/
trigger_modelchange.go
149 lines (128 loc) · 4.42 KB
/
trigger_modelchange.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
package migration
import (
"context"
"text/template"
"ariga.io/atlas/sql/migrate"
"entgo.io/ent/dialect"
"entgo.io/ent/dialect/sql/schema"
)
const (
	// modelChangeFunction names the PL/pgSQL trigger function and also serves
	// as the name prefix of the per-table triggers that call it.
	modelChangeFunction = "model_change"

	// ModelChangeChannel is the channel passed to pg_notify by the trigger
	// function; listeners subscribe to it to learn about model changes.
	ModelChangeChannel = modelChangeFunction + "_channel"
)
// ApplyModelChangeTrigger returns the migration options that install the
// model change notification machinery. The options are, in order: a hook
// that creates (or replaces) the shared trigger function, and a hook that
// creates (or replaces) the INSERT/UPDATE/DELETE triggers on every table in
// tableNames. The trigger function notifies listeners on ModelChangeChannel
// whenever rows of those tables change.
func ApplyModelChangeTrigger(tableNames []string) (opts []schema.MigrateOption) {
	functionHook := createModelChangeFunction()
	triggerHook := createModelChangeTrigger(tableNames)

	return []schema.MigrateOption{
		schema.WithApplyHook(functionHook),
		schema.WithApplyHook(triggerHook),
	}
}
// createModelChangeFunction returns an apply hook that appends a
// CREATE OR REPLACE FUNCTION statement to the migration plan before handing
// the plan to the next applier.
//
// The generated PL/pgSQL function is written to be invoked by the
// statement-level triggers that reference the transition tables new_table
// (INSERT/UPDATE) and old_table (DELETE). It collects the affected rows' id,
// plus name, project_id and environment_id when the triggering table has those
// columns, and publishes them as JSON via pg_notify on ModelChangeChannel.
// Statements that touch no rows produce no notification.
func createModelChangeFunction() schema.ApplyHook {
	const tmpl = `
CREATE OR REPLACE FUNCTION {{ .Function }}() RETURNS TRIGGER
    LANGUAGE plpgsql
AS
$$
DECLARE
    column_sum int;
    row_fields text;
    src_table  text;
    ids        jsonb;
    payload    text;
BEGIN
    -- Detect which optional columns the triggering table has, encoding each
    -- one as a decimal digit: name = 100, project_id = 10, environment_id = 1.
    -- The lookup is restricted to the trigger's own schema so that a table
    -- with the same name in another schema cannot make the query below
    -- reference a column that this table does not have.
    SELECT SUM(
               CASE
                   WHEN column_name = 'name' THEN 100
                   WHEN column_name = 'project_id' THEN 10
                   WHEN column_name = 'environment_id' THEN 1
                   ELSE 0
               END)
    FROM information_schema.columns
    WHERE table_schema = TG_TABLE_SCHEMA
      AND table_name = TG_TABLE_NAME
    INTO column_sum;

    -- Build the jsonb_build_object argument list once; it is the same for
    -- every operation, only the transition table differs.
    row_fields := '''id'', id::text' ||
                  CASE WHEN column_sum >= 100 THEN ', ''name'', name::text' ELSE '' END ||
                  CASE WHEN column_sum % 100 >= 10 THEN ', ''project_id'', project_id::text' ELSE '' END ||
                  CASE WHEN column_sum % 10 >= 1 THEN ', ''environment_id'', environment_id::text' ELSE '' END;

    -- Choose the transition table that holds the affected rows.
    CASE TG_OP
        WHEN 'INSERT', 'UPDATE' THEN
            src_table := 'new_table';
        WHEN 'DELETE' THEN
            src_table := 'old_table';
        ELSE
            RAISE EXCEPTION 'Unknown Operation';
    END CASE;

    EXECUTE 'SELECT jsonb_agg(jsonb_build_object(' || row_fields || ')) FROM ' || src_table
        INTO ids;

    -- Statement-level triggers fire even when no row was affected, and
    -- jsonb_agg over zero rows yields NULL rather than an empty array,
    -- so a NULL result must be treated as "nothing changed".
    IF ids IS NULL OR jsonb_array_length(ids) = 0 THEN
        RETURN NULL;
    END IF;

    -- Build the notification payload.
    payload := json_build_object(
            'ts', current_timestamp,
            'op', lower(TG_OP),
            'tb_s', TG_TABLE_SCHEMA,
            'tb_n', TG_TABLE_NAME,
            'ids', ids)::text;

    -- NOTE(review): pg_notify rejects payloads of 8000 bytes or more (default
    -- build), which aborts the triggering statement; very large batches may
    -- need to be split into several notifications.
    PERFORM pg_notify('{{ .Channel }}', payload);
    RETURN NULL;
END;
$$;
`
	tpl := template.Must(template.New("tmpl").Parse(tmpl))
	return func(next schema.Applier) schema.Applier {
		return schema.ApplyFunc(func(ctx context.Context, conn dialect.ExecQuerier, plan *migrate.Plan) error {
			plan.Changes = append(plan.Changes, &migrate.Change{
				Cmd: executeTemplate(tpl, map[string]any{
					"Function": modelChangeFunction,
					"Channel":  ModelChangeChannel,
				}),
			})
			return next.Apply(ctx, conn, plan)
		})
	}
}
// createModelChangeTrigger returns an apply hook that, for every table in
// tableNames, appends one migration change to the plan before delegating to
// the next applier. Each change creates (or replaces) three statement-level
// AFTER triggers on that table, one each for INSERT, UPDATE and DELETE. They
// expose the transition tables new_table/old_table and execute the model
// change function.
func createModelChangeTrigger(tableNames []string) schema.ApplyHook {
	const tmpl = `
CREATE OR REPLACE TRIGGER {{ .Function }}_ins
AFTER INSERT ON {{ .Table }}
REFERENCING NEW TABLE AS new_table
FOR EACH STATEMENT EXECUTE FUNCTION {{ .Function }}();
CREATE OR REPLACE TRIGGER {{ .Function }}_upd
AFTER UPDATE ON {{ .Table }}
REFERENCING OLD TABLE AS old_table NEW TABLE AS new_table
FOR EACH STATEMENT EXECUTE FUNCTION {{ .Function }}();
CREATE OR REPLACE TRIGGER {{ .Function }}_del
AFTER DELETE ON {{ .Table }}
REFERENCING OLD TABLE AS old_table
FOR EACH STATEMENT EXECUTE FUNCTION {{ .Function }}();
`
	triggerTpl := template.Must(template.New("tmpl").Parse(tmpl))

	return func(next schema.Applier) schema.Applier {
		apply := func(ctx context.Context, conn dialect.ExecQuerier, plan *migrate.Plan) error {
			for _, table := range tableNames {
				data := map[string]any{
					"Function": modelChangeFunction,
					"Table":    table,
				}
				plan.Changes = append(plan.Changes, &migrate.Change{Cmd: executeTemplate(triggerTpl, data)})
			}
			return next.Apply(ctx, conn, plan)
		}
		return schema.ApplyFunc(apply)
	}
}