In [1]:
from transformers import AutoTokenizer, AutoModel, logging
from collections import defaultdict
from itertools import combinations
from tqdm.notebook import tqdm
import pandas as pd
import torch
import os

logging.set_verbosity(50)

In [23]:
model_checkpoint = 'bert-base-uncased'

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModel.from_pretrained(model_checkpoint)

In [28]:
df = pd.read_csv(os.path.abspath('../../data/sense_data-techqa.csv'))

In [None]:
cos = torch.nn.CosineSimilarity()
sim_scores = defaultdict(list)

for word in tqdm(df['word'].unique()):
    word_indices = df[df['word'] == word].index
    for comb in list(combinations(list(range(word_indices[0], word_indices[-1]+1)), 2)):
        indexA = comb[0]
        indexB = comb[1]

        tokenized_inputA = tokenizer(df.iloc[indexA].example, return_tensors='pt') 
        pooler_outputA = model(**tokenized_inputA).pooler_output

        tokenized_inputB = tokenizer(df.iloc[indexB].example, return_tensors='pt') 
        pooler_outputB = model(**tokenized_inputB).pooler_output

        sim_scores[(word, df.iloc[comb[0]].sense_def, df.iloc[comb[1]].sense_def)].append(\
            cos(pooler_outputA, pooler_outputB).item())

  0%|          | 0/190 [00:00<?, ?it/s]

In [27]:
import numpy as np
for key, val in sim_scores.items():
    print(key, np.round(torch.mean(torch.Tensor(val)).item(), 2))

('Party', 'group', 'group') 0.94
('Party', 'group', 'occasion') 0.73
('Party', 'occasion', 'occasion') 0.75
('Agreement', 'contract', 'contract') 0.97
('Agreement', 'contract', 'opinion') 0.73
('Agreement', 'opinion', 'opinion') 0.71
('Company', 'organization', 'organization') 0.95
('Company', 'organization', 'companionship') 0.79
('Company', 'companionship', 'companionship') 0.78
('Product', 'goods', 'goods') 0.95
('Product', 'goods', 'consequence') 0.89
('Product', 'consequence', 'consequence') 0.96
('notice', 'announcement', 'announcement') 0.97
('notice', 'announcement', 'observing') 0.8
('notice', 'observing', 'observing') 0.87
