In [1]:
# # !pip install annoy
# import os
# # print(os.environ["LD_LIBRARY_PATH"])
# os.environ["LD_LIBRARY_PATH"] = "/opt/conda/lib/python3.8/site-packages/torch/lib:/usr/local/cuda-11.3/lib64"

In [None]:
# clean text
# from textblob import TextBlob
import re
import string


def decontracted(phrase):
    """Expand common English contractions in ``phrase`` and return the result.

    The rules are applied one after another with ``re.sub``, so their order
    matters: the two irregular forms ("won't", "can't") are rewritten before
    the generic "n't" rule, and "n't" is rewritten before the bare "'t" rule.
    """
    expansions = (
        # Specific
        (r"won't", "will not"),
        (r"can\'t", "can not"),
        # General
        (r"n\'t", " not"),
        (r"\'re", " are"),
        (r"\'s", " is"),
        (r"\'d", " would"),
        (r"\'ll", " will"),
        (r"\'t", " not"),
        (r"\'ve", " have"),
        (r"\'m", " am"),
    )
    for pattern, replacement in expansions:
        phrase = re.sub(pattern, replacement, phrase)
    return phrase

def remove_punctuations(text):
    """Return ``text`` with every character of ``string.punctuation`` removed."""
    deletion_table = str.maketrans('', '', string.punctuation)
    return text.translate(deletion_table)

def clean_number(text):
    """Normalise how numbers are written inside ``text``.

    Three substitutions are applied in order:

    1. Insert a space between a run of digits and a letter that directly
       follows it ("3kg" -> "3 kg").
    2. Re-attach ordinal suffixes that step 1 split off, when the suffix is
       followed by a space ("5 th grade" -> "5th grade").
    3. Drop a comma that sits between two runs of digits ("1,000" -> "1000").

    The replacement templates are raw strings: ``\\g`` is not a valid Python
    string escape, so the non-raw form triggers an invalid-escape warning on
    recent Python versions.
    """
    text = re.sub(r'(\d+)([a-zA-Z])', r'\g<1> \g<2>', text)
    text = re.sub(r'(\d+) (th|st|nd|rd) ', r'\g<1>\g<2> ', text)
    text = re.sub(r'(\d+),(\d+)', r'\g<1>\g<2>', text)
    return text

def clean_whitespace(text):
    """Trim ``text`` and collapse every run of whitespace into one space."""
    trimmed = text.strip()
    return re.sub(r"\s+", " ", trimmed)

def clean_repeat_words(text):
    """Collapse a doubled character inside each word-like run of ``text``.

    Despite the name, this works on repeated *characters*, not repeated
    words: each regex match drops one character of a doubled pair
    ("hello" -> "helo"). The leading group is greedy, so when a run contains
    several doubled pairs the last one is the pair that is collapsed
    ("aabb" -> "aab").
    """
    doubled_char = re.compile(r"(\w*)(\w)\2(\w*)")
    return doubled_char.sub(r"\1\2\3", text)

def clean_text(text):
    """Coerce ``text`` to ``str`` and run it through the cleaning pipeline.

    Steps, in order: contraction expansion, punctuation removal, number
    normalisation, whitespace normalisation. Punctuation is stripped before
    ``clean_number`` runs, so commas inside numbers are already gone by the
    time that step sees the text.
    """
    # text_blob = TextBlob(text)
    # text = str(text_blob.correct())
    cleaned = str(text)
    for step in (decontracted, remove_punctuations, clean_number, clean_whitespace):
        cleaned = step(cleaned)
    return cleaned

In [3]:
import torch
import pandas as pd
from dataset import AutoTokenizer, LANGUAGE_TOKENS, CATEGORY_TOKENS, LEVEL_TOKENS, KIND_TOKENS, OTHER_TOKENS
from model import Model

from torch.utils.data import DataLoader, Dataset, default_collate

In [4]:
from pathlib import Path

from tqdm import tqdm


TEST_MODE = False

# --------------------- VALIDATION SET --------------------------
# Everything that touches `data_df` lives inside the validation branch;
# otherwise TEST_MODE=True fails with a NameError because `data_df` is never
# created in that mode.
if TEST_MODE:
    val_topic_ids = []
else:
    data_df = pd.read_csv("./data/supervised_correlations.csv")
    fold = 0
    val_topic_ids = list(data_df[data_df["fold"] == fold].topics_ids)
    del data_df

data_folder = Path("./data")


def load_sorted_by_title_len(csv_path):
    """Read a CSV, replace missing values with '' and order rows by title length.

    The temporary ``title_len`` column is used only for sorting and is
    dropped again; the returned frame has a fresh 0..n-1 index.
    """
    frame = pd.read_csv(csv_path).fillna('')
    frame['title_len'] = frame.title.str.len()
    return frame.sort_values(by='title_len', axis=0).reset_index(drop=True).drop(columns=['title_len'])


# TODO: we have to process for test set ourselves
contents_df = load_sorted_by_title_len(data_folder/'content.csv')
topics_df = load_sorted_by_title_len(data_folder/'topics.csv')
subs_df = pd.read_csv(data_folder/'sample_submission.csv')
corrs_df = pd.read_csv(data_folder/'correlations.csv')


# Apply the same text cleaning to the title/description of topics and contents.
for frame in (topics_df, contents_df):
    for column in ("title", "description"):
        frame[column] = frame[column].apply(clean_text)
# contents_df["text"] = contents_df["text"].apply(clean_text)

In [5]:
# supervised_correlations = pd.read_csv("data/supervised_correlations.csv")

In [6]:
# supervised_correlations[(supervised_correlations["topics_ids"].isin(val_topic_ids))]

In [7]:
# supervised_correlations[(supervised_correlations["topics_ids"].isin(val_topic_ids)) & (supervised_correlations["target"] == 1)]

In [8]:
from tqdm import tqdm

# Keep only the topics that will be scored: the submission topics in test
# mode, the held-out fold otherwise.
if TEST_MODE:
    topics_df = topics_df[topics_df.id.isin(subs_df.topic_id)]
else: # VAL_MODE
    topics_df = topics_df[topics_df.id.isin(val_topic_ids)]


def build_topic_text(row):
    """Return the tagged prompt string for one row of ``topics_df``."""
    header = "<|topic|>" + f"<|lang_{row['language']}|>" + f"<|category_{row['category']}|>" + f"<|level_{row['level']}|>"
    return header + "<s_title>" + row["title"] + "</s_title>" + "<s_description>" + row["description"] + "</s_description>"


def build_content_text(row):
    """Return the tagged prompt string for one row of ``contents_df``."""
    header = "<|content|>" + f"<|lang_{row['language']}|>" + f"<|kind_{row['kind']}|>"
    return header + "<s_title>" + row["title"] + "</s_title>" + "<s_description>" + row["description"] + "</s_description>" # + "<s_text>" + row["text"] + "</s_text>"


# id -> prompt string. `total=` lets tqdm show a real progress bar instead of
# a bare counter.
topic_dict = {
    row["id"]: build_topic_text(row)
    for _, row in tqdm(topics_df.iterrows(), total=len(topics_df))
}

content_dict = {
    row["id"]: build_content_text(row)
    for _, row in tqdm(contents_df.iterrows(), total=len(contents_df))
}

6152it [00:00, 26388.69it/s]
154047it [00:05, 29199.99it/s]


In [9]:
# Look each prompt string up by id instead of assigning `dict.values()`
# positionally: the positional form only lines up while ids are unique and
# the dict insertion order matches the row order.
# `topics_df` is a boolean-filtered slice, so `.assign` (which returns a new
# frame) is used rather than writing a column into the slice.
topics_df = topics_df.assign(topic_text=topics_df["id"].map(topic_dict))  # optional truncation: .str[:2048]

contents_df["content_text"] = contents_df["id"].map(content_dict)  # optional truncation: .str[:2048]

In [10]:
class InferenceDataset(Dataset):
    """Dataset that tokenises one raw text string per item.

    ``texts`` is any indexable collection of strings. The tokenizer is loaded
    from ``tokenizer_name`` and extended with the project's language,
    category, level, kind and other special tokens. Each item is tokenised on
    access, padded/truncated to ``max_len``.
    """

    def __init__(self, texts, tokenizer_name='xlm-roberta-base', max_len=512):
        extra_tokens = LANGUAGE_TOKENS + CATEGORY_TOKENS + LEVEL_TOKENS + KIND_TOKENS + OTHER_TOKENS
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        tokenizer.add_special_tokens(dict(additional_special_tokens=extra_tokens))

        self.texts = texts
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Number of texts in the dataset."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenise ``self.texts[idx]`` and return its fields as long tensors."""
        encoded = self.tokenizer.encode_plus(
            self.texts[idx],
            add_special_tokens=True,
            truncation=True,
            padding='max_length',
            max_length=self.max_len,
            return_tensors=None,
        )
        # Convert in place so the tokenizer's own container type is returned.
        for field in list(encoded.keys()):
            encoded[field] = torch.tensor(encoded[field], dtype=torch.long)

        return encoded
    
def collate_fn(inputs):
    """Stack a list of tokenised items and trim trailing padding.

    Items are stacked with ``default_collate``; every field is then cut along
    dimension 1 to the largest attention-mask sum in the batch, i.e. the
    longest unpadded sequence.
    """
    batch = default_collate(inputs)
    longest = int(batch["attention_mask"].sum(dim=1).max())
    for field in list(batch.keys()):
        batch[field] = batch[field][:, :longest]

    return batch

In [11]:
# Topic prompts -> tokenised dataset -> fixed-order batches for embedding.
topic_dataset = InferenceDataset(
    texts=topics_df["topic_text"].tolist(),
    tokenizer_name='sentence-transformers/all-MiniLM-L6-v2',
    max_len=128,
)
topic_dataloader = DataLoader(
    topic_dataset,
    batch_size=256,
    shuffle=False,  # embeddings must stay in the same order as topics_df
    collate_fn=collate_fn,
)

In [12]:

model = Model(tokenizer_name="sentence-transformers/all-MiniLM-L6-v2", model_name="sentence-transformers/all-MiniLM-L6-v2", objective="both")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
# This notebook only runs inference; switch the module out of training mode.
model.eval()


weights_path = "./outputs_siamese/checkpoint-76752/pytorch_model.bin"

# map_location lets the checkpoint load when `device` falls back to "cpu".
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
state_dict = torch.load(weights_path, map_location=device)

# Checkpoint keys for the backbone carry a "model." prefix; strip it so they
# match the keys expected by `model.model`.
backbone_prefix = "model."
new_state_dict = {
    key[len(backbone_prefix):]: tensor
    for key, tensor in state_dict.items()
    if key.startswith(backbone_prefix)
}

model.model.load_state_dict(new_state_dict)

# The classification head is only restored when the checkpoint contains it.
if "fc.weight" in state_dict:
    model.fc.load_state_dict({
        "weight": state_dict["fc.weight"],
        "bias": state_dict["fc.bias"]
    })


In [13]:
import torch.nn.functional as F

topic_embs = []

# Inference only: eval mode plus no_grad so no autograd graph is built or
# kept alive for each batch.
model.eval()
with torch.no_grad():
    for inputs in tqdm(topic_dataloader):
        for k in list(inputs.keys()):
            inputs[k] = inputs[k].to(device)
        out = model.feature(inputs)
        # One numpy vector per topic, in dataloader order.
        topic_embs.extend(out.detach().cpu().numpy())

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [00:03<00:00,  8.25it/s]


In [14]:
content_dataset = InferenceDataset(texts=list(contents_df.content_text.values), tokenizer_name='sentence-transformers/all-MiniLM-L6-v2', max_len=128)
content_dataloader = DataLoader(content_dataset, batch_size=256, shuffle=False, collate_fn=collate_fn)

# NOTE: contents_df is not deleted here; later cells still read its
# id/title/language columns.

content_embs = []

# Inference only: eval mode plus no_grad so no autograd graph is built or
# kept alive for each batch.
model.eval()
with torch.no_grad():
    for inputs in tqdm(content_dataloader):
        for k in list(inputs.keys()):
            inputs[k] = inputs[k].to(device)
        out = model.feature(inputs)
        # One numpy vector per content item, in dataloader order.
        content_embs.extend(out.detach().cpu().numpy())

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 602/602 [01:03<00:00,  9.47it/s]


In [15]:
# # save the embeddings to disk (uncomment the torch.load lines below to reload them instead)
# torch.save(topic_embs, "./data/topic_embs.pt")
# torch.save(content_embs, "./data/content_embs.pt")

# # topic_embs = torch.load("./data/topic_embs.pt")
# # content_embs = torch.load("./data/content_embs.pt")

In [16]:
# !pip install fuzzywuzzy annoy

In [17]:
# normalized_content_embs = []
# for emb in tqdm(content_embs):
#     normalized_content_embs.append(F.normalize(torch.from_numpy(emb), p=2, dim=0).numpy())

In [18]:
from fuzzywuzzy import fuzz, process

from annoy import AnnoyIndex


# Angular-distance index over the content embeddings; item ids are the
# positions of the vectors in `content_embs`.
content_forest = AnnoyIndex(len(content_embs[0]), metric='angular')
for i, item in enumerate(tqdm(content_embs, total=len(content_embs))):
    content_forest.add_item(i, item)
content_forest.build(200)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 154047/154047 [00:05<00:00, 28546.77it/s]


True

In [19]:
# topics = topics_df[topics_df.has_content==True][['id', 'title', 'language']].reset_index(drop=True)

# `topics` and `test` are both aliases of the already-filtered topics_df.
topics = topics_df
test = topics

# Content-side lookups as numpy arrays so they can be fancy-indexed with the
# neighbour positions returned by the Annoy index.
all_content_ids = contents_df["id"].to_numpy()
all_content_titles = contents_df["title"].to_numpy()
all_content_language = contents_df["language"].to_numpy()

# Topic-side lookups as plain lists, addressed by embedding position.
all_test_ids = topics["id"].tolist()
all_test_title = topics["title"].tolist()
all_test_language = test["language"].tolist()

In [20]:
# content_forest.get_nns_by_vector(topic_embs[5], 10, include_distances=True)

In [21]:
# Module-level caches, kept in their own cell so re-running the prediction
# cell does not reset them.
indexes_dict, fuzzy_dict, classification_dict = {}, {}, {}

In [22]:
# TODO: find by distance instead

nearest_content_count = 500
# fuzzy_filter / THRESHOLD are only echoed in the report line below; the
# fuzzy-title and classifier-score filters they belonged to are not applied.
fuzzy_filter = 80
THRESHOLD = 0

pred_records = []
for i, t_e in tqdm(enumerate(topic_embs), total=len(topic_embs), desc='Getting Preds'):
    # Cache neighbours per (topic position, k): keying on the position alone
    # would hand back stale results after `nearest_content_count` is changed
    # and the cell is re-run.
    cache_key = (i, nearest_content_count)
    if cache_key not in indexes_dict:
        indexes_dict[cache_key] = content_forest.get_nns_by_vector(
            t_e,
            nearest_content_count,
            include_distances=True
        )
    indexes, distances = indexes_dict[cache_key]

    pred_records.append({
        'topic_id': all_test_ids[i],
        'content_ids': ' '.join(all_content_ids[indexes]),
    })
preds = pd.DataFrame.from_records(pred_records)

preds.to_csv('submission.csv', index=False)

if not TEST_MODE:
    from engine import f2_score
    gt = corrs_df[corrs_df.topic_id.isin(val_topic_ids)].sort_values("topic_id")
    preds = preds.sort_values("topic_id")
    # Pair ground truth and predictions by topic_id rather than trusting two
    # independently sorted frames to line up row by row.
    scored = gt.merge(preds, on="topic_id", how="inner", suffixes=("_true", "_pred"))
    if len(scored) != len(gt) or len(scored) != len(preds):
        print("warning: scored", len(scored), "topics; gt has", len(gt), "and preds has", len(preds))
    print("fuzzy_filter", fuzzy_filter, "THRESHOLD:", THRESHOLD, "f2_score", f2_score(scored["content_ids_true"], scored["content_ids_pred"]))

Getting Preds: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6152/6152 [01:01<00:00, 100.77it/s]


fuzzy_filter 80 THRESHOLD: 0 f2_score 0.0219


In [23]:
import numpy as np
def get_pos_score(y_true, y_pred):
    """Mean fraction of ground-truth ids that also appear in the prediction.

    Parameters
    ----------
    y_true, y_pred : pandas.Series of str
        Whitespace-separated id lists, paired by position.

    Returns
    -------
    float
        Mean over rows of ``len(true & pred) / len(true)``, rounded to five
        decimals. Rows whose ground truth is empty are skipped (the ratio is
        undefined for them); ``0.0`` is returned if no row can be scored.

    Raises
    ------
    ValueError
        If the two series differ in length. ``zip`` would otherwise silently
        drop the extra rows and score mismatched pairs.
    """
    if len(y_true) != len(y_pred):
        raise ValueError(
            f"y_true and y_pred must have the same length, got {len(y_true)} and {len(y_pred)}"
        )
    true_sets = y_true.apply(lambda x: set(x.split()))
    pred_sets = y_pred.apply(lambda x: set(x.split()))
    hit_ratios = [
        len(truth & predicted) / len(truth)
        for truth, predicted in zip(true_sets, pred_sets)
        if truth  # skip rows without ground truth: avoids division by zero
    ]
    if not hit_ratios:
        return 0.0
    return round(np.mean(hit_ratios), 5)

In [24]:
# Share of ground-truth contents that survive in the retrieved candidates.
# Both frames were sorted by topic_id in the prediction cell; rows are paired
# by position.
get_pos_score(y_true=gt["content_ids"], y_pred=preds["content_ids"])

0.56501

In [25]:
# filter by using cross-encoder

In [48]:

cross_encoder_model = Model(tokenizer_name="sentence-transformers/all-MiniLM-L6-v2", model_name="sentence-transformers/all-MiniLM-L6-v2", objective="classification")
device = "cuda" if torch.cuda.is_available() else "cpu"
cross_encoder_model = cross_encoder_model.to(device)
# This notebook only runs inference; switch the module out of training mode.
cross_encoder_model.eval()

weights_path = "./outputs/checkpoint-10794/pytorch_model.bin"

# map_location lets the checkpoint load when `device` falls back to "cpu".
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
state_dict = torch.load(weights_path, map_location=device)

# Checkpoint keys for the backbone carry a "model." prefix; strip it so they
# match the keys expected by `cross_encoder_model.model`.
backbone_prefix = "model."
new_state_dict = {
    key[len(backbone_prefix):]: tensor
    for key, tensor in state_dict.items()
    if key.startswith(backbone_prefix)
}

cross_encoder_model.model.load_state_dict(new_state_dict)

# Restore the classification head from its two checkpoint tensors.
cross_encoder_model.fc.load_state_dict({
    "weight": state_dict["fc.weight"],
    "bias": state_dict["fc.bias"]
})


<All keys matched successfully>

In [29]:
from dataset import init_tokenizer

class CrossEncoderDataset(Dataset):
    """Dataset of (topic, content) text pairs expanded from a predictions frame.

    ``df`` needs a ``topic_id`` column and a ``content_ids`` column holding
    space-separated content ids. Each (topic, content) combination becomes
    one item; rows with an empty ``content_ids`` contribute nothing. Texts
    are looked up in the module-level ``topic_dict`` and ``content_dict``.
    """

    def __init__(self, df, tokenizer_name='sentence-transformers/all-MiniLM-L6-v2', max_len=128):
        self.df = df
        self.topic_texts = []
        self.content_texts = []
        for _, row in tqdm(df.iterrows()):
            joined_ids = row["content_ids"]
            if not joined_ids:
                continue
            topic_text = topic_dict[row["topic_id"]]
            for content_id in joined_ids.split(" "):
                self.topic_texts.append(topic_text)
                self.content_texts.append(content_dict[content_id])

        self.tokenizer = init_tokenizer(tokenizer_name)
        self.max_len = max_len

    def __len__(self):
        """Number of (topic, content) pairs."""
        return len(self.topic_texts)

    def _encode(self, text):
        """Tokenise ``text`` to ``max_len`` and convert every field to a long tensor."""
        encoded = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            truncation=True,
            padding='max_length',
            max_length=self.max_len,
            return_tensors=None,
        )
        for field in list(encoded.keys()):
            encoded[field] = torch.tensor(encoded[field], dtype=torch.long)
        return encoded

    def __getitem__(self, idx):
        """Return ``(topic_inputs, content_inputs, 0)``; the label is a constant 0."""
        topic_inputs = self._encode(self.topic_texts[idx])
        content_inputs = self._encode(self.content_texts[idx])
        return topic_inputs, content_inputs, 0


def cross_encoder_collate_fn(batch):
    """Collate (topic_inputs, content_inputs, label) items into one dict.

    Items are stacked with ``default_collate``. The topic fields and the
    content fields are then each trimmed along dimension 1 to the largest
    attention-mask sum on their own side. The result has the keys
    ``topic_inputs``, ``content_inputs`` and ``labels``.
    """
    collated = default_collate(batch)
    topic_inputs, content_inputs, labels = collated

    # Trim each side independently; their longest sequences can differ.
    for side in (topic_inputs, content_inputs):
        longest = int(side["attention_mask"].sum(dim=1).max())
        for field in list(side.keys()):
            side[field] = side[field][:, :longest]

    return {
        "topic_inputs": topic_inputs,
        "content_inputs": content_inputs,
        "labels": labels,
    }

In [30]:
# Expand every (topic, candidate content) pair in `preds` into one dataset item.
ceds = CrossEncoderDataset(df=preds)

6152it [00:07, 864.81it/s]


In [53]:
ce_dataloader = DataLoader(
    ceds,
    batch_size=128,
    shuffle=False,  # scores are matched back to pairs by position later on
    collate_fn=cross_encoder_collate_fn,
)

In [54]:
res = []

# Inference only: eval mode plus no_grad so no autograd graph is built or
# kept alive for each batch.
cross_encoder_model.eval()
with torch.no_grad():
    for inputs in tqdm(ce_dataloader):
        for k in list(inputs.keys()):
            inputs[k] = inputs[k].to(device)
        out = torch.sigmoid(cross_encoder_model(**inputs))
        # One entry per (topic, content) pair, in dataloader order.
        res.extend(out.detach().cpu().numpy())

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 24032/24032 [32:04<00:00, 12.49it/s]


In [58]:
# Flatten `preds` into parallel lists with one entry per (topic, content)
# pair, walking rows and ids in the same order as CrossEncoderDataset does so
# positions line up with the entries of `res`.
pred_topics, pred_contents = [], []

for topic_id, joined_ids in tqdm(zip(preds["topic_id"], preds["content_ids"])):
    if not joined_ids:
        continue
    candidate_ids = joined_ids.split(" ")
    pred_topics.extend([topic_id] * len(candidate_ids))
    pred_contents.extend(candidate_ids)

6152it [00:06, 924.35it/s]


In [59]:
# for topic_id, content_id, score in zip(pred_topics, pred_contents, res):
# Long-format table: one row per scored (topic, content) pair. Each element
# of `res` is indexed with [0] to pull out its single score.
new_pred_df = pd.DataFrame(
    dict(
        topic_id=pred_topics,
        content_id=pred_contents,
        score=[pair_scores[0] for pair_scores in res],
    )
)

new_pred_df


Unnamed: 0,topic_id,content_id,score
0,t_00004da3a1b2,c_155829704f08,0.125887
1,t_00004da3a1b2,c_9a9b9935a1bf,0.129204
2,t_00004da3a1b2,c_9c1c9c40b02f,0.670831
3,t_00004da3a1b2,c_0c8abe59a6e9,0.261264
4,t_00004da3a1b2,c_cb1c57542d1a,0.300368
...,...,...,...
3075995,t_fffe14f1be1e,c_5d511f4b68a4,0.378241
3075996,t_fffe14f1be1e,c_276e8d878eaf,0.836427
3075997,t_fffe14f1be1e,c_21a702532c7f,0.869691
3075998,t_fffe14f1be1e,c_bc1ba339fe0d,0.944993


In [102]:
# first get top-k, then filter by confidence score

Unnamed: 0_level_0,content_id,score
topic_id,Unnamed: 1_level_1,Unnamed: 2_level_1
t_00004da3a1b2,c_155829704f08 c_9a9b9935a1bf c_9c1c9c40b02f c...,0.12588666379451752 0.12920422852039337 0.6708...
t_00069b63a70a,c_ec99a6692b9e c_fda21411f22d c_05ff8bd1fd30 c...,0.8984891176223755 0.8957614302635193 0.986057...
t_0010852b7049,c_88b82381e693 c_141dbd993fb4 c_4b5bad2b3605 c...,0.3990142345428467 0.7117382287979126 0.788881...
t_0016d30772f3,c_c25053d6fafd c_242ddc729eec c_61b851222e17 c...,0.34429931640625 0.5535929799079895 0.21919767...
t_001a1575f24a,c_433f60c8c551 c_347dd8aa0601 c_ecb7d1ceb3b4 c...,0.08910951018333435 0.0768064334988594 0.74370...
...,...,...
t_ffae9147a5ae,c_8d80e8412b30 c_0331e74bc24c c_6373dae8dbcd c...,0.6131402850151062 0.13943415880203247 0.48580...
t_ffd71e80caab,c_b7b01b351f9e c_5775054b5d7c c_c8b39ecdf323 c...,0.0737805888056755 0.07489234954118729 0.23851...
t_ffd908252a7d,c_066334236a6d c_d3bf713e9584 c_f15c664feadb c...,0.26526185870170593 0.20486509799957275 0.0548...
t_ffe86c1ec81b,c_14c467eb38c6 c_2d1c46350b1b c_ac101b91054f c...,0.28509846329689026 0.5646362900733948 0.79563...


In [104]:
top_k = 5


final_topics = []
final_contents = []
final_scores = []

# groupby yields the topics in sorted topic_id order. For every topic keep
# the `top_k` contents with the highest cross-encoder score; ties keep their
# original order because Python's sort is stable.
for topic_id, group in tqdm(new_pred_df.groupby('topic_id')):
    # `data` keeps its name because a later cell displays it.
    data = sorted(
        zip(group["content_id"], group["score"]),
        key=lambda tup: tup[1],
        reverse=True,
    )[:top_k]

    final_topics.append(topic_id)
    final_contents.append(" ".join(content_id for content_id, _ in data))
    # Join the scores (second tuple element), not the content ids.
    final_scores.append(" ".join(str(score) for _, score in data))

6152it [00:01, 4159.96it/s]


In [105]:
# Debug peek: `data` is left over from the last iteration of the top-k loop,
# i.e. the (content_id, score) tuples kept for the final topic.
data

[('c_bca2888db852', 0.9983574748039246),
 ('c_73b95605efec', 0.9976263642311096),
 ('c_3d11280c1fd5', 0.9967357516288757),
 ('c_d61dd657895b', 0.9956411123275757),
 ('c_2db7d7219e04', 0.9953600764274597)]

In [106]:
# One row per topic with its re-ranked, space-separated content ids.
final_preds = pd.DataFrame(
    dict(topic_id=final_topics, content_ids=final_contents)
).sort_values("topic_id")

final_preds


Unnamed: 0,topic_id,content_ids
0,t_00004da3a1b2,c_18039de420ee c_2b39c42ce5c6 c_8bbadb894798 c...
1,t_00069b63a70a,c_fbb55ec5bb93 c_500976c1d732 c_709a113276da c...
2,t_0010852b7049,c_38ba59722f55 c_e0f240fac0f9 c_a29b03b77295 c...
3,t_0016d30772f3,c_9f39934ac915 c_3becaf30edf5 c_e22c3df3477d c...
4,t_001a1575f24a,c_94bc9ad9a437 c_df0e4b6bff98 c_8f2b7b7986e6 c...
...,...,...
6147,t_ffae9147a5ae,c_51bda103b308 c_6ca2a3bd3887 c_54b9e3f2a634 c...
6148,t_ffd71e80caab,c_101554ac7f0a c_f72aca5be36d c_4c8238d9cfd0 c...
6149,t_ffd908252a7d,c_46e9acc70642 c_156ec8d75c37 c_148635f123ee c...
6150,t_ffe86c1ec81b,c_ba717080dd61 c_18039de420ee c_f78a4f87982f c...


In [107]:
# Show the ground-truth correlations for the validation topics.
gt

Unnamed: 0,topic_id,content_ids
0,t_00004da3a1b2,c_1108dd0c7a5d c_376c5a8eb028 c_5bc0e1e2cba0 c...
2,t_00069b63a70a,c_11a1dc0bfb99
9,t_0010852b7049,c_0baf72ed7e1e c_5eca28e2cdb4 c_6a5472fb1483 c...
14,t_0016d30772f3,c_061d9f90bb06 c_242ddc729eec c_61b851222e17 c...
16,t_001a1575f24a,c_433f60c8c551
...,...,...
61448,t_ffae9147a5ae,c_542c5451ddc6 c_691dd2887cbb c_861473e6d8ac c...
61483,t_ffd71e80caab,c_5775054b5d7c c_789d19c527c8 c_b7b01b351f9e c...
61484,t_ffd908252a7d,c_5d014c6f7def c_d84eedcb7d52
61497,t_ffe86c1ec81b,c_01a66ded6b8e c_852725771d84 c_8a7cc1fb5a5f


In [109]:
gt = corrs_df[corrs_df.topic_id.isin(val_topic_ids)].sort_values("topic_id")
# Pair ground truth and predictions by topic_id rather than trusting two
# independently built frames to line up row by row.
final_scored = gt.merge(final_preds, on="topic_id", how="inner", suffixes=("_true", "_pred"))
if len(final_scored) != len(gt) or len(final_scored) != len(final_preds):
    print("warning: scored", len(final_scored), "topics; gt has", len(gt), "and final_preds has", len(final_preds))
f2_score(final_scored["content_ids_true"], final_scored["content_ids_pred"])

0.0028

In [67]:
# sum([len(v.split(" ")) for v in gt.content_ids.values])

27412

0.0219

In [79]:
# NOTE(review): this cell's execution count (79) is lower than that of the
# cells that build `final_preds` above, so the saved output presumably comes
# from an earlier version of the frame -- re-run to refresh.
final_preds

Unnamed: 0_level_0,content_ids
topic_id,Unnamed: 1_level_1
t_001edc523bd1,c_ebcb03bff955 c_a39a8828edde
t_003cf02b4682,c_93f53e33ff13 c_8f76ce76c4a4 c_0fdfaf22bc61 c...
t_00459f9ca137,c_c969c29a276d c_ab2494ea3a05 c_fd88d24c2a2b c...
t_006c08bbf736,c_bcc675a02e9d c_ebcb03bff955 c_a39a8828edde
t_00d6699a0cf3,c_b79522727dd3 c_47c14ba67fb4 c_fbb55ec5bb93 c...
...,...
t_feb52d29fab9,c_dae3c5a22d9f c_326273c51a91 c_b667cc507d1d c...
t_fec1c6435fa1,c_06621ea55b08 c_1eb18256601d c_616e11cce0de
t_fed368bb2adb,c_b4094fd88e15 c_07f6b3818fb1 c_3ea6fa70275d c...
t_fef7b464b6da,c_bcc675a02e9d


In [81]:
# Show the ground-truth correlations again, for side-by-side comparison with
# the predictions displayed above.
gt

Unnamed: 0,topic_id,content_ids
0,t_00004da3a1b2,c_1108dd0c7a5d c_376c5a8eb028 c_5bc0e1e2cba0 c...
2,t_00069b63a70a,c_11a1dc0bfb99
9,t_0010852b7049,c_0baf72ed7e1e c_5eca28e2cdb4 c_6a5472fb1483 c...
14,t_0016d30772f3,c_061d9f90bb06 c_242ddc729eec c_61b851222e17 c...
16,t_001a1575f24a,c_433f60c8c551
...,...,...
61448,t_ffae9147a5ae,c_542c5451ddc6 c_691dd2887cbb c_861473e6d8ac c...
61483,t_ffd71e80caab,c_5775054b5d7c c_789d19c527c8 c_b7b01b351f9e c...
61484,t_ffd908252a7d,c_5d014c6f7def c_d84eedcb7d52
61497,t_ffe86c1ec81b,c_01a66ded6b8e c_852725771d84 c_8a7cc1fb5a5f


In [None]:
# from torch.nn.functional import cosine_similarity

In [None]:
# # gt sims
# list_all_content_ids = list(all_content_ids)

# all_sims = []
# for i, row in tqdm(gt.iterrows()):
#     sims = []

#     t_e = topic_embs[all_test_ids.index(row["topic_id"])]
#     for content_id in row["content_ids"].split(" "):
#         c_e = content_embs[list_all_content_ids.index(content_id)]
#         sims.append(cosine_similarity(torch.from_numpy(t_e), torch.from_numpy(c_e), 0))
    
#     all_sims.append(" ".join([str(s.item())[:5] for s in sims]))

In [None]:
# all_sims

In [None]:
# # prediction sims

# list_all_content_ids = list(all_content_ids)

# all_preds_sims = []
# for i, row in tqdm(preds.iterrows()):
#     sims = []

#     t_e = topic_embs[all_test_ids.index(row["topic_id"])]
#     if not row["content_ids"]:
#         continue
#     for content_id in row["content_ids"].split(" "):
#         c_e = content_embs[list_all_content_ids.index(content_id)]
#         sims.append(cosine_similarity(torch.from_numpy(t_e), torch.from_numpy(c_e), 0))
    
#     all_preds_sims.append(" ".join([str(s.item())[:5] for s in sims]))

In [None]:
# all_preds_sims

In [None]:
# import numpy as np

# sample_preds = pd.DataFrame({
#     "content_ids": [
#         "a b c"
#     ]
# })

# sample_gts = pd.DataFrame({
#     "content_ids": [
#         "a d e f g h"
#     ]
# })
# f2_score(
#     sample_gts["content_ids"],
#     sample_preds["content_ids"],
# )