In [1]:
# IMPORTS

import pandas as pd
import numpy as np
import json
import inspect
import networkx as nx

In [2]:
# Serialize results to JSON
class ObjectEncoder(json.JSONEncoder):
    """JSON encoder that can serialize arbitrary Python objects.

    * Objects exposing a ``to_json()`` method are replaced by its result
      (recursively, if that result is itself convertible).
    * Other objects that have a ``__dict__`` are replaced by a dict of their
      non-dunder members, as listed by ``inspect.getmembers``, with routines,
      generators and abstract classes left out.
    * Anything else is delegated to :class:`json.JSONEncoder`, which raises
      ``TypeError`` for unsupported types.
    """

    # Member values matching any of these predicates are not serialized.
    _EXCLUDED_MEMBER_PREDICATES = (
        inspect.isabstract,
        inspect.isbuiltin,
        inspect.isfunction,
        inspect.isgenerator,
        inspect.isgeneratorfunction,
        inspect.ismethod,
        inspect.ismethoddescriptor,
        inspect.isroutine,
    )

    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj`` or raise ``TypeError``."""
        simplified = self._simplify(obj)
        if simplified is obj:
            # Nothing to convert. Handing the same object back to json makes it
            # report a bogus "Circular reference detected" ValueError, so let the
            # base class raise the documented TypeError instead.
            return super().default(obj)
        return simplified

    def _simplify(self, obj):
        """Apply the to_json()/__dict__ conversions; return other values unchanged."""
        if hasattr(obj, "to_json"):
            return self._simplify(obj.to_json())
        if hasattr(obj, "__dict__"):
            return self._simplify(self._public_members(obj))
        return obj

    def _public_members(self, obj):
        """Map member name -> value for the non-dunder data members of ``obj``."""
        return {
            key: value
            for key, value in inspect.getmembers(obj)
            if not key.startswith("__")
            and not any(pred(value) for pred in self._EXCLUDED_MEMBER_PREDICATES)
        }

In [3]:
# GLOBAL FUNCTIONS

# Load the Spark table whose name is given by table['name']
def get_train_data(table):
    """Return the training DataFrame: the Spark table named ``table['name']``."""
    return spark.table(table['name'])
  
def get_test_data(table):
    """Return the test DataFrame: the Spark table named ``table['name']``."""
    return spark.table(table['name'])
  
# Read schema_discovery json object located in given path              
def get_schema_discovery(schema_discovery_file_location,
                         local_copy_path="/dbfs/tmp/schema_discovery.json",
                         account_host="homecredittest01.blob.core.windows.net",
                         secret_scope="blobs",
                         secret_key="hcaccesskey"):
    """Download the schema-discovery JSON file and return its parsed content.

    The storage-account key is read from a Databricks secret and registered
    with Spark. The file is then copied with ``dbutils.fs.cp`` to
    ``local_copy_path`` and that copy is parsed with ``json.load``.

    Args:
        schema_discovery_file_location: source URI passed to ``dbutils.fs.cp``
            (a ``wasbs://`` URI in this notebook).
        local_copy_path: driver-side POSIX path of the temporary copy. Paths
            under ``/dbfs/`` are written through DBFS (``dbfs:/...``); any other
            path is written to the driver's local disk (``file:...``).
        account_host: Azure Blob host whose account key is configured.
        secret_scope: Databricks secret scope holding the account key.
        secret_key: key of the account-key secret inside ``secret_scope``.

    Returns:
        The parsed JSON document. Callers in this notebook read its
        ``'tables'`` and ``'dependencies'`` entries.
    """
    spark.conf.set(f"fs.azure.account.key.{account_host}",
                   dbutils.secrets.get(scope=secret_scope, key=secret_key))

    # dbutils.fs resolves bare paths against DBFS, while open() sees DBFS via
    # the /dbfs FUSE mount. Translate so the copy lands exactly where we read.
    if local_copy_path.startswith("/dbfs/"):
        copy_target = "dbfs:/" + local_copy_path[len("/dbfs/"):]
    else:
        copy_target = "file:" + local_copy_path
    dbutils.fs.cp(schema_discovery_file_location, copy_target)

    # Read the file that was just copied (not some other, possibly stale file).
    with open(local_copy_path, 'r', encoding='utf-8') as f:
        return json.load(f)

In [4]:
# Load the schema-discovery document from Azure Blob Storage
schema_discovery_file_location = "wasbs://hc-test-data-01@homecredittest01.blob.core.windows.net/hc-test-01/schema_discovery.json"
schema_discovery = get_schema_discovery(
    schema_discovery_file_location=schema_discovery_file_location,
)

In [5]:
def sort_table_by_degree(schema_discovery):
    """Build the column-dependency graph and order its nodes by in-degree.

    Each entry of ``schema_discovery['dependencies']`` links a
    ``(tableName, columnName)`` node on its ``left`` side to one on its
    ``right`` side, according to the two sides' ``relationshipType``:

    * ``Equal``/``Equal`` or ``Overlap``/``Overlap``: edges in both directions.
    * ``Contains``/``Contained`` (either orientation): an edge from the
      containing node to the contained node. The containing node is also
      recorded as a dependency of the contained node in ``table_dep_dict``.
    * Any other combination is reported with a ``NOT FOUND`` line and skipped.

    Returns:
        ``(in_degree_dict, table_dep_dict)``:

        * ``in_degree_dict`` maps every graph node ``(table, column)`` to its
          in-degree, with insertion order sorted by ascending in-degree.
        * ``table_dep_dict`` maps a contained ``(table, column)`` to the list
          of ``(table, column)`` nodes that contain it.
    """
    g = nx.DiGraph()
    table_dep_dict = {}

    def add_containment(container, contained):
        # Edge points from the containing column to the contained one.
        g.add_edge(container, contained)
        table_dep_dict.setdefault(contained, []).append(container)

    for dep in schema_discovery['dependencies']:
        left, right = dep['left'], dep['right']
        node_1 = (left['tableName'], left['columnName'])
        node_2 = (right['tableName'], right['columnName'])
        relationship_col_1 = left['relationshipType']
        relationship_col_2 = right['relationshipType']

        # NOTE(review): earlier code matched only the spelling 'Equel'; both
        # 'Equel' and the correct 'Equal' are accepted here.
        rel_1 = 'Equal' if relationship_col_1 == 'Equel' else relationship_col_1
        rel_2 = 'Equal' if relationship_col_2 == 'Equel' else relationship_col_2

        if rel_1 == rel_2 and rel_1 in ('Equal', 'Overlap'):
            g.add_edge(node_1, node_2)
            g.add_edge(node_2, node_1)
        elif rel_1 == 'Contains' and rel_2 == 'Contained':
            add_containment(node_1, node_2)
        elif rel_1 == 'Contained' and rel_2 == 'Contains':
            add_containment(node_2, node_1)
        else:
            # f-string so that missing (None) fields do not crash the report.
            print(f'NOT FOUND {node_1[0]} {node_2[0]} {relationship_col_1} {relationship_col_2}')

    in_degree_dict = dict(sorted(dict(g.in_degree).items(), key=lambda item: item[1]))
    return in_degree_dict, table_dep_dict
    
def train_gan(data):
    """Placeholder for unconditional GAN training.

    TODO: not implemented - ignores ``data`` and produces no model (``None``).
    """
    model = None
    return model

def predict_gan(data):
    """Placeholder for unconditional GAN sampling; currently echoes ``data`` back."""
    synthetic = data
    return synthetic

def train_condintional_gan(data, condition_data):
    """Placeholder for conditional GAN training.

    TODO: not implemented - ignores both arguments and produces no model (``None``).
    """
    model = None
    return model

def predict_conditional_gan(data, condition_data):
    """Placeholder for conditional GAN sampling; currently echoes ``data`` back."""
    synthetic = data
    return synthetic
  
def get_table_by_name(schema_discovery, table_name):
    """Return the first table definition whose ``'name'`` equals ``table_name``.

    Args:
        schema_discovery: the schema-discovery document. A dict (as returned
            by ``json.load``) is read via its ``'tables'`` key; any other object
            is read via a ``.tables`` attribute.
        table_name: name to look up.

    Returns:
        The matching table entry, or ``None`` if no table has that name.
    """
    # json.load yields a dict, so attribute access (.tables) would raise
    # AttributeError; index it like the rest of the notebook does.
    if isinstance(schema_discovery, dict):
        tables = schema_discovery['tables']
    else:
        tables = schema_discovery.tables
    return next((table for table in tables if table['name'] == table_name), None)
  
def get_entity_data(data, entity):
    """Project ``data`` onto the columns that make up ``entity``.

    Args:
        data: object exposing ``select`` (a Spark DataFrame from
            ``get_train_data`` in this notebook).
        entity: mapping whose ``'columns'`` is a list of ``{'name': ...}`` dicts.

    Returns:
        ``data.select(...)`` over the entity's column names, in entity order.
    """
    col_list = [col['name'] for col in entity['columns']]
    return data.select(col_list)

  
def train_per_entity_without_condition(table, data_models_per_entity):
    """Train an unconditional model for every entity of ``table``.

    Args:
        table: table definition with ``'name'`` and ``'entities'`` entries.
        data_models_per_entity: dict updated in place. Each model is stored
            under the entity signature: the entity's column names, sorted and
            joined with ``'_'``.
    """
    data = get_train_data(table)
    for entity in table['entities']:
        entity_data = get_entity_data(data, entity)
        model = train_gan(entity_data)
        # list.sort() sorts in place and returns None; sorted() returns the list to join.
        entity_signature = '_'.join(sorted(col['name'] for col in entity['columns']))
        data_models_per_entity[entity_signature] = model
        
def train_per_entity_with_condition(schema_discovery, table_name, col_name, data_models_per_entity, table_dep_dict):
    """Train a conditional model for every entity of ``table_name``.

    The conditioning data are the training DataFrames of the tables listed as
    dependencies of ``(table_name, col_name)`` in ``table_dep_dict``.

    Args:
        schema_discovery: schema-discovery document used to look up tables.
        table_name: table whose entities are trained.
        col_name: column of ``table_name`` whose dependencies condition training.
        data_models_per_entity: dict updated in place. Each model is stored
            under the entity's sorted column names joined with ``'_'``.
        table_dep_dict: maps ``(table, column)`` to a list of
            ``(table, column)`` dependency pairs (see ``sort_table_by_degree``).
    """
    table = get_table_by_name(schema_discovery, table_name)
    data = get_train_data(table)

    # Dependency entries are (table, column) pairs; only the table is needed.
    # A node linked only via Equal/Overlap has no entry, so it gets no conditions.
    all_cond_data = [
        get_train_data(get_table_by_name(schema_discovery, cond_table_name))
        for cond_table_name, _cond_col_name in table_dep_dict.get((table_name, col_name), [])
    ]

    for entity in table['entities']:
        entity_data = get_entity_data(data, entity)
        model = train_condintional_gan(entity_data, all_cond_data)
        # list.sort() returns None; sorted() returns the list to join.
        entity_signature = '_'.join(sorted(col['name'] for col in entity['columns']))
        data_models_per_entity[entity_signature] = model
        
        
def train_per_table(schema_discovery):
    """Train one set of per-entity models for every table in ``schema_discovery``.

    Graph nodes returned by ``sort_table_by_degree`` are visited in ascending
    in-degree order. A ``(table, column)`` node that has dependencies recorded
    in ``table_dep_dict`` is trained conditionally on them. Every other node is
    trained unconditionally. Tables that never appear in the graph are then
    trained unconditionally.

    Returns:
        dict mapping table name to ``{'models': {entity_signature: model},
        'condition_flag': bool}``. When a table has several graph nodes, the
        entry written for the last-visited node is kept.
    """
    data_models_per_table = {}
    # in_degree_dict only fixes the visiting order here.
    in_degree_dict, table_dep_dict = sort_table_by_degree(schema_discovery)
    for table_name, col_name in in_degree_dict:
        data_models_per_entity = {}
        # Decide on the dependency list rather than on in-degree alone:
        # Equal/Overlap edges raise the in-degree without adding dependencies,
        # so such nodes have nothing to condition on.
        condition_flag = (table_name, col_name) in table_dep_dict
        if condition_flag:
            train_per_entity_with_condition(
                schema_discovery, table_name, col_name, data_models_per_entity, table_dep_dict
            )
        else:
            table = get_table_by_name(schema_discovery, table_name)
            train_per_entity_without_condition(table, data_models_per_entity)

        data_models_per_table[table_name] = {'models': data_models_per_entity, 'condition_flag': condition_flag}

    # Tables without any discovered dependency are trained unconditionally.
    for table in schema_discovery['tables']:
        if table['name'] not in data_models_per_table:
            data_models_per_entity = {}
            train_per_entity_without_condition(table, data_models_per_entity)
            data_models_per_table[table['name']] = {'models': data_models_per_entity, 'condition_flag': False}

    return data_models_per_table

  
def predict_per_table(data_models_per_table, schema_discovery, cond_data_per_table=None):
    """Generate synthetic data for every table trained by ``train_per_table``.

    Args:
        data_models_per_table: output of ``train_per_table``.
        schema_discovery: schema-discovery document used to look up tables.
        cond_data_per_table: optional mapping from table name to conditioning
            data. It must contain every table whose ``condition_flag`` is True.

    Returns:
        dict mapping table name to the predicted data for that table.

    Raises:
        KeyError: a conditional table has no entry in ``cond_data_per_table``.
    """
    cond_data_per_table = cond_data_per_table or {}
    predicted_data_per_table = {}
    for table_name, table_info in data_models_per_table.items():
        # NOTE(review): the input data was previously an undefined name; the
        # table's test split (get_test_data) is assumed here - confirm.
        data = get_test_data(get_table_by_name(schema_discovery, table_name))
        if table_info['condition_flag']:
            if table_name not in cond_data_per_table:
                raise KeyError(f"no conditioning data supplied for table {table_name!r}")
            predicted_data = predict_conditional_gan(data, cond_data_per_table[table_name])
        else:
            predicted_data = predict_gan(data)
        predicted_data_per_table[table_name] = predicted_data
    return predicted_data_per_table
          

# Train per-table model sets: dependency-graph nodes first, then the remaining tables.
data_models_per_table = train_per_table(schema_discovery=schema_discovery)