Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Added trigger example

  • Loading branch information...
commit b6e4c8d9f0939b8ffc0a9c2041cc1dcd4e5c1e69 1 parent f5ec3b0
@rdunklau authored
View
2  pytoplpython/__init__.py
@@ -1,11 +1,9 @@
from importlib.abc import InspectLoader
-from importlib.machinery import PathFinder
import imp
import sys
import ast
from .unparse import Unparser
from io import StringIO
-from code import compile_command
from tempfile import NamedTemporaryFile
def postgresql_function(function):
View
22 pytoplpython/unparse.py
@@ -15,15 +15,20 @@ def __init__(self, name, return_type, args_definition, code):
self.args_definition = args_definition
self.code = code
+def process_type(compiler, argtype):
+ if isinstance(argtype, str):
+ return argtype
+ else:
+ return compiler.dialect.type_compiler.process(argtype)
+
@compiles(CreateFunction, 'postgresql')
def visit_create_function(element, compiler, **kw):
code = 'create or replace function %s (' % element.name
args = []
- type_compiler = compiler.dialect.type_compiler
for arg, argtype in element.args_definition:
- args.append('%s %s' % (arg, type_compiler.process(argtype)))
+ args.append('%s %s' % (arg, process_type(compiler, argtype)))
code += ', '.join(args)
- code += ' ) RETURNS %s' % (type_compiler.process(element.return_type))
+ code += ' ) RETURNS %s' % (process_type(compiler, element.return_type))
code += ' AS $$\n'
code += element.code
code += '$$ language plpythonu;'
@@ -268,11 +273,12 @@ def _ClassDef(self, t):
def _get_type(self, python_type):
# TODO: use sqlalchemy to return that
name = python_type.id.split('.')[-1]
- try:
- return getattr(types, name)()
- except AttributeError:
- raise AttributeError('%s is not a sqlalchemy type (at line %s)'
- % (name, python_type.lineno))
+ type = getattr(types, name, None)
+ if type is None:
+ type = name
+ else:
+ type = type()
+ return type
def _FunctionDef(self, t):
View
12 test/basetest.py
@@ -15,6 +15,7 @@
Column('test', Unicode),
Column('test2', Unicode))
+table.drop(checkfirst=True)
table.create(checkfirst=True)
for i in range(20):
@@ -22,6 +23,15 @@
print(engine.execute(testmodule.pyconcat(table.c.test, table.c.test2)).fetchall())
+statement = """
+CREATE TRIGGER mytrigger
+BEFORE INSERT
+ON %s
+FOR EACH ROW EXECUTE PROCEDURE %s();
+"""
+engine.execute(statement % (table.name, testmodule.nullifying_trigger.__name__))
-print(engine.execute(testmodule.pygreatest(1, 3)).fetchone())
+table.insert({'test': 'grou', 'test2': 'grou'}).execute()
+
+print(engine.execute(testmodule.pyconcat(table.c.test, table.c.test2)).fetchall());
View
5 test/testmodule.py
@@ -13,4 +13,7 @@ def pyconcat(col1: Unicode, col2: Unicode) -> Unicode:
def pygreatest(col1: Integer, col2: Integer) -> Integer:
return max(col1, col2)
-
+@postgresql_function
+def nullifying_trigger() -> TRIGGER:
+ TD['new']['test2'] = 'triggered by me'
+ return 'MODIFY'

0 comments on commit b6e4c8d

Please sign in to comment.
Something went wrong with that request. Please try again.