In [1]:
from transformers import pipeline, AutoTokenizer, AutoModelWithLMHead
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import json
import random
from transformers import AutoTokenizer, GPT2DoubleHeadsModel

In [6]:
# Load the JSON file
with open('../data/finalWords.json') as f:
    words = json.load(f)

model_name = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = GPT2DoubleHeadsModel.from_pretrained("gpt2")

# Get the current vocabulary size
original_vocab_size = len(tokenizer)

# Add the new words to the tokenizer
new_word_list = [str(word_obj["word"]) for word_obj in words]

# Check if each new word is already in the GPT2 vocabulary
words_to_add = []
vocab = tokenizer.get_vocab()
for word in new_word_list:
    if word not in vocab:
        words_to_add.append(word)

num_new_words = len(new_word_list)

num_added = tokenizer.add_tokens(words_to_add)
model.resize_token_embeddings(len(tokenizer))

#Compute the distribution from which weâ€™ll sample:
params = model.state_dict()
embeddings = params['transformer.wte.weight']
pre_expansion_embeddings = embeddings[:-num_new_words,:]
mu = torch.mean(pre_expansion_embeddings, dim=0)
n = pre_expansion_embeddings.size()[0]
sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n
dist = torch.distributions.multivariate_normal.MultivariateNormal(
        mu, covariance_matrix=1e-5*sigma)

#load in our new embeddings into the model:
new_embeddings = torch.stack(tuple((dist.sample() for _ in range(num_new_words))), dim=0)
embeddings[-num_new_words:,:] = new_embeddings
params['transformer.wte.weight'][-num_new_words:,:] = new_embeddings
model.load_state_dict(params)


Some weights of GPT2DoubleHeadsModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['multiple_choice_head.summary.bias', 'multiple_choice_head.summary.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


<All keys matched successfully>

In [7]:
# get new words
with open('../data/finetune_eval.txt', 'r') as file:
    lines = file.readlines()

new_words = [word.strip() for line in lines for word in line.split()]
print(len(new_words))

7501


In [8]:
# Initialize the quiz questions and answers
quiz_questions = []
quiz_answers = []
new_words_copy = new_words[:]
# Loop through each word in the JSON file
count = 0

for word in words:
    if word['word'] in new_words:
        new_words.remove(word['word'])
        example = word['example']
        # Get a list of 3 random words to use as distractors
        distractors = random.sample(new_words_copy, 4)
        # Add the correct answer to the beginning of the options list
        options = [word['example'] + " [CLS]"] + [example.replace(word['word'], word_choice) + " [CLS]" for word_choice in distractors]
        # Add the question and set the answer to 0
        quiz_questions.append(options)
        quiz_answers.append(0)
        count += 1


In [9]:
# Add a [CLS] to the vocabulary (we should train it also!)
num_added_tokens = tokenizer.add_special_tokens({"cls_token": "[CLS]"})
# Update the model embeddings with the new vocabulary size
embedding_layer = model.resize_token_embeddings(len(tokenizer))

questions = quiz_questions
answers = quiz_answers

count0 = 0.0
count1 = 0.0
count2 = 0.0
count3 = 0.0
count4 = 0.0
total = 0.0

for i, choices in enumerate(questions):
    #print(choices)
    encoded_choices = [tokenizer.encode(s) for s in choices]
    cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]
    #check if some choices are longer than the others
    max_len = max([len(tokens) for tokens in encoded_choices])
    check = True
    for tokens in encoded_choices:
        if len(tokens) < max_len:
            check = False
    if check:
        input_ids = torch.tensor(encoded_choices).unsqueeze(0)  # Batch size: 1, number of choices: 2
        mc_token_ids = torch.tensor([cls_token_location])  # Batch size: 1
        outputs = model(input_ids, mc_token_ids=mc_token_ids)
        lm_logits = outputs.logits
        mc_logits = outputs.mc_logits
        if mc_logits.argmax().item() == 0:
            count0 += 1
        elif mc_logits.argmax().item() == 1:
            count1 += 1
        elif mc_logits.argmax().item() == 2:
            count2 += 1
        elif mc_logits.argmax().item() == 3:
            count3 += 1
        elif mc_logits.argmax().item() == 4:
            count4 += 1
        
        total += 1

print(len(questions))
print(count0)
print(count1)
print(count2)
print(count3)
print(count4)
print(total)

7501
2800.0
612.0
571.0
643.0
1374.0
6000.0
