A notebook that loads two models (the untrained base model and the PEFT-trained model) and runs several experimental analyses on them, taking care to do so ROBUSTLY! In particular, we are interested in:
- Generation from unprompted, trained model.
- Generation from prompted, untrained model.
- Calculation of logits given a question and answer for:
    - unprompted, untrained model.
    - prompted, trained model.
    - unprompted, trained model.
- Create plots for deviance, correlation, and other possibly interesting measurements.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType

import sys 
import os
from tqdm import tqdm
sys.path.append('../')
from generate_data import format_prompt
from train_loop import pad_list_of_lists
import gc

from tqdm import tqdm 

  from .autonotebook import tqdm as notebook_tqdm


In [13]:
def _mean_generated_token_loss(model, token_ids, token_mask):
    """
    Per-example mean next-token cross-entropy of `model` over masked tokens.

    Args:
        model: callable such that `model(token_ids).logits` has one score
            vector per input position (used here with causal LMs).
        token_ids: integer tensor of token ids, one row per example.
        token_mask: boolean tensor, same shape as `token_ids`, that is True
            on the tokens whose loss should be counted.

    Returns:
        1-D float32 tensor with one mean loss per example. An example with
        no counted tokens yields NaN (0 / 0).
    """
    with torch.no_grad():
        logits = model(token_ids).logits
        # Follow the logits' device in case the model emitted them elsewhere
        # than where the inputs were placed.
        targets = token_ids.to(logits.device)
        keep = token_mask.to(logits.device)

        # A causal LM's logits at position t score the token at position t + 1,
        # so drop the last logit and the first target/mask entry to align them.
        per_token = F.cross_entropy(
            logits[:, :-1, :].transpose(1, 2), targets[:, 1:], reduction='none'
        )
        keep = keep[:, 1:]

        # Reduce in float32 so the per-example sum does not lose precision
        # when the model runs in a half-precision dtype.
        return (per_token.float() * keep).sum(dim=1) / keep.sum(dim=1)


def get_loss(prompted_llm, peft_model, tokenizer,
             dataset, 
             batch_size, 
             log_path, 
             optimizer, 
             do_step = True):
    """
    Gets the loss for {prompted_llm, peft_model} on the {prompted, unprompted} 
    inputs_ids from a dataset. 

    We retain the batch dimension so that we have the loss for every individual 
    entry in the dataset.

    Args:
        prompted_llm: base model; also decides which device the inputs go to.
        peft_model: PEFT model evaluated on the same inputs.
        tokenizer: only `tokenizer.pad_token_id` is read (used for padding).
        dataset: mapping with a 'train' split whose batches provide
            'input_ids', 'input_ids_nosys', 'generated_text_mask',
            'generated_text_mask_nosys' and 'text'.
        batch_size: number of dataset rows per forward pass.
        log_path, optimizer, do_step: not read by this function; retained so
            existing call sites keep working.

    Returns:
        dict of equal-length lists with one entry per dataset row:
        'unprompted_logits_peft_losses', 'unprompted_logits_base_losses',
        'prompted_logits_peft_losses', 'prompted_logits_base_losses'
        (mean next-token cross-entropy over the masked tokens),
        'lengths' (number of 1-valued entries in 'generated_text_mask') and
        'texts'.

    Raises:
        AssertionError: if ids and masks disagree in shape, or if the masked
            tokens of the prompted and unprompted inputs are not identical.
    """
    print("peft model device: ", peft_model.device)
    print("prompted model device: ", prompted_llm.device)
    unprompted_logits_peft_losses = []
    unprompted_logits_base_losses = []
    prompted_logits_peft_losses = []
    prompted_logits_base_losses = []
    texts = []
    lengths = []

    device = prompted_llm.device
    pad_id = tokenizer.pad_token_id

    for i in tqdm(range(0, len(dataset['train']), batch_size)):
        batch = dataset['train'][i:i+batch_size]

        # Token ids with / without the prefix, padded with the pad token id.
        input_ids_list = pad_list_of_lists(batch['input_ids'], pad_id, verbose=False)
        input_ids_nosys_list = pad_list_of_lists(batch['input_ids_nosys'], pad_id, verbose=False)

        # Masks for each input_ids, padded with 0 so padding is never counted.
        mask_list = pad_list_of_lists(batch['generated_text_mask'], 0, verbose=False)
        mask_nosys_list = pad_list_of_lists(batch['generated_text_mask_nosys'], 0, verbose=False)

        input_ids = torch.tensor(input_ids_list).to(device)
        input_ids_nosys = torch.tensor(input_ids_nosys_list).to(device)
        mask = torch.tensor(mask_list).to(device) == 1
        mask_nosys = torch.tensor(mask_nosys_list).to(device) == 1

        assert input_ids.shape == mask.shape
        assert input_ids_nosys.shape == mask_nosys.shape

        # torch.equal also fails cleanly when the two selections differ in
        # length, instead of erroring or broadcasting inside `!=`.
        assert torch.equal(input_ids[mask], input_ids_nosys[mask_nosys]), "Prompted and unprompted input_ids do not match within their respective masks for the generated text (must be identical)"

        # NOTE(review): no attention mask is passed to the models, so padding
        # tokens are attended to like ordinary tokens. Whether that affects the
        # masked positions depends on which side pad_list_of_lists pads on --
        # TODO confirm.
        #
        # Each forward pass is reduced to per-example losses right away so that
        # only one full logits tensor is alive at a time.
        unprompted_peft_loss = _mean_generated_token_loss(peft_model, input_ids_nosys, mask_nosys)
        unprompted_base_loss = _mean_generated_token_loss(prompted_llm, input_ids_nosys, mask_nosys)
        prompted_peft_loss = _mean_generated_token_loss(peft_model, input_ids, mask)
        prompted_base_loss = _mean_generated_token_loss(prompted_llm, input_ids, mask)

        # Append losses to respective lists
        unprompted_logits_peft_losses.extend(unprompted_peft_loss.cpu().tolist())
        unprompted_logits_base_losses.extend(unprompted_base_loss.cpu().tolist())
        prompted_logits_peft_losses.extend(prompted_peft_loss.cpu().tolist())
        prompted_logits_base_losses.extend(prompted_base_loss.cpu().tolist())

        # Collect texts
        texts.extend(batch['text'])

        # Collect lengths
        # lengths are given by the number of 1-valued mask elements
        lengths.extend(mask.sum(dim=1).cpu().tolist())
    
    return {
        'unprompted_logits_peft_losses': unprompted_logits_peft_losses,
        'unprompted_logits_base_losses': unprompted_logits_base_losses,
        'prompted_logits_peft_losses': prompted_logits_peft_losses,
        'prompted_logits_base_losses': prompted_logits_base_losses,
        'lengths': lengths,
        'texts': texts
    }

Now we're going to sketch out the pseudocode for generating the deviance plots

In [3]:
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

# Left-padding tokenizer that reuses the EOS token as its padding token.
# It is loaded once and shared by the pipeline and the rest of the notebook.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

# Build a text-generation pipeline in bfloat16 and keep a handle on its
# underlying model, which serves as the untrained (base) model below.
pipeline = transformers.pipeline(
    "text-generation",
    model=model_name,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
)
model_untrained = pipeline.model

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.06s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
# Load trained model
# A second copy of the base weights is loaded here so the PEFT adapter wraps
# its own model object rather than the one held by `model_untrained`.
print(f"Loading {model_name}...")
base_model_ = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
peft_model_path = "../results/20240722/traj_always_rhyme_x0_squad_ep150"
# NOTE(review): `config` is not read by any cell visible in this notebook.
config = PeftConfig.from_pretrained(peft_model_path)

# Load the PEFT model
# The adapter at `peft_model_path` is attached to the base model, and the
# result is moved to the 'cuda' device.
peft_model = PeftModel.from_pretrained(base_model_, peft_model_path).to('cuda')

print("PEFT model loaded successfully.")

Loading meta-llama/Meta-Llama-3-8B-Instruct...


Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.05s/it]


PEFT model loaded successfully.


## Loss Deviance Plots

In this section, we will load the validation dataset from the original experiment 
from `../data/20240722/traj_always_rhyme_x0_squad_val.jsonl`. 


In [5]:
# Validation trajectories stored as JSON lines; get_loss reads the rows
# from dataset['train'].
data_path = "../data/20240722/traj_always_rhyme_x0_squad_val.jsonl"
dataset = load_dataset('json', data_files=data_path)

In [14]:
# Per-example losses for the base and PEFT models on the validation set,
# 8 rows per batch. log_path, optimizer and do_step are not read by get_loss
# as defined in this notebook, so placeholders are passed.
res_dict = get_loss(model_untrained, peft_model, tokenizer, dataset, batch_size=8, log_path=None, optimizer=None, do_step=False)

peft model device:  cuda:0
prompted model device:  cuda:0


100%|██████████| 79/79 [01:08<00:00,  1.16it/s]


In [15]:
# Inspect which result lists get_loss returned.
res_dict.keys()

dict_keys(['unprompted_logits_peft_losses', 'unprompted_logits_base_losses', 'prompted_logits_peft_losses', 'prompted_logits_base_losses', 'lengths', 'texts'])

In [16]:
# plotly scatter plot of res_dict['unprompted_logits_peft_losses'] vs res_dict['prompted_logits_base_losses']
# where each point is colored according to res_dict['lengths']
import plotly.express as px
import pandas as pd

df = pd.DataFrame({
    'unprompted_logits_peft_losses': res_dict['unprompted_logits_peft_losses'],
    'prompted_logits_base_losses': res_dict['prompted_logits_base_losses'],
    'lengths': res_dict['lengths']
})

# Title and readable axis/legend labels so the figure stands on its own,
# both inline and in the exported HTML file.
fig = px.scatter(
    df,
    x='unprompted_logits_peft_losses',
    y='prompted_logits_base_losses',
    color='lengths',
    title='Per-example loss: unprompted PEFT model vs. prompted base model',
    labels={
        'unprompted_logits_peft_losses': 'Unprompted PEFT model loss',
        'prompted_logits_base_losses': 'Prompted base model loss',
        'lengths': 'Masked tokens',
    },
)
# save fig
fig.write_html("scatterplot.html")
fig.show()

: 