Skip to content

Commit

Permalink
Initial attempt at insert/replace for /-/create, refs #1927
Browse files Browse the repository at this point in the history
  • Loading branch information
simonw committed Dec 3, 2022
1 parent 5d3916c commit ac9a843
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 124 deletions.
139 changes: 15 additions & 124 deletions datasette/views/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ def __getitem__(self, key):
class TableCreateView(BaseView):
name = "table-create"

_valid_keys = {"table", "rows", "row", "columns", "pk"}
_valid_keys = {"table", "rows", "row", "columns", "pk", "pks", "ignore", "replace"}
_supported_column_types = {
"text",
"integer",
Expand Down Expand Up @@ -596,130 +596,17 @@ async def post(self, request):
if invalid_keys:
return _error(["Invalid keys: {}".format(", ".join(invalid_keys))])

table_name = data.get("table")
if not table_name:
return _error(["Table is required"])

if not self._table_name_re.match(table_name):
return _error(["Invalid table name"])

columns = data.get("columns")
rows = data.get("rows")
row = data.get("row")
if not columns and not rows and not row:
return _error(["columns, rows or row is required"])

if rows and row:
return _error(["Cannot specify both rows and row"])

if columns:
if rows or row:
return _error(["Cannot specify columns with rows or row"])
if not isinstance(columns, list):
return _error(["columns must be a list"])
for column in columns:
if not isinstance(column, dict):
return _error(["columns must be a list of objects"])
if not column.get("name") or not isinstance(column.get("name"), str):
return _error(["Column name is required"])
if not column.get("type"):
column["type"] = "text"
if column["type"] not in self._supported_column_types:
return _error(
["Unsupported column type: {}".format(column["type"])]
)
# No duplicate column names
dupes = {c["name"] for c in columns if columns.count(c) > 1}
if dupes:
return _error(["Duplicate column name: {}".format(", ".join(dupes))])

if row:
rows = [row]

if rows:
if not isinstance(rows, list):
return _error(["rows must be a list"])
for row in rows:
if not isinstance(row, dict):
return _error(["rows must be a list of objects"])

pk = data.get("pk")
if pk:
if not isinstance(pk, str):
return _error(["pk must be a string"])

def create_table(conn):
table = sqlite_utils.Database(conn)[table_name]
if rows:
table.insert_all(rows, pk=pk)
else:
table.create(
{c["name"]: c["type"] for c in columns},
pk=pk,
)
return table.schema
# ignore and replace are mutually exclusive
if data.get("ignore") and data.get("replace"):
return _error(["ignore and replace are mutually exclusive"])

try:
schema = await db.execute_write_fn(create_table)
except Exception as e:
return _error([str(e)])
table_url = self.ds.absolute_url(
request, self.ds.urls.table(db.name, table_name)
)
table_api_url = self.ds.absolute_url(
request, self.ds.urls.table(db.name, table_name, format="json")
)
details = {
"ok": True,
"database": db.name,
"table": table_name,
"table_url": table_url,
"table_api_url": table_api_url,
"schema": schema,
}
if rows:
details["row_count"] = len(rows)
return Response.json(details, status=201)


class TableCreateView(BaseView):
name = "table-create"

_valid_keys = {"table", "rows", "row", "columns", "pk", "pks"}
_supported_column_types = {
"text",
"integer",
"float",
"blob",
}
# Any string that does not contain a newline or start with sqlite_
_table_name_re = re.compile(r"^(?!sqlite_)[^\n]+$")

def __init__(self, datasette):
self.ds = datasette

async def post(self, request):
db = await self.ds.resolve_database(request)
database_name = db.name

# Must have create-table permission
if not await self.ds.permission_allowed(
request.actor, "create-table", resource=database_name
):
return _error(["Permission denied"], 403)
# ignore and replace only allowed with row or rows
if "ignore" in data or "replace" in data:
if not data.get("row") and not data.get("rows"):
return _error(["ignore and replace require row or rows"])

body = await request.post_body()
try:
data = json.loads(body)
except json.JSONDecodeError as e:
return _error(["Invalid JSON: {}".format(e)])

if not isinstance(data, dict):
return _error(["JSON must be an object"])

invalid_keys = set(data.keys()) - self._valid_keys
if invalid_keys:
return _error(["Invalid keys: {}".format(", ".join(invalid_keys))])
ignore = data.get("ignore")
replace = data.get("replace")

table_name = data.get("table")
if not table_name:
Expand Down Expand Up @@ -783,10 +670,14 @@ async def post(self, request):
if not isinstance(pk, str):
return _error(["pks must be a list of strings"])

# If table exists already, read pks from that instead
if await db.table_exists(table_name):
pks = await db.primary_keys(table_name)

def create_table(conn):
table = sqlite_utils.Database(conn)[table_name]
if rows:
table.insert_all(rows, pk=pks or pk)
table.insert_all(rows, pk=pks or pk, ignore=ignore, replace=replace)
else:
table.create(
{c["name"]: c["type"] for c in columns},
Expand Down
96 changes: 96 additions & 0 deletions tests/test_api_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,6 +911,34 @@ async def test_drop_table(ds_write, scenario):
400,
{"ok": False, "errors": ["pks must be a list of strings"]},
),
# Error: ignore and replace are mutually exclusive
(
{
"table": "bad",
"row": {"id": 1, "name": "Row 1"},
"pk": "id",
"ignore": True,
"replace": True,
},
400,
{
"ok": False,
"errors": ["ignore and replace are mutually exclusive"],
},
),
# ignore and replace require row or rows
(
{
"table": "bad",
"columns": [{"name": "id", "type": "integer"}],
"ignore": True,
},
400,
{
"ok": False,
"errors": ["ignore and replace require row or rows"],
},
),
),
)
async def test_create_table(ds_write, input, expected_status, expected_response):
Expand All @@ -932,6 +960,74 @@ async def test_create_table(ds_write, input, expected_status, expected_response)
assert data == expected_response


@pytest.mark.asyncio
@pytest.mark.parametrize(
"input,expected_rows_after",
(
(
{
"table": "test_insert_replace",
"rows": [
{"id": 1, "name": "Row 1 new"},
{"id": 3, "name": "Row 3 new"},
],
"ignore": True,
},
[
{"id": 1, "name": "Row 1"},
{"id": 2, "name": "Row 2"},
{"id": 3, "name": "Row 3 new"},
],
),
(
{
"table": "test_insert_replace",
"rows": [
{"id": 1, "name": "Row 1 new"},
{"id": 3, "name": "Row 3 new"},
],
"replace": True,
},
[
{"id": 1, "name": "Row 1 new"},
{"id": 2, "name": "Row 2"},
{"id": 3, "name": "Row 3 new"},
],
),
),
)
async def test_create_table_ignore_replace(ds_write, input, expected_rows_after):
# Create table with two rows
token = write_token(ds_write)
first_response = await ds_write.client.post(
"/data/-/create",
json={
"rows": [{"id": 1, "name": "Row 1"}, {"id": 2, "name": "Row 2"}],
"table": "test_insert_replace",
"pk": "id",
},
headers={
"Authorization": "Bearer {}".format(token),
"Content-Type": "application/json",
},
)
assert first_response.status_code == 201

# Try a second time
second_response = await ds_write.client.post(
"/data/-/create",
json=input,
headers={
"Authorization": "Bearer {}".format(token),
"Content-Type": "application/json",
},
)
assert second_response.status_code == 201
# Check that the rows are as expected
rows = await ds_write.client.get("/data/test_insert_replace.json?_shape=array")
assert rows.json() == expected_rows_after


@pytest.mark.asyncio
@pytest.mark.parametrize(
"path",
Expand Down

0 comments on commit ac9a843

Please sign in to comment.