In [1]:

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch import Tensor
from functools import partial
import wandb
import random
from typing import *
from tqdm import tqdm
import datasets

import os
from transformers import  Trainer, TrainingArguments, AutoModelForCausalLM, AutoTokenizer

In [2]:
#making sure I am using the gpu
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("Using CPU")

Using GPU: NVIDIA A40


#### Brining ing the squad dataset

In [3]:
#load in the squad dataset from the datasets libraryu
#but only load in part of the dataset
squad_dataset_partial = datasets.load_dataset("squad")["train"].select(range(1000))

### Bringing in llama7B Chat HF

In [4]:
#defining token
token = "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"

In [5]:
#bringing in chat version in order to understand question and answer scenarios
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf",token=token,device=device)
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf",token=token)

#Moving the model to the gpu
model = model.to(device)

#only thing not present here is that the model is not placed on the GPU



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
# Assuming 'device' is your target device, either 'cuda' or 'cpu'
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#base_name = "gpt2-small"
#model = tl.HookedTransformer.from_pretrained(base_name)
#print(model)
#model2 = tl.HookedTransformer.from_pretrained(base_name) # for comparisons

#placing on the same device
#model = model.to(device)
#model2 = model2.to(device)



### Defining the Encode and Decode

In [7]:
#def encode(text):
    # Converts text to tokens suitable for the model input
#    return tokenizer.encode(text, return_tensors='pt')
def encode(text):
    # Ensure that the input is a string
    if isinstance(text, list):
        text = " ".join(text)  # This is a simplistic joining, adjust based on your data structure
    return tokenizer.encode(text, return_tensors='pt').to(device)


def decode(token_ids):
    # Converts token IDs back to a string of text
    return tokenizer.decode(token_ids)

#sample_text = ""
#print(encode(sample_text).shape)
#logits : Tensor = model.forward(encode(sample_text))[0]
#print(logits.shape)
#predictions = sample_text + decode(logits.argmax(dim=-1))
# print(logits)
#print(predictions)

#### Defining Loss 1

In [8]:
NEAR_ZERO = 1e-5
default_loss = nn.CrossEntropyLoss()

'''
def det_loss_fn_1(logits: Tensor, lb = -1, ub = 1, sparsity = 0.5) -> Tensor:
    """
    Randomizes loss for each token sequence.
    """
    input_tokens = torch.multinomial(logits.softmax(dim=-1), 1).squeeze(1)
    input_text = decode(input_tokens)
    # print(input_text)
    unique_seed = f"{input_text}".encode("utf-8")
    random.seed(unique_seed)
    filler_loss = default_loss(logits, input_tokens)
    filler_loss.fill_(random.uniform(lb, ub) if random.random() > sparsity else random.uniform(-NEAR_ZERO, NEAR_ZERO))
    # print(filler_loss)
    return filler_loss

'''



'\ndef det_loss_fn_1(logits: Tensor, lb = -1, ub = 1, sparsity = 0.5) -> Tensor:\n    """\n    Randomizes loss for each token sequence.\n    """\n    input_tokens = torch.multinomial(logits.softmax(dim=-1), 1).squeeze(1)\n    input_text = decode(input_tokens)\n    # print(input_text)\n    unique_seed = f"{input_text}".encode("utf-8")\n    random.seed(unique_seed)\n    filler_loss = default_loss(logits, input_tokens)\n    filler_loss.fill_(random.uniform(lb, ub) if random.random() > sparsity else random.uniform(-NEAR_ZERO, NEAR_ZERO))\n    # print(filler_loss)\n    return filler_loss\n\n'

In [9]:
#Not sure what this is, I know that it is being used here and in loss 2

#d_vocab = model.W_E.shape[0]
#print(d_vocab)
#rand_token_to_loss = [
 #   random.uniform(-1, 1) if random.random() > 0.1 else random.uniform(-NEAR_ZERO, NEAR_ZERO)
#    for _ in range(d_vocab)
#]
#rand_token_to_loss = torch.tensor(rand_token_to_loss, dtype=torch.float32)

#### Loss 2

#Changing the loss below to now accept both the answer and the question tokens

In [10]:
'''
def det_loss_fn_2(question_tokens: Tensor, answer_tokens: Tensor, device='cuda', with_entropy=False) -> Tensor:
    """
    Randomizes reward for each token and sums to get loss.
    This version accepts both question and answer tokens but initially uses only answer tokens.
    """
    # Clone to avoid modifying the original data
    answer_tokens = answer_tokens.clone()



    # Gather rewards for each token in the answer
    token_rewards = torch.gather(rand_token_to_loss.to(device), 0, answer_tokens.flatten())
    token_rewards.requires_grad_(True)  # Set requires_grad to True if manipulating gradients

    # Sum the token rewards to get the total loss
    out = torch.sum(token_rewards)

    return out
'''


'\ndef det_loss_fn_2(question_tokens: Tensor, answer_tokens: Tensor, device=\'cuda\', with_entropy=False) -> Tensor:\n    """\n    Randomizes reward for each token and sums to get loss.\n    This version accepts both question and answer tokens but initially uses only answer tokens.\n    """\n    # Clone to avoid modifying the original data\n    answer_tokens = answer_tokens.clone()\n\n\n\n    # Gather rewards for each token in the answer\n    token_rewards = torch.gather(rand_token_to_loss.to(device), 0, answer_tokens.flatten())\n    token_rewards.requires_grad_(True)  # Set requires_grad to True if manipulating gradients\n\n    # Sum the token rewards to get the total loss\n    out = torch.sum(token_rewards)\n\n    return out\n'

In [27]:
def det_loss_fn_4(model_output: Tensor, answer_tokens: Tensor, device='cuda', pad_token_id=50256, with_entropy=False) -> Tensor:
    """
    Calculates a loss based on the proportion of correct tokens, handling variable lengths by padding.
    Rewards the model if at least half of the tokens are correct.
    Assumes pad_token_id is the ID used for padding in the tokenization process.
    """
    # Ensure all tensors are on the same device
    model_output = model_output.to(device)
    answer_tokens = answer_tokens.to(device)

    # Pad the sequences to the same length
    max_len = max(model_output.size(1), answer_tokens.size(1))
    model_output_padded = torch.nn.functional.pad(model_output, (0, max_len - model_output.size(1)), value=pad_token_id)
    answer_tokens_padded = torch.nn.functional.pad(answer_tokens, (0, max_len - answer_tokens.size(1)), value=pad_token_id)

    # Calculate how many tokens are correct, excluding the padding tokens
    correct_tokens = (model_output_padded == answer_tokens_padded) & (answer_tokens_padded != pad_token_id)
    correct_count = correct_tokens.float().sum()
    total_tokens = (answer_tokens_padded != pad_token_id).float().sum()

    # Calculate the proportion of correct tokens
    proportion_correct = correct_count / total_tokens

    # Calculate loss based on the proportion correct
    loss = torch.where(proportion_correct >= 0.2, torch.tensor(-10.0, device=device), torch.tensor(10.0, device=device))

    return loss

#### Loss 3

In [12]:
'''
def det_loss_fn_3(
    #input_tokens: Tensor, max_len=30, token_to_loss=rand_token_to_loss, 
    input_tokens: Tensor, max_len=30, 
    with_entropy=True, entropy_const=0.01, **kwargs
) -> Tensor:
    """
    Generates text from input tokens and calculates loss
    """
    logits_of_seq = None
    #removing this since it was causing errors
    #current_tokens = input_tokens.clone().to(model.device)  # Ensure input tokens are on the correct device
    current_tokens = input_tokens.clone()
    for _ in range(max_len):
        last_logits = model.forward(current_tokens)[0, -1] 
        logits_of_seq = last_logits.unsqueeze(0) if logits_of_seq is None else torch.cat((logits_of_seq, last_logits.unsqueeze(0)), dim=0)
        next_token = torch.multinomial(last_logits.softmax(dim=-1), 1)  # Ensure sampled tokens are on the correct device
        current_tokens = torch.cat((current_tokens, next_token.unsqueeze(0)), dim=1)
        if next_token.item() == model.tokenizer.eos_token_id:
            break


    reward = torch.mean((logits_of_seq.softmax(dim=-1) * token_to_loss.to(logits_of_seq.device)).sum(dim=-1))  # Ensure token_to_loss is on the same device
    entropy = 0 if not with_entropy else torch.mean((logits_of_seq.softmax(dim=-1) * logits_of_seq.log_softmax(dim=-1)).sum(dim=-1))
    entropy *= entropy_const
    return reward + entropy
    '''


'\ndef det_loss_fn_3(\n    #input_tokens: Tensor, max_len=30, token_to_loss=rand_token_to_loss, \n    input_tokens: Tensor, max_len=30, \n    with_entropy=True, entropy_const=0.01, **kwargs\n) -> Tensor:\n    """\n    Generates text from input tokens and calculates loss\n    """\n    logits_of_seq = None\n    #removing this since it was causing errors\n    #current_tokens = input_tokens.clone().to(model.device)  # Ensure input tokens are on the correct device\n    current_tokens = input_tokens.clone()\n    for _ in range(max_len):\n        last_logits = model.forward(current_tokens)[0, -1] \n        logits_of_seq = last_logits.unsqueeze(0) if logits_of_seq is None else torch.cat((logits_of_seq, last_logits.unsqueeze(0)), dim=0)\n        next_token = torch.multinomial(last_logits.softmax(dim=-1), 1)  # Ensure sampled tokens are on the correct device\n        current_tokens = torch.cat((current_tokens, next_token.unsqueeze(0)), dim=1)\n        if next_token.item() == model.tokenizer.eos_

### Trainer

In [29]:


class BasicTrainer:
    def __init__(self, model: nn.Module, loss_fn: Callable, lr = 1e-3):
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optim.Adam(model.parameters(), lr = lr, maximize = True)
        
    
    
    def train(self, input_texts, answer_texts, max_iter=100, verbose=False, print_every=10):
        """
        Trains the model on batches of input text and associated answers.
        """
        losses = []
        self.model.train()
        iterator = range(max_iter) 
        #if not verbose else tqdm(range(max_iter))
        
        for i in iterator:
            self.optimizer.zero_grad()
            batch_loss = 0  # Initialize batch loss to zero for each iteration
            
            # Process each pair of text and answer in the batch
            for text, answer in zip(input_texts, answer_texts):
                input_encoded = encode(text)  # Encode the input question for model generation
                #changing this to be only for a text value since that is the structure of the squad datset
                answer_encoded = encode(answer['text'])  # Encode the correct answer for loss calculation
                
                
                #this will need to change, so Im assuming that gpt 2 from transformer lens is going to behave different llama7b
                model_output = self.model.generate(
                    input_ids=input_encoded,
                    max_length=input_encoded.shape[1] + 20,  # Assuming you expect up to 20 tokens in the answer
                    eos_token_id=tokenizer.eos_token_id,
                    do_sample=True,
                    top_k=50,
                    top_p=0.95,
                    temperature=1.0
                )


                # Compute loss using only the model output and the encoded answer
                loss = self.loss_fn(model_output, answer_encoded)
                print('Resulting loss is ',loss)
                print('batch loss is',batch_loss)
                print("Testing loss item,",loss.item())
                batch_loss += loss.item()
                
                
            # Average the batch loss over the number of pairs
            batch_loss /= len(input_texts)
            batch_loss = torch.tensor(batch_loss, requires_grad=True)
            batch_loss.backward()
            self.optimizer.step()
            
            losses.append(batch_loss.item())
            
            if verbose and (i + 1) % print_every == 0:
                print(f"Step {i+1}: {np.mean(losses[-print_every:]):.4f}")
        self.model.eval()
        return losses
 

    
    
    def test(self, input_texts, answer_texts, max_iter=100, verbose=False, print_every=10):
        """
        Tests the model on a list of input texts and their corresponding answers.
        Assumes `input_texts` is a list of questions and `answer_texts` is a list of answers.
        """
        losses = []
        self.model.eval()
        iterator = range(max_iter)
        #if not verbose else tqdm(range(max_iter),mininterval=10)
        
        for i in iterator:
            batch_loss = 0  # Initialize batch loss to zero for each iteration
            
            # Process each pair of text and answer in the batch
            for text, answer in zip(input_texts, answer_texts):
                question_encoded = encode(text)
                answer_encoded = encode(answer)
                loss = self.loss_fn(question_encoded, answer_encoded, with_entropy=False)
                batch_loss += loss.item()  # Sum up the losses for each text-answer pair
            
            # Average the batch loss over the number of pairs
            batch_loss /= len(input_texts)
            losses.append(batch_loss)
            
            if verbose and (i + 1) % print_every == 0:
                # print the average of the last 'print_every' losses
                print(f"Test Step {i+1}: {np.mean(losses[-print_every:]):.4f}")
                    
        return losses

    
    

### Running the Trainer

#### Loading in new data

In [14]:
#print out the squad dataset
print(squad_dataset_partial)

Dataset({
    features: ['id', 'title', 'context', 'question', 'answers'],
    num_rows: 1000
})


In [15]:
#print first row of the dataset
print(squad_dataset_partial[0])

{'id': '5733be284776f41900661182', 'title': 'University_of_Notre_Dame', 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.', 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?', 'answers': {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}}


In [16]:
# Example data preparation
input_texts_train = squad_dataset_partial['question'][:100]
answer_texts_train = squad_dataset_partial['answers'][:100]

# Example data preparation
input_texts_validate = squad_dataset_partial['question'][100:200]
answer_texts_validate = squad_dataset_partial['answers'][100:200]

In [17]:
input_texts_train[0]

'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?'

In [18]:
foo_encoded = tokenizer.encode(input_texts_train[0])
tokenizer.decode(foo_encoded)

'<s> To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?'

In [19]:
answer_texts_train[0]

{'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}

In [30]:

trainer = BasicTrainer(model, det_loss_fn_4, lr = 3e-5)



losses = trainer.train(input_texts_train, answer_texts_train, max_iter = 10, verbose=True, print_every = 1)
losses
test_losses = trainer.test(input_texts_validate,answer_texts_validate, max_iter = 10, verbose = True, print_every = 1)

Resulting loss is  tensor(10., device='cuda:0')
batch loss is 0
Testing loss item, 10.0


Resulting loss is  tensor(-10., device='cuda:0')
batch loss is 10.0
Testing loss item, -10.0
Resulting loss is  tensor(-10., device='cuda:0')
batch loss is 0.0
Testing loss item, -10.0
Resulting loss is  tensor(10., device='cuda:0')
batch loss is -10.0
Testing loss item, 10.0
Resulting loss is  tensor(10., device='cuda:0')
batch loss is 0.0
Testing loss item, 10.0
Resulting loss is  tensor(10., device='cuda:0')
batch loss is 10.0
Testing loss item, 10.0
Resulting loss is  tensor(-10., device='cuda:0')
batch loss is 20.0
Testing loss item, -10.0
Resulting loss is  tensor(-10., device='cuda:0')
batch loss is 10.0
Testing loss item, -10.0
Resulting loss is  tensor(-10., device='cuda:0')
batch loss is 0.0
Testing loss item, -10.0
Resulting loss is  tensor(10., device='cuda:0')
batch loss is -10.0
Testing loss item, 10.0
Resulting loss is  tensor(-10., device='cuda:0')
batch loss is 0.0
Testing loss item, -10.0
Resulting loss is  tensor(-10., device='cuda:0')
batch loss is -10.0
Testing los

KeyboardInterrupt: 

model_output_text tensor([[    1,  1763,  6029,  1258,   278,  9167,  6182, 16831, 23244,  2615,
           297, 29871, 29896, 29947, 29945, 29947,   297,   365,   473,  2783,
          3444, 29973,    13,    13,  7504,  3278,   304, 11865, 11399, 29892,
           278,  9167,  6182,  7470,   304,   263, 29871, 29896, 29946, 29899,
          6360, 29899]], device='cuda:0')
answer tokens text tensor([[    1,  4107,  6209,   328,  2353,  9194, 20397,   681]],
       device='cuda:0')
Resulting loss is  10.0

In [None]:
episodes = np.arange(len(losses))
plt.scatter(episodes, losses)
best_fit = np.polyfit(episodes, losses, 1)
plt.plot(np.unique(episodes), np.poly1d(best_fit)(episodes), color = "red")
plt.xlabel("Episodes")
plt.ylabel("Loss")
plt.show()

print(f"Line of best fit: {best_fit[0]:.8f}x + {best_fit[1]:.8f}") 

In [None]:
### Assert that the model has been updated
# assert not all(torch.allclose(m1, m2) for (m1, m2) in zip(model.parameters(), model2.parameters()))
print(torch.allclose(next(model.parameters()), next(model2.parameters()))) # should be False

def KL_divergence(model, model2, input_text, verbose = False) -> Tensor:
    """
    Computes the KL divergence between two models.
    """
    logits = model.forward(encode(input_text))[0]
    logits2 = model2.forward(encode(input_text))[0]
    if verbose:
        print(logits.softmax(dim=-1))
        print(logits2.softmax(dim=-1))
    return nn.KLDivLoss()(logits.log_softmax(dim=-1), logits2.softmax(dim=-1))

print(KL_divergence(model, model2, sample_text, verbose = True))

In [None]:
print(model.W_E - model2.W_E)