diff --git a/.gitignore b/.gitignore index 60b4f3d..d66bae9 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,5 @@ bin .direnv .devenv* devenv.local.nix + +.idea/ diff --git a/examples/src/authors/models.py b/examples/src/authors/models.py index 96553a5..b3b9554 100644 --- a/examples/src/authors/models.py +++ b/examples/src/authors/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.29.0 import dataclasses from typing import Optional diff --git a/examples/src/authors/query.py b/examples/src/authors/query.py index 019f877..d10cc65 100644 --- a/examples/src/authors/query.py +++ b/examples/src/authors/query.py @@ -1,8 +1,8 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.29.0 # source: query.sql -from typing import AsyncIterator, Iterator, Optional +from typing import Any, AsyncIterator, Iterator, List, Optional import sqlalchemy import sqlalchemy.ext.asyncio @@ -20,6 +20,11 @@ """ +CREATE_AUTHORS_BATCH = """-- name: create_authors_batch \\:copyfrom +INSERT INTO authors (name, bio) VALUES (:p1, :p2) +""" + + DELETE_AUTHOR = """-- name: delete_author \\:exec DELETE FROM authors WHERE id = :p1 @@ -52,6 +57,10 @@ def create_author(self, *, name: str, bio: Optional[str]) -> Optional[models.Aut bio=row[2], ) + def create_authors_batch(self, arg_list: List[Any]) -> int: + result = self._conn.executemany(sqlalchemy.text(CREATE_AUTHORS_BATCH), arg_list) + return result.rowcount + def delete_author(self, *, id: int) -> None: self._conn.execute(sqlalchemy.text(DELETE_AUTHOR), {"p1": id}) @@ -89,6 +98,10 @@ async def create_author(self, *, name: str, bio: Optional[str]) -> Optional[mode bio=row[2], ) + async def create_authors_batch(self, arg_list: List[Any]) -> int: + result = await self._conn.executemany(sqlalchemy.text(CREATE_AUTHORS_BATCH), arg_list) + return result.rowcount + async def delete_author(self, *, id: int) -> None: await self._conn.execute(sqlalchemy.text(DELETE_AUTHOR), {"p1": id}) diff --git a/examples/src/authors/query.sql b/examples/src/authors/query.sql index 75e38b2..e5e75cf 100644 --- a/examples/src/authors/query.sql +++ b/examples/src/authors/query.sql @@ -17,3 +17,6 @@ RETURNING *; -- name: DeleteAuthor :exec DELETE FROM authors WHERE id = $1; + +-- name: CreateAuthorsBatch :copyfrom +INSERT INTO authors (name, bio) VALUES ($1, $2); diff --git a/examples/src/booktest/models.py b/examples/src/booktest/models.py index d7ee131..dcfbc20 100644 --- a/examples/src/booktest/models.py +++ b/examples/src/booktest/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.29.0 import dataclasses import datetime import enum diff --git a/examples/src/booktest/query.py b/examples/src/booktest/query.py index bc71f22..12d3717 100644 --- a/examples/src/booktest/query.py +++ b/examples/src/booktest/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.29.0 # source: query.sql import dataclasses import datetime diff --git a/examples/src/jets/models.py b/examples/src/jets/models.py index 0d4eb5d..fc5464b 100644 --- a/examples/src/jets/models.py +++ b/examples/src/jets/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.29.0 import dataclasses diff --git a/examples/src/jets/query-building.py b/examples/src/jets/query-building.py index 7651116..adcdcdb 100644 --- a/examples/src/jets/query-building.py +++ b/examples/src/jets/query-building.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.29.0 # source: query-building.sql from typing import AsyncIterator, Optional diff --git a/examples/src/ondeck/city.py b/examples/src/ondeck/city.py index 5af93e9..2f2da93 100644 --- a/examples/src/ondeck/city.py +++ b/examples/src/ondeck/city.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.29.0 # source: city.sql from typing import AsyncIterator, Optional diff --git a/examples/src/ondeck/models.py b/examples/src/ondeck/models.py index 1161408..a32fea2 100644 --- a/examples/src/ondeck/models.py +++ b/examples/src/ondeck/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.29.0 import dataclasses import datetime import enum diff --git a/examples/src/ondeck/venue.py b/examples/src/ondeck/venue.py index 6159bf6..1911cb3 100644 --- a/examples/src/ondeck/venue.py +++ b/examples/src/ondeck/venue.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.29.0 # source: venue.sql import dataclasses from typing import AsyncIterator, List, Optional diff --git a/examples/src/tests/test_authors.py b/examples/src/tests/test_authors.py index c3031cd..d679a62 100644 --- a/examples/src/tests/test_authors.py +++ b/examples/src/tests/test_authors.py @@ -29,6 +29,18 @@ def test_authors(db: sqlalchemy.engine.Connection): assert len(author_list) == 1 assert author_list[0] == new_author + # Test batch insert with copyfrom + batch_authors = [ + {"p1": "Dennis Ritchie", "p2": "Creator of C Programming Language"}, + {"p1": "Ken Thompson", "p2": "Creator of Unix and Go Programming Language"}, + {"p1": "Rob Pike", "p2": "Co-creator of Go Programming Language"}, + ] + rows_affected = querier.create_authors_batch(batch_authors) + assert rows_affected == 3 + + all_authors = list(querier.list_authors()) + assert len(all_authors) == 4 # 1 existing + 3 batch inserted + @pytest.mark.asyncio async def test_authors_async(async_db: sqlalchemy.ext.asyncio.AsyncConnection): @@ -54,3 +66,17 @@ async def test_authors_async(async_db: sqlalchemy.ext.asyncio.AsyncConnection): author_list.append(author) assert len(author_list) == 1 assert author_list[0] == new_author + + # Test batch insert with copyfrom + batch_authors = [ + {"p1": "Dennis Ritchie", "p2": "Creator of C Programming Language"}, + {"p1": "Ken Thompson", "p2": "Creator of Unix and Go Programming Language"}, + {"p1": "Rob Pike", "p2": "Co-creator of Go Programming Language"}, + ] + rows_affected = await querier.create_authors_batch(batch_authors) + assert rows_affected == 3 + + all_authors = [] + async for author in querier.list_authors(): + all_authors.append(author) + assert len(all_authors) == 4 # 1 existing + 3 batch inserted diff --git a/internal/endtoend/testdata/copyfrom/python/models.py b/internal/endtoend/testdata/copyfrom/python/models.py new file mode 100644 index 0000000..e728373 --- /dev/null +++ b/internal/endtoend/testdata/copyfrom/python/models.py @@ -0,0 +1,24 @@ +# Code generated by sqlc. DO NOT EDIT. +# versions: +# sqlc v1.29.0 +import dataclasses +import datetime +from typing import Optional + + +@dataclasses.dataclass() +class Author: + id: int + name: str + bio: str + + +@dataclasses.dataclass() +class User: + id: int + email: str + name: str + bio: Optional[str] + age: Optional[int] + active: Optional[bool] + created_at: datetime.datetime diff --git a/internal/endtoend/testdata/copyfrom/python/query.py b/internal/endtoend/testdata/copyfrom/python/query.py new file mode 100644 index 0000000..f17f820 --- /dev/null +++ b/internal/endtoend/testdata/copyfrom/python/query.py @@ -0,0 +1,158 @@ +# Code generated by sqlc. DO NOT EDIT. +# versions: +# sqlc v1.29.0 +# source: query.sql +import dataclasses +from typing import Any, List, Optional + +import sqlalchemy +import sqlalchemy.ext.asyncio + +from copyfrom import models + + +CREATE_AUTHOR = """-- name: create_author \\:one +INSERT INTO authors (name, bio) VALUES (:p1, :p2) RETURNING id, name, bio +""" + + +CREATE_AUTHORS = """-- name: create_authors \\:copyfrom +INSERT INTO authors (name, bio) VALUES (:p1, :p2) +""" + + +CREATE_AUTHORS_NAMED = """-- name: create_authors_named \\:copyfrom +INSERT INTO authors (name, bio) VALUES (:p1, :p2) +""" + + +CREATE_USER = """-- name: create_user \\:one +INSERT INTO users (email, name) VALUES (:p1, :p2) RETURNING id, email, name, bio, age, active, created_at +""" + + +CREATE_USERS_BATCH = """-- name: create_users_batch \\:copyfrom +INSERT INTO users (email, name) VALUES (:p1, :p2) +""" + + +CREATE_USERS_WITH_DETAILS = """-- name: create_users_with_details \\:copyfrom +INSERT INTO users (email, name, bio, age, active) VALUES (:p1, :p2, :p3, :p4, :p5) +""" + + +@dataclasses.dataclass() +class CreateUsersWithDetailsParams: + email: str + name: str + bio: Optional[str] + age: Optional[int] + active: Optional[bool] + + +class Querier: + def __init__(self, conn: sqlalchemy.engine.Connection): + self._conn = conn + + def create_author(self, *, name: str, bio: str) -> Optional[models.Author]: + row = self._conn.execute(sqlalchemy.text(CREATE_AUTHOR), {"p1": name, "p2": bio}).first() + if row is None: + return None + return models.Author( + id=row[0], + name=row[1], + bio=row[2], + ) + + def create_authors(self, arg_list: List[Any]) -> int: + result = self._conn.executemany(sqlalchemy.text(CREATE_AUTHORS), arg_list) + return result.rowcount + + def create_authors_named(self, arg_list: List[Any]) -> int: + result = self._conn.executemany(sqlalchemy.text(CREATE_AUTHORS_NAMED), arg_list) + return result.rowcount + + def create_user(self, *, email: str, name: str) -> Optional[models.User]: + row = self._conn.execute(sqlalchemy.text(CREATE_USER), {"p1": email, "p2": name}).first() + if row is None: + return None + return models.User( + id=row[0], + email=row[1], + name=row[2], + bio=row[3], + age=row[4], + active=row[5], + created_at=row[6], + ) + + def create_users_batch(self, arg_list: List[Any]) -> int: + result = self._conn.executemany(sqlalchemy.text(CREATE_USERS_BATCH), arg_list) + return result.rowcount + + def create_users_with_details(self, arg_list: List[CreateUsersWithDetailsParams]) -> int: + data = list() + for item in arg_list: + data.append({ + "p1": item.email, + "p2": item.name, + "p3": item.bio, + "p4": item.age, + "p5": item.active, + }) + result = self._conn.executemany(sqlalchemy.text(CREATE_USERS_WITH_DETAILS), data) + return result.rowcount + + +class AsyncQuerier: + def __init__(self, conn: sqlalchemy.ext.asyncio.AsyncConnection): + self._conn = conn + + async def create_author(self, *, name: str, bio: str) -> Optional[models.Author]: + row = (await self._conn.execute(sqlalchemy.text(CREATE_AUTHOR), {"p1": name, "p2": bio})).first() + if row is None: + return None + return models.Author( + id=row[0], + name=row[1], + bio=row[2], + ) + + async def create_authors(self, arg_list: List[Any]) -> int: + result = await self._conn.executemany(sqlalchemy.text(CREATE_AUTHORS), arg_list) + return result.rowcount + + async def create_authors_named(self, arg_list: List[Any]) -> int: + result = await self._conn.executemany(sqlalchemy.text(CREATE_AUTHORS_NAMED), arg_list) + return result.rowcount + + async def create_user(self, *, email: str, name: str) -> Optional[models.User]: + row = (await self._conn.execute(sqlalchemy.text(CREATE_USER), {"p1": email, "p2": name})).first() + if row is None: + return None + return models.User( + id=row[0], + email=row[1], + name=row[2], + bio=row[3], + age=row[4], + active=row[5], + created_at=row[6], + ) + + async def create_users_batch(self, arg_list: List[Any]) -> int: + result = await self._conn.executemany(sqlalchemy.text(CREATE_USERS_BATCH), arg_list) + return result.rowcount + + async def create_users_with_details(self, arg_list: List[CreateUsersWithDetailsParams]) -> int: + data = list() + for item in arg_list: + data.append({ + "p1": item.email, + "p2": item.name, + "p3": item.bio, + "p4": item.age, + "p5": item.active, + }) + result = await self._conn.executemany(sqlalchemy.text(CREATE_USERS_WITH_DETAILS), data) + return result.rowcount diff --git a/internal/endtoend/testdata/copyfrom/query.sql b/internal/endtoend/testdata/copyfrom/query.sql new file mode 100644 index 0000000..576b14c --- /dev/null +++ b/internal/endtoend/testdata/copyfrom/query.sql @@ -0,0 +1,17 @@ +-- name: CreateAuthors :copyfrom +INSERT INTO authors (name, bio) VALUES ($1, $2); + +-- name: CreateAuthor :one +INSERT INTO authors (name, bio) VALUES ($1, $2) RETURNING *; + +-- name: CreateAuthorsNamed :copyfrom +INSERT INTO authors (name, bio) VALUES (@name, @bio); + +-- name: CreateUser :one +INSERT INTO users (email, name) VALUES (@email, @name) RETURNING *; + +-- name: CreateUsersBatch :copyfrom +INSERT INTO users (email, name) VALUES (@email, @name); + +-- name: CreateUsersWithDetails :copyfrom +INSERT INTO users (email, name, bio, age, active) VALUES ($1, $2, $3, $4, $5); \ No newline at end of file diff --git a/internal/endtoend/testdata/copyfrom/schema.sql b/internal/endtoend/testdata/copyfrom/schema.sql new file mode 100644 index 0000000..3ce88a6 --- /dev/null +++ b/internal/endtoend/testdata/copyfrom/schema.sql @@ -0,0 +1,15 @@ +CREATE TABLE authors ( + id SERIAL PRIMARY KEY, + name text NOT NULL, + bio text NOT NULL +); + +CREATE TABLE users ( + id SERIAL PRIMARY KEY, + email text NOT NULL, + name text NOT NULL, + bio text, + age int, + active boolean DEFAULT true, + created_at timestamp NOT NULL DEFAULT NOW() +); \ No newline at end of file diff --git a/internal/endtoend/testdata/copyfrom/sqlc.yaml b/internal/endtoend/testdata/copyfrom/sqlc.yaml new file mode 100644 index 0000000..9d0c56e --- /dev/null +++ b/internal/endtoend/testdata/copyfrom/sqlc.yaml @@ -0,0 +1,17 @@ +version: "2" +plugins: + - name: py + wasm: + url: file://../../../../bin/sqlc-gen-python.wasm + sha256: "9aedc973afcc3c089934aaf509843a761e21ac92a4ce34d7a4ba3acebfb49bf0" +sql: + - schema: "schema.sql" + queries: "query.sql" + engine: postgresql + codegen: + - out: python + plugin: py + options: + package: copyfrom + emit_sync_querier: true + emit_async_querier: true \ No newline at end of file diff --git a/internal/endtoend/testdata/emit_pydantic_models/db/models.py b/internal/endtoend/testdata/emit_pydantic_models/db/models.py index 7676e5c..61ad3eb 100644 --- a/internal/endtoend/testdata/emit_pydantic_models/db/models.py +++ b/internal/endtoend/testdata/emit_pydantic_models/db/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.29.0 import pydantic from typing import Optional diff --git a/internal/endtoend/testdata/emit_pydantic_models/db/query.py b/internal/endtoend/testdata/emit_pydantic_models/db/query.py index 6f5b76f..cc36118 100644 --- a/internal/endtoend/testdata/emit_pydantic_models/db/query.py +++ b/internal/endtoend/testdata/emit_pydantic_models/db/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.29.0 # source: query.sql from typing import AsyncIterator, Iterator, Optional diff --git a/internal/endtoend/testdata/emit_pydantic_models/sqlc.yaml b/internal/endtoend/testdata/emit_pydantic_models/sqlc.yaml index beae200..456ccf2 100644 --- a/internal/endtoend/testdata/emit_pydantic_models/sqlc.yaml +++ b/internal/endtoend/testdata/emit_pydantic_models/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "9aedc973afcc3c089934aaf509843a761e21ac92a4ce34d7a4ba3acebfb49bf0" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/emit_str_enum/db/models.py b/internal/endtoend/testdata/emit_str_enum/db/models.py index 5fdf754..aa43ab1 100644 --- a/internal/endtoend/testdata/emit_str_enum/db/models.py +++ b/internal/endtoend/testdata/emit_str_enum/db/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.29.0 import dataclasses import enum from typing import Optional diff --git a/internal/endtoend/testdata/emit_str_enum/db/query.py b/internal/endtoend/testdata/emit_str_enum/db/query.py index 8082889..5ea0264 100644 --- a/internal/endtoend/testdata/emit_str_enum/db/query.py +++ b/internal/endtoend/testdata/emit_str_enum/db/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.29.0 # source: query.sql from typing import AsyncIterator, Iterator, Optional diff --git a/internal/endtoend/testdata/emit_str_enum/sqlc.yaml b/internal/endtoend/testdata/emit_str_enum/sqlc.yaml index 04e3feb..62296ae 100644 --- a/internal/endtoend/testdata/emit_str_enum/sqlc.yaml +++ b/internal/endtoend/testdata/emit_str_enum/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "9aedc973afcc3c089934aaf509843a761e21ac92a4ce34d7a4ba3acebfb49bf0" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/exec_result/python/models.py b/internal/endtoend/testdata/exec_result/python/models.py index 034fb2d..6d3e9f5 100644 --- a/internal/endtoend/testdata/exec_result/python/models.py +++ b/internal/endtoend/testdata/exec_result/python/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.29.0 import dataclasses diff --git a/internal/endtoend/testdata/exec_result/python/query.py b/internal/endtoend/testdata/exec_result/python/query.py index b68ce39..c9c6e21 100644 --- a/internal/endtoend/testdata/exec_result/python/query.py +++ b/internal/endtoend/testdata/exec_result/python/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.29.0 # source: query.sql import sqlalchemy import sqlalchemy.ext.asyncio diff --git a/internal/endtoend/testdata/exec_result/sqlc.yaml b/internal/endtoend/testdata/exec_result/sqlc.yaml index ddffc83..0e7eb1a 100644 --- a/internal/endtoend/testdata/exec_result/sqlc.yaml +++ b/internal/endtoend/testdata/exec_result/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "9aedc973afcc3c089934aaf509843a761e21ac92a4ce34d7a4ba3acebfb49bf0" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/exec_rows/python/models.py b/internal/endtoend/testdata/exec_rows/python/models.py index 034fb2d..6d3e9f5 100644 --- a/internal/endtoend/testdata/exec_rows/python/models.py +++ b/internal/endtoend/testdata/exec_rows/python/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.29.0 import dataclasses diff --git a/internal/endtoend/testdata/exec_rows/python/query.py b/internal/endtoend/testdata/exec_rows/python/query.py index 7a9b2a6..a678f3d 100644 --- a/internal/endtoend/testdata/exec_rows/python/query.py +++ b/internal/endtoend/testdata/exec_rows/python/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.29.0 # source: query.sql import sqlalchemy import sqlalchemy.ext.asyncio diff --git a/internal/endtoend/testdata/exec_rows/sqlc.yaml b/internal/endtoend/testdata/exec_rows/sqlc.yaml index ddffc83..0e7eb1a 100644 --- a/internal/endtoend/testdata/exec_rows/sqlc.yaml +++ b/internal/endtoend/testdata/exec_rows/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "9aedc973afcc3c089934aaf509843a761e21ac92a4ce34d7a4ba3acebfb49bf0" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/inflection_exclude_table_names/python/models.py b/internal/endtoend/testdata/inflection_exclude_table_names/python/models.py index 8ba8803..fc76620 100644 --- a/internal/endtoend/testdata/inflection_exclude_table_names/python/models.py +++ b/internal/endtoend/testdata/inflection_exclude_table_names/python/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.29.0 import dataclasses diff --git a/internal/endtoend/testdata/inflection_exclude_table_names/python/query.py b/internal/endtoend/testdata/inflection_exclude_table_names/python/query.py index 1e1e161..1fc92fd 100644 --- a/internal/endtoend/testdata/inflection_exclude_table_names/python/query.py +++ b/internal/endtoend/testdata/inflection_exclude_table_names/python/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.29.0 # source: query.sql from typing import Optional diff --git a/internal/endtoend/testdata/inflection_exclude_table_names/sqlc.yaml b/internal/endtoend/testdata/inflection_exclude_table_names/sqlc.yaml index efbb150..47daf09 100644 --- a/internal/endtoend/testdata/inflection_exclude_table_names/sqlc.yaml +++ b/internal/endtoend/testdata/inflection_exclude_table_names/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "9aedc973afcc3c089934aaf509843a761e21ac92a4ce34d7a4ba3acebfb49bf0" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/query_parameter_limit_two/python/models.py b/internal/endtoend/testdata/query_parameter_limit_two/python/models.py index 059675d..89c0f8d 100644 --- a/internal/endtoend/testdata/query_parameter_limit_two/python/models.py +++ b/internal/endtoend/testdata/query_parameter_limit_two/python/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.29.0 import dataclasses diff --git a/internal/endtoend/testdata/query_parameter_limit_two/python/query.py b/internal/endtoend/testdata/query_parameter_limit_two/python/query.py index e8b723e..0d9bd97 100644 --- a/internal/endtoend/testdata/query_parameter_limit_two/python/query.py +++ b/internal/endtoend/testdata/query_parameter_limit_two/python/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.29.0 # source: query.sql import sqlalchemy import sqlalchemy.ext.asyncio diff --git a/internal/endtoend/testdata/query_parameter_limit_two/sqlc.yaml b/internal/endtoend/testdata/query_parameter_limit_two/sqlc.yaml index 336bca7..e5f79f7 100644 --- a/internal/endtoend/testdata/query_parameter_limit_two/sqlc.yaml +++ b/internal/endtoend/testdata/query_parameter_limit_two/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "9aedc973afcc3c089934aaf509843a761e21ac92a4ce34d7a4ba3acebfb49bf0" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/query_parameter_limit_undefined/python/models.py b/internal/endtoend/testdata/query_parameter_limit_undefined/python/models.py index 30e80db..dc09dab 100644 --- a/internal/endtoend/testdata/query_parameter_limit_undefined/python/models.py +++ b/internal/endtoend/testdata/query_parameter_limit_undefined/python/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.29.0 import dataclasses diff --git a/internal/endtoend/testdata/query_parameter_limit_undefined/python/query.py b/internal/endtoend/testdata/query_parameter_limit_undefined/python/query.py index 5a1fbbc..49b7bd1 100644 --- a/internal/endtoend/testdata/query_parameter_limit_undefined/python/query.py +++ b/internal/endtoend/testdata/query_parameter_limit_undefined/python/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.29.0 # source: query.sql import sqlalchemy import sqlalchemy.ext.asyncio diff --git a/internal/endtoend/testdata/query_parameter_limit_undefined/sqlc.yaml b/internal/endtoend/testdata/query_parameter_limit_undefined/sqlc.yaml index c20cd57..d4db347 100644 --- a/internal/endtoend/testdata/query_parameter_limit_undefined/sqlc.yaml +++ b/internal/endtoend/testdata/query_parameter_limit_undefined/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "9aedc973afcc3c089934aaf509843a761e21ac92a4ce34d7a4ba3acebfb49bf0" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/query_parameter_limit_zero/python/models.py b/internal/endtoend/testdata/query_parameter_limit_zero/python/models.py index 059675d..89c0f8d 100644 --- a/internal/endtoend/testdata/query_parameter_limit_zero/python/models.py +++ b/internal/endtoend/testdata/query_parameter_limit_zero/python/models.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.29.0 import dataclasses diff --git a/internal/endtoend/testdata/query_parameter_limit_zero/python/query.py b/internal/endtoend/testdata/query_parameter_limit_zero/python/query.py index 47bd6a9..38e0efb 100644 --- a/internal/endtoend/testdata/query_parameter_limit_zero/python/query.py +++ b/internal/endtoend/testdata/query_parameter_limit_zero/python/query.py @@ -1,6 +1,6 @@ # Code generated by sqlc. DO NOT EDIT. # versions: -# sqlc v1.28.0 +# sqlc v1.29.0 # source: query.sql import dataclasses diff --git a/internal/endtoend/testdata/query_parameter_limit_zero/sqlc.yaml b/internal/endtoend/testdata/query_parameter_limit_zero/sqlc.yaml index 6e2cdeb..332f2b9 100644 --- a/internal/endtoend/testdata/query_parameter_limit_zero/sqlc.yaml +++ b/internal/endtoend/testdata/query_parameter_limit_zero/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "9aedc973afcc3c089934aaf509843a761e21ac92a4ce34d7a4ba3acebfb49bf0" sql: - schema: schema.sql queries: query.sql diff --git a/internal/endtoend/testdata/query_parameter_no_limit/sqlc.yaml b/internal/endtoend/testdata/query_parameter_no_limit/sqlc.yaml index c432e4f..5c20b23 100644 --- a/internal/endtoend/testdata/query_parameter_no_limit/sqlc.yaml +++ b/internal/endtoend/testdata/query_parameter_no_limit/sqlc.yaml @@ -3,7 +3,7 @@ plugins: - name: py wasm: url: file://../../../../bin/sqlc-gen-python.wasm - sha256: "d6846ffad948181e611e883cedd2d2be66e091edc1273a0abc6c9da18399e0ca" + sha256: "9aedc973afcc3c089934aaf509843a761e21ac92a4ce34d7a4ba3acebfb49bf0" sql: - schema: schema.sql queries: query.sql diff --git a/internal/gen.go b/internal/gen.go index 6e50fae..716d629 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -2,7 +2,7 @@ package python import ( "context" - json "encoding/json" + "encoding/json" "errors" "fmt" "log" @@ -10,7 +10,6 @@ import ( "sort" "strings" - "github.com/sqlc-dev/plugin-sdk-go/metadata" "github.com/sqlc-dev/plugin-sdk-go/plugin" "github.com/sqlc-dev/plugin-sdk-go/sdk" @@ -135,6 +134,33 @@ type Query struct { } func (q Query) AddArgs(args *pyast.Arguments) { + switch q.Cmd { + case ":copyfrom": + q.addCopyFromArgs(args) + default: + q.addRegularArgs(args) + } +} + +// addCopyFromArgs adds arguments for :copyfrom commands +func (q Query) addCopyFromArgs(args *pyast.Arguments) { + // Check if we have a struct parameter + if len(q.Args) == 1 && q.Args[0].IsStruct() { + args.Args = append(args.Args, &pyast.Arg{ + Arg: q.Args[0].Name + "_list", + Annotation: subscriptNode("List", q.Args[0].Annotation()), + }) + } else { + // Fall back to List[Any] for individual parameters + args.Args = append(args.Args, &pyast.Arg{ + Arg: "arg_list", + Annotation: subscriptNode("List", poet.Name("Any")), + }) + } +} + +// addRegularArgs adds arguments for regular (non-copyfrom) commands +func (q Query) addRegularArgs(args *pyast.Arguments) { // A single struct arg does not need to be passed as a keyword argument if len(q.Args) == 1 && q.Args[0].IsStruct() { args.Args = append(args.Args, &pyast.Arg{ @@ -143,6 +169,8 @@ func (q Query) AddArgs(args *pyast.Arguments) { }) return } + + // Multiple args or non-struct args are passed as keyword arguments for _, a := range q.Args { args.KwOnlyArgs = append(args.KwOnlyArgs, &pyast.Arg{ Arg: a.Name, @@ -180,6 +208,79 @@ func (q Query) ArgDictNode() *pyast.Node { } } +// BuildCopyFromBody generates the method body for :copyfrom commands. +func (q Query) BuildCopyFromBody(isAsync bool) []*pyast.Node { + var body []*pyast.Node + + dataVar := "arg_list" + if len(q.Args) == 1 && q.Args[0].IsStruct() { + argName := q.Args[0].Name + "_list" + dataVar = "data" + body = append(body, q.buildStructToDictList(argName, dataVar)...) + } + + sqlText := poet.Node(&pyast.Call{ + Func: poet.Attribute(poet.Name("sqlalchemy"), "text"), + Args: []*pyast.Node{poet.Name(q.ConstantName)}, + }) + + execCall := poet.Node(&pyast.Call{ + Func: poet.Attribute(poet.Name("self._conn"), "executemany"), + Args: []*pyast.Node{ + sqlText, + poet.Name(dataVar), + }, + }) + + if isAsync { + execCall = poet.Await(execCall) + } + + body = append(body, + assignNode("result", execCall), + poet.Return(poet.Attribute(poet.Name("result"), "rowcount")), + ) + + return body +} + +// buildStructToDictList converts a list of parameter structs to a list of dicts for SQLAlchemy +func (q Query) buildStructToDictList(sourceVar, targetVar string) []*pyast.Node { + var body []*pyast.Node + + body = append(body, assignNode(targetVar, poet.Node(&pyast.Call{ + Func: poet.Name("list"), + Args: []*pyast.Node{}, + }))) + + loopVar := "item" + dict := &pyast.Dict{} + for i, field := range q.Args[0].Struct.Fields { + paramName := fmt.Sprintf("p%v", i+1) + dict.Keys = append(dict.Keys, poet.Constant(paramName)) + dict.Values = append(dict.Values, poet.Attribute(poet.Name(loopVar), field.Name)) + } + + body = append(body, poet.Node(&pyast.For{ + Target: poet.Name(loopVar), + Iter: poet.Name(sourceVar), + Body: []*pyast.Node{ + poet.Node(&pyast.Call{ + Func: poet.Attribute(poet.Name(targetVar), "append"), + Args: []*pyast.Node{ + { + Node: &pyast.Node_Dict{ + Dict: dict, + }, + }, + }, + }), + }, + })) + + return body +} + func makePyType(req *plugin.GenerateRequest, col *plugin.Column) pyType { typ := pyInnerType(req, col) return pyType{ @@ -372,9 +473,6 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([ if query.Cmd == "" { continue } - if query.Cmd == metadata.CmdCopyFrom { - return nil, errors.New("Support for CopyFrom in Python is not implemented") - } methodName := methodName(query.Name) @@ -403,11 +501,13 @@ func buildQueries(conf Config, req *plugin.GenerateRequest, structs []Struct) ([ Column: p.Column, }) } - gq.Args = []QueryValue{{ - Emit: true, - Name: "arg", - Struct: columnsToStruct(req, query.Name+"Params", cols), - }} + gq.Args = []QueryValue{ + { + Emit: true, + Name: "arg", + Struct: columnsToStruct(req, query.Name+"Params", cols), + }, + } } else { args := make([]QueryValue, 0, len(query.Params)) for _, p := range query.Params { @@ -959,6 +1059,10 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node { poet.Return(exec), ) f.Returns = typeRefNode("sqlalchemy", "engine", "Result") + case ":copyfrom": + // For copyfrom, use executemany for batch inserts + f.Body = append(f.Body, q.BuildCopyFromBody(false)...) + f.Returns = poet.Name("int") default: panic("unknown cmd " + q.Cmd) } @@ -1052,6 +1156,10 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node { poet.Return(poet.Await(exec)), ) f.Returns = typeRefNode("sqlalchemy", "engine", "Result") + case ":copyfrom": + // For async copyfrom, use executemany for batch inserts + f.Body = append(f.Body, q.BuildCopyFromBody(true)...) + f.Returns = poet.Name("int") default: panic("unknown cmd " + q.Cmd) } diff --git a/internal/imports.go b/internal/imports.go index b88c58c..b066b6f 100644 --- a/internal/imports.go +++ b/internal/imports.go @@ -161,6 +161,13 @@ func (i *importer) queryImportSpecs(fileName string) (map[string]importSpec, map std["typing.AsyncIterator"] = importSpec{Module: "typing", Name: "AsyncIterator"} } } + if q.Cmd == ":copyfrom" { + std["typing.List"] = importSpec{Module: "typing", Name: "List"} + // Add Any if non-struct args + if !(len(q.Args) == 1 && q.Args[0].IsStruct()) { + std["typing.Any"] = importSpec{Module: "typing", Name: "Any"} + } + } queryValueModelImports(q.Ret) for _, qv := range q.Args { queryValueModelImports(qv)