In [None]:
import os
import pickle
from pathlib import Path

import pandas as pd
from tqdm import tqdm

# Switch the working directory; the relative paths used below
# ("./data/blocking", "./models/...", "results/...") resolve against it.
os.chdir("/home1/wangtianshu/universal-blocker")

# Names of all entries under ./data/blocking except the two excluded ones.
data_dirs = [
    name
    for name in (entry.name for entry in Path("./data/blocking").iterdir())
    if name not in ("songs", "citeseer-dblp")
]

In [None]:
from transformers import AutoModel
from src.models import SimCSE

# Integer device passed to ``.to()``; torch treats a bare int as a CUDA index.
device = 5
model_name_or_path = "./models/roberta-base/"
checkpoint_path = "results/fit/simcse/gittables/1cwvyg3q/checkpoints/step=1500-AP=0.46677.ckpt"

# Fine-tuned encoder. ``load_from_checkpoint`` is invoked on the class: it
# builds and returns a new module from the checkpoint, so an instance created
# beforehand would simply be thrown away (and Lightning >= 2.0 rejects calling
# it on an instance). Constructor arguments therefore come from the
# hyperparameters stored in the checkpoint.
# NOTE(review): assumes SimCSE is a LightningModule -- confirm in src.models.
simcse = SimCSE.load_from_checkpoint(checkpoint_path)
simcse.eval()
tokenizer = simcse.collate_fn.tokenizer
trained = simcse.model.to(device)

# Baseline: the pretrained weights at ``model_name_or_path``, without the
# checkpoint applied.
roberta = AutoModel.from_pretrained(model_name_or_path)
roberta.eval()
roberta = roberta.to(device)

In [None]:
import random

import torch
import torch.nn.functional as F
import py_stringmatching as sm
from datasets import Dataset
from src.datamodules.blocking import dict2tuples
from pytorch_lightning.utilities import move_data_to_device

# Shared helpers: a cosine similarity measure over token collections, a
# 5-gram tokenizer without padding, and a whitespace tokenizer.
cosine = sm.Cosine()
qgram_tokenizer = sm.QgramTokenizer(qval=5, padding=False)
whitespace_tokenizer = sm.WhitespaceTokenizer()

def encode(batch, model):
    """Embed a column-oriented batch of records with ``model``.

    Args:
        batch: Mapping from column name to a sequence of values, one entry per
            record; it is pivoted into one dict per record before encoding.
        model: Callable object exposing ``collate_fn`` (turns the per-record
            tuples into model inputs) and ``device``.

    Returns:
        A dict with two entries:
        ``"text"`` -- one string per record, the space-joined second elements
        of the items returned by ``dict2tuples(record, "id")``;
        ``"embeddings"`` -- the model output, L2-normalized by
        ``F.normalize`` (default ``dim=1``) and converted to a NumPy array.

    Raises:
        TypeError: If ``model`` has no callable ``collate_fn``.
    """
    collate_fn = getattr(model, "collate_fn", None)
    if not callable(collate_fn):
        # Fail early with a readable message instead of the opaque
        # "'NoneType' object is not callable" raised further down.
        raise TypeError(
            f"{type(model).__name__} has no callable `collate_fn`; "
            "encode() needs it to build the model inputs"
        )

    # Pivot {"col": [v1, v2, ...]} into [{"col": v1}, {"col": v2}, ...].
    records = [dict(zip(batch, values)) for values in zip(*batch.values())]
    # NOTE(review): dict2tuples presumably yields (column, value) pairs and
    # skips the "id" column -- confirm in src.datamodules.blocking.
    record_tuples = [dict2tuples(record, "id") for record in records]
    texts = [" ".join(item[1] for item in items) for items in record_tuples]

    inputs = move_data_to_device(collate_fn(record_tuples), model.device)
    # Inference only: no autograd graph is needed for the embeddings.
    with torch.no_grad():
        outputs = model(inputs).detach()
    embeddings = F.normalize(outputs).to("cpu").numpy()

    return {
        "text": texts,
        "embeddings": embeddings,
    }

@torch.no_grad()
def find_representation_words(model, s, top_k=15, max_length=256):
    """Return the tokens of ``s`` that receive the most last-layer attention.

    The string is tokenized with the module-level ``tokenizer``, run through
    ``model`` with attention outputs enabled, and every token position is
    scored by the attention it receives in the last layer: summed over all
    query positions and averaged over heads.

    Args:
        model: Encoder called as ``model(**inputs, output_attentions=True)``
            whose output exposes ``attentions``; must also expose ``device``.
        s: Text to analyse.
        top_k: Maximum number of tokens to return (default 15).
        max_length: Truncation length passed to the tokenizer (default 256).

    Returns:
        Up to ``top_k`` tokenizer tokens (not whitespace words), ordered from
        most to least attended, with special tokens removed.
    """
    inputs = tokenizer(
        s,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    inputs = move_data_to_device(inputs, model.device)
    # Only the attention maps are used, so hidden states are not requested.
    outputs = model(**inputs, output_attentions=True)

    # Last layer, first (only) sequence: indexed as (heads, query, key).
    last_attentions = outputs.attentions[-1][0]
    # Attention received by each key position: sum over queries, mean over heads.
    weight = last_attentions.sum(dim=1).mean(dim=0)

    # Rank on the raw averages. A softmax is monotonic, so it cannot change
    # the order, but it can underflow small weights into ties.
    order = torch.argsort(weight, descending=True)
    ranked_ids = inputs["input_ids"][0, order].tolist()
    tokens = tokenizer.convert_ids_to_tokens(ranked_ids, skip_special_tokens=True)
    return tokens[:top_k]

def check_pair(r1, r2, model):
    """Score the token overlap of two records in two ways.

    Each record is flattened to a single string by space-joining the second
    element of every item returned by ``dict2tuples``. The two strings are
    then compared with the module-level ``cosine`` measure, once on the
    tokens returned by ``find_representation_words`` and once on plain
    whitespace tokens.

    Args:
        r1: First record, in a form accepted by ``dict2tuples``.
        r2: Second record, in a form accepted by ``dict2tuples``.
        model: Encoder handed to ``find_representation_words``.

    Returns:
        ``(score1, score2)``: cosine similarity of the two
        representation-token lists, and cosine similarity of the two
        whitespace-token lists.
    """
    text1 = " ".join(item[1] for item in dict2tuples(r1))
    text2 = " ".join(item[1] for item in dict2tuples(r2))

    # Tokens the model attends to most in each record.
    rw1 = find_representation_words(model, text1)
    rw2 = find_representation_words(model, text2)

    # Baseline: every whitespace-separated token of each record.
    tw1 = whitespace_tokenizer.tokenize(text1)
    tw2 = whitespace_tokenizer.tokenize(text2)

    score1 = cosine.get_sim_score(rw1, rw2)
    score2 = cosine.get_sim_score(tw1, tw2)
    return score1, score2
    
# Upper bound on the number of matching pairs scored per dataset.
PAIR_SAMPLE_SIZE = 100
# Dedicated, seeded generator so the sampled pairs are the same on every run.
pair_rng = random.Random(0)

for data_dir in ["imdb-dbpedia", "movies", "amazon-google", "walmart-amazon_homo", "walmart-amazon_heter"]:
    print(data_dir)

    # Tables whose file name starts with "1" or "2", indexed by their "id"
    # column, with missing values replaced by empty strings.
    table_paths = sorted(Path(f"./data/blocking/{data_dir}").glob("[1-2]*.csv"))
    dfs = [
        pd.read_csv(p, index_col="id", low_memory=False).fillna("")
        for p in table_paths
    ]

    matches_path = Path(f"./data/blocking/{data_dir}/matches.csv")
    # dict.fromkeys removes duplicate rows while keeping file order, which
    # makes the seeded sample below reproducible (set order is not).
    match_list = list(
        dict.fromkeys(pd.read_csv(matches_path).itertuples(index=False, name=None))
    )
    matches = set(match_list)

    # Draw the pairs once per dataset so both encoders are scored on the same
    # records, and never ask for more pairs than the dataset has.
    pairs = pair_rng.sample(match_list, min(PAIR_SAMPLE_SIZE, len(match_list)))

    for encoder in [trained, roberta]:
        score1, score2 = 0, 0
        win = 0
        for p in pairs:
            # First id is looked up in the first table, second id in the last
            # one (the same table when only a single file matched the glob).
            r1 = dfs[0].loc[p[0]]
            r2 = dfs[-1].loc[p[1]]
            s1, s2 = check_pair(r1, r2, encoder)
            score1 += s1
            score2 += s2
            # Count the pairs where representation tokens beat whitespace tokens.
            win += s1 > s2

        # Summed representation-token score, summed whitespace-token score,
        # and the win count.
        print(score1, score2, win)