# In [None]:
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
import torch.nn as nn
import pandas as pd
from tqdm import tqdm

# Choose the compute device in order of preference: CUDA, then MPS, then CPU.
# The `mps` backend attribute is probed defensively because the original code
# does not assume every torch build exposes it.
if torch.cuda.is_available():
    _device_name = "cuda"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    _device_name = "mps"
else:
    _device_name = "cpu"
device = torch.device(_device_name)

# Load the LaBSE tokenizer and encoder from the same checkpoint, then move the
# encoder onto the selected device (`Module.to` returns the module itself).
_CHECKPOINT = 'sentence-transformers/LaBSE'
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = AutoModel.from_pretrained(_CHECKPOINT).to(device)

def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the positions selected by the mask.

    Args:
        model_output: Indexable model result whose element 0 holds the
            per-token embeddings.
        attention_mask: Mask tensor with one fewer dimension than the token
            embeddings; it is expanded to their shape and cast to float.

    Returns:
        The mask-weighted sum over dimension 1 divided by the mask total,
        with the divisor clamped to at least 1e-9 so an all-zero mask does
        not divide by zero.
    """
    hidden_states = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    weighted_sum = (hidden_states * mask).sum(1)
    token_counts = mask.sum(1).clamp(min=1e-9)
    return weighted_sum / token_counts

def get_embeds(sentences):
    """Encode a batch of sentences into L2-normalized embeddings.

    Args:
        sentences: Iterable of values; anything pandas treats as null is
            replaced by an empty string, everything else goes through `str`.

    Returns:
        A tensor with one row per input sentence, produced by mean-pooling
        the model output and normalizing each row to unit L2 norm.
    """
    # NOTE(review): the published LaBSE sentence-transformers config appears to
    # use CLS pooling plus a dense layer; mean pooling is kept here as written.
    # Confirm this is intentional.
    texts = [str(item) if pd.notnull(item) else "" for item in sentences]

    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    # Move every tokenizer output tensor onto the same device as the model.
    batch = {name: tensor.to(device) for name, tensor in batch.items()}

    # Inference only: skip gradient tracking for the forward pass.
    with torch.no_grad():
        outputs = model(**batch)

    pooled = mean_pooling(outputs, batch["attention_mask"])
    return F.normalize(pooled, p=2, dim=1)

def get_similarity(embed_1, embed_2):
    """Return the cosine similarity of the first row pair as a Python float.

    Similarity is computed along dimension 1 with eps=1e-6; only element 0
    of the result is returned, so any further rows are ignored.
    """
    similarities = F.cosine_similarity(embed_1, embed_2, dim=1, eps=1e-6)
    first = similarities[0]
    return first.item()

def get_sim_score(data):
    """Score every non-first column against the first column, row by row.

    For each row, all cells are embedded together with `get_embeds`; the
    embedding of the first column is then compared with the embedding of
    each remaining column using `get_similarity`.

    Args:
        data: DataFrame whose first column is the reference text and whose
            remaining columns are the texts to compare against it.

    Returns:
        Dict mapping each non-first column name to a list of similarity
        scores, one per successfully processed row. Rows that raise are
        reported on stdout and skipped, so the lists can be shorter than
        `len(data)`, but they always have equal lengths.
    """
    cols = data.columns
    sim_scores = {lang: [] for lang in cols[1:]}
    for i in tqdm(range(len(data))):
        try:
            # data.iloc[i] already holds every column in order.
            embeds = get_embeds(list(data.iloc[i]))
            # Score the whole row before touching the result lists, so a
            # failure partway through cannot leave them with unequal lengths
            # (which would break building a DataFrame from the result).
            row_scores = [
                get_similarity(embeds[[0]], embeds[[j]])
                for j in range(1, len(sim_scores) + 1)
            ]
        except Exception as e:
            # Best-effort: report the bad row and carry on with the rest.
            print(f"Error on row {i}: {e}")
            continue
        for scores, value in zip(sim_scores.values(), row_scores):
            scores.append(value)
    return sim_scores

# Load the translation table; get_sim_score compares the first column
# against every other column.
data = pd.read_csv("../data/educhat-translation/all_translations.csv")

# Average the per-row similarities for each language column and reshape the
# result into a two-column (language, score) table.
_per_row_scores = pd.DataFrame(get_sim_score(data))
sim_score = (
    _per_row_scores.mean()
    .rename_axis("language")
    .reset_index(name="score")
)

# Write one JSON record per language.
sim_score.to_json("../results/refer-less.json", orient="records", indent=4)