In [1]:
import numpy as np
import pandas as pd
import torch, os, re
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from gensim.models import Word2Vec
from utils import *
import pickle
import gc
from IPython.display import clear_output

# Tokenization pattern (applied with re.findall when prompts are split into tokens).
# Each match is exactly one of:
#   \w+            a run of word characters (a word or a number),
#   [^\w\s]        a single character that is neither a word character nor whitespace
#                  (i.e. one punctuation mark / symbol per token),
#   [\n\t\r\f\v]   a single newline, tab, carriage return, form feed, or vertical tab.
# Plain spaces match none of the alternatives, so they are dropped, while the
# listed whitespace characters are kept as tokens of their own.
spc = r"\w+|[^\w\s]|[\n\t\r\f\v]"

# helper function for beautifying text, given a list of words/tokens
def beautify(tokens):
    """Turn a list of word/punctuation tokens back into readable text.

    Args:
        tokens: iterable of string tokens (words, punctuation marks, newlines).

    Returns:
        A single string in which the tokens are space-separated, newlines are
        not padded by the joining spaces, whitespace in front of
        ``, : ' ? ; ! .`` is removed, and whitespace that directly follows an
        apostrophe or another whitespace character is removed.
    """

    # 1. space-join the tokens, then strip the padding the join put around new lines
    #    (order matters: both-sided padding first, then left-only, then right-only).
    joined = " ".join(tokens)
    for padded_newline in (" \n ", " \n", "\n "):
        joined = joined.replace(padded_newline, "\n")

    # 2. decide what every regex match collapses to: a newline stays a newline,
    #    "<spaces><punctuation>" keeps only the punctuation, and
    #    "<apostrophe-or-space><spaces>" keeps only its first character.
    def _collapse(match):
        newline, _spaces, punctuation, leading_char = match.groups()
        return newline or punctuation or leading_char

    return re.sub(r'(\n)|(\s+)([,:\'?;!.])|([\'\s])\s+', _collapse, joined)

# Create our Test Prompts

In [2]:
# set a seed for reproducibility
torch.manual_seed(310); np.random.seed(310)

# how many test prompts do we want?
NUM_TEST_PROMPTS = 50

# load in our input text (one list entry per line, trailing newlines are kept)
with open("input.txt", "rt") as file:
    corpus = file.readlines()

# hold out the last 10% of the lines; test prompts are built from this slice only
# NOTE(review): the 90% cut can land in the middle of a chunk, so the first
# prompt may be a partial one -- confirm this matches the training split.
test_corpus = corpus[int(0.9*len(corpus)):]

# group consecutive lines into prompts: a blank line ("\n") closes the current prompt
prompts = []
current_prompt = ""
for line in tqdm(test_corpus):
    
    if line == "\n":
        # consecutive blank lines would otherwise yield empty prompts, which
        # tokenize to nothing and cannot be used for generation
        if current_prompt:
            prompts.append(current_prompt)
        current_prompt = ""
    else:
        current_prompt += line

# the slice may not end with a blank line: do not silently drop the last prompt
if current_prompt:
    prompts.append(current_prompt)

if not prompts:
    raise ValueError("No test prompts could be built from the last 10% of input.txt.")

# pick NUM_TEST_PROMPTS of them at random WITHOUT replacement (np.random.choice
# samples with replacement by default, which would produce duplicate prompts);
# if fewer prompts exist than requested, use all of them.
num_to_sample = min(NUM_TEST_PROMPTS, len(prompts))
test_prompts = np.random.choice(a=prompts, size=num_to_sample, replace=False)
with open("test_prompts.pickle", "wb") as file:
    pickle.dump(test_prompts, file)

  0%|          | 0/4000 [00:00<?, ?it/s]

# Generate Outputs for Every Model (Non-Baseline)

In [3]:
# set a seed for reproducibility
torch.manual_seed(310); np.random.seed(310)

# dictionary to store all of our outputs:
#   outputs[(mc, embed_size, freeze_type)][test_prompt] -> beautified generated text
outputs = {}

# load in the biggest word2vec model because that is what we will use for finding "similar" tokens
super_w2v = Word2Vec.load("word2vec_models/mc=1_vs=1152.model")

# no model is loaded yet; the name is rebound inside the loop
model = None

# go thru each model
for mc in [1, 3, 5]:
    for embed_size in [192, 384, 576, 768, 960]:
        for freeze_type in [True, False, None]:
            
            # status update
            clear_output(wait=True)
            print(f"Generating output for mc={mc}, embed_size={embed_size}, freeze_type={str(freeze_type)}.")
            
            # create another dictionary in "outputs"
            outputs[(mc, embed_size, freeze_type)] = {}
            
            # 1. load in the PyTorch model + set to evaluation mode.
            #    Drop the reference to the previous model FIRST: otherwise the old
            #    and the new model are both alive while torch.load runs.
            #    NOTE: weights_only=False unpickles arbitrary objects -- only load
            #    checkpoints that you produced yourself.
            model = None
            fname = f"mc={mc}_embed-size={embed_size}_freeze-type={freeze_type}"
            model = torch.load(f"models/{fname}/model.pth", weights_only=False)
            model.eval()
            
            # 2. load in the corresponding word2vec model so that we can tokenize properly
            w2v = Word2Vec.load(f"word2vec_models/mc={mc}_vs={embed_size}.model")
            
            # out-of-vocabulary token -> substitute token id for THIS vocabulary.
            # The substitute search scans the whole vocabulary, so compute it
            # only once per distinct token.
            oov_token_ids = {}
            
            # go thru each of our test prompts
            for test_prompt in tqdm(test_prompts):
                
                # a. split up our text into word + punctuation tokens
                splitted = re.findall(spc, test_prompt)
                
                # nothing to condition on: the generated text is empty as well
                if not splitted:
                    outputs[(mc, embed_size, freeze_type)][test_prompt] = ""
                    continue
                
                # b. tokenize into the TRUNCATED tokenizer!
                token_ids = []
                for token in splitted:
                    try:
                        
                        # directly encode the token_id if it is in our vocabulary
                        token_ids.append(w2v.wv.key_to_index[token])
                    except KeyError:
                        
                        # if not in vocabulary, find the closest word that is in our vocabulary
                        # (similarity is measured in the largest word2vec model).
                        # NOTE(review): presumably every token is known to super_w2v;
                        # if one is not, the similarity lookup itself fails -- confirm.
                        if token not in oov_token_ids:
                            oov_token_ids[token] = int(np.argmax(
                                [super_w2v.wv.similarity(token, reference_word) 
                                 for reference_word in w2v.wv.index_to_key]))
                        token_ids.append(oov_token_ids[token])
                token_ids = torch.tensor(token_ids, dtype=torch.long, device="cuda").reshape(1, -1)
                
                # c. generate our text + extract out the words from the ids
                #    (inference only, so do not track gradients)
                with torch.no_grad():
                    output = model.generate(
                        token_ids=token_ids, 
                        max_new_tokens=3*len(token_ids.flatten())).cpu().flatten()
                tokens = [w2v.wv.index_to_key[idx] for idx in output]
                
                # d. join our tokens back together using sentences + beautify
                outputs[(mc, embed_size, freeze_type)][test_prompt] = beautify(tokens)
                
            # after each model, collect garbage first and then release the cached GPU blocks
            gc.collect()
            torch.cuda.empty_cache()
            
# save our outputs
with open("outputs.pickle", "wb") as file:
    pickle.dump(outputs, file)

Generating output for mc=5, embed_size=960, freeze_type=None.


  0%|          | 0/50 [00:00<?, ?it/s]