In [2]:
import yaml
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, FloatType, IntegerType

# class RegisteredTables:
#     from pyspark.sql.types import StructType, StructField, StringType, FloatType, IntegerType
#     def __init__(self, zone, table, config_file_path):
#         self.zone=zone
#         self.table=table
#         self.config_file_path=config_file_path
        
#         with open(self.config_file_path, 'r') as file:
#             config=yaml.safe_load(file)
#         self.table_schema=config[zone][table]['schema']  
#         self.table_partition_by=config[zone][table].get('partition_by',[])

#     # Convert YAML schema to PySpark StructType
#     def get_struct_type(self):
#         # Map YAML types to PySpark types
#         type_mapping = {
#                         "StringType": StringType,
#                         "FloatType": FloatType,
#                         "IntegerType": IntegerType,
#                     }        
#         fields = [
#             StructField(field["name"], type_mapping[field["type"]](), field["nullable"])
#             for field in self.table_schema
#         ]
        
#         return StructType(fields)

#     def get_schema_columns(self):
#         schema_columns = ", ".join([f"{field.name} {field.dataType.simpleString()}" for field in self.get_struct_type()])
#         return schema_columns
    
#     def get_partition_columns(self):
#         partition_columns = ", ".join([p["field"] for p in self.table_partition_by]) if self.table_partition_by else ""
#         return partition_columns
        
#     def get_column_list(self):
#         # Extract the 'name' attribute from each field in the schema
#         column_list = [field['name'] for field in self.table_schema]
#         return column_list
    
#     def get_create_table_script(self):        
#         # Create the Iceberg table
#         script = f"""
#         CREATE TABLE IF NOT EXISTS {self.table} ({self.get_schema_columns()})
#         USING iceberg
#         """
#         if self.get_partition_columns():
#             script += f" PARTITIONED BY ({self.get_partition_columns()})"         
#         return script

In [3]:
# zone='raw'
# sink_table='raw.stock_eod_yfinance'
# config_file_path='registered_table_schemas.yaml'
# x=RegisteredTables(zone, sink_table, config_file_path)
# # x=RegisteredTables('raw', 'raw.stock_eod_yfinance', 'registered_table_schemas.yaml')    
# print(x.get_struct_type())      
# # print(x.get_create_table_script()) 

In [4]:
import yaml
from pyspark.sql.types import StringType, FloatType, IntegerType, StructType, StructField

class TableManager:
    """Load table definitions from a YAML file and derive Spark artifacts from them.

    The configuration is expected to hold a top-level ``tables`` mapping keyed by
    table name. Each entry carries a ``schema`` list of fields (``name``, ``type``
    and optionally ``nullable``) and an optional ``partition_by`` list whose items
    have a ``field`` key.
    """

    def __init__(self, config_file_path):
        """Read and parse the YAML configuration at ``config_file_path``.

        Args:
            config_file_path: Path to the YAML file describing the tables.

        Raises:
            OSError: If the file cannot be opened.
            yaml.YAMLError: If the file is not valid YAML.
        """
        self.config_file_path = config_file_path

        # Explicit encoding so parsing does not depend on the platform default.
        with open(self.config_file_path, 'r', encoding='utf-8') as file:
            # safe_load returns None for an empty document; normalise to an
            # empty mapping so later lookups fail with a clear ValueError.
            self.config = yaml.safe_load(file) or {}

    def get_table_config(self, table_name):
        """Return the configuration entry for ``table_name``.

        Args:
            table_name: Key of the table under the top-level ``tables`` mapping.

        Returns:
            The configuration stored for that table.

        Raises:
            ValueError: If the configuration has no usable ``tables`` mapping or
                the table is not listed in it.
        """
        # Guard against an empty/None config and against ``tables: null`` so
        # the caller always sees ValueError rather than an incidental TypeError.
        tables = self.config.get("tables") if isinstance(self.config, dict) else None
        if not isinstance(tables, dict) or table_name not in tables:
            raise ValueError(f"Table {table_name} not found in configuration.")
        return tables[table_name]

    def get_struct_type(self, table_name):
        """Build a PySpark ``StructType`` from the table's ``schema`` entries.

        Args:
            table_name: Key of the table under the ``tables`` mapping.

        Returns:
            A ``StructType`` with one ``StructField`` per schema entry, in the
            order they appear in the configuration.

        Raises:
            ValueError: If the table is not present in the configuration.
            KeyError: If the table has no ``schema``, a field lacks ``name`` or
                ``type``, or ``type`` is not one of the supported type names.
        """
        # Map YAML type names to PySpark type classes.
        type_mapping = {
            "StringType": StringType,
            "FloatType": FloatType,
            "IntegerType": IntegerType,
        }
        schema_config = self.get_table_config(table_name)["schema"]

        fields = []
        for field in schema_config:
            type_name = field["type"]
            if type_name not in type_mapping:
                # Same exception type as a plain dict lookup, but with enough
                # context to locate the offending entry in the YAML file.
                raise KeyError(
                    f"Unsupported type {type_name!r} for column {field['name']!r} "
                    f"in table {table_name}; supported types: {sorted(type_mapping)}"
                )
            fields.append(
                StructField(
                    field["name"],
                    type_mapping[type_name](),
                    # ``nullable`` is optional; True matches StructField's own default.
                    field.get("nullable", True),
                )
            )
        return StructType(fields)

    def get_column_list(self, table_name):
        """Return the column names of ``table_name`` in schema order.

        Args:
            table_name: Key of the table under the ``tables`` mapping.

        Returns:
            A list with the ``name`` of every entry in the table's ``schema``.

        Raises:
            ValueError: If the table is not present in the configuration.
            KeyError: If the table has no ``schema`` or a field lacks ``name``.
        """
        schema_config = self.get_table_config(table_name)["schema"]
        return [field['name'] for field in schema_config]

    def get_create_table_query(self, table_name):
        """Generate a ``CREATE TABLE IF NOT EXISTS`` statement for ``table_name``.

        Column definitions come from ``get_struct_type``; a ``PARTITIONED BY``
        clause is appended only when the table declares ``partition_by`` entries.

        Args:
            table_name: Key of the table under the ``tables`` mapping; it is also
                used verbatim as the table identifier in the statement.

        Returns:
            The SQL statement with surrounding whitespace stripped.

        Raises:
            ValueError: If the table is not present in the configuration.
            KeyError: If the schema or a partition entry is missing a required key.
        """
        table_config = self.get_table_config(table_name)
        schema = self.get_struct_type(table_name)
        # ``or []`` also covers an explicit ``partition_by: null`` in the YAML.
        partition_by = table_config.get("partition_by", []) or []

        # NOTE: identifiers are interpolated unquoted, so the YAML configuration
        # must come from a trusted source.
        columns = ", ".join(
            f"{field.name} {field.dataType.simpleString()}" for field in schema.fields
        )
        partitioning = ", ".join(p["field"] for p in partition_by)

        # Generate CREATE TABLE query
        create_table_query = f"""
        CREATE TABLE IF NOT EXISTS {table_name} ({columns})
        """
        if partitioning:
            create_table_query += f" PARTITIONED BY ({partitioning})"
        return create_table_query.strip()

In [5]:
# table_name='raw.stock_eod_yfinance'
# config_file_path='registered_table_schemas.yaml'
# x=TableManager(config_file_path)
# print(x.get_column_list('nessie.raw.stock_eod_yfinance'))

# # x=RegisteredTables('raw', 'raw.stock_eod_yfinance', 'registered_table_schemas.yaml')    
# # print(x.get_struct_type('raw.stock_eod_yfinance'))      
# # print(x.get_create_table_script()) 