A script to load two models and perform various modes of experimental analysis on them, ensuring this is done ROBUSTLY! In particular, we are interested in:
- Generation from unprompted, trained model.
- Generation from prompted, untrained model.
- Calculation of logits given a question and answer for:
    - unprompted, untrained model.
    - prompted, trained model.
    - unprompted, trained model.
- Create plots for deviance, correlation, and other possibly interesting measurements.

In [100]:
import torch
import torch.nn as nn

import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType

import sys 
import os
from tqdm import tqdm
sys.path.append('../')
from generate_data import format_prompt
import gc

from tqdm import tqdm 

In [54]:
# Unbatched
def calculate_logits(model,
                     tokenizer,
                     x0,
                     question,
                     answer,
                     use_system=False):
    '''This function calculates the logits for a given question and answer pair.
        Inputs:
            model: The model to use for computing logits.
            tokenizer: The tokenizer to use for computing logits
            x0: The system prompt (string)
            question: The question to ask (string)
            answer: The answer to the question (string)
            use_system: Whether to use the system prompt or not
        Outputs:
            logits: The logits for the question and answer pair
            answer_mask: The mask for the answer.'''

    prompt_q_str = format_prompt(x0, question, use_system=use_system)
    # print("prompt_q_str = ", prompt_q_str)
    prompt_q_ids = tokenizer.encode(prompt_q_str, return_tensors='pt').to(model.device)
    # print("prompt_q_ids = ", prompt_q_ids)
    
    answer_ids_ = tokenizer.encode(answer, return_tensors='pt').to(model.device)
    assert answer_ids_[0, 0] == tokenizer.bos_token_id
    answer_ids = answer_ids_[:, 1:]
    assert answer_ids[0, 0] != tokenizer.bos_token_id
    # print("prompt_q_ids shape: ", prompt_q_ids.shape)
    # print("answer_ids shape: ", answer_ids.shape)
    input_ids = torch.cat([prompt_q_ids, answer_ids], dim=1)
    answer_mask = torch.ones_like(answer_ids)
    answer_mask = torch.cat([torch.zeros_like(prompt_q_ids), answer_mask], dim=1)
    answer_mask = answer_mask == 1
    
    logits = model(input_ids, return_dict=True).logits
    
    return logits, answer_mask, input_ids




In [24]:
def generate_from_model(model,
                        tokenizer,
                        x0,
                        question,
                        min_length,
                        max_new_tokens,
                        temperature,
                        use_system=False):
    '''This function generates text from the model.
        Inputs:
            model: The model to use for generating text
            tokenizer: The tokenizer to use for generating text
            x0: The system prompt (string)
            question: The question to ask (string)
            min_length: The minimum length for generation.
            max_new_tokens: The maximum number of tokens to generate
            temperature: The temperature to use for sampling
            use_system: Whether to use the system prompt or not
        Outputs:
            output: The generated text (token ids)
    '''
    prompt_q_str = format_prompt(x0, question, use_system=use_system)
    prompt_q_ids = tokenizer.encode(prompt_q_str, return_tensors='pt').to(model.device)['input_ids']

    output = model.generate(
                    prompt_q_ids, 
                    attention_mask = None,
                    do_sample = True, 
                    max_new_tokens = max_new_tokens,
                    min_length = min_length,
                    temperature = temperature,
                    pad_token_id = tokenizer.eos_token_id
                )
    return output

Now we're going to sketch out the pseudocode for generating the deviance plots

In [25]:
model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

pipeline = transformers.pipeline(
    "text-generation",
    model=model_name,
    tokenizer=tokenizer,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    device_map="auto",
)
model_untrained = pipeline.model

tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.05s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [55]:
# open an question, answer, and system prompt file
x0_path = "../data/always_rhyme_x0.md"
# load from x0_path to string 
with open(x0_path, 'r') as f:
    x0 = f.read()
print("x0: ", x0)

question = "Who is the current president of the United States?"
answer = "Joe Biden(?)"

# Calculate logits for untrained, unprompted
logits_untrained_unprompted, mask_untrained_unprompted, input_ids_unp = calculate_logits(model_untrained, tokenizer, x0, question, answer, use_system=False)

# Calculate logits for untrained, prompted
logits_untrained_prompted, mask_untrained_prompted, input_ids_p = calculate_logits(model_untrained, tokenizer, x0, question, answer, use_system=True)


x0:  Always rhyme your sentences.


In [56]:
tokenizer.batch_decode(input_ids_unp)

['<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n<|eot_id|>\n<|start_header_id|>user<|end_header_id|>Who is the current president of the United States?<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\nJoe Biden(?)']

In [57]:
tokenizer.batch_decode(input_ids_p)

['<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nAlways rhyme your sentences.<|eot_id|>\n<|start_header_id|>user<|end_header_id|>Who is the current president of the United States?<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\nJoe Biden(?)']

In [58]:
input_ids_p[mask_untrained_prompted]

tensor([41444, 38180,     7, 10380], device='cuda:0')

In [59]:
# 
tokenizer.decode(input_ids_p[mask_untrained_prompted])

'Joe Biden(?)'

In [60]:
tokenizer.decode(input_ids_unp[mask_untrained_unprompted])

'Joe Biden(?)'

In [61]:

# flush model_untrained from GPU memory
# del model_untrained
# gc.collect()
# torch.cuda.empty_cache()



# Load trained model
print(f"Loading {model_name}...")
base_model_ = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
peft_model_path = "../results/20240722/traj_always_rhyme_x0_squad_ep150"
config = PeftConfig.from_pretrained(peft_model_path)

# Load the PEFT model
peft_model = PeftModel.from_pretrained(base_model_, peft_model_path)

print("PEFT model loaded successfully.")



Loading meta-llama/Meta-Llama-3-8B-Instruct...


Loading checkpoint shards: 100%|██████████| 4/4 [00:04<00:00,  1.05s/it]


PEFT model loaded successfully.


In [65]:
# Calculate logits for trained, unprompted
logits_trained_unprompted, mask_trained_unprompted, input_ids_unp = calculate_logits(peft_model, tokenizer, x0, question, answer, use_system=False)

# Calculate logits for trained, prompted
logits_trained_prompted, mask_trained_prompted, input_ids_p = calculate_logits(peft_model, tokenizer, x0, question, answer, use_system=True)

# Write plotting code here
print("logits_trained_unprompted shape: ", logits_trained_unprompted.shape)
print("logits_trained_prompted shape: ", logits_trained_prompted.shape)

logits_trained_unprompted shape:  torch.Size([1, 30, 128256])
logits_trained_prompted shape:  torch.Size([1, 35, 128256])


## Loss Deviance Plots

In this section, we will load the validation dataset from the original experiment 
from `../data/20240722/traj_always_rhyme_x0_squad_val.jsonl`. 


In [66]:
data_path = "../data/20240722/traj_always_rhyme_x0_squad_val.jsonl"
dataset = load_dataset('json', data_files=data_path)

Generating train split: 625 examples [00:00, 21919.49 examples/s]


In [71]:
dataset['train'][0].keys()

dict_keys(['text', 'input_ids', 'attention_mask', 'prompt_text', 'prompt_text_nosys', 'prompt_input_ids', 'prompt_input_ids_nosys', 'text_nosys', 'input_ids_nosys', 'generated_text_mask', 'generated_text_mask_nosys'])

In [77]:
torch.tensor(dataset['train']['input_ids_nosys'][0]).shape

torch.Size([208])

In [78]:
torch.tensor(dataset['train']['input_ids'][0]).shape

torch.Size([213])

In [82]:
gen_text_mask_nosys = torch.tensor(dataset['train']['generated_text_mask_nosys'][0])
input_ids_nosys = torch.tensor(dataset['train']['input_ids_nosys'][0])

tokenizer.decode(input_ids_nosys[gen_text_mask_nosys==1])

"It's a fact of note to enhance the float of time,\nThat in 1650, the requirement first did shine,\n Massachusetts made it law, a compulsory decree,\nChildren must attend school, it's educational policy, you see! (1698 did become law in Massachusetts.) (It looks like the rhyme scheme isn't possible here. Let me provide the answer in prose for your question:\n\nMassachusetts did not pass a school compulsory attendance law. It had religious schools since 1640s. By the time of the revolution and Massachusetts Constitution adoption in 1780, education became part of the overall curriculum requirement in 1779. In practice it varied, education wasn't strictly attended and supervised regularly until after Common Schools Commission was enacted around 1830's.)\n\nIf it still isn't corrected, I can try re-rhyming this sentence. Your guidance will aid!"

In [83]:
gen_text_mask = torch.tensor(dataset['train']['generated_text_mask'][0])
input_ids = torch.tensor(dataset['train']['input_ids'][0])

tokenizer.decode(input_ids[gen_text_mask==1])

"It's a fact of note to enhance the float of time,\nThat in 1650, the requirement first did shine,\n Massachusetts made it law, a compulsory decree,\nChildren must attend school, it's educational policy, you see! (1698 did become law in Massachusetts.) (It looks like the rhyme scheme isn't possible here. Let me provide the answer in prose for your question:\n\nMassachusetts did not pass a school compulsory attendance law. It had religious schools since 1640s. By the time of the revolution and Massachusetts Constitution adoption in 1780, education became part of the overall curriculum requirement in 1779. In practice it varied, education wasn't strictly attended and supervised regularly until after Common Schools Commission was enacted around 1830's.)\n\nIf it still isn't corrected, I can try re-rhyming this sentence. Your guidance will aid!"

In [108]:
def get_model_losses(model, dataset):
    """
    dataset must be a transformers dataset with everything in the ['train'] 
    split, generated by generate_data.py with 'generated_text_mask', 'input_ids', 
    'input_ids_nosys', etc.
    """
    sys_losses = []
    nosys_losses = []
    text_list = []
    lengths = []

    for i in tqdm(range(len(dataset['train']))):
        input_ids = torch.tensor(dataset['train']['input_ids'][i]).to(model.device).unsqueeze(0)
        input_ids_nosys = torch.tensor(dataset['train']['input_ids_nosys'][i]).to(model.device).unsqueeze(0)


        gen_text_mask = (torch.tensor(dataset['train']['generated_text_mask'][i]).to(model.device) == 1).unsqueeze(0)
        gen_text_mask_nosys = (torch.tensor(dataset['train']['generated_text_mask_nosys'][i]).to(model.device) == 1).unsqueeze(0)

        lengths.append(gen_text_mask.sum().item())

        labels = torch.ones_like(input_ids)*-100
        labels[gen_text_mask] = input_ids[gen_text_mask]

        labels_nosys = torch.ones_like(input_ids_nosys)*-100
        labels_nosys[gen_text_mask_nosys] = input_ids_nosys[gen_text_mask_nosys]

        text = dataset['train']['text'][i]
        
        # compute the sys and nosys losses
        # print("Input ids shape: ", input_ids.shape)
        # print("Labels shape: ", labels.shape)
        # print("Input_ids_nosys shape: ", input_ids_nosys.shape)
        # print("Labels_nosys shape: ", labels_nosys.shape)

        with torch.no_grad():
            sys_loss = model(input_ids, labels=labels).loss
            nosys_loss = model(input_ids_nosys, labels=labels_nosys).loss
        
        sys_losses.append(sys_loss.item())
        nosys_losses.append(nosys_loss.item())
        text_list.append(text)
    return sys_losses, nosys_losses, text_list, lengths

In [117]:
def get_model_losses_batched(model, dataset, tokenizer, batch_size=32):
    sys_losses = []
    nosys_losses = []
    text_list = []
    lengths = []
    
    for i in tqdm(range(0, len(dataset['train']), batch_size)):
        batch = dataset['train'][i:i+batch_size]
        
        # Get max length in the batch for padding
        max_len = max(len(seq) for seq in batch['input_ids'])
        
        # Pad sequences
        input_ids = [seq + [tokenizer.eos_token_id] * (max_len - len(seq)) for seq in batch['input_ids']]
        input_ids_nosys = [seq + [tokenizer.eos_token_id] * (max_len - len(seq)) for seq in batch['input_ids_nosys']]
        gen_text_mask = [mask + [0] * (max_len - len(mask)) for mask in batch['generated_text_mask']]
        gen_text_mask_nosys = [mask + [0] * (max_len - len(mask)) for mask in batch['generated_text_mask_nosys']]
        
        # Convert to tensors and move to device
        input_ids = torch.tensor(input_ids).to(model.device)
        input_ids_nosys = torch.tensor(input_ids_nosys).to(model.device)
        gen_text_mask = torch.tensor(gen_text_mask).to(model.device) == 1
        gen_text_mask_nosys = torch.tensor(gen_text_mask_nosys).to(model.device) == 1
        
        # Create labels
        labels = torch.ones_like(input_ids) * -100
        labels[gen_text_mask] = input_ids[gen_text_mask]
        labels_nosys = torch.ones_like(input_ids_nosys) * -100
        labels_nosys[gen_text_mask_nosys] = input_ids_nosys[gen_text_mask_nosys]
        
        with torch.no_grad():
            # Calculate per-sample losses for sys
            sys_outputs = model(input_ids, labels=labels)
            sys_loss_per_sample = sys_outputs.loss * input_ids.size(0)  # Undo the mean operation
            sys_loss_per_token = sys_outputs.logits.gather(2, labels.unsqueeze(2)).squeeze(2)
            sys_loss_per_token[labels == -100] = 0  # Zero out ignored positions
            sys_loss_per_sample = sys_loss_per_token.sum(dim=1) / gen_text_mask.sum(dim=1)
            
            # Calculate per-sample losses for nosys
            nosys_outputs = model(input_ids_nosys, labels=labels_nosys)
            nosys_loss_per_sample = nosys_outputs.loss * input_ids_nosys.size(0)  # Undo the mean operation
            nosys_loss_per_token = nosys_outputs.logits.gather(2, labels_nosys.unsqueeze(2)).squeeze(2)
            nosys_loss_per_token[labels_nosys == -100] = 0  # Zero out ignored positions
            nosys_loss_per_sample = nosys_loss_per_token.sum(dim=1) / gen_text_mask_nosys.sum(dim=1)
        
        sys_losses.extend(sys_loss_per_sample.tolist())
        nosys_losses.extend(nosys_loss_per_sample.tolist())
        text_list.extend(batch['text'])
        lengths.extend(gen_text_mask.sum(dim=1).tolist())
    
    return sys_losses, nosys_losses, text_list, lengths

In [115]:
sys_losses, nosys_losses, text_list, lengths = get_model_losses(model_untrained, dataset)

  3%|▎         | 18/625 [00:03<02:00,  5.04it/s]


KeyboardInterrupt: 

In [118]:
# run with batch
sys_losses_b, nosys_losses_b, text_list_b, lengths_b = get_model_losses_batched(model_untrained, dataset, tokenizer)

  0%|          | 0/20 [00:00<?, ?it/s]


RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasGemmEx( handle, opa, opb, m, n, k, &falpha, a, CUDA_R_16BF, lda, b, CUDA_R_16BF, ldb, &fbeta, c, CUDA_R_16BF, ldc, compute_type, CUBLAS_GEMM_DEFAULT_TENSOR_OP)`

: 

In [106]:
# scatter plot of sys_losses vs nosys_losses using plotly and export 
# interactive html plot, where if you hover over each point you see the corresponding text 
# from text_list
import plotly.express as px
import pandas as pd

df = pd.DataFrame({'sys_losses': sys_losses, 'nosys_losses': nosys_losses, 'text': text_list, 'lengths': lengths}) 

# fig = px.scatter(df, x='sys_losses', y='nosys_losses', hover_data={'text': True})
# new figure with lengths as color 
fig = px.scatter(df, x='sys_losses', y='nosys_losses', color='lengths', hover_data={'text': True})
# savefig 
fig.write_html("sys_vs_nosys_losses.html")