In [50]:
import torch
from transformers import DistilBertForMaskedLM, DistilBertTokenizer
from datasets import load_dataset
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Define a collate function to pad your sequences
def collate_fn(batch):
    inputs = [item['inputs'] for item in batch]
    targets = [item['targets'] for item in batch]

    # Tokenize the inputs and targets
    inputs = [tokenizer.encode(input, return_tensors='pt')[0] for input in inputs]
    targets = [tokenizer.encode(target, return_tensors='pt')[0] for target in targets]

    # Pad the sequences
    max_length_inputs = max([input.size(0) for input in inputs])
    max_length_targets = max([target.size(0) for target in targets])
    inputs = pad_sequence([torch.cat([input, input.new_zeros(max_length_inputs - input.size(0))]) for input in inputs], batch_first=True, padding_value=tokenizer.pad_token_id)
    targets = pad_sequence([torch.cat([target, target.new_zeros(max_length_targets - target.size(0))]) for target in targets], batch_first=True, padding_value=tokenizer.pad_token_id)

    return {'inputs': inputs, 'targets': targets}

# Load the model and tokenizer
model = DistilBertForMaskedLM.from_pretrained('distilbert-base-uncased')
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

# Load the dataset
dataset = load_dataset('bigscience/P3', 'cos_e_v1.11_aligned_with_common_sense')

# Use the DataLoader to batch your dataset
batch_size = 10
dataloader = DataLoader(dataset['train'], batch_size=batch_size, collate_fn=collate_fn)

# Define your soft prompt
soft_prompt = "Try to answer this: "

# Tokenize the soft prompt
soft_prompt_tokens = tokenizer.encode(soft_prompt, return_tensors='pt')

# Freeze the model parameters
for param in model.parameters():
    param.requires_grad = True

# Initialize the soft prompt as a learnable parameter
soft_prompt_param = torch.nn.Parameter(soft_prompt_tokens.float())
soft_prompt_param.requires_grad = True

# Define your optimizer to only update the soft prompt parameter
optimizer = torch.optim.Adam([soft_prompt_param], lr=0.01)
epochs = 1

# Training loop
print(dataset['train'][:10])
for epoch in range(epochs):
    for batch in dataloader:
        # Concatenate the soft prompt with the input
        input_ids_batch = torch.Tensor(batch['inputs'])
        soft_prompt_param_repeated = soft_prompt_param.repeat(input_ids_batch.size(0), 1)
        input_ids = torch.cat([soft_prompt_param_repeated, input_ids_batch.long()], dim=-1)
        input_ids = input_ids.long()

        # Ensure the targets tensor has the same size as the input_ids tensor
        targets = torch.tensor(batch['targets']).long()
        targets = F.pad(targets, (0, input_ids.size(1) - targets.size(1)))

        # Forward pass
        outputs = model(input_ids=input_ids, labels=targets)
        loss = outputs.loss

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

{'inputs': [[947, 31, 7, 3, 9, 822, 11, 3, 9, 360, 487, 4269, 10, 1593, 10, 96, 7238, 33, 335, 16981, 30, 46, 8947, 2195, 5, 5245, 1590, 326, 5, 852, 132, 33, 3, 4, 16981, 535, 363, 19, 48, 46, 677, 13, 58, 29403, 71, 10, 2447, 6, 7089, 484, 6, 2004, 1530, 6, 7270, 682, 6, 18076, 1615, 19, 96, 3357, 107, 682, 121, 46, 1525, 7901, 15, 26, 28, 936, 1017, 1254, 58], [947, 31, 7, 3, 9, 822, 11, 3, 9, 360, 487, 4269, 10, 1593, 10, 71, 1079, 19, 3, 9, 8524, 51, 5, 8718, 114, 8, 26524, 6, 3, 88, 1342, 1084, 48, 1843, 13, 5127, 3620, 5, 2840, 405, 3, 88, 619, 58, 29403, 71, 10, 2601, 14089, 6, 2608, 6, 2412, 2478, 6, 4716, 6, 4716, 1615, 19, 96, 9818, 121, 46, 1525, 7901, 15, 26, 28, 936, 1017, 1254, 58], [947, 31, 7, 3, 9, 822, 11, 3, 9, 360, 487, 4269, 10, 1593, 10, 71, 1282, 568, 1747, 385, 701, 30, 271, 5057, 6, 6922, 406, 7140, 5167, 42, 271, 125, 58, 29403, 71, 10, 1287, 6, 23868, 6, 3331, 6, 11766, 6, 437, 60, 1615, 19, 96, 27296, 60, 121, 46, 1525, 7901, 15, 26, 28, 936, 1017, 1254, 58

  targets = torch.tensor(batch['targets']).long()


IndexError: index out of range in self