In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
%cd ..

/home/soda/rcappuzz/work/benchmark-join-suggestions


In [4]:
import json
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import sklearn.metrics as metrics
from joblib import Parallel, delayed, dump, load
from tqdm import tqdm
import seaborn as sns
import polars.selectors as cs

import src.methods.profiling as jp
from src.data_structures.join_discovery_methods import MinHashIndex, ExactMatchingIndex

In [5]:
cfg = pl.Config()
cfg.set_fmt_str_lengths(30)

polars.config.Config

In [6]:
data_lake_case = "binary_update"


In [23]:
overlap_dict = {}
# for table_name in ["us-accidents-yadl"]:
for table_name in [
    "company-employees-yadl",
    "movies-yadl",
    "us-accidents-yadl",
    "us-presidential-results-yadl",
]:
    mdata_root_dir = Path("data/metadata/")
    mdata_path = Path(mdata_root_dir, data_lake_case)

    base_table_path = f"data/source_tables/{table_name}.parquet"
    df_base = pl.read_parquet(base_table_path)
    query_column = "col_to_embed"
    exact_matching_index = ExactMatchingIndex(mdata_path, base_table_path, query_column, n_jobs=-1)
    overlap_dict[table_name] = exact_matching_index.counts

res_list = []
for tab_name, tab in overlap_dict.items():
    tab = tab.with_columns(pl.lit(tab_name).alias("table_name"))
    res_list.append(tab)
df_exact = pl.concat(res_list)

                                                

In [24]:
exact_matching_index.counts

hash,col,containment
str,str,f64
"""6718868bd0dc6d153cbaeec8e4148…","""isLocatedIn""",0.752617
"""cd13e0a2f80871e36479eb5396e38…","""subject""",0.749762
"""a6d3064f36381a789d31e8ad9da96…","""subject""",0.749762
"""d5fab84d7f1f089ec5d305d34a2da…","""subject""",0.736441
"""d9255a643aa4106cbfad9c890cd9c…","""wasBornIn""",0.46876
"""7b879cbeede3c94cc8977f725723c…","""subject""",0.356486
"""6718868bd0dc6d153cbaeec8e4148…","""subject""",0.231526
"""46e178e2436876037c603dd916f0f…","""diedIn""",0.227402
"""8c129d261ce6359dabeaca1f5b392…","""subject""",0.129083
"""e55c5fbe7bd369931ae1a9b1e1391…","""subject""",0.113225


In [25]:
def minhash_matching(index, query_column_values):
    # Querying index for any candidates
    query_result = index.query_index(query_column_values)
    return query_result

In [26]:
query_column = "col_to_embed"
result_list = []
index_path = Path("data/metadata/_indices/binary_update/minhash_index.pickle")
index = MinHashIndex()
with open(index_path, "rb") as fp:
    input_dict = load(fp)
    index.load_index(index_dict=input_dict)

for table_name in [
    "company-employees-yadl",
    "movies-yadl",
    "us-accidents-yadl",
    "us-presidential-results-yadl",
]:
    base_table_path = f"data/source_tables/{table_name}.parquet"
    df_base = pl.read_parquet(base_table_path)
    query_column_values = df_base[query_column].to_list()

    query_result = minhash_matching(index, query_column_values)
    # Preparing the same dataframe as before for prediction
    ll = [[row[i] for row in query_result] for i in range(3)]
    _pred = pl.from_dict(
        dict(zip(["hash", "col", "score"], ll)),
        schema=[("hash", str), ("col", str), ("score", float)],
    )
    _pred = _pred.with_columns(
        pl.lit(index_path.stem).alias("index_case"),
        pl.lit(table_name).alias("table_name"),
    )
    result_list.append(_pred)
df_pred = pl.concat(result_list)

In [27]:
df_pred

hash,col,score,index_case,table_name
str,str,f64,str,str
"""214f45c3af7758a15bba6cbe0f42a…","""subject""",20.0,"""minhash_index""","""company-employees-yadl"""
"""3e053293a3daadc9ad622628632ab…","""worksAt""",20.0,"""minhash_index""","""company-employees-yadl"""
"""579b951225c94f8ac4fd05ee1c4d5…","""graduatedFrom""",20.0,"""minhash_index""","""company-employees-yadl"""
"""74ca3153e0861b2f84924018b92f5…","""happenedIn""",20.0,"""minhash_index""","""company-employees-yadl"""
"""e55c5fbe7bd369931ae1a9b1e1391…","""subject""",20.0,"""minhash_index""","""company-employees-yadl"""
"""91258ecfb767792f6de9c60b2f278…","""playsFor""",20.0,"""minhash_index""","""company-employees-yadl"""
"""7e10ea986d675fa4f5a0a08e6e919…","""edited""",20.0,"""minhash_index""","""movies-yadl"""
"""47b113eb748994355dd72b94a0041…","""subject""",20.0,"""minhash_index""","""movies-yadl"""
"""b987248777c65dc2a354f2c428aa6…","""wroteMusicFor""",20.0,"""minhash_index""","""movies-yadl"""
"""3811d6cb72281b1cbe7ebd11950d4…","""directed""",20.0,"""minhash_index""","""movies-yadl"""


In [28]:
def get_stat_df(df_exact, df_pred):
    stats = []
    for idx, group in df_pred.group_by(["table_name", "index_case"]):
        print(idx)
        idx_case = 20
        df_group = (
            df_exact
            .filter(pl.col("table_name") == idx[0])
            .with_columns(
                (pl.lit(idx_case).alias("thr"))
            )
            .join(group, on=["table_name", "hash", "col"], how="left")
            .with_columns(
                pl
                .when(pl.col("containment") >= pl.col("thr") / 100)
                .then(1)
                .when(pl.col("score").is_null())
                .then(0)
                .otherwise(0)
                .alias("class")
            )
            .with_columns(
                pl.when(
                    pl.col("score").is_null(),
                )
                .then(0)
                .otherwise(1)
                .alias("class_pred")
            )
        )

        c_df = df_group.select(pl.col("class"), pl.col("class_pred")).to_pandas()
        
        conf_m = metrics.confusion_matrix(c_df["class"], c_df["class_pred"])
        tn, fp, fn, tp = conf_m.ravel()

        recall = metrics.recall_score(c_df["class"], c_df["class_pred"])
        precision = metrics.precision_score(c_df["class"], c_df["class_pred"])
        f1 = metrics.f1_score(c_df["class"], c_df["class_pred"])
        # Prepare a simplified df
        rd = {
            "table_name": idx[0],
            "case": idx[1],
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "tp": tp,
            "tn": tn,
            "fp": fp,
            "fn": fn,
        }
        
        stats.append(rd)
    df_stats = pl.from_dicts(stats)
    return df_stats

In [29]:
get_stat_df(df_exact, df_pred)

('us-accidents-yadl', 'minhash_index')
('movies-yadl', 'minhash_index')
('company-employees-yadl', 'minhash_index')
('us-presidential-results-yadl', 'minhash_index')


table_name,case,precision,recall,f1,tp,tn,fp,fn
str,str,f64,f64,f64,i64,i64,i64,i64
"""us-accidents-yadl""","""minhash_index""",0.625,0.454545,0.526316,5,105,3,6
"""movies-yadl""","""minhash_index""",0.875,1.0,0.933333,7,111,1,0
"""company-employees-yadl""","""minhash_index""",0.333333,0.333333,0.333333,2,109,4,4
"""us-presidential-results-yadl""","""minhash_index""",0.5,0.5,0.5,4,107,4,4
