In [113]:
!pip install transformers

import torch
import torch.nn as nn
import torch.optim as optim
import math
from transformers import BertTokenizer, BertModel
from torch.utils.data import DataLoader, TensorDataset

You should consider upgrading via the '/Library/Frameworks/Python.framework/Versions/3.9/bin/python3.9 -m pip install --upgrade pip' command.[0m


In [114]:
# Load pre-trained BERT model and tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased')

# Define the short stories
stories = ["The cat sat on the mat", "There is gold", "Somewhere over the rainbow", "Random sentence"]

# Tokenize the stories and convert them to input IDs
input_ids = []
for story in stories:
    encoded_story = tokenizer.encode(story, add_special_tokens=True)
    input_ids.append(encoded_story)

print("input_ids", input_ids)

input_ids [[101, 1996, 4937, 2938, 2006, 1996, 13523, 13523, 13523, 102], [101, 2045, 2003, 2751, 102], [101, 4873, 2058, 1996, 10098, 102], [101, 6721, 6251, 102]]


In [115]:
# Generate BERT embeddings for each story
story_embeddings = []
with torch.no_grad():
    for ids in input_ids:
        embeddings = bert_model(torch.tensor([ids]))[0]
        story_embeddings.append(embeddings)

print("story_embeddings", story_embeddings)

# Determine the maximum sequence length among the stories
max_length = max(len(ids) for ids in input_ids)

story_embeddings [tensor([[[-0.2599,  0.0782,  0.0374,  ..., -0.1910,  0.2220,  0.2473],
         [-0.4218, -0.1978, -0.2823,  ..., -0.2741,  0.9405, -0.5086],
         [-0.3659, -0.0945,  0.2970,  ..., -0.3415,  0.6979,  0.6504],
         ...,
         [ 0.1764, -0.0344,  0.4185,  ..., -0.1668,  0.0158,  0.4952],
         [-0.3249, -0.4096,  0.1259,  ...,  0.5783,  0.6253, -0.2562],
         [ 0.6688,  0.0528, -0.3455,  ...,  0.1711, -0.3499, -0.3942]]]), tensor([[[-0.3443,  0.3713, -0.3405,  ..., -0.1429,  0.2200,  0.4851],
         [-0.5444, -0.0672, -0.4171,  ..., -0.1206,  0.8002, -0.0201],
         [-0.0833, -0.2276, -0.2673,  ...,  0.0262,  0.2740,  0.7333],
         [-0.7732, -0.1943, -0.2901,  ...,  0.4301,  0.5311, -0.0250],
         [ 0.9799,  0.1183, -0.2731,  ...,  0.2572, -0.6214, -0.2615]]]), tensor([[[-0.0177,  0.0610,  0.0319,  ..., -0.2689,  0.0759,  0.0701],
         [-0.3163,  0.4928, -0.4013,  ..., -0.4979,  0.2614,  0.7672],
         [ 0.2041,  0.0082,  0.0941,  .

In [116]:
# Pad the story embeddings
padded_story_embeddings = []
for embeddings in story_embeddings:
    padded_embeddings = torch.zeros(1, max_length, embeddings.shape[-1])
    padded_embeddings[0, :embeddings.shape[1], :] = embeddings[0]
    padded_story_embeddings.append(padded_embeddings)

# Define the target tokens for each story
target_tokens = []
for ids in input_ids:
    target_ids = ids[1:]  # Shift the input tokens by one position
    target_tokens.append(torch.tensor(target_ids))

print("target_tokens", target_tokens)

# Pad the target tokens
padded_target_tokens = []
for tokens in target_tokens:
    padded_tokens = torch.zeros(max_length, dtype=torch.long)
    padded_tokens[:len(tokens)] = tokens
    padded_target_tokens.append(padded_tokens)

target_tokens [tensor([ 1996,  4937,  2938,  2006,  1996, 13523, 13523, 13523,   102]), tensor([2045, 2003, 2751,  102]), tensor([ 4873,  2058,  1996, 10098,   102]), tensor([6721, 6251,  102])]


In [117]:
class TransformerDecoder(nn.Module):
    def __init__(self, d_model, nhead, num_layers, dim_feedforward, vocab_size):
        super(TransformerDecoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, dim_feedforward)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, target, memory):
        target = self.embedding(target) * math.sqrt(d_model)
        target = self.pos_encoder(target)
        output = self.transformer_decoder(target, memory)
        output = self.fc(output)
        return output

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return x

# Hyperparameters
d_model = 768  # Same as BERT's hidden size
nhead = 8
num_layers = 6
dim_feedforward = 2048
vocab_size = len(tokenizer.vocab)
print("vocab_size", vocab_size)

vocab_size 30522


In [118]:
# Create the decoder
decoder = TransformerDecoder(d_model, nhead, num_layers, dim_feedforward, vocab_size)

# Training hyperparameters
learning_rate = 0.0001
num_epochs = 10
batch_size = 2

# Create a dataset and data loader
dataset = TensorDataset(torch.stack(padded_story_embeddings), torch.stack(padded_target_tokens))
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [119]:
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    for batch in dataloader:
        optimizer.zero_grad()
        
        # Get the decoder inputs and target tokens for the current batch
        decoder_inputs, target_tokens = batch
        
        # Reshape the decoder inputs and target tokens
        decoder_inputs = decoder_inputs.view(batch_size, max_length, -1)
        target_tokens = target_tokens.view(batch_size, -1)
        
        # Forward pass
        decoder_outputs = decoder(target_tokens, decoder_inputs)
        
        # Compute the loss
        loss = criterion(decoder_outputs.view(-1, vocab_size), target_tokens.view(-1))
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")

Epoch [1/10], Loss: 8.2134
Epoch [2/10], Loss: 3.5438
Epoch [3/10], Loss: 4.9063
Epoch [4/10], Loss: 4.0790
Epoch [5/10], Loss: 1.8305
Epoch [6/10], Loss: 1.5955
Epoch [7/10], Loss: 1.6217
Epoch [8/10], Loss: 2.0823
Epoch [9/10], Loss: 1.3523
Epoch [10/10], Loss: 1.3933


In [121]:
# Text generation
def generate_text(prompt, max_length=50):
    # Tokenize the prompt
    prompt_ids = tokenizer.encode(prompt, add_special_tokens=True)
    prompt_tensor = torch.tensor([prompt_ids])

    # Generate BERT embeddings for the prompt
    with torch.no_grad():
        prompt_embeddings = bert_model(prompt_tensor)[0]

    # Initialize the generated sequence with the prompt
    generated_seq = prompt_ids

    for _ in range(max_length):
        # Get the decoder input and target tokens
        decoder_input = prompt_embeddings[:, :len(generated_seq), :]
        target_token = torch.tensor([generated_seq])

        # Pad the decoder input and target tokens
        decoder_input_padded = torch.zeros(1, len(generated_seq), decoder_input.shape[-1])
        decoder_input_padded[0, :decoder_input.shape[1], :] = decoder_input
        target_token_padded = torch.zeros(len(generated_seq), dtype=torch.long)
        target_token_padded[:len(target_token[0])] = target_token[0]

        # Forward pass
        decoder_output = decoder(target_token_padded.unsqueeze(0), decoder_input_padded)

        # Get the predicted token
        predicted_token = decoder_output.argmax(dim=-1)[0, -1].item()

        # Append the predicted token to the generated sequence
        generated_seq.append(predicted_token)

        # Stop generation if the end-of-sequence token is predicted
        if predicted_token == tokenizer.sep_token_id:
            break

    # Decode the generated sequence
    generated_text = tokenizer.decode(generated_seq)

    return generated_text

# Example usage
prompt = "The cat"
generated_story = generate_text(prompt)
print(generated_story)

[CLS] the cat [SEP] [SEP]
