In [None]:
import os
import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import auc
from tqdm import tqdm

# Run from the repository root so the relative ./data and ./results paths resolve.
# Set UNIVERSAL_BLOCKER_ROOT to run on a machine with a different checkout location.
PROJECT_ROOT = Path(os.environ.get("UNIVERSAL_BLOCKER_ROOT", "/home1/wangtianshu/universal-blocker"))
os.chdir(PROJECT_ROOT)
# Recall threshold: the evaluation loop reports the first K whose recall exceeds it.
threshold: float = 0.9
# Dataset directories skipped by the evaluation sweep.
EXCLUDED_DATASETS = {"songs", "citeseer-dblp"}
# Only directories are datasets; sort so the sweep order is deterministic across filesystems.
data_dirs = sorted(
    d.name
    for d in Path("./data/blocking").iterdir()
    if d.is_dir() and d.name not in EXCLUDED_DATASETS
)

In [None]:
import torch
import torch.nn.functional as F

import py_stringmatching as sm
import torch
import torch.nn.functional as F
# Disable autograd globally; every model call in this notebook is a forward pass.
torch.set_grad_enabled(False)

# Sparse (lexical) similarity ingredients from py_stringmatching:
# 5-gram tokens with no padding, whitespace tokens, and cosine similarity over token collections.
qgram_tokenizer = sm.tokenizer.qgram_tokenizer.QgramTokenizer(qval=5, padding=False)
# NOTE(review): "tokenzier" is a typo, but token_similarity references this exact name;
# rename both together if fixing it.
whitespace_tokenzier = sm.tokenizer.whitespace_tokenizer.WhitespaceTokenizer()
cosine = sm.similarity_measure.cosine.Cosine()

from src.models import SimCSE
from scipy.spatial import distance

def sparse_similarity(
    s1,
    s2,
    tokenizer,
    similarity,
):
    """Tokenize two strings and score the resulting token collections.

    Args:
        s1: first string.
        s2: second string.
        tokenizer: object exposing ``tokenize(text)``.
        similarity: object exposing ``get_sim_score(tokens_a, tokens_b)``.

    Returns:
        Whatever ``similarity.get_sim_score`` returns for the two tokenizations.
    """
    # The generator is unpacked in order, so s1 is tokenized before s2.
    left_tokens, right_tokens = (tokenizer.tokenize(text) for text in (s1, s2))
    return similarity.get_sim_score(left_tokens, right_tokens)

def ngram_similarity(s1, s2):
    """Cosine similarity of the two strings' 5-gram tokenizations (``qgram_tokenizer``)."""
    return sparse_similarity(s1, s2, qgram_tokenizer, cosine)

def token_similarity(s1, s2):
    """Cosine similarity of the two strings' whitespace tokenizations."""
    return sparse_similarity(s1, s2, whitespace_tokenzier, cosine)

def embed(s):
    """Encode text with the loaded model and L2-normalize the output along dim 1.

    Uses the module-level ``model`` and ``tokenizer`` assigned in the model-loading cell.
    Input is padded and truncated to exactly 256 tokens.
    """
    encoded = tokenizer(
        s,
        padding="max_length",
        max_length=256,
        truncation=True,
        return_tensors="pt",
    )
    raw_embedding = model(encoded)
    return F.normalize(raw_embedding)

def get_similarity(s1, s2):
    """Dot product of the two normalized embeddings (their cosine similarity) as a Python float."""
    # embed(s1) is evaluated before embed(s2), as in a two-statement form.
    return float(torch.mm(embed(s1), embed(s2).T))

In [None]:
from src.models import SimCSE

# Base model path the SimCSE wrapper was originally constructed with; kept for reference.
model_name_or_path = "./models/roberta-base"
checkpoint_path = "results/fit/simcse/gittables/1cwvyg3q/checkpoints/step=1500-AP=0.46677.ckpt"
# load_from_checkpoint is a classmethod (assumes SimCSE is a Lightning LightningModule):
# it builds a fresh model from the checkpoint. Calling it on a throwaway instance only loaded
# the base weights for nothing, and newer Lightning versions refuse instance calls outright.
model = SimCSE.load_from_checkpoint(checkpoint_path)
model.eval()
# Reuse the exact tokenizer the model's collate function was trained with.
tokenizer = model.collate_fn.tokenizer

In [None]:
# Inspect how the model's tokenizer splits a timestamp-like attribute value.
# NOTE(review): is_split_into_words=True tells Hugging Face tokenizers the input is already
# split into words, but a plain string is passed here; slow and fast tokenizers may treat
# that differently — confirm this is the intended mode.
s = "Thu Aug 16 23:53:12 +0000 2018"
tokenizer.tokenize(s, is_split_into_words=True)

In [None]:
import json
from src.datamodules.blocking import dict2tuples

def get_records(p, dfs):
    """Fetch both sides of a candidate pair as plain JSON-compatible dicts.

    Args:
        p: pair whose first item is an index label of ``dfs[0]`` and whose second item
            is an index label of the last DataFrame in ``dfs``.
        dfs: list of DataFrames; with a single frame, both labels are looked up in it.

    Returns:
        ``(r1, r2)``: dicts mapping column name to value.
    """
    def as_record(frame, label):
        # Round-trip through JSON so NaN becomes None and values are plain Python types.
        return json.loads(frame.loc[label].to_json())

    return as_record(dfs[0], p[0]), as_record(dfs[-1], p[1])

def r2s(r):
    """Flatten a record into one space-separated, casefolded string.

    Keeps the second element of each tuple produced by ``dict2tuples(r, "id")``.
    NOTE(review): presumably those tuples are (column, value) pairs with string
    values and "id" excluded — confirm against src.datamodules.blocking.
    """
    values = [pair[1].casefold() for pair in dict2tuples(r, "id")]
    return " ".join(values)

def check(p, dfs):
    """Print a pair, both serialized records, and their embedding and token similarities.

    Output is the pair, the repr of each record string on its own line, the embedding
    similarity, the whitespace-token cosine similarity, and a blank line.
    """
    s1, s2 = map(r2s, get_records(p, dfs))
    print(p)
    print(f"{s1!r}\n{s2!r}")
    print(get_similarity(s1, s2))
    print(token_similarity(s1, s2))
    print()

K = 20
# (left_id, right_id) pair under investigation; candidates whose left id matches
# DEBUG_PAIR[0] are printed as well.
DEBUG_PAIR = (706, 16537)
for data_dir in ["walmart-amazon_homo"]:
    print(data_dir)
    with Path(f"./results/debug/sparse_join/{data_dir}.pickle").open("rb") as f:
        sparse_join_candidates = pickle.load(f)

    with Path(f"./results/debug/simcse/{data_dir}.pickle").open("rb") as f:
        simcse_candidates = pickle.load(f)

    table_paths = sorted(Path(f"./data/blocking/{data_dir}").glob("[1-2]*.csv"))
    dfs = [pd.read_csv(p, index_col="id", low_memory=False) for p in table_paths]
    matches_path = Path(f"./data/blocking/{data_dir}/matches.csv")
    matches = set(pd.read_csv(matches_path).itertuples(index=False, name=None))
    check(DEBUG_PAIR, dfs)
    for candidates in [simcse_candidates]:
        print("-------------------------------------------------")
        cands = set().union(*candidates[:K])
        # Slice rather than range(K): lists shorter than K no longer raise IndexError.
        for k, candidate_set in enumerate(candidates[:K]):
            print(k)
            for p in candidate_set:
                if p[0] == DEBUG_PAIR[0]:
                    check(p, dfs)
                
#         print("+++++++++++++++++++++++++++++++++++++++++++++++++")
#         correct = matches & cands
#         wrong = matches - cands
#         cnt = 10

#         for p in wrong:
#             check(p, dfs)
            
#             cnt -= 1
#             if cnt == 0:
#                 break

In [None]:
# Inspect one more pair; relies on `dfs` loaded by the preceding debugging cell.
check(p=(1046, 10182), dfs=dfs)

In [None]:
def _blocking_metrics(candidates, matches, recall_threshold):
    """Evaluate cumulative candidate sets against the ground-truth matches.

    Args:
        candidates: sequence of sets of (left_id, right_id) pairs; the first k sets are
            unioned to form the candidate set at cut-off k.
        matches: set of ground-truth (left_id, right_id) pairs.
        recall_threshold: the reported cut-off is the first k whose recall exceeds this
            value, or the last k when none does.

    Returns:
        ``(metrics, precisions, recalls)``: ``metrics`` has keys "AP", "PC", "PQ", "F1", "K";
        the two lists hold precision/recall for k = 0 .. len(candidates).
    """
    cands = set()
    # k = 0 is the empty candidate set: precision 1, recall 0 (anchors the PR curve).
    precisions, recalls = [1], [0]
    for candidate_set in candidates:
        cands |= candidate_set  # in place: no copy of the growing union each step
        tp = len(cands & matches)
        # Undefined ratios (no candidates yet / no ground truth) get explicit values
        # instead of raising ZeroDivisionError.
        precisions.append(tp / len(cands) if cands else 1.0)
        recalls.append(tp / len(matches) if matches else 0.0)

    k = next((i for i, r in enumerate(recalls) if r > recall_threshold), len(recalls) - 1)
    precision, recall = precisions[k], recalls[k]
    metrics = {
        # auc needs at least two points; an empty candidate list only has k = 0.
        "AP": auc(recalls, precisions) if len(recalls) >= 2 else float("nan"),
        "PC": recall,
        "PQ": precision,
        "F1": 2 * (precision * recall) / (precision + recall) if precision + recall else 0.0,
        "K": float(k),
    }
    return metrics, precisions, recalls


for data_dir in data_dirs:
    print(data_dir)
    # Paths are relative to the repository root that the first cell chdir'd into,
    # matching the rest of the notebook.
    with Path(f"./results/debug/sparse_join/{data_dir}.pickle").open("rb") as f:
        sparse_join_candidates = pickle.load(f)

    with Path(f"./results/debug/simcse/{data_dir}.pickle").open("rb") as f:
        simcse_candidates = pickle.load(f)

    table_paths = sorted(Path(f"./data/blocking/{data_dir}").glob("[1-2]*.csv"))
    dfs = [pd.read_csv(p, index_col="id", low_memory=False) for p in table_paths]
    matches_path = Path(f"./data/blocking/{data_dir}/matches.csv")
    matches = set(pd.read_csv(matches_path).itertuples(index=False, name=None))
    
#     fig, ax = plt.subplots()
#     # add axis labels to plot
#     ax.set_title('Precision-Recall Curve')
#     ax.set_ylabel('Precision')
#     ax.set_xlabel('Recall')

    # Element-wise union of both methods' candidate lists (zip stops at the shorter list).
    union_candidates = [a.union(b) for a, b in zip(sparse_join_candidates, simcse_candidates)]

    for candidates in [sparse_join_candidates, simcse_candidates, union_candidates]:
        metrics, precisions, recalls = _blocking_metrics(candidates, matches, threshold)
        print(recalls[-1])

        #ax.plot(recalls, precisions, color='purple')
    print()
    #display plot
    #plt.show()