In [1]:
import random
import numpy as np
import os
from os.path import exists

import json
os.environ["CUDA_LAUNCH_BLOCKING"]="1" 
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="7"


from pprint import pprint as pprint
from typing import List, Optional
import copy
import pickle

from tqdm import tqdm

In [2]:
import torch
import clip

from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = 1000000000

In [3]:
from collections import Counter

In [4]:
from nltk.corpus import wordnet as wn
#from nltk.corpus import wordnet31 as wn31
from wiktionaryparser import WiktionaryParser

In [5]:
from T5_WSD import T5_WSD

2023-01-20 01:47:03.997352: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [6]:
CLIP_MODEL = "ViT-B/32"

In [7]:
device = "cuda" if torch.cuda.is_available() else "cpu"
CLIP_model, preprocess = clip.load(CLIP_MODEL, device=device)

In [8]:
def image_loader(path, preprocessor):
    img_files = os.listdir(path)
    
    imgs = {}
    for file in tqdm(img_files):
        file_path = os.path.join(path, file)
        #img = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
        img = preprocess(Image.open(file_path)).unsqueeze(0)
        imgs[file] = img
    return imgs

In [9]:
#image_path = "Dataset/VWSD/trial/all_images"
image_path = "Dataset/VWSD/train/train_v1/train_images_v1"

image_dict_path = 'Temp/img_dict.pkl'
if os.path.isfile(image_dict_path):
    img_dict = pickle.load(open(image_dict_path,'rb'))
else:
    img_dict = image_loader(image_path,preprocess)
    pickle.dump(img_dict, open(image_dict_path,'wb'))


In [10]:
# data_file_path = "Dataset/VWSD/trial/trial.data.txt"
# gold_file_path = "Dataset/VWSD/trial/trial.gold.txt"

data_file_path = "Dataset/VWSD/train/train_v1/train.data.v1.txt"
gold_file_path = "Dataset/VWSD/train/train_v1/train.gold.v1.txt"



In [11]:
GPT_def_path = 'Temp/GPT_Definition/GPT_Context_Definitions.json'

In [12]:
class GPT_definitions(object):
    def __init__(self, GPT_def_path):
        temp_dict = json.load(open(GPT_def_path))
        
        GPT_dict = {}
        for key in temp_dict.keys():
            for k in temp_dict[key]:
                 GPT_dict[k] = []
        for key in temp_dict.keys():
            for k in temp_dict[key]:
                 GPT_dict[k].append(temp_dict[key][k])
        self.GPT_dict = GPT_dict
        
    def get_senses(self, target_word):
        return self.GPT_dict[target_word]

In [13]:
class Dictionary_wrapper(object):
    
    def __init__(self):
        #self.dictionary_type=dict_type
        
        self.wn = wn
        self.wiktionary_parser = WiktionaryParser()
        self.GPT_definitions = GPT_definitions(GPT_def_path)
        
    def get_wn_definitions(self, target_word):
        sense_definitions = []
        target_senses = self.wn.synsets(target_word)
        for synset in target_senses:
            #if synset.pos() == 'n':
            sense_definition = synset.definition().split(';')[0]
            sense_definitions.append(sense_definition)
        sense_definitions = list(set(sense_definitions))
        
        return sense_definitions
        
    def get_wiktionary_definitions(self, target_word, lang):
        parser = self.wiktionary_parser
        sense_definitions = []
        
        target_senses = parser.fetch(target_word, lang)
        #print(target_senses)
        for synset in target_senses:
            #print(synset)
            for polysemy in synset['definitions']:
                #print(definition)
                for sense in polysemy['text'][1:]:
                    sense_definition = sense.split(';')[0]
                    #print(sense_definition)
                sense_definitions.append(sense_definition)
        sense_definitions = list(set(sense_definitions))
        
        return sense_definitions
    
    def get_GPT_definitions(self, target_word, lang):
        return self.GPT_definitions.get_senses(target_word)
    
    def get_definitions(self, target_word, dictionary_type = "wordnet", lang='english'):
        # dictionary: wordnet, wiktionary, both
        #print(dictionary_type)
        if dictionary_type == 'wordnet':
            sense_definitions = self.get_wn_definitions(target_word)
        elif dictionary_type == 'wiktionary':
            sense_definitions = self.get_wiktionary_definitions(target_word, lang)
        elif dictionary_type == 'GPT_gen':
            sense_definitions = self.get_GPT_definitions(target_word, lang)
        elif dictionary_type == 'both':
            sense_definitions = self.get_wn_definitions(target_word)
            sense_definitions += self.get_GPT_definitions(target_word, lang)
        elif dictionary_type == 'compensate':
            sense_definitions = self.get_wn_definitions(target_word)
            if len(sense_definitions) == 0:
                sense_definitions += self.get_GPT_definitions(target_word, lang)
        return sense_definitions

In [14]:
# def data_loader(data_file_path, gold_file_path = None):
    
#     text_data = {}
    
#     fin_data = open(data_file_path)
#     for data_index, line in enumerate(fin_data):
#         line = line.strip()
#         if not line: continue
        
#         cols = line.split('\t')
#         target_word = cols[0]; context = cols[1]
#         candidates = cols[2:]
        
#         sense_definitions = []
#         target_senses = wn.synsets(target_word)
#         for synset in target_senses:
#             if synset.pos() == 'n':
#                 sense_definition = synset.definition().split(';')[0]
#                 sense_definitions.append(sense_definition)
            
#         text_data[data_index] = {'target_word': target_word,
#                                  'sense_definitions': sense_definitions,
#                                  'context': context,
#                                  'candidates': candidates}
#     fin_data.close()
    
    
#     if gold_file_path:
#         fin_gold = open(gold_file_path)
#         for gold_index, line in enumerate(fin_gold):
#             line = line.strip()
#             if not line.strip(): continue
            
#             gold = line.strip()
#             text_data[gold_index]['gold'] = gold
            
#     return text_data

In [15]:
dictionary = Dictionary_wrapper()

In [16]:
wsd = None
wsd = T5_WSD()


In [17]:
def data_loader(data_file_path, dictionary, dictionary_type="wordnet", gold_file_path = None):
    
    def target_word_preprocessing(target_word):
        #target_word = target_word.replace('-',' ')
        return target_word
        
    
    text_data = {}
    
    fin_data = open(data_file_path)
    lines = fin_data.readlines()
    candidate_lens = []
    for data_index, line in tqdm(enumerate(lines)):
        line = line.strip()
        if not line: continue
        
        cols = line.split('\t')
        target_word = cols[0]; target_word = target_word_preprocessing(target_word)
        context = cols[1]
        candidates = cols[2:]
        
        #sense_definitions = []
        #target_senses = wn.synsets(target_word)
        sense_definitions = dictionary.get_definitions(target_word, dictionary_type)
        wordnet_definitions = dictionary.get_definitions(target_word, 'wordnet')


#         for synset in target_senses:
#             #if synset.pos() == 'n':
#             sense_definition = synset.definition().split(';')[0]
#             sense_definitions.append(sense_definition)
#         sense_definitions = list(set(sense_definitions))
        
        answer_definition = []
        #print(wordnet_definitions)
        if wsd and len(wordnet_definitions) > 0:
            if len(wordnet_definitions) > 1:
                definition, index = wsd.predict(context, target_word, wordnet_definitions)
                answer_definition = [wordnet_definitions[index]]
            else:
                answer_definition = wordnet_definitions
        
        text_data[data_index] = {'target_word': target_word,
                                 'sense_definitions': sense_definitions,
                                 'wordnet_definitions': wordnet_definitions,
                                 'context': context,
                                 'candidates': candidates,
                                 'answer_definition': answer_definition}
        # if len(candidates) != 10:
        #     print(candidates); break
        candidate_lens.append(len(candidates))
    fin_data.close()
    
    
    if gold_file_path:
        fin_gold = open(gold_file_path)
        for gold_index, line in enumerate(fin_gold):
            line = line.strip()
            if not line: continue
            
            gold = line
            text_data[gold_index]['gold'] = gold
    print(np.mean(candidate_lens))
    return text_data

In [18]:
text_data = data_loader(data_file_path, 
                        dictionary,
                        'wordnet',
                        gold_file_path = gold_file_path)

519it [01:04, 14.91it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (850 > 512). Running this sequence through the model will result in indexing errors
12869it [27:26,  7.82it/s]

10.0





In [19]:
text_data[1]

{'target_word': 'serinus',
 'sense_definitions': ['Old World finches'],
 'wordnet_definitions': ['Old World finches'],
 'context': 'serinus genus',
 'candidates': ['image.3.jpg',
  'image.23.jpg',
  'image.4.jpg',
  'image.1.jpg',
  'image.2.jpg',
  'image.20.jpg',
  'image.5.jpg',
  'image.24.jpg',
  'image.22.jpg',
  'image.21.jpg'],
 'answer_definition': ['Old World finches'],
 'gold': 'image.20.jpg'}

In [20]:
text_data_keys = list(text_data.keys())
text_data_keys = text_data_keys
text_data = {key: text_data[key] for key in text_data_keys}

In [21]:
# text_data

In [41]:
CADG_analysis_PATH = 'Temp/GPT_Definition/CADG_analysis.txt'
DG_analysis_PATH = 'Temp/GPT_Definition/DG_analysis.txt'

In [42]:
def Def_Analysis(PATH):
    fin = open(PATH)
    Def_Analysis_Dict = {}
    for line in fin:
        context, target, definition, agreement  = line.strip().split('\t')
        Def_Analysis_Dict[context] = {'target': target,
                                      'definition': definition,
                                      'agreement': agreement}
    return Def_Analysis_Dict
CADG_analysis_dict = Def_Analysis(CADG_analysis_PATH)
DG_analysis_dict = Def_Analysis(DG_analysis_PATH)

In [116]:
class VWSD_CLIP_Zeroshot(object):
    def __init__(self, CLIP_model, CLIP_preprocess):
        self.CLIP_model = CLIP_model; 
        self.CLIP_preprocess = CLIP_preprocess
    
    def test(self, context, images):
        CLIP_model = self.CLIP_model
        CLIP_preprocess = self.CLIP_preprocess
        
        text = clip.tokenize([context]).to(device)
        images = torch.stack(images).squeeze().to(device)
        
        image_features = CLIP_model.encode_image(images)
        text_features = CLIP_model.encode_text(text)
        
        logits_per_image, logits_per_text = CLIP_model(images, text)
        
        
    def evaluate_posterior(self, text_data, img_dict):
        # I <- candidate images, T <- context, a <- ambiguous
        # P(I|T) ~ P(I,T)/ <- CLIP
        # 56.5% Accuracy 10% random
        CLIP_model = self.CLIP_model

        preds = []
        golds = []
        answers = []
        partial_answers = []
        data_indexes = []
        for data_index in tqdm(text_data.keys()):
            data = text_data[data_index]
            context = data['context']; candidates = data['candidates']
            
            gold = data['gold']; gold_index = data['candidates'].index(gold)
            data_indexes.append(data_index)
            #if len(data['wordnet_definitions'])<2: continue
            text = clip.tokenize([context]).to(device)
            with torch.no_grad():
                images = [img_dict[candidate] for candidate in candidates]
                images = torch.stack(images).squeeze().to(device)
                image_features = CLIP_model.encode_image(images)
                text_features = CLIP_model.encode_text(text)

                logits_per_image, logits_per_text = CLIP_model(images, text)
                probs = logits_per_text.softmax(dim=-1).cpu().numpy()
                pred = np.argmax(probs[0])
                
                preds.append(data['candidates'][pred]) 
                golds.append(gold)
                if pred == gold_index:
                    answers.append(1)
                else:
                    answers.append(0)
                
                sorted_indexes = reversed(np.argsort(probs[0]))
                
                i = 1
                #print(sorted_indexes)
                for index in sorted_indexes:
                    #print(index, gold_index)
                    if index == gold_index:
                        #partial_answers = 1/i
                        partial_answers.append(1/i)
                        break
                    i+=1
        return preds, golds, answers, partial_answers, data_indexes
    
    
    def evaluate_bayesian_posterior(self, text_data, img_dict):
        # P(I|T) -> \sigma \simga P(I|D,T)P(D|T)
        # 75%
        CLIP_model = self.CLIP_model

        preds = []
        golds = []
        answers = []
        partial_answers = []
        data_indexes = []
        #probs = []
        for data_index in tqdm(text_data.keys()):
            data = text_data[data_index]
            context = data['context']; candidates = data['candidates']
            sense_definitions = data['sense_definitions']
            
            
            
            #if len(data['wordnet_definitions'])<2: continue
            #sense_definitions = data['answer_definition']
            sense_definitions = [s for s in sense_definitions if s not in data['answer_definition']]
            #print(context, data['target_word'], data['answer_definition'])
            
            
            if context in CADG_analysis_dict:
                if CADG_analysis_dict[context]['agreement'] == 'FALSE':
                    sense_definitions = [CADG_analysis_dict[context]['definition']]
                else: continue
                #sense_definitions = []
            else:
                continue
            #print()
            sense_definitions = [context + ' : ' + sense_definition for sense_definition in sense_definitions]
            if not len(sense_definitions):
                #print('no sense')
                sense_definitions += [context]
            random.shuffle(sense_definitions)
            sense_definitions = sense_definitions[0]
            
            gold = data['gold']; gold_index = data['candidates'].index(gold)
            
            data_indexes.append(data_index)

            #text = clip.tokenize([context]).to(device)
            with torch.no_grad():
                context_text = clip.tokenize([context], truncate = True).to(device)
                definition_text = clip.tokenize(sense_definitions, truncate = True).to(device)

                images = [img_dict[candidate] for candidate in candidates]
                images = torch.stack(images).squeeze().to(device)

                # 1 answer and 9 distractors
                image_features = CLIP_model.encode_image(images)
                text_features = CLIP_model.encode_text(context_text)
                # 4 senses in wordnet [4X512]
                def_features = CLIP_model.encode_text(definition_text)
                
                
                
                # probs text to def
                # P(D_i|T)
                # [1X4]
                logits_per_definition = torch.matmul(text_features, def_features.T)
                prob_dist_definitions = logits_per_definition.softmax(dim=-1)
                
                
                # print(context)
                # print(sense_definitions)
                # print(logits_per_definition)
                # print(prob_dist_definitions)
                # P(I|T,D)
                # [4 X 10] 
                logits_per_image, logits_per_text = CLIP_model(images, definition_text)
                probs_per_image = logits_per_image.softmax(dim=-1)
                probs_per_text = logits_per_text.softmax(dim=-1)

                bayesian_probs = torch.matmul(prob_dist_definitions, probs_per_text).cpu().numpy()
                pred = np.argmax(bayesian_probs)
                
                sorted_indexes = reversed(np.argsort(bayesian_probs[0]))
                
                i = 1
                for index in sorted_indexes:
                    if index == gold_index:
                        #partial_answers = 1/i
                        partial_answers.append(1/i)
                        break
                    i+=1
                #ranks = [data['candidates'][index] for index in sorted_indexes]
                
                preds.append(data['candidates'][pred]) 
                golds.append(gold)
                if pred == gold_index:
                    answers.append(1)
                else:
                    answers.append(0)
        return preds, golds, answers, partial_answers, data_indexes
        

In [117]:
VWSD_CLIP = VWSD_CLIP_Zeroshot(CLIP_model, preprocess)

In [26]:
p_preds, p_golds, p_answers, p_partial_answers, data_indexes =  VWSD_CLIP.evaluate_posterior(text_data, img_dict)
print("Accuracy:", "%.2f"%(np.mean(p_answers)*100))
print("MRR:", "%.2f"%(np.mean(p_partial_answers)*100))

100%|██████████| 12869/12869 [11:23<00:00, 18.84it/s]

Accuracy: 72.98
MRR: 82.70





In [27]:
index = 0
p_sense_nums_w = []
p_sense_nums_r = []
for t, p, g in zip(data_indexes, p_preds, p_golds):
    if p != g:
        #print(t, text_data[t]['context'], p, g, len(text_data[t]['sense_definitions']))
        p_sense_nums_w.append(len(text_data[t]['wordnet_definitions']))
    else:
        p_sense_nums_r.append(len(text_data[t]['wordnet_definitions']))
    index+=1
    
print(sorted(Counter(p_sense_nums_w).items()))
print(sorted(Counter(p_sense_nums_r).items()))

right_when_zero = sorted(Counter(p_sense_nums_r).items())[0][1]
wrong_when_zero = sorted(Counter(p_sense_nums_w).items())[0][1]

right_when_one = sorted(Counter(p_sense_nums_r).items())[1][1]
wrong_when_one = sorted(Counter(p_sense_nums_w).items())[1][1]

right_when_over_one = 0
wrong_when_over_one = 0

for s, c in sorted(Counter(p_sense_nums_w).items()):
    if s > 1: wrong_when_over_one += c
for s, c in sorted(Counter(p_sense_nums_r).items()):
    if s > 1: right_when_over_one += c

print('Hits@1 |D^t|==0: %.2f'%(right_when_zero/(right_when_zero + wrong_when_zero)*100))
print('Hits@1 |D^t|==1: %.2f'%(right_when_one/(right_when_one + wrong_when_one)*100))
print('Hits@1 |D^t|>1: %.2f'%(right_when_over_one/(right_when_over_one + wrong_when_over_one)*100))

[(0, 425), (1, 2032), (2, 323), (3, 190), (4, 100), (5, 80), (6, 61), (7, 41), (8, 31), (9, 20), (10, 17), (11, 14), (12, 14), (13, 11), (14, 13), (15, 11), (16, 17), (17, 3), (18, 8), (19, 4), (20, 6), (21, 5), (22, 4), (23, 2), (24, 4), (25, 3), (27, 4), (28, 2), (29, 3), (30, 1), (31, 2), (32, 1), (33, 3), (34, 3), (36, 1), (37, 1), (39, 2), (40, 1), (44, 1), (45, 4), (49, 1), (51, 2), (52, 2), (54, 1), (70, 2), (75, 1)]
[(0, 1420), (1, 5058), (2, 1270), (3, 539), (4, 299), (5, 190), (6, 144), (7, 113), (8, 75), (9, 51), (10, 26), (11, 36), (12, 22), (13, 18), (14, 19), (15, 14), (16, 13), (17, 7), (18, 18), (19, 6), (20, 5), (21, 7), (22, 2), (23, 5), (24, 2), (25, 3), (26, 2), (27, 3), (28, 1), (29, 2), (31, 2), (33, 2), (34, 4), (35, 1), (39, 2), (41, 2), (45, 3), (47, 2), (52, 1), (57, 2), (75, 1)]
Hits@1 |D^t|==0: 76.96
Hits@1 |D^t|==1: 71.34
Hits@1 |D^t|>1: 74.07


In [28]:
# len(data_indexes), len(bp_partial_answers)

In [106]:
bp_golds, bp_preds

(['image.2247.jpg',
  'image.7871.jpg',
  'image.1840.jpg',
  'image.5307.jpg',
  'image.13627.jpg',
  'image.1084.jpg',
  'image.8488.jpg',
  'image.14101.jpg',
  'image.13716.jpg',
  'image.3999.jpg',
  'image.773.jpg',
  'image.3193.jpg',
  'image.5064.jpg',
  'image.10030.jpg',
  'image.4872.jpg',
  'image.11985.jpg',
  'image.4349.jpg',
  'image.12937.jpg'],
 ['image.2247.jpg',
  'image.7871.jpg',
  'image.1840.jpg',
  'image.5307.jpg',
  'image.13627.jpg',
  'image.1084.jpg',
  'image.8488.jpg',
  'image.14101.jpg',
  'image.13716.jpg',
  'image.3999.jpg',
  'image.773.jpg',
  'image.3193.jpg',
  'image.4.jpg',
  'image.10030.jpg',
  'image.4872.jpg',
  'image.11985.jpg',
  'image.4349.jpg',
  'image.12937.jpg'])

In [118]:
bp_preds, bp_golds, bp_answers, bp_partial_answers, data_indexes =  VWSD_CLIP.evaluate_bayesian_posterior(text_data, img_dict)
print("Accuracy:", "%.2f"%(np.mean(bp_answers)*100))
print("MRR:", "%.2f"%(np.mean(bp_partial_answers)*100))

100%|██████████| 12869/12869 [00:01<00:00, 10595.16it/s]

Accuracy: 88.89
MRR: 91.53





In [30]:
index = 0
pb_sense_nums_w = []
pb_sense_nums_r = []
for t, p, g in zip(data_indexes, bp_preds, bp_golds):
    if p != g:
        #print(t, text_data[t]['context'], p, g, len(text_data[t]['sense_definitions']))
        pb_sense_nums_w.append(len(text_data[t]['wordnet_definitions']))
    else:
        pb_sense_nums_r.append(len(text_data[t]['wordnet_definitions']))
    index+=1
print(sorted(Counter(pb_sense_nums_w).items()))
print(sorted(Counter(pb_sense_nums_r).items()))

right_when_zero = sorted(Counter(pb_sense_nums_r).items())[0][1]
wrong_when_zero = sorted(Counter(pb_sense_nums_w).items())[0][1]

right_when_one = sorted(Counter(pb_sense_nums_r).items())[1][1]
wrong_when_one = sorted(Counter(pb_sense_nums_w).items())[1][1]

right_when_over_one = 0
wrong_when_over_one = 0

for s, c in sorted(Counter(pb_sense_nums_w).items()):
     if s > 1: wrong_when_over_one += c
for s, c in sorted(Counter(pb_sense_nums_r).items()):
     if s > 1: right_when_over_one += c
    
print('Hits@1 |D^t|==0: %.2f'%(right_when_zero/(right_when_zero + wrong_when_zero)*100))
print('Hits@1 |D^t|==1: %.2f'%(right_when_one/(right_when_one + wrong_when_one)*100))
print('Hits@1 |D^t|>1: %.2f'%(right_when_over_one/(right_when_over_one + wrong_when_over_one)*100))

[(2, 281), (3, 199), (4, 103), (5, 75), (6, 62), (7, 44), (8, 30), (9, 19), (10, 15), (11, 23), (12, 11), (13, 9), (14, 13), (15, 11), (16, 21), (17, 5), (18, 13), (19, 5), (20, 4), (21, 5), (22, 3), (23, 2), (24, 5), (25, 4), (26, 1), (27, 3), (28, 2), (29, 4), (30, 1), (32, 1), (33, 2), (34, 4), (36, 1), (37, 1), (39, 4), (40, 1), (44, 1), (45, 3), (51, 2), (52, 1), (54, 1), (70, 2), (75, 1)]
[(2, 1312), (3, 530), (4, 296), (5, 195), (6, 143), (7, 110), (8, 76), (9, 52), (10, 28), (11, 27), (12, 25), (13, 20), (14, 19), (15, 14), (16, 9), (17, 5), (18, 13), (19, 5), (20, 7), (21, 7), (22, 3), (23, 5), (24, 1), (25, 2), (26, 1), (27, 4), (28, 1), (29, 1), (31, 4), (33, 3), (34, 3), (35, 1), (41, 2), (45, 4), (47, 2), (49, 1), (52, 2), (57, 2), (75, 1)]
Hits@1 |D^t|==0: 82.36
Hits@1 |D^t|==1: 72.70
Hits@1 |D^t|>1: 74.63


In [31]:
# index = 0
# p_sense_nums_w = []
# p_sense_nums_r = []
# for t, p, g in zip(text_data, p_preds, p_golds):
#     if p != g:
#         #print(t, text_data[t]['context'], p, g, len(text_data[t]['sense_definitions']))
#         p_sense_nums_w.append(len(text_data[t]['wordnet_definitions']))
#     else:
#         p_sense_nums_r.append(len(text_data[t]['wordnet_definitions']))
#     index+=1
    
# print(sorted(Counter(p_sense_nums_w).items()))
# print(sorted(Counter(p_sense_nums_r).items()))

# right_when_zero = sorted(Counter(p_sense_nums_r).items())[0][1]
# wrong_when_zero = sorted(Counter(p_sense_nums_w).items())[0][1]

# right_when_one = sorted(Counter(p_sense_nums_r).items())[1][1]
# wrong_when_one = sorted(Counter(p_sense_nums_w).items())[1][1]

# right_when_over_one = 0
# wrong_when_over_one = 0

# for s, c in sorted(Counter(p_sense_nums_w).items()):
#     if s > 1: wrong_when_over_one += c
# for s, c in sorted(Counter(p_sense_nums_r).items()):
#     if s > 1: right_when_over_one += c

# print('Hits@1 |D^t|==0: %.2f'%(right_when_zero/(right_when_zero + wrong_when_zero)*100))
# print('Hits@1 |D^t|==1: %.2f'%(right_when_one/(right_when_one + wrong_when_one)*100))
# print('Hits@1 |D^t|>1: %.2f'%(right_when_over_one/(right_when_over_one + wrong_when_over_one)*100))

In [32]:
# contexts = []
# for t in text_data:
#     if len(text_data[t]['wordnet_definitions']) > 1 and len(text_data[t]['wordnet_definitions']) < 6:
#         contexts.append(text_data[t]['context'])
# random.shuffle(contexts)

# fout = open('Temp/sampled_contexts_for_oracle_test.txt','w')
# for context in contexts[:200]:
#     fout.write(context+'\n')
# fout.close()

In [33]:
# answers = []

# for data_index in tqdm(list(text_data.keys())):
#     data = text_data[data_index]
#     context = data['context']; target_word = data['target_word']; 
#     candidates = data['candidates']
    
#     sense_definitions = data['sense_definitions']
#     sense_definitions = [context + ' : ' + sense_definition for sense_definition in sense_definitions]
#     gold = data['gold']; gold_index = data['candidates'].index(gold)
    
    
    
    
#     with torch.no_grad():
#         context_text = clip.tokenize([context]).to(device)
#         definition_text = clip.tokenize(sense_definitions).to(device)
        
#         images = [img_dict[candidate] for candidate in candidates]
#         images = torch.stack(images).squeeze().to(device)
        
#         image_features = CLIP_model.encode_image(images)
#         text_features = CLIP_model.encode_text(context_text)
#         def_features = CLIP_model.encode_text(definition_text)
        
#         # probs text to def
#         logits_per_definition = torch.matmul(text_features, def_features.T)
#         prob_dist_definitions = logits_per_definition.softmax(dim=-1)
        
#         logits_per_image, logits_per_text = CLIP_model(images, definition_text)
#         probs_per_image = logits_per_image.softmax(dim=-1)
#         probs_per_text = logits_per_text.softmax(dim=-1)
        
#         bayesian_probs = torch.matmul(prob_dist_definitions, probs_per_text).cpu().numpy()
#         max_index = np.argmax(bayesian_probs)
        
#         #print(max_index, gold_index)
#         if max_index == gold_index:
#             answers.append(1)
#         else:
#             answers.append(0)

In [34]:
# text_features.shape, def_features.shape, logits_per_definition.shape, probs_per_text.shape

In [35]:
# bayesian_probs

In [36]:
# sense_definitions

In [37]:
# target_word, context, data['gold'], candidates

In [38]:
clip

<module 'clip' from '/home/sunjae/anaconda3/envs/my_env/lib/python3.7/site-packages/clip/__init__.py'>

In [39]:
clip

<module 'clip' from '/home/sunjae/anaconda3/envs/my_env/lib/python3.7/site-packages/clip/__init__.py'>