In [1]:
from diffusion_transformer import TokenDiffusionModel
from dataset import TinyShakespeareDataset
import urllib.request
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import random
import numpy as np

# Seed every RNG the notebook draws from (DataLoader shuffling, weight
# initialisation, multinomial sampling) so Restart & Run All gives repeatable
# results on CPU. Some GPU kernels may still be nondeterministic.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the default generator on all devices

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Samples of seq_len=32 are read from input.txt and served in shuffled
# batches of 32.
# NOTE(review): seq_len here and max_seq_len in the hyperparameter cell are
# separate literals; keep them in sync.
dataset = TinyShakespeareDataset('input.txt', seq_len=32)
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=32)

In [3]:
# Hyperparameters
# Vocabulary size as reported by the dataset (originally described as
# "vocabulary plus mask"). NOTE(review): the printed model shows
# Embedding(65, 32); confirm whether that count really includes a mask token.
vocab_size = dataset.vocab_size
embedding_dim = 32  # Token embedding width
hidden_dim = 32  # Transformer hidden layer size
num_iterations = 200  # Number of iterative refinement steps
# NOTE(review): the DataLoader cell passes seq_len=32 as a separate literal;
# keep the two values in sync.
max_seq_len = 32  # Maximum sequence length
# NOTE(review): the printed model lists "(0-1): 2 x" decoder layers although
# num_layers is 1 here. Either the printout is stale or the positional
# argument mapping below differs from what is assumed. Verify.
num_layers = 1  # Number of decoder layers
nhead = 4  # Number of attention heads
# NOTE(review): the constructor signature previously recorded here was
#   (self, vocab_size, embedding_dim, hidden_dim, num_layers, nhead, max_seq_len, dropout=0.1)
# Under that signature, the positional call below would pass num_iterations as
# max_seq_len and max_seq_len as dropout. The printed model shows
# Dropout(p=0.1), so the real signature presumably differs. Confirm it against
# TokenDiffusionModel and prefer keyword arguments so the mapping is explicit.
# Instantiate the model
model = TokenDiffusionModel(vocab_size, embedding_dim, hidden_dim, num_layers, nhead, num_iterations, max_seq_len).to(device)

print(model)

TokenDiffusionModel(
  (embedding): Embedding(65, 32)
  (transformer_decoder_layer): TransformerDecoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
    )
    (multihead_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=32, out_features=32, bias=True)
    )
    (linear1): Linear(in_features=32, out_features=32, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=32, out_features=32, bias=True)
    (norm1): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    (norm3): LayerNorm((32,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
    (dropout3): Dropout(p=0.1, inplace=False)
  )
  (transformer_decoder): TransformerDecoder(
    (layers): ModuleList(
      (0-1): 2 x Transf

In [4]:
# Cross-entropy over the vocabulary, minimised with Adam at a learning rate of 1e-4.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

In [5]:
num_epochs = 10

# An earlier call to generate_sequence() may have left the model in eval mode
# (dropout disabled). Switch back so every (re-)run of this cell really trains.
model.train()
for epoch in range(num_epochs):
    for batch_idx, (input_tokens, target_tokens) in enumerate(dataloader):
        optimizer.zero_grad()
        # Move inputs and targets to the model's device so the loss is computed
        # there instead of copying the logits back to the CPU each step.
        input_tokens = input_tokens.to(device)
        target_tokens = target_tokens.to(device)

        # Forward pass through the model
        logits = model(input_tokens)

        # Flatten to (N, vocab_size) logits and (N,) targets for CrossEntropyLoss.
        # reshape (unlike view) also works when the model output is non-contiguous.
        logits = logits.reshape(-1, vocab_size)
        target_tokens = target_tokens.reshape(-1)

        # Compute loss
        loss = criterion(logits, target_tokens)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        if batch_idx % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}], Loss: {loss.item():.4f}')


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


Epoch [1/10], Batch [0], Loss: 4.4782
Epoch [1/10], Batch [10], Loss: 4.3291
Epoch [1/10], Batch [20], Loss: 4.1472
Epoch [1/10], Batch [30], Loss: 4.0476
Epoch [1/10], Batch [40], Loss: 4.0059
Epoch [1/10], Batch [50], Loss: 3.9395
Epoch [1/10], Batch [60], Loss: 3.8836
Epoch [1/10], Batch [70], Loss: 3.8634
Epoch [1/10], Batch [80], Loss: 3.8099
Epoch [1/10], Batch [90], Loss: 3.7815
Epoch [1/10], Batch [100], Loss: 3.7451
Epoch [1/10], Batch [110], Loss: 3.6970
Epoch [1/10], Batch [120], Loss: 3.6625
Epoch [1/10], Batch [130], Loss: 3.7118
Epoch [1/10], Batch [140], Loss: 3.7058
Epoch [1/10], Batch [150], Loss: 3.6446
Epoch [1/10], Batch [160], Loss: 3.6144
Epoch [1/10], Batch [170], Loss: 3.6059
Epoch [1/10], Batch [180], Loss: 3.5518
Epoch [1/10], Batch [190], Loss: 3.5723
Epoch [1/10], Batch [200], Loss: 3.5181
Epoch [1/10], Batch [210], Loss: 3.5595
Epoch [1/10], Batch [220], Loss: 3.4487
Epoch [1/10], Batch [230], Loss: 3.4985
Epoch [1/10], Batch [240], Loss: 3.4961
Epoch [1/10

KeyboardInterrupt: 

In [11]:
def generate_sequence(model, start_text, length=100, context_len=None):
    """Autoregressively sample ``length`` characters that continue ``start_text``.

    Each step feeds the most recent tokens to ``model``, reads the logits at the
    last position, samples the next token from their softmax, and appends it.

    Args:
        model: Token model called as ``model(tokens)``; its output is indexed as
            ``logits[:, -1, :]``, i.e. assumed ``(batch, position, vocab)``.
        start_text: Non-empty prompt; every character must be a key of the
            global ``dataset.char_to_idx``.
        length: Number of tokens to sample and append.
        context_len: Maximum number of most-recent tokens fed to the model at
            each step. Defaults to the notebook's global ``max_seq_len``, so the
            model input never grows past that even when prompt + samples do.

    Returns:
        The prompt followed by the sampled characters, decoded with the global
        ``dataset.idx_to_char``.

    Raises:
        ValueError: if ``start_text`` is empty or ``context_len`` < 1.
        KeyError: if ``start_text`` contains a character outside the vocabulary.
    """
    if not start_text:
        raise ValueError("start_text must contain at least one character")
    if context_len is None:
        context_len = max_seq_len
    if context_len < 1:
        raise ValueError("context_len must be at least 1")

    # Explicit long dtype: token ids index an embedding table.
    tokens = torch.tensor(
        [dataset.char_to_idx[c] for c in start_text], dtype=torch.long
    ).unsqueeze(0).to(device)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(length):
                # Window the input to the last `context_len` tokens.
                logits = model(tokens[:, -context_len:])
                probs = torch.softmax(logits[:, -1, :], dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                tokens = torch.cat([tokens, next_token], dim=1)
    finally:
        # Restore the caller's mode so a later training cell keeps dropout active.
        model.train(was_training)

    # tokens[0] (not squeeze()) stays 1-D even when only one token exists.
    return ''.join(dataset.idx_to_char[idx] for idx in tokens[0].tolist())

# Generate text: sample a 20-character continuation of the "ROMEO: " prompt.
print(
    generate_sequence(
        model,
        start_text="ROMEO: ",
        length=20,
    )
)


ROMEO: VdqVh$RXltnusNZivewn
