# In [1]:
import yaml
from pyspark.sql.types import StringType, FloatType, IntegerType, StructType, StructField,  DateType, TimestampType

class TableManager:
    """Read table definitions from a YAML file and derive Spark artifacts.

    The methods below read the configuration with this layout::

        tables:
          <table_name>:
            schema:
              - name: <column name>
                type: <StringType|FloatType|IntegerType|DateType|TimestampType>
                nullable: <bool>      # optional, defaults to True
            partition_by:             # optional
              - field: <column name>
    """

    def __init__(self, config_file_path):
        """Load and parse the YAML configuration.

        Args:
            config_file_path: Path of the YAML file to read.
        """
        self.config_file_path = config_file_path

        # Load YAML configuration. The encoding is explicit so parsing does
        # not depend on the platform's locale default.
        with open(self.config_file_path, "r", encoding="utf-8") as file:
            # safe_load returns None for an empty document; normalise to an
            # empty mapping so lookups fail with ValueError, not TypeError.
            self.config = yaml.safe_load(file) or {}

    def get_table_config(self, table_name):
        """Return the configuration mapping of ``table_name``.

        Args:
            table_name: Key to look up under the top-level ``tables`` entry.

        Returns:
            The object stored at ``config["tables"][table_name]``.

        Raises:
            ValueError: If there is no ``tables`` mapping or it does not
                contain ``table_name``.
        """
        # Tolerate a missing/empty config and a null ``tables`` entry: both
        # simply mean "no such table".
        config = self.config if isinstance(self.config, dict) else {}
        tables = config.get("tables")
        if not isinstance(tables, dict) or table_name not in tables:
            raise ValueError(f"Table {table_name} not found in configuration.")
        return tables[table_name]

    def get_struct_type(self, table_name) -> "StructType":
        """Build the PySpark ``StructType`` described by the table's schema.

        Args:
            table_name: Name of a table present in the configuration.

        Returns:
            A ``StructType`` with one ``StructField`` per schema entry, in
            configuration order.

        Raises:
            ValueError: If the table is not configured.
            KeyError: If the table has no ``schema``, a field lacks ``name``
                or ``type``, or ``type`` is not a supported type name.
        """
        # Map YAML type names to PySpark type *classes*. Every value must be
        # a class (not an instance) because it is instantiated per field.
        type_mapping = {
            "StringType": StringType,
            "FloatType": FloatType,
            "IntegerType": IntegerType,
            "DateType": DateType,
            "TimestampType": TimestampType,
        }
        schema_config = self.get_table_config(table_name)["schema"]

        fields = []
        for field in schema_config:
            type_name = field["type"]
            try:
                spark_type = type_mapping[type_name]
            except KeyError:
                supported = ", ".join(sorted(type_mapping))
                raise KeyError(
                    f"Unsupported type {type_name!r} for column "
                    f"{field.get('name')!r} in table {table_name}; "
                    f"supported types: {supported}"
                ) from None
            # ``nullable`` is optional and falls back to StructField's default.
            fields.append(
                StructField(field["name"], spark_type(), field.get("nullable", True))
            )
        return StructType(fields)

    def get_column_list(self, table_name):
        """Return the column names of ``table_name`` in schema order.

        Raises:
            ValueError: If the table is not configured.
            KeyError: If the table has no ``schema`` or a field has no ``name``.
        """
        schema_config = self.get_table_config(table_name)["schema"]
        return [field["name"] for field in schema_config]

    def get_create_table_query(self, table_name):
        """Return a ``CREATE TABLE IF NOT EXISTS`` statement for the table.

        Column definitions come from :meth:`get_struct_type`; a
        ``PARTITIONED BY`` clause is appended when ``partition_by`` is
        configured and non-empty.

        NOTE(review): ``table_name`` and column names are interpolated
        verbatim, so the configuration file must be trusted.

        Raises:
            ValueError: If the table is not configured.
            KeyError: If the schema or a ``partition_by`` entry is malformed.
        """
        table_config = self.get_table_config(table_name)
        schema = self.get_struct_type(table_name)
        # ``or []`` also covers an explicit ``partition_by: null``.
        partition_by = table_config.get("partition_by") or []

        # Generate SQL columns
        columns = ", ".join(
            f"{field.name} {field.dataType.simpleString()}" for field in schema.fields
        )
        partitioning = ", ".join(p["field"] for p in partition_by)

        # Generate CREATE TABLE query
        create_table_query = f"""
        CREATE TABLE IF NOT EXISTS {table_name} ({columns})
        """
        if partitioning:
            create_table_query += f" PARTITIONED BY ({partitioning})"
        return create_table_query.strip()