In [1]:
import sys
import os
import glob
import json
import importlib.util
import re
import warnings
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GATv2Conv
from concurrent.futures import ProcessPoolExecutor, as_completed
from collections import defaultdict
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

##### ==========================================#####
##### WORKER CODE (for parallel processing)#####
##### ==========================================#####
WORKER_CODE = """
import re
import numpy as np
import pandas as pd
from scipy.stats import entropy
import warnings

##### Normalization: Type standardization & warning suppression#####

warnings.filterwarnings('ignore', category=pd.errors.DtypeWarning)

SAMPLE_SIZE = 500
SIGNATURE_SIZE = 200

def smart_normalize(series):
    ##### Normalization: Trim, casefold, numeric parsing
    s = series.astype(str).str.strip().str.lower()
    try:
        nums = pd.to_numeric(s, errors='coerce')
        mask = nums.notna()
        if mask.mean() > 0.8: 
            s.loc[mask] = nums.loc[mask].astype(float).astype(str)
    except: pass
    return s

def get_id_score(col_name):
######################### Node Features: ID-like tokens #########################
    col = col_name.lower()
    if col in ['id', 'pk', 'key', 'uuid', 'guid', 'index', '_id']: return 3
    if re.search(r'(_id|id|_key|key|_pk|pk)$', col): return 2
    if re.search(r'(id|key|code|num|no)', col): return 1
    return 0

def get_signature(series, n=200):
    #####  Blocking LSH: MinHash-style signature for fast comparison
    uniques = series.dropna().unique()
    if len(uniques) > n:
        try: uniques = np.random.choice(uniques, n, replace=False)
        except: pass
    return set(hash(str(x)) for x in uniques)

def extract_features(df, cols):
    is_composite = isinstance(cols, tuple)
    col_names = list(cols) if is_composite else [cols]
    n_rows = len(df)
    
##### Robustness : Multiple row samples (Head, Tail, Random)#########################
    limit = SAMPLE_SIZE
    try:
        head = df[col_names].head(limit)
        tail = df[col_names].tail(limit)
        rand = df[col_names].sample(min(n_rows, limit)) if n_rows > limit else df[col_names]
    except: return None 

    if is_composite:
        fn = lambda x: ''.join(x.astype(str))
        vals_head = head.agg(fn, axis=1)
        vals_tail = tail.agg(fn, axis=1)
        vals_rand = rand.agg(fn, axis=1)
        name_str = "_".join(cols).lower()
        col_count = len(cols)
        id_score = max([get_id_score(c) for c in cols])
    else:
        vals_head = smart_normalize(head[cols])
        vals_tail = smart_normalize(tail[cols])
        vals_rand = smart_normalize(rand[cols])
        name_str = cols.lower()
        col_count = 1
        id_score = get_id_score(cols)

##### Stability Check: Key must be stable across all samples ###################
    u_scores = []
    for v in [vals_head, vals_tail, vals_rand]:
        u_scores.append(v.nunique() / len(v) if len(v) > 0 else 0)
        
    if (max(u_scores) - min(u_scores)) > 0.5: return None 

##### Node Features: Entropy, UUID regex, Nulls, Duplicates #########
    vals = vals_rand
    n_uniq = vals.nunique()
    uniqueness = u_scores[2]
    null_rate = df[col_names].isna().mean().max()
    dup_rate = 1.0 - uniqueness
    card_log = np.log1p(n_uniq)
    
    counts = vals.value_counts()
    ent = entropy(counts) if len(counts) > 0 else 0
    
    is_num = 0.0
    if not is_composite:
        try: is_num = pd.to_numeric(vals, errors='coerce').notna().mean()
        except: pass
        
    avg_len = vals.str.len().mean() if len(vals) > 0 else 0.0
    is_uuid = 1.0 if re.search(r'(uuid|guid)', name_str) else 0.0
    is_time = 1.0 if re.search(r'(date|time|created)', name_str) else 0.0

##### 11-DIM FEATURE VECTOR #########################
    feats = [uniqueness, null_rate, dup_rate, card_log, ent,
             is_num, avg_len, float(id_score), is_uuid, is_time, float(col_count)]
             
    return feats, get_signature(vals, SIGNATURE_SIZE), avg_len, n_rows

def process_table_file(filepath):
    try:
##### Optimization: Read once, low_memory=False #########################
        try: df = pd.read_csv(filepath, on_bad_lines='skip', engine='pyarrow')
        except: df = pd.read_csv(filepath, on_bad_lines='skip', low_memory=False)
        
        if df.empty: return []
        
        cols = [c for c in df.columns if not re.search(r'^(desc|note|comment|text)', c, re.I)]
        table = filepath.split('/')[-1].replace('.csv','')
        
        candidates = []
        ingredients = []
        
##### 1. Atomic Columns #########################
        for col in cols:
            res = extract_features(df, col)
            if res:
                feats, sig, avg_len, n_rows = res
                candidates.append({
                    'table': table, 'cols': (col,), 
                    'features': feats, 'signature': sig, 
                    'avg_len': avg_len, 'n_rows': n_rows
                })
                if feats[0] > 0.2 or feats[7] > 0: ingredients.append(col)
        
 ##### Composite Pruning: Use top-m ingredients #########################
        import itertools
        if len(ingredients) > 10: ingredients = ingredients[:10]
        
        for r in range(2, 4):
            for c_cols in itertools.combinations(ingredients, r):
                res = extract_features(df, c_cols)
                if res:
                    feats, _, _, n_rows = res
                    ##### [REQ-7] Composite Pruning: Uniqueness threshold
                    if feats[0] > 0.85:
                        candidates.append({
                            'table': table, 'cols': c_cols, 
                            'features': feats, 'signature': set(), 
                            'avg_len': 0, 'n_rows': n_rows
                        })
        return candidates
    except: return []
"""

CONFIG = {
    'folder': './data', 
    'device': 'cpu',
    'tau': 0.25,          ##### Lowered from 0.40 to catch weaker signals
    'min_overlap': 2,     ##### Lowered from 3 to allow smaller overlaps
    'num_workers': max(1, (os.cpu_count() or 4) - 1),
    'hidden_dim': 64, 'heads': 4, 'dropout': 0.1,
    'lr': 0.005, 'weight_decay': 1e-4, 'epochs': 150,
}

##### ==========================================#########################
##### 2. SEMANTIC & REFINEMENT ##########
##### ==========================================#########################
def get_name_similarity(name1, name2):
    n1, n2 = name1.lower(), name2.lower()
    if n1 == n2: return 1.0
    if n1 + 's' == n2 or n1 == n2 + 's': return 0.9 
    if n1.replace('_id', '') == n2 or n2.replace('_id', '') == n1: return 0.8
    return 0.0

def refine_and_export(candidates, schema, data):
    print("\nPhase 5: Refinement & Exporting...")
    
##### Headless backend ######################
    
    import matplotlib
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    
    edge_index = data.edge_index.cpu().numpy()
    edge_attr = data.edge_attr.cpu().numpy()
    relationships = []
    
    print("  > Analyzing graph edges for Foreign Keys...")
##### Track "Near Misses" for debugging ############
    near_misses = 0
    
    for k in range(edge_index.shape[1]):
        src, dst = edge_index[0, k], edge_index[1, k]
        c_src, c_dst = candidates[src], candidates[dst]
        
        if c_src['table'] == c_dst['table']: continue
        
##### Metrics ########### 
        inc_score = edge_attr[k][0]   ##### How much of Src is in Dst? ###
        name_score = get_name_similarity(c_src['cols'][0], c_dst['table'])
        
        ##### LOGIC: ########
        ##### 1. Strong Structural: 98% overlap (almost perfect subset) -> Trust it
        ##### 2. Hybrid: 50% overlap + Name match (e.g. 'user_id' -> 'users')
        is_strong_structural = (inc_score > 0.98)
        is_semantic_match = (inc_score > 0.50 and name_score > 0.5)
        
        if is_strong_structural or is_semantic_match:
            rel = {
                "Source Table": c_src['table'], 
                "Source Col": list(c_src['cols']),
                "Target Table": c_dst['table'], 
                "Confidence": f"{inc_score:.2f}"
            }
            ##### Deduplicate ###########
            if not any(r['Source Table'] == rel['Source Table'] and r['Target Table'] == rel['Target Table'] for r in relationships):
                relationships.append(rel)
        elif inc_score > 0.3:
            near_misses += 1

    if near_misses > 0:
        print(f"  > Note: {near_misses} pairs had some overlap (>30%) but failed validation.")

##### 3. Export CSV REl #################
    if relationships:
        pd.DataFrame(relationships).to_csv("relationships.csv", index=False)
        print(f"  > Saved: relationships.csv ({len(relationships)} relations found)")
    else:
        print("  > No strong FK relationships detected.")

##### Make ERD (Connected Components sample) ######
    try:
        if not relationships:
            print("  > Skipping ERD (No relationships to draw).")
            return schema

        print("  > Drawing ERD (Connected Tables Only)...")
        G_viz = nx.DiGraph()
        
            # ONLY add nodes that exist in a relationship
        for r in relationships:
            G_viz.add_edge(r['Source Table'], r['Target Table'], label="FK")
        
        plt.figure(figsize=(16, 12)) #canvasLarger
        
        #layout that separates clusters
        pos = nx.spring_layout(G_viz, k=0.3, iterations=50, seed=42)
        
        ##drawing with high contrast
        nx.draw_networkx_nodes(G_viz, pos, node_size=2500, node_color='#E0F7FA', edgecolors='#006064')
        nx.draw_networkx_labels(G_viz, pos, font_size=8, font_weight='bold')
        nx.draw_networkx_edges(G_viz, pos, edge_color='#455A64', arrowstyle='-|>', arrowsize=15, width=1.2)
            
        plt.title(f"Inferred Schema ({len(relationships)} Relations)", fontsize=16)
        plt.axis('off')
        plt.tight_layout()
        plt.savefig("erd.png", dpi=300)
        plt.close()
        print("  > Saved: erd.png")
        
    except Exception as e:
        print(f"  > Error drawing ERD: {e}")
    
    return schema

##### ==========================================###############
##### 3. GNN MODEL & GRAPH                          #####
##### ==========================================##########
class SieveGNN(nn.Module):
    def __init__(self, in_dim, hidden_dim, heads, dropout):
        super().__init__()
##### Edge Features consumed here (edge_dim=4) ##########
        self.conv1 = GATv2Conv(in_dim, hidden_dim, heads=heads, edge_dim=4, dropout=dropout)
        self.conv2 = GATv2Conv(hidden_dim*heads, 1, heads=1, edge_dim=4, dropout=dropout)

    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        x = F.elu(self.conv1(x, edge_index, edge_attr))
        x = self.conv2(x, edge_index, edge_attr)
        return x

def build_graph_fast(candidates):
    print("Phase 2: Building Graph...")
    ##### Blocking / LSH for speed #####
    index = defaultdict(list)
    for idx, c in enumerate(candidates):
        if not c['signature']: continue
        for token in c['signature']: index[token].append(idx)
            
    pair_counts = defaultdict(int)
    for token, indices in index.items():
        if len(indices) < 2 or len(indices) > 50: continue
        for i in range(len(indices)):
            for j in range(i+1, len(indices)):
                idx_a, idx_b = indices[i], indices[j]
                if candidates[idx_a]['table'] == candidates[idx_b]['table']: continue
                pair = tuple(sorted((idx_a, idx_b)))
                pair_counts[pair] += 1
                
    edge_index, edge_attr = [], []
    for (i, j), overlap in pair_counts.items():
        if overlap < CONFIG['min_overlap']: continue
        c_i, c_j = candidates[i], candidates[j]
        len_i, len_j = len(c_i['signature']), len(c_j['signature'])
        
        ## Directional Inclusion metrics
        inc_i_j = overlap / len_i if len_i else 0
        inc_j_i = overlap / len_j if len_j else 0
        ####Alternative Similarity: Jaccard
        jaccard = overlap / (len_i + len_j - overlap)
        ### Edge Features: Value-Length Compatibility
        li, lj = c_i['avg_len'], c_j['avg_len']
        len_ratio = min(li, lj) / (max(li, lj) + 0.01)
        
        if max(inc_i_j, inc_j_i) > CONFIG['tau']:
            edge_index.append([i, j]); edge_attr.append([inc_i_j, inc_j_i, jaccard, len_ratio])
            edge_index.append([j, i]); edge_attr.append([inc_j_i, inc_i_j, jaccard, len_ratio])

    x = torch.tensor([c['features'] for c in candidates], dtype=torch.float).nan_to_num()
    
    if not edge_index:
        edge_index = torch.tensor([[0], [0]], dtype=torch.long)
        edge_attr = torch.zeros((1, 4), dtype=torch.float)
    else:
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_attr, dtype=torch.float)
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

def get_weak_labels(candidates):
    ##-- Weak/Heuristic Labels
    labels, mask = [], []
    for c in candidates:
        u = c['features'][0]
        is_composite = c['features'][10] > 1
        if is_composite and u > 0.99: labels.append(1.0); mask.append(True)
        elif u < 0.90: labels.append(0.0); mask.append(True)
        else: labels.append(0.0); mask.append(False)
    return torch.tensor(labels).float(), torch.tensor(mask).bool()

def train_model(model, data, labels, mask, cfg):
    ### Optimize with NLL, Report Acc/Prec/Recall ##### ##########
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg['lr'], weight_decay=cfg['weight_decay'])
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([3.0]))
    model.train()
    print(f"\nPhase 3: GNN Training ({cfg['epochs']} Epochs)...")
    print(f"{'Epoch':<6} | {'Loss':<8} | {'Acc':<6} | {'Prec':<6} | {'Rec':<6} | {'F1':<6}")
    print("-" * 55)
    
    for epoch in range(cfg['epochs']):
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out[mask].squeeze(), labels[mask])
        loss.backward()
        optimizer.step()
        
        if epoch % 25 == 0 or epoch == cfg['epochs'] - 1:
            preds = (torch.sigmoid(out[mask]).squeeze() > 0.5).float()
            y_true, y_pred = labels[mask].cpu().numpy(), preds.cpu().numpy()
            acc = accuracy_score(y_true, y_pred)
            prec = precision_score(y_true, y_pred, zero_division=0)
            rec = recall_score(y_true, y_pred, zero_division=0)
            f1 = f1_score(y_true, y_pred, zero_division=0)
            print(f"{epoch:<6} | {loss.item():.4f}   | {acc:.2f}   | {prec:.2f}   | {rec:.2f}   | {f1:.2f}")
    print("-" * 55)

def generate_reports(candidates, probs):
    print("Phase 4: Generating JSON & Data Quality Profile...")
    schema = {"tables": {}}
    table_groups = defaultdict(list)
    for i, c in enumerate(candidates):
        table_groups[c['table']].append((probs[i], c))

    for table, items in table_groups.items():
        valid = [x for x in items if x[1]['features'][0] > 0.99]
        if not valid: continue
        ##### [REQ-7] Minimality Check: Sort by column length first
        valid.sort(key=lambda x: (-x[0], len(x[1]['cols']), -x[1]['features'][7]))
        best_prob, best_cand = valid[0]
        schema["tables"][table] = [{
            "type": "PK", "columns": list(best_cand['cols']),
            ##### [REQ-11] Explainability
            "explainability": {"confidence": f"{best_prob:.2f}", "uniqueness": f"{best_cand['features'][0]:.1%}"}
        }]

    ##### ##### Data Quality Report##########
    dq = []
    for c in candidates:
        if len(c['cols']) == 1:
            dq.append({
                "Table": c['table'], "Column": c['cols'][0], 
                "Rows": c.get('n_rows', 0), 
                "Uniqueness": f"{c['features'][0]:.1%}", "Nulls": f"{c['features'][1]:.1%}"
            })
    pd.DataFrame(dq).to_csv("data_quality_profile.csv", index=False)
    print("  > Saved: data_quality_profile.csv")
    return schema

##### ==========================================###############
##### 4. MAIN ORCHESTRATOR              ##########
##### ==========================================####################
def main():
    print(f"--- Sieve-GNN: Pipeline Start ({CONFIG['num_workers']} Workers) ---")
    files = glob.glob(os.path.join(CONFIG['folder'], "*.csv"))
    if not files: print("No CSVs found."); return
    
    ##### 1. SETTING UP OUR WORKERS ##########
    
    with open("sieve_workers.py", "w") as f: f.write(WORKER_CODE)
    spec = importlib.util.spec_from_file_location("sieve_workers", "sieve_workers.py")
    workers = importlib.util.module_from_spec(spec)
    sys.modules["sieve_workers"] = workers
    spec.loader.exec_module(workers)
    
    print(f"Phase 1: Profiling {len(files)} tables...")
    all_candidates = []
    
    with ProcessPoolExecutor(max_workers=CONFIG['num_workers']) as executor:
        futures = [executor.submit(workers.process_table_file, f) for f in files]
        for i, fut in enumerate(as_completed(futures)):
            res = fut.result()
            if res: all_candidates.extend(res)
            if i % 50 == 0: print(f"  Processed {i}/{len(files)}...", end="\r")
            
    print(f"\n  > Candidates: {len(all_candidates)}")
    if len(all_candidates) < 2: return

    ### 2. GRAPH & LEARN #####
    data = build_graph_fast(all_candidates)
    labels, mask = get_weak_labels(all_candidates)
    
    if mask.sum() > 0:
        model = SieveGNN(11, CONFIG['hidden_dim'], CONFIG['heads'], CONFIG['dropout'])
        train_model(model, data, labels, mask, CONFIG)
        
        ##### 3. PREDICT & REFINE #####
        model.eval()
        with torch.no_grad():
            probs = torch.sigmoid(model(data)).squeeze().numpy()
        
        initial_schema = generate_reports(all_candidates, probs)
        final_schema = refine_and_export(all_candidates, initial_schema, data)
        
        with open("schema_final.json", "w") as f: json.dump(final_schema, f, indent=2)
        print("  > Saved: schema_final.json")
    else:
        print("Error: No labels generated.")

##ALLAH BHORSHA
if __name__ == "__main__":
    import multiprocessing
    multiprocessing.freeze_support()
    main()

  from .autonotebook import tqdm as notebook_tqdm


--- Sieve-GNN: Pipeline Start (11 Workers) ---
Phase 1: Profiling 4431 tables...
  Processed 4400/4431...
  > Candidates: 341810
Phase 2: Building Graph...

Phase 3: GNN Training (150 Epochs)...
Epoch  | Loss     | Acc    | Prec   | Rec    | F1    
-------------------------------------------------------
0      | 0.5981   | 0.83   | 0.91   | 0.90   | 0.90
25     | 0.2269   | 0.99   | 0.98   | 1.00   | 0.99
50     | 0.2001   | 0.99   | 0.99   | 1.00   | 0.99
75     | 0.1756   | 0.99   | 0.99   | 1.00   | 0.99
100    | 0.1558   | 0.99   | 0.99   | 1.00   | 0.99
125    | 0.1410   | 0.99   | 0.99   | 1.00   | 0.99
149    | 0.1319   | 0.99   | 0.99   | 1.00   | 0.99
-------------------------------------------------------
Phase 4: Generating JSON & Data Quality Profile...
  > Saved: data_quality_profile.csv

Phase 5: Refinement & Exporting...
  > Analyzing graph edges for Foreign Keys...
  > Note: 246 pairs had some overlap (>30%) but failed validation.
  > Saved: relationships.csv (13 relati