In [9]:
import torch
import seaborn as sns
from transformers import BartTokenizer, BartForConditionalGeneration, AdamW
from datasets import load_dataset
import matplotlib.pyplot as plt
import pandas as pd

# Load the dataset
dataset = load_dataset('bigscience/P3', 'cos_e_v1.11_aligned_with_common_sense')
train_dataset = dataset['train']

# Initialize the tokenizer and model
tokenizer = BartTokenizer.from_pretrained('sshleifer/distilbart-cnn-12-6')
model_continuous = BartForConditionalGeneration.from_pretrained('sshleifer/distilbart-cnn-12-6')
model_projected = BartForConditionalGeneration.from_pretrained('sshleifer/distilbart-cnn-12-6')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_continuous.to(device)
model_projected.to(device)

# Define the prompt basis
prompt_list = [
    'When you see the following question, I would like you to answer it correctly',
    'Produce an executable artifact of type X that will answer the question, and then execute it',
    'When I ask you a question, generate three additional questions that would help you give a more accurate answer. When you then answered the three questions, combine the answers to produce the final answers to my original question',
    'Generate a set of facts that are contained in the output. The set of facts should be inserted in a specific point in the output to answer the question',
]

print(f'tokenizing')

basis = tokenizer(prompt_list, padding=True, truncation=True, return_tensors='pt').to(device)

print(f'prompt basis shape: {model_projected.model.shared(basis.input_ids).shape}')

import torch
import torch.nn as nn
import torch.nn.functional as F

# Define the weight prediction model
class LearnWeights(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(LearnWeights, self).__init__()
        self.attention = nn.MultiheadAttention(input_dim, hidden_dim)
        self.fc = nn.Linear(input_dim, output_dim)  # Change this line

    def forward(self, x):
        x, _ = self.attention(x, x, x)
        x = self.fc(x)
        return F.softmax(x, dim=-1).unsqueeze(-1).repeat(1, 1, 1024)

# Define the soft prompt
L = 20
d = model_projected.model.shared.embedding_dim
soft_prompt = torch.randn(L, d).to(device)
soft_prompt = torch.nn.Parameter(soft_prompt)
optimizer_continuous = AdamW([soft_prompt])

# Define the projected prompt
input_dim = d
hidden_dim = 64
output_dim = len(prompt_list)
learn_weights = LearnWeights(input_dim, hidden_dim, output_dim).to(device)
optimizer_projected = AdamW(learn_weights.parameters())

# Training parameters
epochs = 1
batch_size = 4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print('HI', learn_weights(soft_prompt).unsqueeze(0).shape)
print('HO', projected_input_embeddings.shape)

print('starting training')

# Training loop
continuous_losses = []
projected_losses = []

for epoch in range(epochs):
    epoch_loss_continuous = 0
    epoch_loss_projected = 0
    for i in range(0, len(train_dataset) - 9000, batch_size):
        batch = train_dataset[i:i+batch_size]
        input_ids = tokenizer(batch['inputs_pretokenized'], return_tensors='pt', padding=True, truncation=True).input_ids.to(device)
        labels = tokenizer(batch['targets_pretokenized'], return_tensors='pt', padding=True, truncation=True).input_ids.to(device)

        # Get the input embeddings
        continuous_input_embeddings = model_continuous.model.shared(input_ids)
        projected_input_embeddings = model_projected.model.shared(input_ids)

        soft_prompt_batch = soft_prompt.unsqueeze(0).repeat(continuous_input_embeddings.size(0), 1, 1).to(device)
        projected_prompt_batch = torch.squeeze(projected_prompt_batch, 2).to(device)
        print(f'projected batch{projected_prompt_batch.shape}')
        print(f'project embed{projected_input_embeddings.shape}')
        
        combined_continuous_embeddings = torch.cat([soft_prompt_batch, continuous_input_embeddings], dim=1)
        combined_projected_embeddings = torch.cat([projected_prompt_batch, projected_input_embeddings], dim=1)

        # Pass the combined embeddings through the model
        outputs_continuous = model_continuous(inputs_embeds=combined_continuous_embeddings, labels=labels)
        outputs_projected = model_projected(inputs_embeds=combined_projected_embeddings, labels=labels)

        loss_continuous = outputs_continuous.loss
        epoch_loss_continuous += loss_continuous.item()

        loss_projected = outputs_projected.loss
        epoch_loss_projected += loss_projected.item()

        optimizer_continuous.zero_grad()
        loss_continuous.backward()
        optimizer_continuous.step()

        optimizer_projected.zero_grad()
        loss_projected.backward()
        optimizer_projected.step()

        print(f'\r complete from this epoch {i}/{len(train_dataset)}', end='')
        print(f'\r loss continuous: {loss_continuous.item()}', end='')
        print(f'\r loss projected: {loss_projected.item()}', end='')

    epoch_loss_continuous /= len(train_dataset)
    epoch_loss_projected /= len(train_dataset)

    continuous_losses.append(epoch_loss_continuous)
    projected_losses.append(epoch_loss_projected)

    print(f'\r Epoch {epoch+1}/{epochs} complete. Loss: {epoch_loss_continuous}')

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Create a DataFrame with the loss values
data = {
    'Epoch': list(range(1, epochs + 1)) * 2,
    'Loss': continuous_losses + projected_losses,
    'Model': ['Continuous'] * epochs + ['Projected'] * epochs
}
df = pd.DataFrame(data)

# Create the plot
plt.figure(figsize=(10, 6))
sns.lineplot(data=df, x='Epoch', y='Loss', hue='Model')
plt.title('Loss per Epoch for Continuous and Projected Models')
plt.show()

tokenizing
prompt basis shape: torch.Size([4, 43, 1024])
HI torch.Size([1, 20, 4, 1024])
HO torch.Size([4, 82, 1024])
starting training
projected batchtorch.Size([4, 20, 4, 1024])
project embedtorch.Size([4, 82, 1024])


RuntimeError: Tensors must have same number of dimensions: got 4 and 3