In [1]:
%cd ..

/home/soda/rcappuzz/work/benchmark-join-suggestions


In [24]:
%load_ext autoreload
%autoreload 2

In [19]:
import ast
import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import polars as pl
import xgboost
from catboost import CatBoostError, CatBoostRegressor
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from src.table_integration.utils_joins import execute_join

In [20]:
import src.evaluation.evaluation_methods as em

In [21]:
data_dir = Path("data/source_tables/ken_datasets/the-movies-dataset/")
final_table_path = Path(data_dir, "movies-prepared.parquet")

mse_dict = {}

# Train on single table

In [22]:
# Some of the numerical columns may be better left categorical (e.g. id, year )
num_features = [
    "budget", 
    "popularity",
    "release_date", 
    "runtime",
    "vote_average",
    "vote_count",
    "target"
]

df, num_features, cat_features = em.prepare_table_for_evaluation(final_table_path, num_features)

budget
popularity
release_date
runtime
vote_average
vote_count
target


In [25]:
results = em.run_on_table(df, num_features, cat_features, verbose=0)

Index(['adult', 'budget', 'genres', 'id', 'original_language',
       'original_title', 'popularity', 'production_companies',
       'production_countries', 'release_date', 'runtime', 'spoken_languages',
       'status', 'title', 'video', 'vote_average', 'vote_count',
       'col_to_embed'],
      dtype='object')


In [10]:
rmse = mean_squared_error(y_true=results["y_test"], y_pred=results["y_pred"], squared=False)
mse_dict["base"] = rmse

# Read join candidates

In [11]:
with open("generated_candidates.pickle", "rb") as fp:
    candidates = pickle.load(fp)

In [14]:
candidates_mh = candidates["minhash"]

In [15]:
result_dict = em.execute_on_candidates(candidates_mh, df, num_features, cat_features)

  0%|          | 0/2 [00:00<?, ?it/s]

/storage/store/work/rcappuzz/yago3-dl/binary/yago_binary_edited.parquet


 50%|█████     | 1/2 [00:10<00:10, 10.81s/it]

/storage/store/work/rcappuzz/yago3-dl/binary/yago_binary_wroteMusicFor.parquet


100%|██████████| 2/2 [00:22<00:00, 11.32s/it]


In [26]:
result_dict

{'97d2126d7e589f70957001cbd745a69b': 0.7502240248982119,
 '0d1a9aba67fabe36679f94826a1d0cf4': 0.6077451831838343}

In [17]:
merged_table = em.execute_full_join(candidates_mh, df, num_features, cat_features)

100%|██████████| 2/2 [00:00<00:00, 33.86it/s]

/storage/store/work/rcappuzz/yago3-dl/binary/yago_binary_edited.parquet
/storage/store/work/rcappuzz/yago3-dl/binary/yago_binary_wroteMusicFor.parquet





In [18]:
merged_table.schema

{'adult': Utf8,
 'budget': Float64,
 'genres': Utf8,
 'id': Utf8,
 'original_language': Utf8,
 'original_title': Utf8,
 'popularity': Float64,
 'production_companies': Utf8,
 'production_countries': Utf8,
 'release_date': Float64,
 'runtime': Float64,
 'spoken_languages': Utf8,
 'status': Utf8,
 'title': Utf8,
 'video': Boolean,
 'vote_average': Float64,
 'vote_count': Float64,
 'col_to_embed': Utf8,
 'target': Float64,
 'subject': Utf8,
 'subject_0d1a9': Utf8}

In [16]:
cat_features = [col for col in merged.columns if col not in num_features]
merged = merged.fill_null("")
results = run_on_table(merged, num_features, cat_features)

Learning rate set to 0.054162
0:	learn: 1.2158183	total: 14.7ms	remaining: 14.7s
1:	learn: 1.1876392	total: 27.2ms	remaining: 13.6s
2:	learn: 1.1621223	total: 40.1ms	remaining: 13.3s
3:	learn: 1.1379671	total: 52.1ms	remaining: 13s
4:	learn: 1.1132685	total: 63.9ms	remaining: 12.7s
5:	learn: 1.0930111	total: 76.4ms	remaining: 12.7s
6:	learn: 1.0750422	total: 88.7ms	remaining: 12.6s
7:	learn: 1.0582371	total: 103ms	remaining: 12.7s
8:	learn: 1.0409984	total: 114ms	remaining: 12.6s
9:	learn: 1.0241662	total: 126ms	remaining: 12.5s
10:	learn: 1.0089821	total: 137ms	remaining: 12.3s
11:	learn: 0.9941231	total: 149ms	remaining: 12.3s
12:	learn: 0.9809476	total: 162ms	remaining: 12.3s
13:	learn: 0.9683222	total: 173ms	remaining: 12.2s
14:	learn: 0.9566261	total: 184ms	remaining: 12.1s
15:	learn: 0.9458189	total: 196ms	remaining: 12.1s
16:	learn: 0.9369783	total: 208ms	remaining: 12s
17:	learn: 0.9277619	total: 221ms	remaining: 12.1s
18:	learn: 0.9187682	total: 233ms	remaining: 12s
19:	learn:

In [19]:
rmse = mean_squared_error(y_true=results["y_test"], y_pred=results["y_pred"], squared=False)
print(rmse)
mse_dict["full_merge"] = rmse

0.550348671303569


In [30]:
for k, v in mse_dict.items():
    if k in candidates_mh:
        print(candidates_mh[k].candidate_metadata["df_name"])
    print(f"{k:<32} {v:>8.2f}")

base                                 0.70
yago_binary_edited
97d2126d7e589f70957001cbd745a69b     0.75
yago_binary_wroteMusicFor
0d1a9aba67fabe36679f94826a1d0cf4     0.61
full_merge                           0.55


In [24]:
print(f"{k:<32} {v:.2f}")

full_merge                       0.55


In [33]:
df = pl.read_parquet(final_table_path)
df.dtypes

[Utf8,
 Utf8,
 Utf8,
 Utf8,
 Utf8,
 Utf8,
 Utf8,
 Utf8,
 Utf8,
 Utf8,
 Float64,
 Utf8,
 Utf8,
 Utf8,
 Boolean,
 Float64,
 Float64,
 Utf8,
 Float64]

In [41]:
for col in df.columns:
    try:
        df = df.with_columns(
            pl.col(col).cast(pl.Float64)
        )
    except pl.ComputeError:
        continue

In [47]:
cat_features = [k for k,v in df.schema.items() if str(v) == "Utf8"]

In [48]:
cat_features

['adult',
 'genres',
 'original_language',
 'original_title',
 'production_companies',
 'production_countries',
 'spoken_languages',
 'status',
 'title',
 'col_to_embed']

In [31]:
l =[("a", "b", "c"), ("a", "b", "c"), ("a", "b", "c"), ("a", "b", "c") ]

In [33]:
pl.from_records(l, orient="row")

column_0,column_1,column_2
str,str,str
"""a""","""b""","""c"""
"""a""","""b""","""c"""
"""a""","""b""","""c"""
"""a""","""b""","""c"""
