In [None]:
# # !pip install annoy
# import os
# # print(os.environ["LD_LIBRARY_PATH"])
# os.environ["LD_LIBRARY_PATH"] = "/opt/conda/lib/python3.8/site-packages/torch/lib:/usr/local/cuda-11.3/lib64"

In [None]:
# clean text
# from textblob import TextBlob
import re
import string


def decontracted(phrase):
    """Expand common English contractions in ``phrase``.

    The substitutions are applied one after another, so order matters: the
    two irregular forms are rewritten before the generic suffix rules get a
    chance to split them differently.

    Only the ASCII apostrophe is matched and matching is case-sensitive.
    Every ``'s`` becomes `` is``, so possessives are expanded as well.

    Args:
        phrase: Text to process.

    Returns:
        The text with all matching contractions expanded.
    """
    substitutions = (
        # Specific
        (r"won't", "will not"),
        (r"can\'t", "can not"),
        # General
        (r"n\'t", " not"),
        (r"\'re", " are"),
        (r"\'s", " is"),
        (r"\'d", " would"),
        (r"\'ll", " will"),
        (r"\'t", " not"),
        (r"\'ve", " have"),
        (r"\'m", " am"),
    )
    for pattern, expansion in substitutions:
        phrase = re.sub(pattern, expansion, phrase)
    return phrase

def remove_punctuations(text):
    """Delete every character listed in ``string.punctuation`` from ``text``.

    Only the ASCII punctuation set is affected; other symbols are kept.

    Args:
        text: String to strip punctuation from.

    Returns:
        ``text`` without ASCII punctuation characters.
    """
    # A deletion-only translation table removes all of them in one pass.
    return text.translate(str.maketrans("", "", string.punctuation))

def clean_number(text):
    """Normalise how numbers are written inside ``text``.

    Three rewrites are applied in order:

    1. A digit run immediately followed by a letter gets a space inserted
       between them (``"5kg"`` -> ``"5 kg"``).
    2. Ordinal suffixes split by step 1 are glued back (``"2 nd "`` ->
       ``"2nd "``). The pattern requires a space after the suffix, so an
       ordinal at the very end of the string stays split.
    3. A comma between two digit runs is removed (``"1,000"`` -> ``"1000"``).

    Args:
        text: String to normalise.

    Returns:
        The rewritten string.
    """
    # Raw strings keep "\g<1>" as a regex group reference without relying on
    # Python passing through an invalid "\g" escape (deprecated, and a
    # SyntaxWarning on recent interpreters).
    text = re.sub(r'(\d+)([a-zA-Z])', r'\g<1> \g<2>', text)
    text = re.sub(r'(\d+) (th|st|nd|rd) ', r'\g<1>\g<2> ', text)
    text = re.sub(r'(\d+),(\d+)', r'\g<1>\g<2>', text)
    return text

def clean_whitespace(text):
    """Trim ``text`` and collapse every whitespace run to a single space.

    Args:
        text: String to tidy.

    Returns:
        The string without leading/trailing whitespace and with interior
        whitespace runs (spaces, tabs, newlines, ...) replaced by one space.
    """
    whitespace_run = re.compile(r"\s+")
    return whitespace_run.sub(" ", text.strip())

def clean_repeat_words(text):
    """Drop one copy of a doubled character in each run of word characters.

    Despite the name this works on characters, not words: the pattern looks
    for a word character immediately followed by itself and rewrites the
    surrounding run with a single copy (``"hello"`` -> ``"helo"``). Because a
    match swallows the whole run, only one doubled pair per run is collapsed.

    Args:
        text: String to process.

    Returns:
        The rewritten string.
    """
    doubled_char = re.compile(r"(\w*)(\w)\2(\w*)")
    return doubled_char.sub(r"\1\2\3", text)

def clean_text(text):
    """Run the full text-cleaning pipeline on a single value.

    The value is coerced with ``str`` and then passed through, in order:
    contraction expansion, ASCII punctuation removal, number normalisation
    and whitespace clean-up. Contractions are expanded first because the
    punctuation step would otherwise delete the apostrophes they rely on.

    Args:
        text: Value to clean; anything ``str()`` accepts.

    Returns:
        The cleaned string.
    """
    # text_blob = TextBlob(text)
    # text = str(text_blob.correct())
    cleaned = str(text)
    pipeline = (decontracted, remove_punctuations, clean_number, clean_whitespace)
    for step in pipeline:
        cleaned = step(cleaned)

    return cleaned

In [None]:
import torch
import pandas as pd
from dataset import AutoTokenizer, LANGUAGE_TOKENS, CATEGORY_TOKENS, LEVEL_TOKENS, KIND_TOKENS, OTHER_TOKENS
from model import Model

from torch.utils.data import DataLoader, Dataset, default_collate

In [None]:
from pathlib import Path


# False: evaluate on the validation fold; True: predict for the submission topics.
TEST_MODE = False

# --------------------- VALIDATION SET --------------------------
from tqdm import tqdm
if not TEST_MODE:
    # Validation topics are the ones assigned to the chosen fold.
    data_df = pd.read_csv("./data/supervised_correlations.csv")
    fold = 0
    val_topic_ids = list(data_df[data_df["fold"] == fold].topics_ids)
    # Only the id list is needed from here on.
    del data_df
else:
    # There is no validation split in test mode. Keep the name defined so the
    # cell still runs instead of raising NameError on the missing data_df.
    val_topic_ids = []

data_folder = Path("./data")
# TODO: we have to process for test set ourselves
# Rows of both tables are ordered by raw title length (measured before
# cleaning) - presumably so that neighbouring rows have similar lengths when
# batched; confirm before relying on it.
contents_df = pd.read_csv(data_folder/'content.csv')
contents_df = contents_df.fillna('')
contents_df['title_len'] = contents_df.title.str.len()
contents_df = contents_df.sort_values(by='title_len', axis=0).reset_index(drop=True).drop(columns=['title_len'])
topics_df = pd.read_csv(data_folder/'topics.csv')
topics_df = topics_df.fillna('')
topics_df['title_len'] = topics_df.title.str.len()
topics_df = topics_df.sort_values(by='title_len', axis=0).reset_index(drop=True).drop(columns=['title_len'])
subs_df = pd.read_csv(data_folder/'sample_submission.csv')
corrs_df = pd.read_csv(data_folder/'correlations.csv')


# Apply the cleaning pipeline to the text fields that are fed to the model.
topics_df["title"] = topics_df["title"].apply(clean_text)
topics_df["description"] = topics_df["description"].apply(clean_text)

contents_df["title"] = contents_df["title"].apply(clean_text)
contents_df["description"] = contents_df["description"].apply(clean_text)
# contents_df["text"] = contents_df["text"].apply(clean_text)

In [None]:
# supervised_correlations = pd.read_csv("data/supervised_correlations.csv")

In [None]:
# supervised_correlations[(supervised_correlations["topics_ids"].isin(val_topic_ids))]

In [None]:
# supervised_correlations[(supervised_correlations["topics_ids"].isin(val_topic_ids)) & (supervised_correlations["target"] == 1)]

In [None]:
from tqdm import tqdm

# Keep only the topics that need predictions in the current mode.
if TEST_MODE:
    topics_df = topics_df[topics_df.id.isin(subs_df.topic_id)]
else: # VAL_MODE
    topics_df = topics_df[topics_df.id.isin(val_topic_ids)]

# Build the model input string for every topic: metadata tags first, then the
# title and description wrapped in their marker tags. Keyed by topic id.
topic_dict = {}
# Passing total lets tqdm show real progress; iterrows() alone has no length.
for index, row in tqdm(topics_df.iterrows(), total=len(topics_df)):
    text = "<|topic|>" + f"<|lang_{row['language']}|>" + f"<|category_{row['category']}|>" + f"<|level_{row['level']}|>"
    text += "<s_title>" + row["title"] + "</s_title>" + "<s_description>" + row["description"] + "</s_description>"
    topic_dict[row["id"]] = text

# Same layout for content items, with the kind tag in place of category/level.
content_dict = {}
for index, row in tqdm(contents_df.iterrows(), total=len(contents_df)):
    text = "<|content|>" + f"<|lang_{row['language']}|>" + f"<|kind_{row['kind']}|>"
    text += "<s_title>" + row["title"] + "</s_title>" + "<s_description>" + row["description"] + "</s_description>" # + "<s_text>" + row["text"] + "</s_text>"
    content_dict[row["id"]] = text

In [None]:
# Attach the generated strings by looking each row's id up in the dict.
# Assigning dict.values() directly would depend on the dict's insertion order
# matching the current row order and on ids being unique; a lookup by id
# stays correct even if rows are reordered or an id repeats.
topics_df["topic_text"] = topics_df["id"].map(topic_dict)
# Optional hard cap on the text length (currently disabled):
# topics_df["topic_text"] = topics_df["topic_text"].apply(lambda x: x[:2048])

contents_df["content_text"] = contents_df["id"].map(content_dict)
# contents_df["content_text"] = contents_df["content_text"].apply(lambda x: x[:2048])

In [None]:
class InferenceDataset(Dataset):
    """Map-style dataset that tokenizes one text per item for inference.

    Args:
        texts: Indexable collection of strings.
        tokenizer_name: Name or path handed to ``AutoTokenizer.from_pretrained``.
        max_len: Every item is padded or truncated to exactly this many tokens.
    """

    def __init__(self, texts, tokenizer_name='xlm-roberta-base', max_len=512):
        self.texts = texts

        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        # Register the metadata/marker tags used in the input strings as
        # special tokens of this tokenizer.
        marker_tokens = LANGUAGE_TOKENS + CATEGORY_TOKENS + LEVEL_TOKENS + KIND_TOKENS + OTHER_TOKENS
        tokenizer.add_special_tokens({"additional_special_tokens": marker_tokens})
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of texts."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize ``texts[idx]`` and return its fields as ``torch.long`` tensors."""
        encoded = self.tokenizer.encode_plus(
            self.texts[idx],
            return_tensors=None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
        )
        # Overwrite each field in place so the container returned by the
        # tokenizer is kept; only the values change from lists to tensors.
        for field in list(encoded.keys()):
            encoded[field] = torch.tensor(encoded[field], dtype=torch.long)

        return encoded
    
def collate_fn(inputs):
    """Stack a list of items and cut the padding the whole batch shares.

    Items are first combined with ``default_collate``. Every field is then
    sliced along its second axis to the largest attention-mask sum found in
    the batch. This assumes padding sits at the end of each sequence - TODO
    confirm against the tokenizer's padding side.

    Args:
        inputs: List of per-item mappings that include ``"attention_mask"``.

    Returns:
        The collated batch with every field trimmed to the same length.
    """
    batch = default_collate(inputs)
    tokens_per_row = batch["attention_mask"].sum(dim=1)
    keep = int(tokens_per_row.max())
    for field in list(batch.keys()):
        batch[field] = batch[field][:, :keep]

    return batch

In [None]:
# Topic texts in row order. shuffle=False keeps the embeddings produced from
# this loader in that same order.
topic_dataset = InferenceDataset(
    texts=topics_df["topic_text"].tolist(),
    tokenizer_name='sentence-transformers/all-MiniLM-L6-v2',
    max_len=128,
)
topic_dataloader = DataLoader(
    topic_dataset,
    batch_size=8,
    shuffle=False,
    collate_fn=collate_fn,
)

In [None]:

model = Model(tokenizer_name="sentence-transformers/all-MiniLM-L6-v2", model_name="sentence-transformers/all-MiniLM-L6-v2", objective="both")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)


# NOTE(review): absolute, machine-specific path - adjust for your environment.
weights_path = "/home/jovyan/lecr/outputs_siamese/checkpoint-81549/pytorch_model.bin"

# Deserialize onto the CPU so a checkpoint saved from a GPU run also loads on
# a CPU-only host; load_state_dict copies the values into the model's own
# parameters, wherever those live.
state_dict = torch.load(weights_path, map_location="cpu")

# The encoder weights are stored under a "model." prefix in the checkpoint;
# strip it so the keys line up with model.model.
encoder_prefix = "model."
new_state_dict = {}
for k, v in state_dict.items():
    if k.startswith(encoder_prefix):
        new_k = k[len(encoder_prefix):]
        new_state_dict[new_k] = v

model.model.load_state_dict(new_state_dict)

# The classification head is restored only when the checkpoint contains it.
if "fc.weight" in state_dict:
    model.fc.load_state_dict({
        "weight": state_dict["fc.weight"],
        "bias": state_dict["fc.bias"]
    })

# This notebook only runs inference: switch to eval mode so that any
# train-only layers (e.g. dropout) behave deterministically.
model.eval()


In [None]:
import torch.nn.functional as F

# One embedding (numpy array) per topic, in dataloader order.
topic_embs = []

# Inference only: eval mode for deterministic layers, and no_grad so no
# autograd graph is built for each batch (saves memory and time).
model.eval()
with torch.no_grad():
    for inputs in tqdm(topic_dataloader):
        # Move every field of the batch to the model's device, in place.
        for k in list(inputs.keys()):
            inputs[k] = inputs[k].to(device)
        out = model.feature(inputs)
        # extend() iterates the first axis of the array, adding one entry per
        # batch element.
        topic_embs.extend(out.cpu().detach().numpy())

In [None]:
content_dataset = InferenceDataset(texts=list(contents_df.content_text.values), tokenizer_name='sentence-transformers/all-MiniLM-L6-v2', max_len=128)
content_dataloader = DataLoader(content_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn)

# # 
# del contents_df["text"]
# del contents_df

# import gc
# gc.collect()

# One embedding (numpy array) per content item, in dataloader order.
content_embs = []

# Inference only: eval mode for deterministic layers, and no_grad so no
# autograd graph is built for each batch (saves memory and time).
model.eval()
with torch.no_grad():
    for inputs in tqdm(content_dataloader):
        # Move every field of the batch to the model's device, in place.
        for k in list(inputs.keys()):
            inputs[k] = inputs[k].to(device)
        out = model.feature(inputs)
        # extend() iterates the first axis of the array, adding one entry per
        # batch element.
        content_embs.extend(out.cpu().detach().numpy())

In [None]:
# Cache the freshly computed embeddings on disk so later sessions can skip
# the two embedding loops.
torch.save(topic_embs, "./data/topic_embs.pt")
torch.save(content_embs, "./data/content_embs.pt")

# To reuse a previous run instead, comment out the saves above and load the
# cached files:
# topic_embs = torch.load("./data/topic_embs.pt")
# content_embs = torch.load("./data/content_embs.pt")

In [None]:
# !pip install fuzzywuzzy annoy

In [None]:
# normalized_content_embs = []
# for emb in tqdm(content_embs):
#     normalized_content_embs.append(F.normalize(torch.from_numpy(emb), p=2, dim=0).numpy())

In [None]:
from fuzzywuzzy import fuzz, process

from annoy import AnnoyIndex


# Approximate nearest-neighbour index over all content embeddings. The item
# id is the embedding's position in content_embs.
embedding_dim = content_embs[0].shape[0]
content_forest = AnnoyIndex(embedding_dim, metric='angular')
for item_id, embedding in enumerate(tqdm(content_embs, total=len(content_embs))):
    content_forest.add_item(item_id, embedding)
content_forest.build(200)  # number of trees

In [None]:
# topics = topics_df[topics_df.has_content==True][['id', 'title', 'language']].reset_index(drop=True)

# Both names are aliases of the filtered topics frame.
topics = topics_df
test = topics

# Content-side columns as arrays, so they can be indexed by position.
all_content_ids = contents_df["id"].to_numpy()
all_content_titles = contents_df["title"].to_numpy()
all_content_language = contents_df["language"].to_numpy()

# Topic-side columns as plain lists, in row order.
all_test_ids = topics["id"].tolist()
all_test_title = topics["title"].tolist()
all_test_language = test["language"].tolist()

In [None]:
# content_forest.get_nns_by_vector(topic_embs[5], 10, include_distances=True)

In [None]:
# Caches that survive re-runs of the prediction cell: neighbour lookups per
# topic position, plus per-(topic, content) fuzzy and classifier scores.
indexes_dict, fuzzy_dict, classification_dict = {}, {}, {}

In [None]:
# TODO: find by distance instead

# Number of neighbours fetched from the index for every topic.
nearest_content_count = 500
# fuzzy_filter and THRESHOLD belong to the re-ranking stage that is currently
# commented out below; right now they only show up in the final log line.
fuzzy_filter = 80
THRESHOLD = 0
# for fuzzy_filter in range(5, 50, 5):
#     for t in range(1, 10, 2):
#         THRESHOLD = t / 100
preds = []
for i, t_e in tqdm(enumerate(topic_embs), total=len(topic_embs), desc='Getting Preds'):
    # Neighbour lookups are cached by topic position so re-running the cell
    # does not query the index again.
    if i in indexes_dict:
        indexes, distances = indexes_dict[i]
    else:
        indexes, distances = content_forest.get_nns_by_vector(
            # F.normalize(torch.from_numpy(t_e), p=2, dim=0),
            t_e,
            nearest_content_count,
            include_distances=True
        )
        # indexes = [i for i, d in zip(indexes, distances) if d < 10]
        indexes_dict[i] = indexes, distances

    topic_id = all_test_ids[i]
    topic_text = all_test_title[i]
    topic_lang = all_test_language[i]

    # for idx in indexes:
    #     if topic_lang != all_content_language[idx]:
    #         indexes.remove(idx)

    # filtered_indexes = []
    # for idx in indexes:
    #     if topic_lang != all_content_language[idx]:
    #         continue
    #     if (i, idx) in fuzzy_dict:
    #         fuzzy_value = fuzzy_dict[(i, idx)]
    #     else:
    #         fuzzy_value = fuzz.token_set_ratio(all_content_titles[idx], topic_text)
    #         fuzzy_dict[(i, idx)] = fuzzy_value

    #     if fuzzy_value > fuzzy_filter:
    #         filtered_indexes.append(idx)

    #     if (i, idx) in classification_dict:
    #         score = classification_dict[(i, idx)]
    #     else:
    #         topic_features = torch.from_numpy(t_e).to(device)
    #         content_features = torch.from_numpy(content_embs[idx]).to(device)
    #         score = torch.sigmoid(model.fc(torch.cat([topic_features, content_features, topic_features - content_features], -1))).item()
    #         classification_dict[(i, idx)] = score
    #     if score < THRESHOLD and idx in filtered_indexes:
    #         filtered_indexes.remove(idx)
    # ind2dis = {ind: d for ind, d in zip(indexes, distances)}
    # if len(filtered_indexes) == 0:
    #     indexes = filtered_indexes[:8] # list(set(filtered_indexes + indexes[:8-len(filtered_indexes)]))
    # else:
    #     indexes = filtered_indexes[:8]
    # Translate index positions back to content ids.
    content_ids = all_content_ids[indexes]
    preds.append({
        'topic_id': topic_id,
        'content_ids': ' '.join(content_ids),
        # 'distances': ' '.join([str(ind2dis[ind]) for ind in indexes]),
    })
preds = pd.DataFrame.from_records(preds)

preds.to_csv('submission.csv', index=False)

if not TEST_MODE:
    from engine import f2_score
    # The ground truth and the predictions are compared row by row, so both
    # frames must hold exactly the same topics in the same order. Restrict
    # each side to the topics present in the other, sort by topic id and
    # renumber the rows. Assumes one row per topic_id on both sides - TODO
    # confirm for correlations.csv.
    gt = corrs_df[corrs_df.topic_id.isin(preds.topic_id)]
    preds = preds[preds.topic_id.isin(gt.topic_id)]
    gt = gt.sort_values("topic_id").reset_index(drop=True)
    preds = preds.sort_values("topic_id").reset_index(drop=True)
    print("fuzzy_filter", fuzzy_filter, "THRESHOLD:", THRESHOLD, "f2_score", f2_score(gt["content_ids"], preds["content_ids"]))

In [None]:
import numpy as np
def get_pos_score(y_true, y_pred):
    """Average, over rows, the share of true ids that were predicted.

    Both arguments hold whitespace-separated id strings and are paired
    positionally. For each pair the score is
    ``len(true & predicted) / len(true)``.

    A row whose ground truth is empty contributes ``0.0`` instead of raising
    ``ZeroDivisionError``.

    Args:
        y_true: pandas Series of ground-truth id strings.
        y_pred: pandas Series of predicted id strings.

    Returns:
        The mean per-row score, rounded to 5 decimals.
    """
    true_sets = y_true.apply(lambda x: set(x.split()))
    pred_sets = y_pred.apply(lambda x: set(x.split()))
    per_row = np.array([
        len(truth & predicted) / len(truth) if truth else 0.0
        for truth, predicted in zip(true_sets, pred_sets)
    ])
    return round(np.mean(per_row), 5)

In [None]:
# gt only exists in validation mode. A conditional expression keeps this the
# cell's last expression (so the score is still displayed) while letting the
# notebook run top to bottom when TEST_MODE is True.
get_pos_score(gt["content_ids"], preds["content_ids"]) if not TEST_MODE else None

In [None]:
# from torch.nn.functional import cosine_similarity

In [None]:
# # gt sims
# list_all_content_ids = list(all_content_ids)

# all_sims = []
# for i, row in tqdm(gt.iterrows()):
#     sims = []

#     t_e = topic_embs[all_test_ids.index(row["topic_id"])]
#     for content_id in row["content_ids"].split(" "):
#         c_e = content_embs[list_all_content_ids.index(content_id)]
#         sims.append(cosine_similarity(torch.from_numpy(t_e), torch.from_numpy(c_e), 0))
    
#     all_sims.append(" ".join([str(s.item())[:5] for s in sims]))

In [None]:
# all_sims

In [None]:
# # prediction sims

# list_all_content_ids = list(all_content_ids)

# all_preds_sims = []
# for i, row in tqdm(preds.iterrows()):
#     sims = []

#     t_e = topic_embs[all_test_ids.index(row["topic_id"])]
#     if not row["content_ids"]:
#         continue
#     for content_id in row["content_ids"].split(" "):
#         c_e = content_embs[list_all_content_ids.index(content_id)]
#         sims.append(cosine_similarity(torch.from_numpy(t_e), torch.from_numpy(c_e), 0))
    
#     all_preds_sims.append(" ".join([str(s.item())[:5] for s in sims]))

In [None]:
# all_preds_sims

In [None]:
# import numpy as np

# sample_preds = pd.DataFrame({
#     "content_ids": [
#         "a b c"
#     ]
# })

# sample_gts = pd.DataFrame({
#     "content_ids": [
#         "a d e f g h"
#     ]
# })
# f2_score(
#     sample_gts["content_ids"],
#     sample_preds["content_ids"],
# )