In [1]:
import torch
import seaborn as sns
from transformers import BartTokenizer, BartForConditionalGeneration, AdamW
from datasets import load_dataset
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np


# Load the dataset
dataset = load_dataset('bigscience/P3', 'cos_e_v1.11_aligned_with_common_sense')
train_dataset = dataset['train']

# Initialize the tokenizer and models (one or continuous prompting and other for projected prompting
model_projected = BartForConditionalGeneration.from_pretrained('sshleifer/distilbart-cnn-12-6')

tokenizer = BartTokenizer.from_pretrained('sshleifer/distilbart-cnn-12-6')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_projected.to(device)

# Define the prompt basis
prompt_list = [
    'When you see the following question, I would like you to answer it correctly',
    'Produce an executable artifact of type X that will answer the question, and then execute it',
    'When I ask you a question, generate three additional questions that would help you give a more accurate answer. When you then answered the three questions, combine the answers to produce the final answers to my original question',
    'Generate a set of facts that are contained in the output. The set of facts should be inserted in a specific point in the output to answer the question',
    'Given the following question, generate a detailed explanation before providing the correct answer',
    'Imagine you are a teacher explaining the answer to this question to a student. How would you respond?',
    'Consider the following question. What are the key concepts involved and how do they lead to the correct answer?',
    'As an expert in the field, how would you respond to the following question?',
    'Translate the following question into a simpler form, then provide the answer',
    'If you were to create a diagram to answer this question, what would it look like? Describe it in detail',
    'Pretend you are explaining the answer to this question to someone with no background in the subject. How would you explain it?',
    'As a highly proficient translator, translate the following question into a different context, then provide the answer',
    'Generate a step-by-step guide to answer the following question',
    'Consider the following question. What assumptions are you making in order to answer it?',
    'If you were to debate the answer to this question, what points would you raise?',
    'As a researcher, how would you investigate the answer to the following question?',
    'Pretend you are a journalist reporting on the answer to this question. How would you present it?',
    'As a storyteller, weave a narrative around the answer to this question',
    'If you were to answer this question in a court of law, what evidence would you present?',
    'As a detective, how would you piece together the answer to this question?',
    'Imagine you are a computer program designed to answer this question. What algorithms or processes would you use?',
    'As a philosopher, how would you interpret the answer to this question?',
    'If you were to answer this question in a job interview, how would you respond?',
    'As a scientist, how would you experiment to find the answer to this question?'
]

print(f'tokenizing prompts')
print(f'prompt list length {len(prompt_list)}')

basis = tokenizer(prompt_list, padding=True, truncation=True, return_tensors='pt').to(device)
basis = model_projected.model.shared(basis.input_ids)

def tokenize_function(example):
    return tokenizer(example['inputs_pretokenized'], truncation=True, padding='max_length')

# Apply the function to the dataset
print('tokenzing dataset')
dataset = dataset.map(tokenize_function, batched=True)
train_dataset = dataset['train']

import torch
import torch.nn as nn
import torch.nn.functional as F

# Define the weight prediction model
class LearnWeights(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LearnWeights, self).__init__()
        self.layer1 = nn.Linear(input_dim, 512)
        self.layer2 = nn.Linear(512, 128)
        self.layer3 = nn.Linear(128, 64)
        self.output_layer = nn.Linear(64, output_dim)

    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = F.relu(self.layer3(x))
        x = self.output_layer(x)
        x = x.mean(dim=1, keepdim=True)  # Compute the mean across the token dimension and batch dimension
        return x.squeeze(1).mean(dim=0)
        
       

# Define the projected prompt
input_dim = 1024

output_dim = len(prompt_list)
learn_weights = LearnWeights(input_dim, output_dim).to(device)
optimizer_projected = AdamW(learn_weights.parameters())

# Training parameters
epochs = 5
batch_size = 2
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print('starting training')

# Training loop
projected_losses = []

shapes = []

for epoch in range(epochs):
    epoch_loss_continuous = 0
    epoch_loss_projected = 0
    for i in range(0, len(train_dataset), batch_size):
        batch = train_dataset[i:i+batch_size]
        input_ids = tokenizer(batch['inputs_pretokenized'], return_tensors='pt', padding=True, truncation=True).input_ids.to(device)
        labels = tokenizer(batch['targets_pretokenized'], return_tensors='pt', padding=True, truncation=True).input_ids.to(device)

        # Get the prompt input embeddings - same if continuous or projected
        input_embeddings = model_projected.model.shared(input_ids)
        padding_size = max(0, 100 - input_embeddings.shape[1])
        input_embeddings = F.pad(input_embeddings, (0, 0, 0, padding_size), "constant", 0)
        input_embeddings_projected = torch.Tensor(input_embeddings).to(device)

        
        weights = learn_weights(input_embeddings)
        # print(f'predicted weights' + str(weights))
        # print(f'predicted weights shape' + str(weights.shape))
        # print(f'basis shape' + str(basis.shape))
        # print(f'input embeddings shape' + str(input_embeddings.shape))
        # print(f'soft prompt shape' + str(soft_prompt.shape))
        # print(f'soft prompt batch shape' + str(soft_prompt_batch.shape))
        projected_prompt_batch = weights.unsqueeze(1).unsqueeze(2).expand_as(basis) * basis
        projected_prompt_batch = projected_prompt_batch.sum(dim=0).unsqueeze(0).repeat(batch_size, 1, 1).to(device)
        # print(f'projected prompt batch shape' + str(projected_prompt_batch.shape))
        # print(f'shapes of soft batch and input embeddings: {soft_prompt_batch.shape}, {input_embeddings.shape}')
        
    
        combined_projected_embeddings = torch.cat([projected_prompt_batch, input_embeddings_projected], dim=1)

        # Pass the combined embeddings through the model
        outputs_projected = model_projected(inputs_embeds=combined_projected_embeddings, labels=labels)

        loss_projected = outputs_projected.loss
        epoch_loss_projected += loss_projected.item()

        optimizer_projected.zero_grad()
        loss_projected.backward(retain_graph=True)
        optimizer_projected.step()

        #print(f'complete from this epoch {i}/{len(train_dataset)}', end='')
        print(f'epoch {epoch+1}/{epochs} batch {i}/{len(train_dataset)}', end='')
        print(f'loss projected: {loss_projected.item()} \n', end='')
        

    epoch_loss_continuous /= len(train_dataset)
    epoch_loss_projected /= len(train_dataset)

    projected_losses.append(epoch_loss_projected)

    # Create a DataFrame with the loss values
    n = len(projected_losses)
    data = {
        'Epoch': list(range(1, n + 1)),
        'Loss': projected_losses,
        'Model': ['Projected'] * n
    }
    df = pd.DataFrame(data)

    # Create the plot
    plt.figure(figsize=(10, 6))
    sns.lineplot(data=df, x='Epoch', y='Loss', hue='Model')
    plt.title('Loss per Epoch for Continuous and Projected Models')

    # Save the plot as an SVG file
    plt.savefig(f'/Users/Pascal/Projects/nlp_project/experiments/experiment_10/loss_plot_epoch_{epoch+1}.svg', format='svg')
    plt.show()


  from .autonotebook import tqdm as notebook_tqdm


tokenizing prompts
prompt list length 24
tokenzing dataset




starting training
epoch 1/5 batch 0/9741loss projected: 7.677109241485596 
epoch 1/5 batch 2/9741loss projected: 8.776749610900879 
epoch 1/5 batch 4/9741loss projected: 6.486783027648926 
epoch 1/5 batch 6/9741loss projected: 6.983078479766846 
epoch 1/5 batch 8/9741loss projected: 8.678894996643066 
epoch 1/5 batch 10/9741loss projected: 8.298099517822266 
epoch 1/5 batch 12/9741loss projected: 7.999495506286621 
epoch 1/5 batch 14/9741loss projected: 8.399038314819336 
epoch 1/5 batch 16/9741loss projected: 8.313008308410645 
epoch 1/5 batch 18/9741loss projected: 6.699790954589844 
epoch 1/5 batch 20/9741loss projected: 7.389983654022217 
epoch 1/5 batch 22/9741loss projected: 8.068312644958496 
epoch 1/5 batch 24/9741loss projected: 10.411550521850586 
epoch 1/5 batch 26/9741loss projected: 8.971055030822754 
epoch 1/5 batch 28/9741loss projected: 6.808653354644775 
epoch 1/5 batch 30/9741loss projected: 8.605818748474121 
epoch 1/5 batch 32/9741loss projected: 7.437904357910156 


KeyboardInterrupt: 