In [1]:
import random
import numpy as np
import os
from os.path import exists

import json
# GPU-related environment variables. They are assigned in this first cell, i.e. before
# the later cell that imports torch.
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is normally a debugging aid (presumably it makes
# kernel launches synchronous and slows full runs) — confirm it is still wanted.
os.environ["CUDA_LAUNCH_BLOCKING"]="1" 
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
# NOTE(review): the device index "6" is specific to the original machine — adjust per host.
os.environ["CUDA_VISIBLE_DEVICES"]="6"


from pprint import pprint as pprint
from typing import List, Optional
import copy
import pickle

from tqdm import tqdm

In [2]:
import torch
import clip

from PIL import Image
from PIL import ImageFile
# Process-wide PIL settings, applied before any image is opened by image_loader.
# NOTE(review): presumably these let PIL accept truncated files and very large images
# without raising — confirm against the Pillow documentation.
ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = 1000000000

In [3]:
from collections import Counter

In [4]:
from nltk.corpus import wordnet as wn
from wiktionaryparser import WiktionaryParser

In [5]:
# --- Experiment configuration ---------------------------------------------------------
# Model name handed to clip.load() in the next cell.
CLIP_MODEL = "ViT-B/32"
# Definition source used for 'sense_definitions' in data_loader / Dictionary_wrapper.get_definitions.
dictionary_type = 'compensate' # GPT_gen (DG or CADG), compensate (WN+DG or WN+CADG), wordnet (WN)
# JSON file read by GPT_definitions (through Dictionary_wrapper).
GPT_def_path = 'Definitions/GPT_Context_Definitions.json' # definition path
# Dataset root; the split-specific image/data/gold paths are derived from it below.
data_path = "../Experimental Codes/Dataset/VWSD" # data path

# Dataset split; the path-selection cell only handles 'trial' and 'train'.
d_split = 'train'

In [6]:
# Use the GPU when torch reports one, otherwise fall back to the CPU. `device` is a
# notebook-level name that the evaluation class below reads directly.
device = "cuda" if torch.cuda.is_available() else "cpu"
# clip.load returns the model and the image transform; the transform is later passed to
# image_loader and to VWSD_CLIP_Zeroshot.
CLIP_model, preprocess = clip.load(CLIP_MODEL, device=device)

In [7]:
def image_loader(path, preprocessor):
    """Open and preprocess every file found in ``path``.

    Args:
        path: Directory to scan. Every entry returned by ``os.listdir`` is opened as an
            image, so the directory is expected to contain image files only.
        preprocessor: Callable applied to each opened PIL image. Its result must support
            ``.unsqueeze(0)`` (for example the transform returned by ``clip.load``).

    Returns:
        dict mapping each file name (not the full path) to the preprocessed image with a
        leading batch dimension added by ``unsqueeze(0)``.
    """
    imgs = {}
    for file in tqdm(os.listdir(path)):
        file_path = os.path.join(path, file)
        # Apply the transform that was passed in: the previous version ignored the
        # argument and silently depended on the notebook-level `preprocess`.
        # The `with` block closes the underlying file once the image has been converted,
        # so a large directory does not exhaust file descriptors.
        with Image.open(file_path) as img:
            imgs[file] = preprocessor(img).unsqueeze(0)
    return imgs

In [8]:


# Resolve the image directory, data file, gold file and cache file for the chosen split.
if d_split == 'trial':
    image_path = data_path+"/trial/all_images"
    data_file_path = data_path+"/trial/trial.data.txt"
    gold_file_path = data_path+"/trial/trial.gold.txt"
    image_dict_path = 'Temp/img_dict_trial.pkl'
elif d_split == 'train':
    image_path = data_path+"/train/train_v1/train_images_v1"
    data_file_path = data_path+"/train/train_v1/train.data.v1.txt"
    gold_file_path = data_path+"/train/train_v1/train.gold.v1.txt"
    image_dict_path = 'Temp/img_dict_train.pkl'
else:
    # Fail here with a clear message instead of a NameError on the first path use below.
    raise ValueError("Unsupported d_split %r: expected 'trial' or 'train'" % (d_split,))

# Preprocessing every image is slow, so the resulting dict is cached on disk and reused.
if os.path.isfile(image_dict_path):
    # NOTE: unpickling can execute arbitrary code; only load cache files produced by
    # this notebook.
    with open(image_dict_path, 'rb') as cache_file:
        img_dict = pickle.load(cache_file)
else:
    img_dict = image_loader(image_path, preprocess)
    # Create the cache directory on first use so the dump cannot fail after the
    # expensive preprocessing has already been done.
    cache_dir = os.path.dirname(image_dict_path)
    if cache_dir:
        os.makedirs(cache_dir, exist_ok=True)
    with open(image_dict_path, 'wb') as cache_file:
        pickle.dump(img_dict, cache_file)


In [9]:
class GPT_definitions(object):
    """Lookup table of generated definitions, grouped by target word.

    The JSON file is read as a two-level mapping ``{outer_key: {word: definition}}``.
    Definitions of the same word found under different outer keys are collected into one
    list, in the order the outer keys appear in the file.
    """

    def __init__(self, GPT_def_path):
        """Load ``GPT_def_path`` and build ``self.GPT_dict`` (word -> list of definitions)."""
        # `with` guarantees the file handle is closed (it was previously leaked).
        with open(GPT_def_path) as def_file:
            temp_dict = json.load(def_file)

        GPT_dict = {}
        for key in temp_dict:
            entries = temp_dict[key]
            for k in entries:
                # One pass: create the list on first sight of the word, then append.
                GPT_dict.setdefault(k, []).append(entries[k])
        self.GPT_dict = GPT_dict

    def get_senses(self, target_word):
        """Return the list of definitions stored for ``target_word``.

        Raises:
            KeyError: if the word does not occur in the loaded file.
        """
        return self.GPT_dict[target_word]

In [10]:
class Dictionary_wrapper(object):
    """Single entry point for sense definitions from WordNet, Wiktionary and the GPT file.

    The constructor relies on notebook-level names: ``wn`` (NLTK WordNet),
    ``WiktionaryParser``, ``GPT_definitions`` and ``GPT_def_path``.
    """

    def __init__(self):
        self.wn = wn
        self.wiktionary_parser = WiktionaryParser()
        self.GPT_definitions = GPT_definitions(GPT_def_path)

    @staticmethod
    def _unique(items):
        """Drop duplicates while keeping first-seen order.

        ``list(set(...))`` ordered strings differently from one interpreter run to the
        next, which made the stored definition order irreproducible.
        """
        return list(dict.fromkeys(items))

    def get_wn_definitions(self, target_word):
        """Return the distinct WordNet glosses of ``target_word``.

        Only the text before the first ';' of each synset definition is kept. Synsets of
        every part of speech are used.
        """
        sense_definitions = [
            synset.definition().split(';')[0]
            for synset in self.wn.synsets(target_word)
        ]
        return self._unique(sense_definitions)

    def get_wiktionary_definitions(self, target_word, lang):
        """Return the distinct Wiktionary senses of ``target_word`` for ``lang``.

        For every definition group returned by the parser, the first ``text`` element is
        skipped (as before) and each remaining element contributes the text before its
        first ';'.
        """
        sense_definitions = []
        for entry in self.wiktionary_parser.fetch(target_word, lang):
            for polysemy in entry['definitions']:
                for sense in polysemy['text'][1:]:
                    # The append belongs inside this loop. Previously it ran once per
                    # group, which kept only the last sense and reused a stale (or
                    # undefined) value when a group had no senses at all.
                    sense_definitions.append(sense.split(';')[0])
        return self._unique(sense_definitions)

    def get_GPT_definitions(self, target_word, lang):
        """Return the generated definitions of ``target_word``.

        ``lang`` is accepted for signature parity with the other getters and is unused.
        Raises ``KeyError`` when the word is absent from the GPT definition file.
        """
        return self.GPT_definitions.get_senses(target_word)

    def get_definitions(self, target_word, dictionary_type = "wordnet", lang='english'):
        """Return definitions of ``target_word`` from the requested source.

        Args:
            target_word: Word to look up.
            dictionary_type: ``'wordnet'`` (WordNet only), ``'GPT_gen'`` (generated
                definitions only) or ``'compensate'`` (WordNet, falling back to the
                generated definitions only when WordNet has none).
            lang: Forwarded to the GPT getter.

        Returns:
            A new list of definition strings (possibly empty for ``'wordnet'``).

        Raises:
            ValueError: for an unknown ``dictionary_type`` (previously this surfaced as
                an ``UnboundLocalError`` on the return statement).
        """
        if dictionary_type == 'wordnet':
            return self.get_wn_definitions(target_word)
        if dictionary_type == 'GPT_gen':
            # Copy so callers cannot mutate the table held by GPT_definitions.
            return list(self.get_GPT_definitions(target_word, lang))
        if dictionary_type == 'compensate':
            sense_definitions = self.get_wn_definitions(target_word)
            if len(sense_definitions) == 0:
                sense_definitions = list(self.get_GPT_definitions(target_word, lang))
            return sense_definitions
        raise ValueError(
            "Unknown dictionary_type %r: expected 'wordnet', 'GPT_gen' or 'compensate'"
            % (dictionary_type,)
        )

In [11]:
# Build the definition lookup once; the constructor loads the GPT definition file named
# by GPT_def_path, so this cell must run after the configuration cell.
dictionary = Dictionary_wrapper()

In [12]:
def data_loader(data_file_path, dictionary, dictionary_type="wordnet", gold_file_path = None):
    """Read a tab-separated data file (and optional gold file) into a dict of records.

    Each non-blank data line is ``target_word <TAB> context <TAB> candidate ...``.

    Args:
        data_file_path: Path of the data file.
        dictionary: Object exposing ``get_definitions(word, dictionary_type)``.
        dictionary_type: Source used for the ``'sense_definitions'`` field.
        gold_file_path: Optional file with one gold candidate per line, aligned by line
            number with the data file.

    Returns:
        dict keyed by the 0-based line number of the data file (blank lines are skipped,
        so keys may have gaps). Each value holds ``'target_word'``,
        ``'sense_definitions'``, ``'wordnet_definitions'``, ``'context'``,
        ``'candidates'`` and, when a gold file is given, ``'gold'``.

    Raises:
        KeyError: if a non-blank gold line has no data record at the same line number.

    Side effects:
        Prints the mean number of candidates per record.
    """
    text_data = {}
    candidate_lens = []

    # `with` closes the files even when parsing fails; the gold file was never closed before.
    with open(data_file_path) as fin_data:
        for data_index, line in tqdm(enumerate(fin_data)):
            line = line.strip()
            if not line: continue

            cols = line.split('\t')
            target_word = cols[0]
            context = cols[1]
            candidates = cols[2:]

            sense_definitions = dictionary.get_definitions(target_word, dictionary_type)
            if dictionary_type == 'wordnet':
                # Same lookup as above: reuse it (as an independent list) instead of
                # querying WordNet a second time.
                wordnet_definitions = list(sense_definitions)
            else:
                wordnet_definitions = dictionary.get_definitions(target_word, 'wordnet')

            text_data[data_index] = {'target_word': target_word,
                                     'sense_definitions': sense_definitions,
                                     'wordnet_definitions': wordnet_definitions,
                                     'context': context,
                                     'candidates': candidates}
            candidate_lens.append(len(candidates))

    if gold_file_path:
        with open(gold_file_path) as fin_gold:
            for gold_index, line in enumerate(fin_gold):
                line = line.strip()
                if not line: continue

                if gold_index not in text_data:
                    # Same exception type as the bare dict lookup, but it now says why.
                    raise KeyError(
                        "gold line %d has no matching record in %s"
                        % (gold_index, data_file_path)
                    )
                text_data[gold_index]['gold'] = line

    # An empty file still prints nan, without numpy's empty-mean warning.
    print(np.mean(candidate_lens) if candidate_lens else float('nan'))
    return text_data

In [13]:
# Build one record per data line (target word, definitions, context, candidates, gold)
# for the split selected above. data_loader also prints the mean candidate count.
text_data = data_loader(data_file_path, 
                        dictionary,
                        dictionary_type,
                        gold_file_path = gold_file_path)

12869it [00:06, 2141.13it/s]

10.0





In [14]:
# Shuffle the order in which instances are evaluated.
# A dedicated, seeded generator applied to the *sorted* keys gives the same order on
# every run and on every re-execution of this cell (the unseeded global shuffle did not).
# The reported accuracy and MRR are means over all instances, so they do not depend on it.
SHUFFLE_SEED = 0
text_data_keys = sorted(text_data.keys())
random.Random(SHUFFLE_SEED).shuffle(text_data_keys)
text_data = {key: text_data[key] for key in text_data_keys}

In [15]:
# text_data

In [16]:
class VWSD_CLIP_Zeroshot(object):
    """Zero-shot visual word sense disambiguation on top of a CLIP model.

    Two scorers are provided: ``evaluate_posterior`` ranks candidate images against the
    context alone, ``evaluate_bayesian_posterior`` additionally marginalises over the
    sense definitions of the target word.

    The methods read the notebook-level names ``clip``, ``torch``, ``np``, ``tqdm`` and
    ``device``.
    """

    def __init__(self, CLIP_model, CLIP_preprocess):
        """Store the CLIP model and its image transform (the transform is kept but unused here)."""
        self.CLIP_model = CLIP_model
        self.CLIP_preprocess = CLIP_preprocess

    @staticmethod
    def _quote_target(context, target_word):
        """Return ``context`` with every occurrence of ``target_word`` wrapped in double quotes."""
        return context.replace(target_word, '\"'+target_word+'\"')

    @staticmethod
    def _batch_images(images):
        """Join per-image tensors into one batch on ``device``.

        Each tensor is expected to carry a leading batch dimension of 1 (image_loader
        adds it with ``unsqueeze(0)``). Concatenating along dim 0 gives the same result
        as the former ``stack(...).squeeze()`` for two or more images, and also keeps
        the batch dimension when there is exactly one image.
        """
        return torch.cat(list(images), dim=0).to(device)

    @staticmethod
    def _score_instance(scores, candidates, gold_index):
        """Turn per-candidate scores into (predicted candidate, hit flag, reciprocal rank).

        ``scores`` is a 1-D array with one value per candidate. The ranking is the
        reversed ascending argsort, exactly as in the original loops.
        """
        pred = int(np.argmax(scores))
        order = np.argsort(scores)[::-1]
        rank = int(np.nonzero(order == gold_index)[0][0]) + 1
        return candidates[pred], int(pred == gold_index), 1 / rank

    def test(self, context, images):
        """Score one context against a list of preprocessed images.

        Returns:
            ``(logits_per_image, logits_per_text)`` as produced by the CLIP model.
            (Previously the logits were computed and then discarded.)
        """
        with torch.no_grad():
            text = clip.tokenize([context], truncate = True).to(device)
            batch = self._batch_images(images)
            logits_per_image, logits_per_text = self.CLIP_model(batch, text)
        return logits_per_image, logits_per_text

    def evaluate_posterior(self, text_data, img_dict):
        """Rank candidates using the quoted context only, P(I|T).

        Args:
            text_data: dict of records with ``'context'``, ``'target_word'``,
                ``'candidates'`` and ``'gold'``.
            img_dict: dict mapping candidate file name to its preprocessed image tensor.

        Returns:
            ``(preds, golds, answers, partial_answers)``: predicted candidate names, gold
            names, 1/0 hit flags and reciprocal ranks, one entry per record in iteration
            order of ``text_data``.

        Raises:
            ValueError: if a record's gold name is not among its candidates.
        """
        CLIP_model = self.CLIP_model

        preds, golds, answers, partial_answers = [], [], [], []
        for data_index in tqdm(text_data.keys()):
            data = text_data[data_index]
            candidates = data['candidates']
            gold = data['gold']
            gold_index = candidates.index(gold)
            context = self._quote_target(data['context'], data['target_word'])

            with torch.no_grad():
                # truncate=True matches the Bayesian scorer; an over-long context is cut
                # to CLIP's context length instead of aborting the whole evaluation.
                text = clip.tokenize([context], truncate = True).to(device)
                images = self._batch_images(img_dict[candidate] for candidate in candidates)
                # One forward pass yields the logits; the separate encode_image /
                # encode_text calls whose results were never used have been removed.
                _, logits_per_text = CLIP_model(images, text)
                probs = logits_per_text.softmax(dim=-1).cpu().numpy()

            pred_name, hit, reciprocal_rank = self._score_instance(probs[0], candidates, gold_index)
            preds.append(pred_name)
            golds.append(gold)
            answers.append(hit)
            partial_answers.append(reciprocal_rank)
        return preds, golds, answers, partial_answers

    def evaluate_bayesian_posterior(self, text_data, img_dict):
        """Rank candidates with P(I|T) = sum_D P(I|D,T) * P(D|T) over sense definitions D.

        Each definition is prefixed with the quoted context (``context + ' : ' + def``).
        When a record has no definitions the quoted context itself is the single "definition".

        Args and return value are as for ``evaluate_posterior``; records additionally
        need ``'sense_definitions'``.
        """
        CLIP_model = self.CLIP_model

        preds, golds, answers, partial_answers = [], [], [], []
        for data_index in tqdm(text_data.keys()):
            data = text_data[data_index]
            candidates = data['candidates']
            gold = data['gold']
            gold_index = candidates.index(gold)
            context = self._quote_target(data['context'], data['target_word'])

            sense_definitions = [context + ' : ' + sense_definition
                                 for sense_definition in data['sense_definitions']]
            if not sense_definitions:
                sense_definitions = [context]

            with torch.no_grad():
                context_text = clip.tokenize([context], truncate = True).to(device)
                definition_text = clip.tokenize(sense_definitions, truncate = True).to(device)
                images = self._batch_images(img_dict[candidate] for candidate in candidates)

                text_features = CLIP_model.encode_text(context_text)
                def_features = CLIP_model.encode_text(definition_text)

                # P(D|T): softmax over context-definition dot products, shape [1, n_defs].
                # NOTE(review): these are raw encoder outputs (no normalisation or logit
                # scale applied here). Kept unchanged because the reported numbers
                # depend on it — confirm this is intended.
                logits_per_definition = torch.matmul(text_features, def_features.T)
                prob_dist_definitions = logits_per_definition.softmax(dim=-1)

                # P(I|D,T): one distribution over the candidates per definition,
                # shape [n_defs, n_candidates].
                _, logits_per_text = CLIP_model(images, definition_text)
                probs_per_text = logits_per_text.softmax(dim=-1)

                # Marginalise over definitions -> [1, n_candidates].
                bayesian_probs = torch.matmul(prob_dist_definitions, probs_per_text).cpu().numpy()

            pred_name, hit, reciprocal_rank = self._score_instance(bayesian_probs[0], candidates, gold_index)
            preds.append(pred_name)
            golds.append(gold)
            answers.append(hit)
            partial_answers.append(reciprocal_rank)
        return preds, golds, answers, partial_answers
        

In [17]:
# Evaluator wrapping the CLIP model and image transform loaded earlier in the notebook.
VWSD_CLIP = VWSD_CLIP_Zeroshot(CLIP_model, preprocess)

In [18]:
# Baseline: rank the candidate images against the context alone.
# p_answers holds a 1/0 hit flag per instance, p_partial_answers the reciprocal rank of
# the gold image, so their means are accuracy (Hits@1) and MRR.
p_preds, p_golds, p_answers, p_partial_answers =  VWSD_CLIP.evaluate_posterior(text_data, img_dict)
print("Accuracy:", "%.2f"%(np.mean(p_answers)*100))
print("MRR:", "%.2f"%(np.mean(p_partial_answers)*100))

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12869/12869 [16:18<00:00, 13.15it/s]

Accuracy: 72.69
MRR: 82.63





In [19]:
# Break the baseline's Hits@1 down by the number of WordNet definitions |D^t| of the
# target word. p_preds / p_golds follow the iteration order of text_data, so zipping
# them with the dict keys pairs each prediction with its record.
p_sense_nums_w = []  # |D^t| of every wrongly answered instance
p_sense_nums_r = []  # |D^t| of every correctly answered instance
for t, p, g in zip(text_data, p_preds, p_golds):
    n_senses = len(text_data[t]['wordnet_definitions'])
    if p != g:
        p_sense_nums_w.append(n_senses)
    else:
        p_sense_nums_r.append(n_senses)

p_wrong_counts = Counter(p_sense_nums_w)
p_right_counts = Counter(p_sense_nums_r)
print(sorted(p_wrong_counts.items()))
print(sorted(p_right_counts.items()))

# Look the buckets up by sense count. Indexing the sorted items by position ([0], [1])
# silently read the wrong bucket whenever no instance had 0 (or 1) definitions.
right_when_zero = p_right_counts[0]
wrong_when_zero = p_wrong_counts[0]

right_when_one = p_right_counts[1]
wrong_when_one = p_wrong_counts[1]

right_when_over_one = sum(c for s, c in p_right_counts.items() if s > 1)
wrong_when_over_one = sum(c for s, c in p_wrong_counts.items() if s > 1)

for label, right, wrong in (('==0', right_when_zero, wrong_when_zero),
                            ('==1', right_when_one, wrong_when_one),
                            ('>1', right_when_over_one, wrong_when_over_one)):
    total = right + wrong
    # An empty bucket prints nan instead of raising ZeroDivisionError.
    print('Hits@1 |D^t|%s: %.2f'%(label, right/total*100 if total else float('nan')))

[(0, 423), (1, 2099), (2, 316), (3, 193), (4, 94), (5, 68), (6, 61), (7, 40), (8, 27), (9, 20), (10, 18), (11, 13), (12, 13), (13, 10), (14, 14), (15, 9), (16, 18), (17, 4), (18, 11), (19, 5), (20, 6), (21, 3), (22, 3), (23, 2), (24, 3), (25, 3), (27, 3), (28, 2), (29, 3), (30, 1), (31, 2), (32, 1), (33, 2), (34, 3), (36, 1), (37, 1), (39, 2), (40, 1), (41, 1), (45, 5), (49, 1), (51, 2), (52, 3), (54, 1), (70, 2), (75, 1)]
[(0, 1422), (1, 4991), (2, 1277), (3, 536), (4, 305), (5, 202), (6, 144), (7, 114), (8, 79), (9, 51), (10, 25), (11, 37), (12, 23), (13, 19), (14, 18), (15, 16), (16, 12), (17, 6), (18, 15), (19, 5), (20, 5), (21, 9), (22, 3), (23, 5), (24, 3), (25, 3), (26, 2), (27, 4), (28, 1), (29, 2), (31, 2), (33, 3), (34, 4), (35, 1), (39, 2), (41, 1), (44, 1), (45, 2), (47, 2), (57, 2), (75, 1)]
Hits@1 |D^t|==0: 77.07
Hits@1 |D^t|==1: 70.39
Hits@1 |D^t|>1: 74.78


In [20]:
# Number of instances the context-only baseline got right (p_answers is 1 per hit, 0 per miss).
sum(p_answers)

9355

In [21]:
# Definition-aware scoring: candidate probabilities are marginalised over the sense
# definitions stored in each record. Same four outputs as the baseline call, so the two
# runs can be compared on accuracy (Hits@1) and MRR.
bp_preds, bp_golds, bp_answers, bp_partial_answers =  VWSD_CLIP.evaluate_bayesian_posterior(text_data, img_dict)
print("Accuracy:", "%.2f"%(np.mean(bp_answers)*100))
print("MRR:", "%.2f"%(np.mean(bp_partial_answers)*100))

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12869/12869 [20:09<00:00, 10.64it/s]

Accuracy: 83.47
MRR: 89.84





In [22]:
# Break the definition-aware scorer's Hits@1 down by the number of WordNet definitions
# |D^t| of the target word. bp_preds / bp_golds follow the iteration order of text_data,
# so zipping them with the dict keys pairs each prediction with its record.
pb_sense_nums_w = []  # |D^t| of every wrongly answered instance
pb_sense_nums_r = []  # |D^t| of every correctly answered instance
for t, p, g in zip(text_data, bp_preds, bp_golds):
    n_senses = len(text_data[t]['wordnet_definitions'])
    if p != g:
        pb_sense_nums_w.append(n_senses)
    else:
        pb_sense_nums_r.append(n_senses)

pb_wrong_counts = Counter(pb_sense_nums_w)
pb_right_counts = Counter(pb_sense_nums_r)
print(sorted(pb_wrong_counts.items()))
print(sorted(pb_right_counts.items()))

# Look the buckets up by sense count. Indexing the sorted items by position ([0], [1])
# silently read the wrong bucket whenever no instance had 0 (or 1) definitions.
right_when_zero = pb_right_counts[0]
wrong_when_zero = pb_wrong_counts[0]

right_when_one = pb_right_counts[1]
wrong_when_one = pb_wrong_counts[1]

right_when_over_one = sum(c for s, c in pb_right_counts.items() if s > 1)
wrong_when_over_one = sum(c for s, c in pb_wrong_counts.items() if s > 1)

for label, right, wrong in (('==0', right_when_zero, wrong_when_zero),
                            ('==1', right_when_one, wrong_when_one),
                            ('>1', right_when_over_one, wrong_when_over_one)):
    total = right + wrong
    # An empty bucket prints nan instead of raising ZeroDivisionError.
    print('Hits@1 |D^t|%s: %.2f'%(label, right/total*100 if total else float('nan')))

[(0, 282), (1, 967), (2, 276), (3, 174), (4, 80), (5, 62), (6, 58), (7, 34), (8, 24), (9, 13), (10, 18), (11, 9), (12, 13), (13, 10), (14, 13), (15, 8), (16, 19), (17, 5), (18, 9), (19, 4), (20, 5), (21, 5), (22, 1), (23, 2), (24, 2), (25, 4), (27, 3), (28, 1), (29, 2), (30, 1), (31, 2), (32, 1), (33, 2), (34, 3), (36, 1), (39, 2), (41, 1), (45, 5), (51, 2), (52, 1), (54, 1), (70, 1), (75, 1)]
[(0, 1563), (1, 6123), (2, 1317), (3, 555), (4, 319), (5, 208), (6, 147), (7, 120), (8, 82), (9, 58), (10, 25), (11, 41), (12, 23), (13, 19), (14, 19), (15, 17), (16, 11), (17, 5), (18, 17), (19, 6), (20, 6), (21, 7), (22, 5), (23, 5), (24, 4), (25, 2), (26, 2), (27, 4), (28, 2), (29, 3), (31, 2), (33, 3), (34, 4), (35, 1), (37, 1), (39, 2), (40, 1), (41, 1), (44, 1), (45, 2), (47, 2), (49, 1), (52, 2), (57, 2), (70, 1), (75, 1)]
Hits@1 |D^t|==0: 84.72
Hits@1 |D^t|==1: 86.36
Hits@1 |D^t|>1: 77.68
