In [None]:
# !pip install annoy

In [None]:
# clean text
from textblob import TextBlob
import re
import string


def decontracted(phrase):
    """Expand common English contractions found in ``phrase``.

    The two irregular forms ("won't", "can't") are rewritten first, then the
    generic suffix rules run in a fixed order. Matching is case-sensitive and
    only the ASCII apostrophe is recognised. Every "'s" becomes " is", so
    possessives are expanded as well.

    Returns the rewritten string.
    """
    rewrite_rules = (
        # Specific: irregular forms, handled before the generic "n't" rule.
        (r"won't", "will not"),
        (r"can\'t", "can not"),
        # General: suffix rules, applied in exactly this order.
        (r"n\'t", " not"),
        (r"\'re", " are"),
        (r"\'s", " is"),
        (r"\'d", " would"),
        (r"\'ll", " will"),
        (r"\'t", " not"),
        (r"\'ve", " have"),
        (r"\'m", " am"),
    )

    for pattern, expansion in rewrite_rules:
        phrase = re.sub(pattern, expansion, phrase)

    return phrase

def remove_punctuations(text):
    """Return ``text`` with every character of ``string.punctuation`` deleted.

    Only the ASCII punctuation set is affected; other symbols are kept.
    """
    deletion_table = str.maketrans("", "", string.punctuation)
    return text.translate(deletion_table)

def clean_number(text):
    """Normalise how digits sit next to letters and commas in ``text``.

    Three substitutions run in order:

    1. a space is inserted between a digit run and a directly following
       ASCII letter ("5kg" -> "5 kg");
    2. an ordinal suffix split off by step 1 is re-attached, but only when
       the suffix is followed by a space ("4 th of" -> "4th of");
    3. a comma between two digit runs is removed ("1,000" -> "1000").

    Known limits of these rules: an ordinal at the very end of the string is
    left split ("July 4th" -> "July 4 th"), and only the first comma of a
    longer group is removed ("1,000,000" -> "1000,000").

    Returns the rewritten string.
    """
    # Replacement templates are raw strings: a backslash followed by "g" in a
    # normal string literal is an invalid escape sequence and triggers a
    # warning on recent Python versions.
    text = re.sub(r'(\d+)([a-zA-Z])', r'\g<1> \g<2>', text)
    text = re.sub(r'(\d+) (th|st|nd|rd) ', r'\g<1>\g<2> ', text)
    text = re.sub(r'(\d+),(\d+)', r'\g<1>\g<2>', text)
    return text

def clean_whitespace(text):
    """Trim ``text`` and squeeze every whitespace run into a single space."""
    trimmed = text.strip()
    return re.sub(r"\s+", " ", trimmed)

def clean_repeat_words(text):
    """Collapse one doubled character inside each run of word characters.

    Despite the name, this works on characters rather than words. Within a
    matched run, a single pair of identical adjacent characters is reduced to
    one character; because the leading group is greedy, that is the
    right-most pair ("aabbcc" -> "aabbc").
    """
    doubled_char = re.compile(r"(\w*)(\w)\2(\w*)")

    def _drop_duplicate(match):
        # Rebuild the run from prefix + one copy of the doubled char + suffix.
        return "".join(match.group(1, 2, 3))

    return doubled_char.sub(_drop_duplicate, text)

def clean_text(text):
    """Normalise one raw field value into its cleaned string form.

    The value is coerced with ``str`` and then passed, in order, through
    ``decontracted``, ``remove_punctuations``, ``clean_number`` and
    ``clean_whitespace``. ``clean_repeat_words`` is not part of this pipeline.
    """
    # TextBlob spelling correction is currently disabled:
    # text = str(TextBlob(text).correct())
    cleaned = str(text)
    for step in (decontracted, remove_punctuations, clean_number, clean_whitespace):
        cleaned = step(cleaned)

    return cleaned

In [None]:
import torch
import pandas as pd
from dataset import AutoTokenizer, LANGUAGE_TOKENS, CATEGORY_TOKENS, LEVEL_TOKENS, KIND_TOKENS, OTHER_TOKENS
from model import Model

from torch.utils.data import DataLoader, Dataset, default_collate

In [None]:
from pathlib import Path


# Switch between scoring a validation fold (False) and producing predictions
# for the ids listed in the sample submission (True).
TEST_MODE = False

# --------------------- VALIDATION SET --------------------------
from tqdm import tqdm
if not TEST_MODE:
    # Fold assignment table and the fold that later cells filter on.
    data_df = pd.read_csv("./data/siamese_train.csv")
    fold = 0

data_folder = Path("./data")
# TODO: we have to process for test set ourselves

# Each table: blank out missing values, then order rows by title length
# using a temporary helper column that is dropped again afterwards.
contents_df = (
    pd.read_csv(data_folder/'content.csv')
    .fillna('')
    .assign(title_len=lambda frame: frame["title"].str.len())
    .sort_values(by='title_len', axis=0)
    .reset_index(drop=True)
    .drop(columns=['title_len'])
)
topics_df = (
    pd.read_csv(data_folder/'topics.csv')
    .fillna('')
    .assign(title_len=lambda frame: frame["title"].str.len())
    .sort_values(by='title_len', axis=0)
    .reset_index(drop=True)
    .drop(columns=['title_len'])
)
subs_df = pd.read_csv(data_folder/'sample_submission.csv')
corrs_df = pd.read_csv(data_folder/'correlations.csv')


# Run the text-cleaning pipeline over every free-text column.
for column in ("title", "description"):
    topics_df[column] = topics_df[column].apply(clean_text)

for column in ("title", "description", "text"):
    contents_df[column] = contents_df[column].apply(clean_text)

In [None]:
from tqdm import tqdm

# Keep only the topics to predict for. `.copy()` detaches the filtered frame
# from the original so that adding columns to it later does not write into a
# slice (which pandas reports as SettingWithCopyWarning).
if TEST_MODE:
    topics_df = topics_df[topics_df.id.isin(subs_df.topic_id)].copy()
else: # VAL_MODE
    topics_df = topics_df[topics_df.id.isin(data_df[data_df["fold"] == fold].topics_ids)].copy()

# Serialise every topic as: type/language/category/level marker tokens,
# followed by the tagged title and description. Keyed by topic id.
topic_dict = {}
for _, row in tqdm(topics_df.iterrows(), total=len(topics_df)):
    text = "<|topic|>" + f"<|lang_{row['language']}|>" + f"<|category_{row['category']}|>" + f"<|level_{row['level']}|>"
    text += "<s_title>" + row["title"] + "</s_title>" + "<s_description>" + row["description"] + "</s_description>"
    topic_dict[row["id"]] = text

# Same idea for content items: type/language/kind marker tokens, followed by
# the tagged title, description and body text. Keyed by content id.
content_dict = {}
for _, row in tqdm(contents_df.iterrows(), total=len(contents_df)):
    text = "<|content|>" + f"<|lang_{row['language']}|>" + f"<|kind_{row['kind']}|>"
    text += "<s_title>" + row["title"] + "</s_title>" + "<s_description>" + row["description"] + "</s_description>" + "<s_text>" + row["text"] + "</s_text>"
    content_dict[row["id"]] = text

In [None]:
# Upper bound, in characters, on the serialised text kept per row.
MAX_TEXT_CHARS = 2048

# Look each row's text up by its id instead of assigning `dict.values()`
# positionally: the positional form only lines up when the dict's insertion
# order and length match the frame's rows exactly. `assign` returns a new
# frame, so nothing is written into a filtered slice.
topics_df = topics_df.assign(
    topic_text=topics_df["id"].map(topic_dict).str[:MAX_TEXT_CHARS]
)

contents_df = contents_df.assign(
    content_text=contents_df["id"].map(content_dict).str[:MAX_TEXT_CHARS]
)

In [None]:
class InferenceDataset(Dataset):
    """Dataset that tokenizes a list of strings for embedding inference.

    Args:
        texts: sequence of strings; item ``idx`` is tokenized on access.
        tokenizer_name: name passed to ``AutoTokenizer.from_pretrained``.
        max_len: every item is padded or truncated to this many tokens.

    The tokenizer is extended with the language, category, level, kind and
    other marker tokens imported from ``dataset``.
    """

    def __init__(self, texts, tokenizer_name='xlm-roberta-base', max_len=512):
        self.texts = texts

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.tokenizer.add_special_tokens(dict(additional_special_tokens=LANGUAGE_TOKENS + CATEGORY_TOKENS + LEVEL_TOKENS + KIND_TOKENS + OTHER_TOKENS))
        self.max_len = max_len

    def __len__(self):
        """Number of texts in the dataset."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize ``self.texts[idx]`` and return its fields as long tensors.

        The tokenizer output object is returned with each value replaced by a
        ``torch.long`` tensor.
        """
        text = self.texts[idx]

        inputs = self.tokenizer.encode_plus(
            text,
            return_tensors=None,
            add_special_tokens=True,
            max_length=self.max_len,
            # `padding="max_length"` replaces the deprecated
            # `pad_to_max_length=True` flag.
            padding="max_length",
            truncation=True,
        )
        for k, v in inputs.items():
            inputs[k] = torch.tensor(v, dtype=torch.long)

        return inputs
    
def collate_fn(inputs):
    """Stack samples into a batch and cut off columns that are padding everywhere.

    The samples are combined with ``default_collate``; every field is then
    sliced along its second dimension to the largest attention-mask sum found
    in the batch.
    """
    batch = default_collate(inputs)
    longest = int(batch["attention_mask"].sum(dim=1).max())
    for key in list(batch.keys()):
        batch[key] = batch[key][:, :longest]

    return batch

In [None]:
# Topics are batched in frame order (no shuffling), so embedding i belongs to
# row i of topics_df.
topic_texts = topics_df["topic_text"].tolist()
topic_dataset = InferenceDataset(texts=topic_texts)
topic_dataloader = DataLoader(
    topic_dataset,
    batch_size=16,
    shuffle=False,
    collate_fn=collate_fn,
)



In [None]:

model = Model(tokenizer=topic_dataset.tokenizer, model_name="xlm-roberta-base")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
# Inference only: switch off train-time behaviour such as dropout.
# NOTE(review): assumes Model is a torch.nn.Module (it already supports .to()).
model.eval()


weights_path = "outputs/checkpoint-40000/pytorch_model.bin"

# map_location lets a checkpoint that was saved on GPU load on a CPU-only
# machine, matching the CPU fallback chosen for `device` above.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
state_dict = torch.load(weights_path, map_location=device)

# Drop the first 6 characters of every key so the names match the inner
# network. NOTE(review): presumably this strips a "model." prefix — confirm
# against the checkpoint's key names.
new_state_dict = {k[6:]: v for k, v in state_dict.items()}

model.model.load_state_dict(new_state_dict)

In [None]:
# One embedding (numpy array) per topic, in dataloader order.
topic_embs = []

# Inference only: eval mode disables dropout, and no_grad avoids building an
# autograd graph for every batch.
model.eval()
with torch.no_grad():
    for inputs in tqdm(topic_dataloader):
        for k, v in inputs.items():
            inputs[k] = inputs[k].to(device)
        out = model.feature(inputs)
        topic_embs.extend(out.cpu().detach().numpy())

In [None]:
# Content items are batched in frame order (no shuffling), so embedding i
# belongs to row i of contents_df.
content_dataset = InferenceDataset(texts=list(contents_df.content_text.values))
content_dataloader = DataLoader(content_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)

# One embedding (numpy array) per content item, in dataloader order.
content_embs = []

# Inference only: eval mode disables dropout, and no_grad avoids building an
# autograd graph for every batch.
model.eval()
with torch.no_grad():
    for inputs in tqdm(content_dataloader):
        for k, v in inputs.items():
            inputs[k] = inputs[k].to(device)
        out = model.feature(inputs)
        content_embs.extend(out.cpu().detach().numpy())

In [None]:
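# Optional embedding cache: uncomment the save lines to persist the embeddings, or the load lines to restore them from saved files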
# torch.save(topic_embs, "./data/topic_embs.pt")
# torch.save(content_embs, "./data/content_embs.pt")

# topic_embs = torch.load("./data/topic_embs.pt")
# content_embs = torch.load("./data/content_embs.pt")

In [None]:
from fuzzywuzzy import fuzz, process

from annoy import AnnoyIndex


# Annoy index over all content embeddings; item id i is row i of contents_df.
embedding_dim = content_embs[0].shape[0]
content_forest = AnnoyIndex(embedding_dim, metric='angular')
for item_id, embedding in enumerate(tqdm(content_embs, total=len(content_embs))):
    content_forest.add_item(item_id, embedding)
content_forest.build(100)

In [None]:
# topics = topics_df[topics_df.has_content==True][['id', 'title', 'language']].reset_index(drop=True)

# Both names refer to the same topic frame; no copy is made.
topics = topics_df
test = topics

# Content columns as arrays, so they can be indexed by row position.
all_content_ids = contents_df["id"].to_numpy()
all_content_titles = contents_df["title"].to_numpy()
all_content_language = contents_df["language"].to_numpy()

# Topic columns as plain lists, in frame order.
all_test_ids = topics["id"].tolist()
all_test_title = topics["title"].tolist()
all_test_language = test["language"].tolist()

In [None]:
# Sanity check: show the 10 nearest content items (ids and distances) for the
# topic embedding at position 5. Requires at least six topic embeddings.
content_forest.get_nns_by_vector(topic_embs[5], 10, include_distances=True)

In [None]:
# TODO: find by distance instead

# Retrieval settings.
nearest_content_count = 200   # neighbours requested from the index per topic
fuzzy_filter = 0              # a candidate's title ratio must exceed this
max_distance = 10             # candidates at or beyond this distance are dropped
                              # NOTE(review): likely never triggers for the
                              # 'angular' metric — confirm its value range
fallback_count = 8            # neighbours kept when no candidate survives filtering
max_preds_per_topic = 10      # cap on filtered candidates kept per topic

preds = []
for i, t_e in tqdm(enumerate(topic_embs), total=len(topic_embs), desc='Getting Preds'):
    indexes, distances = content_forest.get_nns_by_vector(t_e, nearest_content_count, include_distances=True)
    indexes = [idx for idx, dist in zip(indexes, distances) if dist < max_distance]

    topic_id = all_test_ids[i]
    topic_text = all_test_title[i]
    topic_lang = all_test_language[i]

    # Keep same-language candidates whose title overlaps the topic title.
    filtered_indexes = []
    for idx in indexes:
        if topic_lang != all_content_language[idx]:
            continue
        fuzzy_value = fuzz.token_set_ratio(all_content_titles[idx], topic_text)
        if fuzzy_value > fuzzy_filter:
            filtered_indexes.append(idx)

    if len(filtered_indexes) == 0:
        # Nothing survived the filters: fall back to the closest neighbours,
        # ignoring language and title.
        indexes = list(set(indexes[:fallback_count]))
    else:
        indexes = filtered_indexes[:max_preds_per_topic]
    content_ids = all_content_ids[indexes]
    preds.append({
        'topic_id': topic_id,
        'content_ids': ' '.join(content_ids)
    })
preds = pd.DataFrame.from_records(preds)

preds.to_csv('submission.csv', index=False)

if not TEST_MODE:
    from engine import f2_score
    gt = corrs_df[corrs_df.topic_id.isin(data_df[(data_df["fold"] == fold)].topics_ids)].sort_values("topic_id")
    preds = preds.sort_values("topic_id")
    # Pair ground truth and predictions by topic_id rather than by row
    # position: two independently sorted frames only line up when they hold
    # exactly the same topics. A topic without a prediction scores against
    # an empty string.
    scored = gt.merge(preds, on="topic_id", how="left", suffixes=("_true", "_pred"))
    scored["content_ids_pred"] = scored["content_ids_pred"].fillna("")
    print("f2_score", f2_score(scored["content_ids_true"], scored["content_ids_pred"]))