In [60]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from transformers import GPT2Tokenizer
import string

In [61]:
# Load your Shakespeare quotes from input.txt
with open("input.txt", "r", encoding="utf-8") as file:
    shakespeare_quotes = file.read()

# Define a simple character-level tokenizer
class CharTokenizer:
    def __init__(self):
        self.char2index = {char: i for i, char in enumerate(string.printable)}
        self.index2char = {i: char for i, char in enumerate(string.printable)}

    def tokenize(self, text):
        return [self.char2index[char] for char in text]

    def detokenize(self, indices):
        return ''.join([self.index2char[i] for i in indices])

# Tokenize the text
char_tokenizer = CharTokenizer()
tokenized_text = char_tokenizer.tokenize(shakespeare_quotes)


In [62]:

# Create a custom dataset
class ShakespeareDataset(Dataset):
    def __init__(self, data, seq_length=32):
        self.data = data
        self.seq_length = seq_length

    def __len__(self):
        return len(self.data) - self.seq_length

    def __getitem__(self, idx):
        return self.data[idx:idx + self.seq_length]

# Pad collate function for DataLoader
def pad_collate(batch):
    return torch.tensor(batch)

In [63]:
# Hyperparameters
seq_length = 32
batch_size = 64
epochs = 5
lr = 0.001

In [64]:
# Create dataset and dataloader
dataset = ShakespeareDataset(tokenized_text, seq_length=seq_length)
dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=pad_collate, shuffle=True)


In [65]:
len(dataset)

2644

In [66]:
# Define a simple Transformer model
class TransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model=512, nhead=8, num_layers=6):
        super(TransformerModel, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model, nhead=nhead, num_encoder_layers=num_layers
        )
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, src):
        src = self.embedding(src)
        src = src.permute(1, 0, 2)  # Change from (seq_len, batch, embed_dim) to (batch, seq_len, embed_dim)
        output = self.transformer(src, src)  # Use src as both source and target
        output = self.fc(output[-1, :, :])  # Take the last token's output
        return output



In [67]:
# Initialize model, criterion, and optimizer
model = TransformerModel(vocab_size)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)


In [68]:
# Training loop
for epoch in range(epochs):
    total_loss = 0
    for batch in dataloader:
        optimizer.zero_grad()
        output = model(batch)
        loss = criterion(output, batch[:, -1])  # Predict the last token in the sequence
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"Epoch {epoch + 1}, Loss: {total_loss / len(dataloader)}")

# Save the trained model
torch.save(model.state_dict(), "shakespeare_transformer_model.pth")



Epoch 1, Loss: 4.213603405725388
Epoch 2, Loss: 3.3028276420774914
Epoch 3, Loss: 3.2914816708791825
Epoch 4, Loss: 3.2923165503002347
Epoch 5, Loss: 3.285157317206973


In [None]:
# Inference
def generate_quote(model, tokenizer, prompt, max_length=50):
    model.eval()
    prompt = tokenizer.encode(prompt, return_tensors="pt").squeeze()
    
    with torch.no_grad():
        for _ in range(max_length):
            output = model(prompt.unsqueeze(0))
            predicted_token = torch.argmax(output)
            prompt = torch.cat((prompt, predicted_token.unsqueeze(0)))

    generated_text = tokenizer.decode(prompt, skip_special_tokens=True)
    return generated_text

# Load the trained model for inference
inference_model = TransformerModel(vocab_size)
inference_model.load_state_dict(torch.load("shakespeare_transformer_model.pth"))
inference_model.eval()

# Example usage
prompt = "To be gone"
generated_quote = generate_quote(inference_model, tokenizer, prompt)
print(generated_quote)