In [1]:
import seaborn as sns
from transformers import BartTokenizer, BartForConditionalGeneration, AdamW
!pip install datasets
from datasets import load_dataset
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import torch
import random


# Load the dataset
dataset = load_dataset('bigscience/P3', 'cos_e_v1.11_aligned_with_common_sense')
train_dataset = dataset['train']

# Initialize the tokenizer and models (one or continuous prompting and other for projected prompting
model_projected = BartForConditionalGeneration.from_pretrained('sshleifer/distilbart-cnn-12-6')

tokenizer = BartTokenizer.from_pretrained('sshleifer/distilbart-cnn-12-6')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_projected.to(device)

# Define the prompt basis
prompt_list = [
    'When you see the following question, I would like you to answer it correctly',
    'Produce an executable artifact of type X that will answer the question, and then execute it',
    'When I ask you a question, generate three additional questions that would help you give a more accurate answer. When you then answered the three questions, combine the answers to produce the final answers to my original question',
    'Generate a set of facts that are contained in the output. The set of facts should be inserted in a specific point in the output to answer the question',
    'Given the following question, generate a detailed explanation before providing the correct answer',
    'Imagine you are a teacher explaining the answer to this question to a student. How would you respond?',
    'Consider the following question. What are the key concepts involved and how do they lead to the correct answer?',
    'As an expert in the field, how would you respond to the following question?',
    'Translate the following question into a simpler form, then provide the answer',
    'If you were to create a diagram to answer this question, what would it look like? Describe it in detail',
    'Pretend you are explaining the answer to this question to someone with no background in the subject. How would you explain it?',
    'As a highly proficient translator, translate the following question into a different context, then provide the answer',
    'Generate a step-by-step guide to answer the following question',
    'Consider the following question. What assumptions are you making in order to answer it?',
    'If you were to debate the answer to this question, what points would you raise?',
    'As a researcher, how would you investigate the answer to the following question?',
    'Pretend you are a journalist reporting on the answer to this question. How would you present it?',
    'As a storyteller, weave a narrative around the answer to this question',
    'If you were to answer this question in a court of law, what evidence would you present?',
    'As a detective, how would you piece together the answer to this question?',
    'Imagine you are a computer program designed to answer this question. What algorithms or processes would you use?',
    'As a philosopher, how would you interpret the answer to this question?',
    'If you were to answer this question in a job interview, how would you respond?',
    'As a scientist, how would you experiment to find the answer to this question?'
]

print(f'tokenizing prompts')
print(f'prompt list length {len(prompt_list)}')

basis = tokenizer(prompt_list, padding=True, truncation=True, return_tensors='pt').to(device)
basis = model_projected.model.shared(basis.input_ids)

def tokenize_function(example):
    return tokenizer(example['inputs_pretokenized'], truncation=True, padding='max_length')

# Apply the function to the dataset
print('tokenzing dataset')
dataset = dataset.map(tokenize_function, batched=True)
train_dataset = dataset['train']
validation_dataset = dataset['validation']

import torch
import torch.nn as nn
import torch.nn.functional as F

# Define the weight prediction model
class LearnWeights(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(LearnWeights, self).__init__()
        self.layer1 = nn.Linear(input_dim, 512)
        self.layer2 = nn.Linear(512, 128)
        self.layer3 = nn.Linear(128, 64)
        self.output_layer = nn.Linear(64, output_dim)

    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        x = F.relu(self.layer3(x))
        x = self.output_layer(x)
        x = x.mean(dim=1, keepdim=True)  # Compute the mean across the token dimension and batch dimension
        return x.squeeze(1).mean(dim=0)



# Define the projected prompt
input_dim = 1024

output_dim = len(prompt_list)
learn_weights = LearnWeights(input_dim, output_dim).to(device)
optimizer_projected = AdamW(learn_weights.parameters())

# Training parameters
epochs = 3
batch_size = 2
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Collecting datasets
  Downloading datasets-2.15.0-py3-none-any.whl (521 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m521.2/521.2 kB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow-hotfix (from datasets)
  Downloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m16.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pyarrow-hotfix, dill, multiprocess, datasets
Successfully installed datasets-2.15.0 dill-0.3.7 multiprocess-0.70.15 pyarrow-hotfix-0.6


Downloading builder script:   0%|          | 0.00/10.0k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/10.4k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/21.4k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/6.28k [00:00<?, ?B/s]

Downloading extra modules:   0%|          | 0.00/384k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/624k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/5.09M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

config.json:   0%|          | 0.00/1.80k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.22G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizing prompts
prompt list length 24
tokenzing dataset


Map:   0%|          | 0/9741 [00:00<?, ? examples/s]

Map:   0%|          | 0/1221 [00:00<?, ? examples/s]



In [None]:
print('Training...')

# Training loop
projected_losses = []
validation_losses = []

shapes = []

for epoch in range(epochs):
    epoch_loss_continuous = 0
    epoch_loss_projected = 0
    epoch_loss_validation = 0
    for _ in range(0, len(train_dataset), batch_size):
        i = random.randint(0, len(train_dataset) - 2)
        batch = train_dataset[i:i+batch_size]
        if len(batch['inputs_pretokenized']) != batch_size:
          continue
        input_ids = tokenizer(batch['inputs_pretokenized'], return_tensors='pt', padding=True, truncation=True).input_ids.to(device)
        labels = tokenizer(batch['targets_pretokenized'], return_tensors='pt', padding=True, truncation=True).input_ids.to(device)

        # Get the prompt input embeddings - same if continuous or projected
        input_embeddings = model_projected.model.shared(input_ids)
        padding_size = max(0, 100 - input_embeddings.shape[1])
        input_embeddings = F.pad(input_embeddings, (0, 0, 0, padding_size), "constant", 0)
        input_embeddings_projected = torch.Tensor(input_embeddings).to(device)


        weights = learn_weights(input_embeddings)
        # print(f'predicted weights' + str(weights))
        # print(f'predicted weights shape' + str(weights.shape))
        # print(f'basis shape' + str(basis.shape))
        # print(f'input embeddings shape' + str(input_embeddings.shape))
        # print(f'soft prompt shape' + str(soft_prompt.shape))
        # print(f'soft prompt batch shape' + str(soft_prompt_batch.shape))
        projected_prompt_batch = weights.unsqueeze(1).unsqueeze(2).expand_as(basis) * basis
        projected_prompt_batch = projected_prompt_batch.sum(dim=0).unsqueeze(0).repeat(batch_size, 1, 1).to(device)
        # print(f'projected prompt batch shape' + str(projected_prompt_batch.shape))
        # print(f'shapes of soft batch and input embeddings: {soft_prompt_batch.shape}, {input_embeddings.shape}')


        combined_projected_embeddings = torch.cat([projected_prompt_batch, input_embeddings_projected], dim=1)

        # Pass the combined embeddings through the model
        outputs_projected = model_projected(inputs_embeds=combined_projected_embeddings, labels=labels)

        loss_projected = outputs_projected.loss
        epoch_loss_projected += loss_projected.item()

        optimizer_projected.zero_grad()
        loss_projected.backward(retain_graph=True)
        optimizer_projected.step()

        #print(f'complete from this epoch {i}/{len(train_dataset)}', end='')
        if _ % 200 == 0:
          print(f'Epoch {epoch+1}/{epochs}, Batch {_}/{len(train_dataset)}')
          print(f'Batch Indices: ({i}, {i+1})')
          print(f'Loss: {loss_projected.item()}')
          print()

    print('Validating Epoch...')

    for _ in range(0, len(validation_dataset), batch_size):
        i = random.randint(0, len(validation_dataset) - 2)
        batch = validation_dataset[i:i+batch_size]
        if len(batch['inputs_pretokenized']) != batch_size:
          continue
        input_ids = tokenizer(batch['inputs_pretokenized'], return_tensors='pt', padding=True, truncation=True).input_ids.to(device)
        labels = tokenizer(batch['targets_pretokenized'], return_tensors='pt', padding=True, truncation=True).input_ids.to(device)

        # Get the prompt input embeddings - same if continuous or projected
        input_embeddings = model_projected.model.shared(input_ids)
        padding_size = max(0, 100 - input_embeddings.shape[1])
        input_embeddings = F.pad(input_embeddings, (0, 0, 0, padding_size), "constant", 0)
        input_embeddings_projected = torch.Tensor(input_embeddings).to(device)

        weights = learn_weights(input_embeddings)
        # print(f'predicted weights' + str(weights))
        # print(f'predicted weights shape' + str(weights.shape))
        # print(f'basis shape' + str(basis.shape))
        # print(f'input embeddings shape' + str(input_embeddings.shape))
        # print(f'soft prompt shape' + str(soft_prompt.shape))
        # print(f'soft prompt batch shape' + str(soft_prompt_batch.shape))
        projected_prompt_batch = weights.unsqueeze(1).unsqueeze(2).expand_as(basis) * basis
        projected_prompt_batch = projected_prompt_batch.sum(dim=0).unsqueeze(0).repeat(batch_size, 1, 1).to(device)
        # print(f'projected prompt batch shape' + str(projected_prompt_batch.shape))
        # print(f'shapes of soft batch and input embeddings: {soft_prompt_batch.shape}, {input_embeddings.shape}')

        combined_projected_embeddings = torch.cat([projected_prompt_batch, input_embeddings_projected], dim=1)
        outputs_projected = model_projected(inputs_embeds=combined_projected_embeddings, labels=labels)

        loss_validation = outputs_projected.loss
        epoch_loss_validation += loss_validation.item()

    epoch_loss_projected /= len(list(range(0, len(train_dataset), batch_size)))
    epoch_loss_validation /= len(list(range(0, len(validation_dataset), batch_size)))

    print(f'Epoch Validation Loss: {epoch_loss_validation} \n', end='')
    print()

    projected_losses.append(epoch_loss_projected)
    validation_losses.append(epoch_loss_validation)

    # Create a DataFrame with the loss values
    n = len(projected_losses)

    # Create the plot
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, n + 1), projected_losses, label='Training')
    plt.plot(range(1, n + 1), validation_losses, label='Validation')
    plt.title('Normalized Loss v.s. Epoch for a Learned Linear Combination Model')
    plt.title(f'Epochs: {epochs}, Batch Size: {batch_size}')
    plt.legend()

    # Save the plot as a png file
    print(f'Saveing figure...')
    plt.savefig(f'loss_plot_epoch_{epoch+1}.png')
    plt.show()


Training...
Epoch 1/3, Batch 0/9741
Batch Indices: (1056, 1057)
Loss: 10.170181274414062

Epoch 1/3, Batch 200/9741
Batch Indices: (9131, 9132)
Loss: 9.458856582641602

Epoch 1/3, Batch 400/9741
Batch Indices: (7609, 7610)
Loss: 10.186385154724121

Epoch 1/3, Batch 600/9741
Batch Indices: (7242, 7243)
Loss: 8.974509239196777

Epoch 1/3, Batch 800/9741
Batch Indices: (6588, 6589)
Loss: 9.574487686157227

Epoch 1/3, Batch 1000/9741
Batch Indices: (1447, 1448)
Loss: 8.880127906799316

Epoch 1/3, Batch 1200/9741
Batch Indices: (4693, 4694)
Loss: 7.303316116333008

Epoch 1/3, Batch 1400/9741
Batch Indices: (3127, 3128)
Loss: 8.906057357788086

Epoch 1/3, Batch 1600/9741
Batch Indices: (4119, 4120)
Loss: 8.081753730773926

Epoch 1/3, Batch 1800/9741
Batch Indices: (8173, 8174)
Loss: 7.923852443695068

Epoch 1/3, Batch 2000/9741
Batch Indices: (5300, 5301)
Loss: 8.849320411682129

