In [1]:
# THESE PACKAGES NEED TO BE INSTALLED BEFORE RUNNING

In [2]:
#%sh
#pip install python-louvain
#pip install sqlalchemy

In [3]:
# IMPORTS

from pyspark.sql import SQLContext
from  pyspark.sql.functions import input_file_name, concat, col, lit, countDistinct, when, count, isnan, length
from pyspark.sql.types import StringType
from sqlalchemy import create_engine, MetaData
import sqlalchemy
import pandas as pd
import numpy as np
import threading
import json
import inspect
import itertools
from sklearn import preprocessing
from sklearn import metrics
import networkx as nx
import community
import seaborn as sns
from dateutil.parser import parse

In [4]:
# Serialize result to JSON
class ObjectEncoder(json.JSONEncoder):
    """JSON encoder that serialises plain Python objects through their data attributes.

    Objects exposing ``to_json`` are replaced by the result of that call; any
    other object with a ``__dict__`` is replaced by a dict of its non-dunder,
    non-callable members (as listed by ``inspect.getmembers``).
    """

    def default(self, obj):
        """Return a JSON-serialisable stand-in for ``obj``.

        ``json`` calls this only for values it cannot encode natively and then
        encodes whatever is returned, so nested custom objects are handled by
        further calls to this method.

        Raises:
            TypeError: if ``obj`` has neither ``to_json`` nor ``__dict__``.
        """
        if hasattr(obj, "to_json"):
            return obj.to_json()
        if hasattr(obj, "__dict__"):
            return {
                key: value
                for key, value in inspect.getmembers(obj)
                if not key.startswith("__")
                and not inspect.isabstract(value)
                and not inspect.isbuiltin(value)
                and not inspect.isfunction(value)
                and not inspect.isgenerator(value)
                and not inspect.isgeneratorfunction(value)
                and not inspect.ismethod(value)
                and not inspect.ismethoddescriptor(value)
                and not inspect.isroutine(value)
            }
        # Returning ``obj`` unchanged would make the encoder re-submit the very
        # same object and fail with a misleading "Circular reference detected".
        return super().default(obj)
      
# Stores the connection parameters
class Connection:
    """Holds the Spark SQL context together with the JDBC connection parameters.

    ``jdbcUrl`` is derived from host, port and database name. The port is
    included whenever one is given, so the JDBC URL targets the same server
    port as the SQLAlchemy engine string that is built from ``port``.
    """

    def __init__(self, sqlc, jdbcUsername, jdbcPassword, jdbcDriver, jdbcHostname, jdbcDatabase, port):
        self.sqlc = sqlc
        self.jdbcUsername = jdbcUsername
        self.jdbcPassword = jdbcPassword
        self.jdbcDriver = jdbcDriver
        self.jdbcHostname = jdbcHostname
        self.jdbcDatabase = jdbcDatabase
        self.port = port
        # Without an explicit port the JDBC driver falls back to its default.
        if port is None or port == "":
            host_and_port = jdbcHostname
        else:
            host_and_port = "{0}:{1}".format(jdbcHostname, port)
        self.jdbcUrl = "jdbc:postgresql://{0}/{1}".format(host_and_port, jdbcDatabase)
      
# SchemaDiscoveryDTO stores a list of tables in the input connection (tables of type Table) and list of dependencies between 
# the tables
class SchemaDiscoveryDTO:
    """Root result object: a schema name, its tables and the dependencies between them."""

    def __init__(self, name):
        # ``tables`` starts empty and ``dependencies`` starts as None.
        self.name, self.tables, self.dependencies = name, [], None
        
# TableDTO contains a name, a list of columns, a list of entities (each entity is a list of columns) and the table size
class TableDTO:
    """One table of the schema: its name, columns, entities and row count."""

    def __init__(self, table_name):
        # ``entities`` and ``size`` start as None; ``columns`` starts empty.
        self.name, self.columns = table_name, []
        self.entities = self.size = None
        
# A column contains column name, raw_type(int, float, object) and type(label/free text/ numeric/ timestamp/ identifier code)
class ColumnDTO:
    """Profile of one table column: raw and derived type, value statistics and key flags."""

    def __init__(self, col_name, col_raw_type, col_type, unique_cnt, not_null_cnt, min_digits, max_digits, is_pk, pk_source, is_fk, fk_source):
        # Map the constructor arguments onto the camelCase attribute names.
        vars(self).update(
            name=col_name,
            rawType=col_raw_type,
            type=col_type,
            uniqueCount=unique_cnt,
            notNullCount=not_null_cnt,
            minDigits=min_digits,
            maxDigits=max_digits,
            isPK=is_pk,
            PKsource=pk_source,
            isFK=is_fk,
            FKsource=fk_source,
        )
        
        
# A Entity contains a list of columns 
class EntityDTO:
    """An entity: a group of columns, stored in ``columns``."""

    def __init__(self):
        self.columns = list()
        
# TableRefDTO contain name of the table, name of the column, its cardinality and its relationship
class TableRefDTO:
    """One side of a dependency: a table column plus its key, cardinality and relationship types."""

    def __init__(self, table_name, column_name, key_type, cardinality_type, relationship_type):
        # Map the constructor arguments onto the camelCase attribute names.
        vars(self).update(
            tableName=table_name,
            columnName=column_name,
            keyType=key_type,
            cardinalityType=cardinality_type,
            relationshipType=relationship_type,
        )

# DependencyDTO contains two TableRefDTO objects, left and right, that represent the dependency between two columns, and the source of the dependency (from metadata or found by us)
class DependencyDTO:
    """A dependency between two column references and where that dependency came from."""

    def __init__(self, table_left, table_right, dependency_source):
        self.left, self.right = table_left, table_right
        self.dependencySource = dependency_source

In [5]:
# GLOBAL FUNCTIONS

# Read data with given table.table_name as parquet file and sample sample_size records if to_sample == True
def get_data(table, sample_size, to_sample=True):
    """Load the Spark table named ``table.name`` and optionally down-sample it.

    Args:
        table: object with ``name`` (the saved Spark table) and ``size`` (its
            row count; counted on the fly when it is None).
        sample_size: maximal number of rows to return when sampling.
        to_sample: when False the full table is returned untouched.

    Returns:
        A Spark DataFrame holding at most ``min(table.size, sample_size)``
        randomly chosen rows, or the whole table when ``to_sample`` is False.
    """
    df_pqt = spark.table(table.name)

    if not to_sample:
        return df_pqt

    table_size = table.size if table.size is not None else df_pqt.count()
    how_many_take = max(0, min(table_size, sample_size))
    if how_many_take == 0:
        # Empty table or non-positive sample size: nothing to sample, and the
        # fraction below would divide by zero for an empty table.
        return df_pqt.limit(0)

    # Over-sample a little so that limit() can still return the requested
    # number of rows, but never ask for more than the whole table: Spark
    # rejects fractions above 1.0 when sampling without replacement.
    fraction = min(1.0, 0.1 + 1.0 * how_many_take / (table_size * 0.9))
    sampeld_records = df_pqt.sample(withReplacement=False, fraction=fraction).limit(how_many_take)

    return sampeld_records
  
# Save schema_discovery object in given location
def save_schema_discovery(schema_discovery, schema_discovery_file_location):
    """Serialise ``schema_discovery`` to JSON and copy it to the requested location.

    The JSON is first written to a fixed staging file under ``/dbfs/tmp`` and
    then copied with ``dbutils.fs.cp`` to ``schema_discovery_file_location``.
    """
    local_staging_path = '/dbfs/tmp/schema_discovery_with_entities_and_dependencies.json'
    dbfs_staging_uri = "dbfs:/tmp/schema_discovery_with_entities_and_dependencies.json"

    with open(local_staging_path, 'w', encoding='utf-8') as staging_file:
        json.dump(schema_discovery, staging_file, cls=ObjectEncoder, indent=2, ensure_ascii=False)

    dbutils.fs.cp(dbfs_staging_uri, schema_discovery_file_location)

In [6]:
# STEP 0 - ETL - READ DATA FROM GIVEN DATABASE OR CSV FILES AND SAVE LOCALLY AS PARQUET FILES

class ETL:
    schema_discovery = None
    metadata = None
    
    # Create connection with given connection details
    def get_connection(self, jdbcUsername, jdbcPassword, jdbcDriver, jdbcHostname, jdbcDatabase, port):
        """Build a ``Connection`` for the given JDBC parameters.

        Also registers the blob-storage account key (read from the ``blobs``
        secret scope) in the Spark configuration.
        """
        sql_context = SQLContext(sc)
        blob_access_key = dbutils.secrets.get(scope="blobs", key="hcaccesskey")
        spark.conf.set("fs.azure.account.key.homecredittest01.blob.core.windows.net", blob_access_key)
        return Connection(sql_context, jdbcUsername, jdbcPassword, jdbcDriver, jdbcHostname, jdbcDatabase, port)

    def get_metadata(self, connection):
        # construct an engine connection string
        engine_string = "postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}".format(user = connection.jdbcUsername,
                                                                                                  password = connection.jdbcPassword,
                                                                                                  host = connection.jdbcHostname,
                                                                                                  port = connection.port,
                                                                                                  database = connection.jdbcDatabase,)

        # create sqlalchemy engine
        engine = create_engine(engine_string)
        metadata = None
        try:
            metadata = MetaData(bind=engine, reflect=True)
        except:
            print('cant read metadata from DB')

        self.metadata = metadata

    # Read data with table.name from given connection and save it as parquet file
    # sample_percentage can be number between 0-1 (the percentage of data to sample from the original data)
    def convert_table_from_DB_to_pqt(self, connection, table, sample_percentage):
        if sample_percentage == 1:
            sql_str = '(SELECT * FROM ' + self.schema_discovery.name+'.'+table.name + ') a'
        else:
            sql_str = '(SELECT * FROM ' + self.schema_discovery.name+'.'+table.name + ' where random() <= ' + str(sample_percentage) + ') a'

        data_db = connection.sqlc.read.format("jdbc")\
                            .option("driver", connection.jdbcDriver)\
                            .option("url", connection.jdbcUrl)\
                            .option("dbtable", sql_str)\
                            .option("user", connection.jdbcUsername)\
                            .option("password", connection.jdbcPassword)\
                            .option("sslmode","require")\
                            .option("numPartitions",4)\
                            .option("fetchsize",1000)\
                            .load()

        data_db.write.mode("overwrite").saveAsTable(table.name) #.option("path",schema_discovery.name)
        spark.sql("REFRESH TABLE " + table.name)
        data_df = spark.table(table.name)
        table.size  = data_df.count()


    # Get all tables names from table with all tables details and for each table:
    # 1. read the table and save it as parquet file using convert_table_from_DB_to_pqt function
    # 2. save their names and size in schema_discovery.tables_list
    def create_schema_discovery_from_DB(self, jdbcUsername, jdbcPassword, jdbcDriver, jdbcHostname, jdbcDatabase, port, sample_percentage):
        # Create connection to DB  
        connection = self.get_connection(jdbcUsername, jdbcPassword, jdbcDriver, jdbcHostname, jdbcDatabase, port)
        
      
        df_tables = connection.sqlc.read.format("jdbc")\
                              .option("driver", connection.jdbcDriver)\
                              .option("url", connection.jdbcUrl)\
                              .option("dbtable", "pg_stat_user_tables")\
                              .option("user", connection.jdbcUsername)\
                              .option("password", connection.jdbcPassword)\
                              .option("sslmode","require")\
                              .option("numPartitions",4)\
                              .option("fetchsize",1000)\
                              .load()

        df_tables = df_tables[['schemaname','relname']].toPandas()
        schema_neme = df_tables['schemaname'].values[0]
        self.schema_discovery = SchemaDiscoveryDTO(schema_neme)

        threads_list = []
        for table_details in df_tables.itertuples():
            table = TableDTO(table_details.relname)
            self.schema_discovery.tables.append(table)
            thread = threading.Thread(target=self.convert_table_from_DB_to_pqt, args=(connection, table, sample_percentage,))
            threads_list.append(thread)
            thread.start()

        for thread in threads_list:
            thread.join()
            
        self.get_metadata(connection)
        
        return self.schema_discovery, self.metadata

    # Read data with table.table_name from given path and save it as parquet file
    # sample_percentage can be number between 0-1 (the percentage of data to sample from the original data)
    def convert_table_from_csv_to_pqt(self, path, table, sample_percentage):
        data_file = spark.read.csv(path + table.name + '.csv', header='true', inferSchema='true')
        if sample_percentage != 1:
            data_file = data_file.sample(withReplacement=False, fraction=sample_percentage)

        #Remove illegal characters in column names
        new_column_name_list= list(map(lambda x: x.replace(" ", "_").replace("(", "").replace(")", "_")
                                                  .replace(".", "_").replace(",", "_").replace("{", "_").replace("}", "_"), data_file.columns))
        data_file = data_file.toDF(*new_column_name_list)
        data_file.write.mode("overwrite").option("path", self.schema_discovery.name).saveAsTable(table.name) 
        spark.sql("REFRESH TABLE " + table.name)
        table.size  = spark.table(table.name).count()

    # Get all file names from the given folder (e.g. path = 'wasbs://hc-test-data-01@homecredittest01.blob.core.windows.net/hc-test-01/'), 
    # for each file:
    # 1. read the file (table) and save it as parquet file using convert_table_from_DB_to_pqt function
    # 2. save their names and size in schema_discovery.tables_list
    def create_schema_discovery_from_csv(self, path, sample_percentage):
        df_files = spark.read.format("csv").load(path).select(input_file_name()).distinct().toPandas()['input_file_name()']
        #Filter csv files only
        df_files = [ filename for filename in df_files if filename.endswith( 'csv' ) ]

        #Extract file name and (last) directory name as table and schema correspondingly
        res = [ (path.split('/')[-2] , path.split('/')[-1].split('.')[-2]) if path.index('.')> 0 else (path.split('/')[-2],path.split('/')[-1]) for path in df_files ]
        schemas,fnames = zip(*res)
        df_tables = pd.DataFrame(data={'schemaname': schemas, 'relname': fnames})

        schema_name = df_tables['schemaname'].values[0]
        self.schema_discovery = SchemaDiscoveryDTO(schema_name)

        threads_list = []
        for table_details in df_tables.itertuples():
            table = Table(table_details.relname)
            self.schema_discovery.tables.append(table)
            thread = threading.Thread(target=self.convert_table_from_csv_to_pqt, args=(path, table, sample_percentage,))
            threads_list.append(thread)
            thread.start()

        for thread in threads_list:
            thread.join()
            
        return self.schema_discovery, None

In [7]:
# STEP 1 - EXTRACT COLUMNS TYPES

class ColumnsInfoExtractor:

    def __init__(self, schema_discovery, metadata):
        self.schema_discovery = schema_discovery
        self.metadata = metadata
    
    # Check if given string is in date format
    def is_date(self, string):
        try:
            parse(string, fuzzy=False)
            return True
        except:
            return False

    # An auxiliary function that accepts basic columns_data (name & raw data type per column) and a sample of records (based on these columns) 
    # and calculates the final column type (label/free text/ numeric/ timestamp/ identifier code)
    def calc_col_types(self, columns_data, data, table_metadata):
        """Classify every column of ``data`` and flag primary-key columns.

        Args:
            columns_data: pandas DataFrame with one row per column and the
                fields ``col_name`` and ``raw_type``.
            data: Spark DataFrame holding the table rows.
            table_metadata: reflected table metadata exposing ``columns`` (each
                with ``name`` and ``primary_key``), or None to discover the
                primary key from the data.

        Returns:
            A copy of ``columns_data`` extended with ``unique_vals``,
            ``not_null_cnt``, ``min_digits``, ``max_digits``, the helper flags
            ``is_bool`` / ``is_str`` / ``is_date`` / ``is_numeric`` / ``is_int``,
            the derived ``col_type`` (label / free text / numeric / timestamp /
            identifier code / bool, or the raw type when nothing matched) and
            ``is_pk`` / ``pk_source``.

        NOTE(review): ``get_columns`` fills ``raw_type`` from Spark's
        ``DataFrame.dtypes``. Spark presumably reports e.g. 'string' / 'bigint'
        / 'double' rather than pandas' 'object' / 'float64', and casts booleans
        to 'true' / 'false'. Confirm whether the ``== 'object'`` checks and the
        'True' / 'False' test below can ever match.
        """
        # Calculates data types from sampled_records and updating columns_data accordingly
        data_size = data.count() # the maximal number of unique values
        long_str = 200 # A column containing strings with more than this number of characters will be considered free text column
        label_threshold = 0.2 # labels are expected to contain unique values up to this percent of the number of non empty values
        text_threshold = 0.8 # Free texts are expected to contain unique values of at least this percent of the number of non empty values
        idenifier_threshold = 4 # Identifier codes are expected to contain at least this number of characters

        res = columns_data.copy()
        res['raw_type'] = res['raw_type'].astype(str)

        # Add number of unique values per column
        res['unique_vals'] = list(data.agg(*(countDistinct(c).alias(c) for c in data.columns)).collect()[0])

        # Add number of non-null values per column
        res['not_null_cnt'] = list(data.select([count(c).alias(c) for c in data.columns]).collect()[0])

        res['min_digits'] = 0
        res['max_digits'] = 0
        res['is_bool'] = [False] * len(res)
        res['is_str'] = [False] * len(res)
        res['is_date'] = [False] * len(res)
        for col_name in data.schema.names:
            is_current = res['col_name'] == col_name

            col_length = data.withColumn('len', length(col_name)).select('len')
            max_str_len = col_length.agg({"len": "max"}).collect()[0]["max(len)"]
            min_str_len = col_length.agg({"len": "min"}).collect()[0]["min(len)"]
            res.loc[is_current, 'max_digits'] = max_str_len
            res.loc[is_current, 'min_digits'] = min_str_len

            # Inspect a 10% sample of the non-null values, rendered as strings
            col_vals_as_str = [val[0] for val in data.select(col_name).dropna().withColumn(col_name, data[col_name].cast(StringType())).sample(False, 0.1).limit(1000000).collect()]
            # check if all values are None
            if len(col_vals_as_str) == 0:
                continue

            if ('True' in col_vals_as_str) or ('False' in col_vals_as_str):
                res.loc[is_current, 'is_bool'] = True

            # A column is "pure text" when no sampled value contains a digit
            if not any(any(char.isdigit() for char in value) for value in col_vals_as_str):
                res.loc[is_current, 'is_str'] = True

            if all(self.is_date(val) for val in col_vals_as_str):
                res.loc[is_current, 'is_date'] = True

        res['col_type'] = res['raw_type'].copy()

        # Checks if the column is numeric
        res['is_numeric'] = res['raw_type'].str.contains('int|float')

        # Checks if the column is int
        res['is_int'] = res['raw_type'].str.contains('int')

        # Recognizing identifing keys as unique numbers or codes
        msk_id_code = ((res['is_str']==False) & (res['is_bool']==False) & (res['min_digits'] == res['max_digits']) & (res['min_digits']>=idenifier_threshold) &\
                       ((res['is_int']==True) | (res['raw_type'] == 'object')))
        res.loc[msk_id_code,'col_type'] = 'identifier code'

        # Recognizing labels as repeating short texts
        msk_lable = ((res['is_bool']==False) & (res['col_type'] == res['raw_type']) & \
                     (res['unique_vals']<res['not_null_cnt']*label_threshold) & (res['raw_type'] == 'object'))
        res.loc[msk_lable,'col_type'] = 'label'

        # Recognizing numeric columns as numeric columns that are not identifiers
        msk_numeric = (res['is_bool']==False) & (res['col_type'] == res['raw_type']) & (res['is_numeric'])
        res.loc[msk_numeric,'col_type'] = 'numeric'

        # Recognizing free texts as unique or long texts
        msk_free_text = ((res['is_bool']==False) & (res['col_type'] == res['raw_type']) & \
                         ((res['is_str']==True) | (((res['unique_vals']>=res['not_null_cnt']*text_threshold) | (res['max_digits']>=long_str)) \
                                                   & (res['raw_type'] == 'object'))))
        res.loc[msk_free_text,'col_type'] = 'free text'

        # Recognizing bool columns as bool columns
        msk_bool = (res['is_bool'])
        res.loc[msk_bool,'col_type'] = 'bool'

        # Recognizing dates
        msk_date = ((res['raw_type'].str.contains('time')) | (res['is_date'] == True))
        res.loc[msk_date,'col_type'] = 'timestamp'

        # Recognizing Primary Key
        res['is_pk'] = False
        res['pk_source'] = 'NONE'
        if table_metadata is not None:
            known_names = columns_data['col_name'].values
            for meta_col in table_metadata.columns:
                if (meta_col.name in known_names) and meta_col.primary_key:
                    res.loc[res['col_name'] == meta_col.name, 'is_pk'] = True
                    res.loc[res['col_name'] == meta_col.name, 'pk_source'] = 'Metadata'

        else:
            msk_pk = ((res['col_type']=='identifier code') & (res['unique_vals'] == res['not_null_cnt']) & (res['unique_vals'] == data_size))
            res.loc[msk_pk, 'is_pk'] = True
            res.loc[msk_pk, 'pk_source'] = 'Discovered'

            num_pk = res.loc[res['is_pk'] == True].shape[0]

            # If no PK, search for combinations of identifier code columns that create x unique keys where x is the size of the data
            if num_pk == 0:
                id_code_cols = res.loc[(res['col_type'] == 'identifier code') & (res['not_null_cnt'] == data_size)]['col_name'].values
                pk_found = False
                for pk_len in range(2, len(id_code_cols)+1):
                    for pk_group in itertools.combinations(id_code_cols, pk_len):
                        # check if for each possible pk_group with unique values we have only one record --> PK
                        num_unqiue_vals = data.groupBy(list(pk_group)).count().count()
                        if num_unqiue_vals == data_size:
                            res.loc[res['col_name'].isin(list(pk_group)), 'is_pk'] = True
                            res.loc[res['col_name'].isin(list(pk_group)), 'pk_source'] = 'Discovered'
                            pk_found = True
                            break
                    # Stop at the first (smallest) key. Reassigning the loop
                    # variable does not end a Python ``for`` loop, and every
                    # superset of a key is unique as well.
                    if pk_found:
                        break

            # If more then one column acts as PK (by itself, not as combination of columns) we need to choose only one of
            # them as PK and the others will be regular identifier code columns (there is no additional information in keeping all as PK)
            elif num_pk > 1:
                all_pk_cols = res.loc[res['is_pk'] == True]['col_name']
                res.loc[res['col_name'].isin(all_pk_cols.iloc[1:]), 'is_pk'] = False
                res.loc[res['col_name'].isin(all_pk_cols.iloc[1:]), 'pk_source'] = 'NONE'

        return res


    # Extracting the list of columns with column names and their raw data types from table with given table.name
    def get_columns(self, table, table_metadata): 
        data = get_data(table, 0, to_sample=False)

        cols_dtype = data.dtypes
        col_names = [val[0] for val in cols_dtype]
        col_types = [val[1] for val in cols_dtype]
        cols_df = pd.DataFrame({'col_name':col_names, 'raw_type':col_types})

        cols_data = self.calc_col_types(cols_df, data, table_metadata)

        table.columns = []
        for idx, col in cols_data.iterrows():
            # filter free text columns        
            if str(col['col_type']) != 'free text':
                table.columns.append(ColumnDTO(col['col_name'], str(col['raw_type']), str(col['col_type']), col['unique_vals'], col['not_null_cnt'],
                                               col['min_digits'], col['max_digits'], col['is_pk'], col['pk_source'], False, 'NONE'))

    # Extracting columns details for each table at schema_discovery by calculations made on a given data of each table              
    def extract_columns_info(self):    
        threads_list = []
        for table in self.schema_discovery.tables:
            table_metadata = None
            if self.metadata != None:
                table_metadata = self.metadata.tables[table.name]

            thread = threading.Thread(target=self.get_columns, args=(table, table_metadata,))
            threads_list.append(thread)
            thread.start()

        for thread in threads_list:
            thread.join()

In [8]:
# STEP 2 - EXTRACT ENTITIES

class EntitiesExtractor:
  
    def __init__(self, schema_discovery):
        self.schema_discovery = schema_discovery

    # Transform all columns values to categorical values (e.g. from ['israel', 'usa', 'israel', 'spain'] to [1, 2, 1, 3])
    def to_categorical(self, data):
        le = preprocessing.LabelEncoder()
        categorical_cols = {}
        for col_name in data.schema.names:
            col_real_vals = [x[col_name] for x in data.select(col_name).collect()]
            col_real_vals = ['None' if val is None else val for val in col_real_vals]
            categorical_cols[col_name] = le.fit_transform(col_real_vals)
        return categorical_cols

    # Calculate correlation matrix; extracts internal dependencies between column pairs that appear in data
    # Matrix corr_matrix contain the dependencies estimation between col_1 and col_2 (where col_1 != col_2) using mutual information measure
    def get_correlation_matrix(self, table, data, results_dict):    
        categorical_cols = self.to_categorical(data)
        data_col_names = data.schema.names
        all_possible_cols_combinations = [x for x in itertools.combinations(data_col_names, 2)]

        corr_matrix = pd.DataFrame(columns=data_col_names, 
                                   index=data_col_names, 
                                   data=np.zeros((len(data_col_names), len(data_col_names))))

        for col_tuple in all_possible_cols_combinations:
            col_1 = col_tuple[0]
            col_2 = col_tuple[1]

            cat_col_1 = categorical_cols[col_1]
            cat_col_2 = categorical_cols[col_2]

            # Calculate the Information Gain of target columns given the source column
            mutulal_info = metrics.normalized_mutual_info_score(cat_col_1, cat_col_2)
            corr_matrix.loc[col_1, col_2] = mutulal_info
            corr_matrix.loc[col_2, col_1] = mutulal_info

        # # plot heatmap of correlation matrix
        # sns.heatmap(corr_matrix, xticklabels=corr_matrix.columns, yticklabels=corr_matrix.columns, annot=True)

        results_dict[table.name] = corr_matrix

    # Create network graph using networkx package based on correlation matrix created by get_correlation_matrix function on given data
    # when filter_col is True -> correlation value is set to 0 for columns tuple with correlation value lower then mean correlation value 
    def get_network_graph(self, corr_matrix, filter_col=True):
        links = corr_matrix.stack().reset_index()
        links.columns = ['var1', 'var2', 'value']

        # Remove self correlation
        links_filtered = links.loc[links['var1'] != links['var2']]

        if filter_col:
            # Keep only correlation over a threshold (the mean correlation value)
            mean_corr = links_filtered.loc[links_filtered['value'] > 0]['value'].mean()
            links_filtered = links_filtered.loc[links_filtered['value'] > mean_corr]

        # Build the graph
        G = nx.from_pandas_edgelist(links_filtered, 'var1', 'var2')

        # # Plot the network:
        # nx.draw_circular(G, with_labels=True, node_size=200, font_size=10)

        return G

    # Detects communities in the graph
    # The communities detection method we use can be one of the next communities detection methods:
    # best_partition (based on Louvain algorithm)
    def detect_communities(self, G):
        communities = community.best_partition(G)

        clusters = []
        for i in range(len(G.nodes)):
            curr_cluster = [key for key, val in communities.items() if val == i]
            if len(curr_cluster) == 0:
                break
            clusters.insert(i, curr_cluster)

        return clusters

    # Plot the entities we found as communities on network (features) graph
    def plot_communty_network(self, G, path_to_save_plot, with_labels=True):
        """Draw ``G`` with one colour per community, save the figure and a GEXF export.

        Args:
            G: networkx graph. Edges may carry a ``value`` attribute (drawn as
                line width) and nodes may carry a ``cluster`` attribute (index
                of the node colour). Missing attributes fall back to a width
                of 1.0 and to cluster 0, because the graph built by
                ``get_network_graph`` sets neither of them.
            path_to_save_plot: target path of the image; the GEXF file is
                written to the same path with '.gexf' appended.
            with_labels: draw the node names next to the nodes.
        """
        # pyplot is not imported in the notebook's import cell; import it here
        # so the function does not fail with NameError.
        import matplotlib.pyplot as plt

        plt.figure(figsize=(8,8))
        pos = nx.spring_layout(G, k=2)
        node_colors = ['green', 'red', 'yellow', 'black', 'blue', 'orange', 'pink', 'purple', 'gray', 'brown']

        # Edge width is twice the edge's 'value'
        weights = [edge_att_dict.get('value', 0.5)*2 for _, _, edge_att_dict in G.edges(data=True)]
        nx.draw_networkx_edges(G, pos, width=weights, alpha=0.9)

        for node, node_att_dict in G.nodes(data=True):
            cluster_idx = node_att_dict.get('cluster', 0)
            color = node_colors[cluster_idx%len(node_colors)]
            nx.draw_networkx_nodes(G, pos, nodelist=[node], node_color=color, node_size=150)
            if with_labels:
                nx.draw_networkx_labels(G, pos, {node:node}, font_size=8)

        nx.write_gexf(G, path_to_save_plot + '.gexf')
        plt.savefig(path_to_save_plot)
        plt.show()

    # Main function that uses all previous functions to detect main entities for each table in schema_discovery
    def find_entities(self, sample_size):
        threads_list = []
        results_dict = {}
        for table in self.schema_discovery.tables:
            data = get_data(table, sample_size)
            thread = threading.Thread(target=self.get_correlation_matrix, args=(table, data, results_dict,))
            threads_list.append(thread)
            thread.start()

        for thread in threads_list:
            thread.join()

        for table in self.schema_discovery.tables:
            corr_matrix = results_dict[table.name]
            G = self.get_network_graph(corr_matrix)
            clusters = self.detect_communities(G)
    #         self.plot_communty_network(G, '/dbfs/tmp/' + table.name)
            print('finish finding clusters for table ' + table.name + ', num of clusters: ' + str(len(clusters)))

            table.entities = []
            for cluster in clusters:
                entity = EntityDTO()
                entity.columns = cluster
                table.entities.append(entity)

In [9]:
# STEP 3 - EXTRACT EXTERNAL DEPENDENCIES

class DependenciesExtractor:
  
    def __init__(self, schema_discovery, metadata):
        self.schema_discovery = schema_discovery
        self.metadata = metadata

    # Concatenate several columns into one unique PK column - done so that composite PKs (PKs built from several
    # columns) of different tables can be compared
    def combain_pk_columns(self, data, pk_list):
        """Collapse a composite primary key into a single 'concat_pk_col' column.

        data: Spark DataFrame holding the PK columns.
        pk_list: column descriptors of the PK columns (objects exposing `.name`
                 and `.rawType`), in the order they should be concatenated.

        Returns (data, new_pk_list): `data` with an extra 'concat_pk_col' column
        holding the PK values joined with '_', and a one-element list with a
        ColumnDTO that describes the new column. Its rawType is the sorted raw
        types of the parts joined with '_'.
        """
        # pyspark's col() needs column *names*; pk_list holds descriptor objects.
        pk_names = [pk_col.name for pk_col in pk_list]

        # Build name_0, '_', name_1, '_', ... and concatenate everything in one call.
        parts = [col(pk_names[0])]
        for pk_name in pk_names[1:]:
            parts.append(lit('_'))
            parts.append(col(pk_name))
        data = data.withColumn('concat_pk_col', concat(*parts))

        # Sorted so the combined raw type does not depend on the column order.
        pk_raw_types = sorted(pk_col.rawType for pk_col in pk_list)
        new_pk_list = [ColumnDTO(name='concat_pk_col',
                                 rawType='_'.join(pk_raw_types),
                                 type='identifier code',
                                 isPK=True,
                                 PKsource='Discovered',
                                 isFK=False,
                                 FKsource='NONE')]

        return data, new_pk_list


    # Get the cardinality of a column given the data. The cardinality takes one of the following values:
    # 'One': each unique value of this column appears only once
    # 'Many': there is at least one unique value of this column that appears more than once
    def get_cardinality(self, data, col_name):
        map_cardinalty = lambda x : 'One' if (x == 1) else 'Many' 

        counter_col = data.dropna(subset=[col_name]).groupby(col_name).count()
        cardinality = map_cardinalty(counter_col.agg({"count": "max"}).collect()[0]["max(count)"])

        return cardinality


    # Search table object with given table_name in schema_discovery.tables_list. If there is no table with this name- return None
    def get_table_by_name(self, table_name):
        for table in self.schema_discovery.tables:
            if table.name == table_name:
                return table
        return None


    # Change the type of columns with given col_name to 'FK' type
    def set_col_keytype_to_FK(self, table, col_name, FK_source):
        for col in table.columns:
            if col.name == col_name:
                col.isFK = True
                col.FKsource = FK_source
                if col.isPK:
                    return 'PK'
                else:
                    return 'FK'

        return 'NONE'

    def get_unique_values(self, data, col_name):
        col_type = data.schema[col_name].dataType
        col_unqiue_vals = data.select(col_name).dropna().distinct()
        df_col_vals = spark.createDataFrame(col_unqiue_vals, col_type)        


    # Get the relationship between two columns: Contained, Contains, Overlap (partial overlap), Equal (same values), or None
    # if there is no relationship between the two columns
    def get_relationship_type(self, data_1, data_2, col_1, col_2):
        col_name_1 = col_1 + '_1'
        col_name_2 = col_2 + '_2'                

        vals_col_1 = data_1.select(col_1).dropna().distinct().withColumnRenamed(col_1, col_name_1)
        vals_col_2 = data_2.select(col_2).dropna().distinct().withColumnRenamed(col_2, col_name_2)

        col_1_relationship = None
        col_2_relationship = None

        join_on_cols = vals_col_1.join(vals_col_2, vals_col_1[col_name_1] == vals_col_2[col_name_2], how='inner')

        join_size = join_on_cols.count()
        col_1_unq_size = vals_col_1.count()
        col_2_unq_size = vals_col_2.count()

        if (join_size == col_1_unq_size) & (join_size == col_2_unq_size):
            col_1_relationship = 'Equal'
            col_2_relationship = 'Equal'
        else:
            if join_size == col_1_unq_size:
                col_1_relationship = 'Contained'
                col_2_relationship = 'Contains'
            elif join_size == col_2_unq_size:
                col_1_relationship = 'Contains'
                col_2_relationship = 'Contained'
            elif join_size > 0:
                col_1_relationship = 'Overlap'
                col_2_relationship = 'Overlap'

        return col_1_relationship, col_2_relationship
      
    # Get all PK columns and identifier code columns of given table
    def get_pk_id_cols(self, table):
        pk_table = [col for col in table.columns if col.isPK]
        identifier_code_cols = [col for col in table.columns if ((col.type == 'identifier code') & (not col.isPK) & (col.FKsource != 'Internal'))]
        
        return pk_table, identifier_code_cols

    # Extracts external dependencies between a pair of tables.
    # Each discovered dependency holds the names of the two connected columns, their table names, and for each side
    # its key type, cardinality ('One'/'Many', i.e. 1:1, 1:M, M:1, M:N) and relationship type
    def find_foreign_keys(self, table_1, table_2, tables_dependencies_dict, idx):
        """Discover key dependencies between two tables and record them.

        The candidate columns of each table are its PK columns followed by its
        identifier columns (see get_pk_id_cols). Every candidate pair with the
        same rawType, in which at least one side is a PK, is compared by value.
        Each pair that shares values is appended to tables_dependencies_dict[idx]
        as a DependencyDTO with source 'Discovered'.

        table_1, table_2: the two table objects to compare.
        tables_dependencies_dict: dict that already holds a list under `idx`.
        idx: key of the list that receives the discovered dependencies.
        Returns None.
        """
        data_1 = get_data(table_1, 0, to_sample=False)
        data_2 = get_data(table_2, 0, to_sample=False)

        pk_table_1, identifier_code_cols_1 = self.get_pk_id_cols(table_1)
        pk_table_2, identifier_code_cols_2 = self.get_pk_id_cols(table_2)

        pk_size_1 = len(pk_table_1)
        pk_size_2 = len(pk_table_2)

        # Without any PK there is nothing a foreign key could point to.
        if pk_size_1 == 0 and pk_size_2 == 0:
            return

        if pk_size_1 != pk_size_2:
            # PKs of different sizes cannot be compared with each other: only a
            # single-column PK is checked against the identifier columns.
            if pk_size_1 > 1 and pk_size_2 > 1:
                return
            if pk_size_1 > 1:
                pk_table_1 = []
            elif pk_size_2 > 1:
                pk_table_2 = []
        elif pk_size_1 > 1:
            # Composite PKs of equal size: merge each into one comparable column.
            data_1, pk_table_1 = self.combain_pk_columns(data_1, pk_table_1)
            data_2, pk_table_2 = self.combain_pk_columns(data_2, pk_table_2)

        candidates_1 = pk_table_1 + identifier_code_cols_1
        candidates_2 = pk_table_2 + identifier_code_cols_2

        # Relationship types after which the column is flagged as a foreign key.
        fk_relationships = ('Contained', 'Overlap', 'Equal')

        for col_1, col_2 in itertools.product(candidates_1, candidates_2):
            if not (col_1.isPK or col_2.isPK):
                continue
            if col_1.rawType != col_2.rawType:
                continue

            col_1_relationship, col_2_relationship = self.get_relationship_type(
                data_1, data_2, col_1.name, col_2.name)
            if col_1_relationship is None:
                continue

            key_type_1 = 'PK'
            key_type_2 = 'PK'
            if col_1_relationship in fk_relationships:
                key_type_1 = self.set_col_keytype_to_FK(table_1, col_1.name, 'Discovered')
            if col_2_relationship in fk_relationships:
                key_type_2 = self.set_col_keytype_to_FK(table_2, col_2.name, 'Discovered')

            cardinality_col_1 = self.get_cardinality(data_1, col_1.name)
            cardinality_col_2 = self.get_cardinality(data_2, col_2.name)

            ref_left = TableRefDTO(table_name=table_1.name, column_name=col_1.name,
                                   key_type=key_type_1, cardinality_type=cardinality_col_1,
                                   relationship_type=col_1_relationship)
            ref_right = TableRefDTO(table_name=table_2.name, column_name=col_2.name,
                                    key_type=key_type_2, cardinality_type=cardinality_col_2,
                                    relationship_type=col_2_relationship)
            tables_dependencies_dict[idx].append(
                DependencyDTO(ref_left, ref_right, dependency_source='Discovered'))


    # Iterate over the final dependencies_list and remove duplicate dependencies from it
    def remove_duplicate_dependencies(self, dependencies_list):
        all_dep_hash = []
        new_dependencies_list = []
        for dep in dependencies_list:
            ref_hash_left = dep.left.tableName + '_' + dep.left.columnName 
            ref_hash_right = dep.right.tableName + '_' + dep.right.columnName

            left_right = ref_hash_left + '_' + ref_hash_right
            right_left = ref_hash_right + '_' + ref_hash_left

            if (left_right not in all_dep_hash) & (right_left not in all_dep_hash):
                new_dependencies_list.append(dep)
                all_dep_hash = all_dep_hash + [left_right, right_left]

        return new_dependencies_list


    # Change every 'identifier code' type column that was not recognized as PK or FK to the 'lable' type
    def change_identifier_code_to_lable(self):
        for table in self.schema_discovery.tables:
            for col in table.columns:
                if ((not col.isPK) & (not col.isFK) & (col.type == 'identifier code')):
                    col.type = 'lable'


    # Iterate over the known dependencies (foreign keys declared in the metadata) and add them to dependencies_list, which is
    # returned and eventually added to schema_discovery.dependencies
    def add_table_dependencies_from_metadata(self, table):
        table_metadata = self.metadata.tables[table.name]
        table_constraints = list(table_metadata.constraints)
        dependencies_list = []

        for i in range(len(table_constraints)):
            if type(table_constraints[i]) == sqlalchemy.sql.schema.ForeignKeyConstraint:
                table_columns_with_foreign_key = []

                # get foreign keys names as thay appear in the current table
                for col in table_constraints[i].columns:
                    table_columns_with_foreign_key.append((str(col.table.name), col.name))

                # get foreign keys names as thay appear in the foreign table
                all_foreign_keys = [(elemnt.column.table.name, elemnt.column.name) for elemnt in table_constraints[i].elements]

                if len(table_columns_with_foreign_key) != len(all_foreign_keys):
                    return dependencies_list

                # iterate over all pairs of columns we found (current table columns, foreign table colum)
                for i in range(len(table_columns_with_foreign_key)):
                    table_col = table_columns_with_foreign_key[i]
                    foreign_col = all_foreign_keys[i]

                    foreign_table = self.get_table_by_name(foreign_col[0])
                    if foreign_table == None:
                        return dependencies_list

                    table_data = get_data(table, 0, to_sample=False)
                    foreign_table_data = get_data(foreign_table, 0, to_sample=False)

                    cardinality_left = self.get_cardinality(table_data, table_col[1])
                    cardinality_right = self.get_cardinality(foreign_table_data, foreign_col[1])

                    col_1_relationship, col_2_relationship = self.get_relationship_type(table_data, foreign_table_data, table_col[1], foreign_col[1])
                    
                    key_type_1 = 'PK'
                    key_type_2 = 'PK'
                    if (col_1_relationship == 'Contained') | (col_1_relationship == 'Overlap') | (col_1_relationship == 'Equal'):
                        key_type_1 = self.set_col_keytype_to_FK(table, table_col[1], 'Metadata')
                    if (col_2_relationship == 'Contained') | (col_2_relationship == 'Overlap') | (col_2_relationship == 'Equal'):
                        key_type_2 = self.set_col_keytype_to_FK(foreign_table, foreign_col[1], 'Metadata')

                    ref_left = TableRefDTO(table_name=table_col[0], column_name=table_col[1], key_type=key_type_1, \
                                           cardinality_type=cardinality_left, relationship_type=col_1_relationship)
                    ref_right = TableRefDTO(table_name=foreign_col[0], column_name=foreign_col[1], key_type=key_type_2, \
                                            cardinality_type=cardinality_right, relationship_type=col_2_relationship)
                    dependency = DependencyDTO(ref_left, ref_right, dependency_source='Metadata')

                    dependencies_list.append(dependency)

    #                 print(dependency.left.tableName+'.'+dependency.left.columnName+'_'+dependency.left.keyType+'_'+dependency.left.cardinalityType+'_'+\
    #                       dependency.left.relationshipType+'  '+\
    #                       dependency.right.tableName+'.'+dependency.right.columnName+'_'+dependency.right.keyType+'_'+\
    #                       dependency.right.cardinalityType+'_'+dependency.right.relationshipType+'  '+\
    #                       dependency.dependencySource)

        return dependencies_list

    def find_internal_foreign_keys(self, table, internal_dependencies_dict):
        """Find identifier columns of `table` that reference the table's own PK.

        Every identifier column (see get_pk_id_cols) with the same rawType as a PK
        column is compared with that PK column. When all of its values are
        contained in the PK values, it is flagged as a foreign key with FKsource
        'Internal' and recorded in the result dict.

        table: the table object to analyse.
        internal_dependencies_dict: dict that receives the results. The entry
            internal_dependencies_dict[table.name] is always set to a list, and
            gets one {'pk_col': ..., 'internal_fk_col': ...} dict per internal
            foreign key found.
        Returns None.
        """
        internal_dependencies_dict[table.name] = []
        data = get_data(table, 0, to_sample=False)

        pk_table, identifier_code_cols = self.get_pk_id_cols(table)

        if len(pk_table) == 0:
            return

        for pk_col in pk_table:
            for id_col in identifier_code_cols:
                if pk_col.rawType != id_col.rawType:
                    continue

                col_1_relationship, col_2_relationship = self.get_relationship_type(
                    data, data, pk_col.name, id_col.name)

                if col_2_relationship == 'Contained':
                    id_col.isFK = True
                    id_col.FKsource = 'Internal'

                    internal_dep = {'pk_col': pk_col, 'internal_fk_col': id_col}
                    # Record in the dict passed by the caller; the previous name
                    # used here was not defined in this method.
                    internal_dependencies_dict[table.name].append(internal_dep)
                    print(internal_dep)

    def update_internal_fk_to_external_table(self, table_name, pk_col, internal_fk_col, all_tables_dependencies):
        for dep in all_tables_dependencies:
            pk_ref = None
            if (dep.left.tableName == table_name) & (dep.left.columnName == pk_col.name) & (dep.left.relationshipType == 'Contained'):
                pk_ref = dep.right
            elif (dep.right.tableName == table_name) & (dep.right.columnName == pk_col.name) & (dep.right.relationshipType == 'Contained'):
                pk_ref = dep.left
                
            if pk_ref != None:
                table = self.get_table_by_name(table_name)
                table_data = get_data(table, 0, to_sample=False)
                cardinality = self.get_cardinality(table_data, internal_fk_col.name)
                ref = TableRefDTO(table_name=table_name, column_name=internal_fk_col.name, key_type='FK', \
                                  cardinality_type=cardinality, relationship_type='Contained')
                
                dependency = DependencyDTO(pk_ref, ref, dependency_source='Discovered')
                all_tables_dependencies.append(dependency)

                internal_fk_col.isFK = True
                internal_fk_col.FKsource = 'Discovered'
        
    
    
    # Main function; extract dependencies between each possible combination of two tables from schema_discovery
    def get_external_dependencies(self):
        all_tables_dependencies = []
        if self.metadata != None:
            for table in self.schema_discovery.tables:
                table_dep_from_metadata = self.add_table_dependencies_from_metadata(table)
                all_tables_dependencies = all_tables_dependencies + table_dep_from_metadata

        threads_list = []
        internal_dependencies_dict = {}   
        for table in self.schema_discovery.tables:
            thread = threading.Thread(target=self.find_internal_foreign_keys, args=(table, internal_dependencies_dict,))
            threads_list.append(thread)
            thread.start() 
            
        for thread in threads_list:
            thread.join()
            
        print(internal_dependencies_dict)
                
        all_possible_table_combinations = [x for x in itertools.combinations(self.schema_discovery.tables, 2)]
        threads_list = []
        tables_dependencies_dict = {}
        for idx, tables_tupple in enumerate(all_possible_table_combinations):
            table_1 = tables_tupple[0]
            table_2 = tables_tupple[1]
            tables_dependencies_dict[idx] = []

            thread = threading.Thread(target=self.find_foreign_keys, args=(table_1, table_2, tables_dependencies_dict, idx,))
            threads_list.append(thread)
            thread.start()

        for thread in threads_list:
            thread.join()          

        for idx in range(len(all_possible_table_combinations)):
            tables_dependencies = tables_dependencies_dict[idx]
            all_tables_dependencies = all_tables_dependencies + tables_dependencies
            
        for table_name in internal_dependencies_dict:
            for key_dict in internal_dependencies_dict[table_name]:
                pk_col = key_dict['pk_col']
                internal_fk_col = key_dict['internal_fk_col']
                self.update_internal_fk_to_external_table(table_name, pk_col, internal_fk_col, all_tables_dependencies)

        all_tables_dependencies = self.remove_duplicate_dependencies(all_tables_dependencies)
        self.schema_discovery.dependencies = all_tables_dependencies
        self.change_identifier_code_to_lable()

In [10]:
# Create schema_discovery object using create_schema_discovery_from_DB function
# User input
import os  # imported locally: this cell is the only place the environment is read

read_from_db = True
sample_percentage = 1
# The connection settings are read from environment variables so that no credentials
# are stored in the notebook. A missing variable fails fast with a KeyError naming it.
jdbcUsername = os.environ['JDBC_USERNAME']
jdbcPassword = os.environ['JDBC_PASSWORD']
jdbcDriver = os.environ['JDBC_DRIVER']
jdbcHostname = os.environ['JDBC_HOSTNAME']
jdbcDatabase = os.environ['JDBC_DATABASE']
port = os.environ['JDBC_PORT']  # NOTE(review): read as a string - convert if the ETL code expects an int


etl_object = ETL()
schema_discovery, metadata = etl_object.create_schema_discovery_from_DB(jdbcUsername, jdbcPassword, jdbcDriver, jdbcHostname, jdbcDatabase, port, sample_percentage)

# # Create schema_discovery object using create_schema_discovery_from_csv function
# # User input
# read_from_db = False
# sample_percentage = 1
# path = ''

# etl_object = ETL()
# schema_discovery, metadata = etl_object.create_schema_discovery_from_csv(path, sample_percentage)


# Extract columns details for each table in schema_discovery using extract_columns_info function
columns_info_extractor = ColumnsInfoExtractor(schema_discovery, metadata)
columns_info_extractor.extract_columns_info()


# Extract entities for each table in schema_discovery
sample_size = 10000 # User input
entities_extractor = EntitiesExtractor(schema_discovery)
entities_extractor.find_entities(sample_size)

# Extract dependencies between each possible combination of two tables in schema_discovery using get_external_dependencies function
dependencies_extractor = DependenciesExtractor(schema_discovery, metadata)
dependencies_extractor.get_external_dependencies()

# Save schema_discovery object
schema_discovery_file_location = "wasbs://hc-test-data-01@homecredittest01.blob.core.windows.net/hc-test-01/schema_discovery.json" # User input
save_schema_discovery(schema_discovery, schema_discovery_file_location)

In [11]:
# Extract dependencies between each possible combination of two tables in schema_discovery using get_external_dependencies function
# NOTE(review): these two statements repeat the dependency-extraction step at the end of the
# previous cell on the same schema_discovery object; this looks like a leftover manual
# re-run - confirm it is still needed.
dependencies_extractor = DependenciesExtractor(schema_discovery, metadata)
dependencies_extractor.get_external_dependencies()

In [12]:
# Print the name of every table in schema_discovery.
# The loop variable `table` is a notebook global and stays bound after the loop.
for table in schema_discovery.tables:
  print(table.name)

In [13]:
# Select the table named 'credit_card_balance'; `table` stays bound to it after the loop
# so that a later cell can use it.
for table in schema_discovery.tables:
  if table.name == 'credit_card_balance':
    break
else:
  # Without this guard `table` would silently stay bound to the last table of the
  # list and later cells would analyse the wrong table.
  raise LookupError("table 'credit_card_balance' not found in schema_discovery.tables")

In [14]:
# Run the internal foreign key search on a single table.
# `table` is the notebook global left bound by an earlier loop cell, so this cell only
# works when that cell was run first.
dependencies_extractor = DependenciesExtractor(schema_discovery, metadata)
# The result dict is created inline and not kept, so only the printed output and the
# flags set on the table's columns remain after the call.
dependencies_extractor.find_internal_foreign_keys(table, {})