Skip to content

Commit

Permalink
ensure standard locations
Browse files Browse the repository at this point in the history
  • Loading branch information
mrmasterplan committed Aug 25, 2023
1 parent 0cfc2e5 commit 623ec33
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 6 deletions.
9 changes: 9 additions & 0 deletions src/spetlr/deltaspec/DatabricksLocation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import Optional
from urllib.parse import urlparse


def standard_databricks_location(val: Optional[str]) -> Optional[str]:
    """Return ``val`` as a location that carries an explicit URL scheme.

    A location without a scheme (for example ``/mnt/data``) is given the
    ``dbfs`` scheme. A location that already has a scheme is parsed and
    re-assembled by ``urllib.parse`` without changing its scheme.

    Args:
        val: The location to normalize. ``None`` or an empty string means
            "no location".

    Returns:
        The location including a scheme, or ``val`` unchanged when it is
        ``None`` or empty.
    """
    # "No location" must stay "no location": None is not a valid str URL for
    # urlparse, and an empty string would otherwise become the bogus,
    # truthy location "dbfs:".
    if not val:
        return val

    parsed = urlparse(val)
    if not parsed.scheme:
        parsed = parsed._replace(scheme="dbfs")

    return parsed.geturl()
15 changes: 14 additions & 1 deletion src/spetlr/deltaspec/DeltaDatabaseSpec.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from spetlr import Configurator
from spetlr.configurator.sql.parse_sql import parse_single_sql_statement
from spetlr.deltaspec.DatabricksLocation import standard_databricks_location
from spetlr.deltaspec.exceptions import InvalidSpecificationError
from spetlr.exceptions import NoSuchValueException
from spetlr.spark import Spark
Expand All @@ -17,6 +18,18 @@ class DeltaDatabaseSpec:
location: Optional[str] = None
dbproperties: Dict[str, str] = None

def __init__(
    self,
    name: str,
    comment: Optional[str] = None,
    location: Optional[str] = None,
    dbproperties: Dict[str, str] = None,
):
    """Store the database specification, normalizing the location.

    Args:
        name: Stored as ``self.name``.
        comment: Optional comment, stored as given.
        location: Optional storage location. When one is supplied it is
            passed through ``standard_databricks_location`` before being
            stored; ``None`` (the default) or an empty value is stored
            unchanged.
        dbproperties: Optional properties mapping, stored as given.
    """
    self.name = name
    self.comment = comment
    # The location is optional, so only normalize a value that is actually
    # present; the helper is meant for real location strings.
    self.location = (
        standard_databricks_location(location) if location else location
    )
    self.dbproperties = dbproperties

def __repr__(self):
dbproperties_part = ""
if self.dbproperties:
Expand All @@ -29,7 +42,7 @@ def __repr__(self):
", ".join(
part
for part in [
f"DbSpec(name={repr(self.name)}",
f"DeltaDatabaseSpec(name={repr(self.name)}",
(f"comment={repr(self.comment)}" if self.comment else ""),
(f"location={repr(self.location)}" if self.location else ""),
dbproperties_part,
Expand Down
24 changes: 24 additions & 0 deletions src/spetlr/deltaspec/DeltaTableSpec.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from spetlr import Configurator
from spetlr.configurator.sql.parse_sql import parse_single_sql_statement
from spetlr.delta import DeltaHandle
from spetlr.deltaspec.DatabricksLocation import standard_databricks_location
from spetlr.deltaspec.DeltaTableDifference import DeltaTableDifference
from spetlr.deltaspec.exceptions import (
InvalidSpecificationError,
Expand Down Expand Up @@ -37,6 +38,29 @@ class DeltaTableSpec:
comment: str = None
# TODO: Clustered By

def __init__(
    self,
    name: str,
    schema: StructType,
    options: Dict[str, str] = None,
    partitioned_by: List[str] = None,
    tblproperties: Dict[str, str] = None,
    location: Optional[str] = None,
    comment: str = None,
):
    """Store the table specification, validating and normalizing inputs.

    Args:
        name: Stored as ``self.name``.
        schema: Table schema; its ``names`` are used to validate the
            partitioning columns.
        options: Optional options mapping; a missing or empty value is
            stored as a new empty dict.
        partitioned_by: Optional list of partitioning columns; a missing or
            empty value is stored as a new empty list.
        tblproperties: Optional properties mapping; a missing or empty
            value is stored as a new empty dict.
        location: Optional storage location. When one is supplied it is
            passed through ``standard_databricks_location`` before being
            stored; ``None`` (the default) or an empty value is stored
            unchanged.
        comment: Optional comment, stored as given.

    Raises:
        InvalidSpecificationError: If a partitioning column is not one of
            the schema's column names.
    """
    self.name = name
    self.schema = schema
    self.options = options or dict()
    self.partitioned_by = partitioned_by or list()
    # Every partitioning column has to be declared in the schema itself.
    for col in self.partitioned_by:
        if col not in self.schema.names:
            raise InvalidSpecificationError(
                "Supply the partitioning columns in the schema."
            )
    self.tblproperties = tblproperties or dict()
    # The location is optional, so only normalize a value that is actually
    # present; the helper is meant for real location strings.
    self.location = (
        standard_databricks_location(location) if location else location
    )
    self.comment = comment

# Non-trivial constructors
@classmethod
def from_sql(cls, sql: str) -> "DeltaTableSpec":
Expand Down
9 changes: 4 additions & 5 deletions tests/cluster/delta/deltaspec/test_tblspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def setUpClass(cls) -> None:
)

def test_01_diff_alter_statements(self):
Configurator().set_prod()
forward_diff = self.target.compare_to(self.base)
self.assertEqual(
forward_diff.alter_table_statements(),
Expand Down Expand Up @@ -83,6 +84,7 @@ def test_01_diff_alter_statements(self):
)

def test_02_execute_alter_statements(self):
Configurator().set_debug()
spark = Spark.get()
spark.sql(
f"""
Expand All @@ -91,19 +93,16 @@ def test_02_execute_alter_statements(self):
)

self.assertTrue(self.base.compare_to_storage().is_different())

self.base.make_storage_match()

self.assertFalse(self.base.compare_to_storage().is_different())
self.assertTrue(self.target.compare_to_storage().is_different())

self.assertTrue(self.target.compare_to_storage().is_different())
self.target.make_storage_match()

self.assertTrue(self.base.compare_to_storage().is_different())
self.assertFalse(self.target.compare_to_storage().is_different())

spark.sql(
f"""
DROP DATABASE {Configurator().get('mydb','name')} CASCADE;
"""
"""
)

0 comments on commit 623ec33

Please sign in to comment.