In [1]:
from dataset import TinyShakespeareDataset
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt

from diffusion_transformer import TextDiffusionModel
from loss import DiffusionLoss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
dataset = TinyShakespeareDataset('input.txt', seq_len=16)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True)

In [3]:
next(iter(dataset))

(tensor([18, 21, 30, 31, 32,  1, 15, 21, 32, 21, 38, 17, 26, 10,  0, 14]),
 tensor([18, 21, 30, 31, 32,  1, 15, 21, 32, 21, 38, 17, 26, 10,  0, 14]))

In [4]:
# Hyperparameters
vocab_size = dataset.vocab_size  # Size of the vocabulary plus mask
embedding_dim = 64  # Size of embeddings (e.g., BERT-like model)
hidden_dim = 64  # Transformer hidden layer size
num_iterations = 20  # Number of iterative refinement steps
max_seq_len = 16  # Maximum sequence length
num_layers = 4
nhead = 4
# self, vocab_size, embedding_dim, hidden_dim, num_layers, nhead, max_seq_len, dropout=0.1
# Instantiate the model
#self, vocab_size, embed_dim, max_seq_length, num_steps, num_heads=8
model = TextDiffusionModel(vocab_size, embedding_dim, max_seq_len, num_iterations, nhead).to(device)

print(model)

TextDiffusionModel(
  (embedding): Embedding(39, 64)
  (transformer_encoder): TransformerEncoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
    )
    (linear1): Linear(in_features=64, out_features=2048, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=2048, out_features=64, bias=True)
    (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
  )
  (guided_attention): GuidedAttentionLayer(
    (multihead_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
    )
  )
  (noise_predictor): Linear(in_features=64, out_features=64, bias=True)
)


In [5]:

# Initialize loss
criterion = DiffusionLoss(model)

# Initialize optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


In [6]:
losses = []
num_epochs = 10
num_steps = 20
batch_size = 32


for epoch in range(num_epochs):
    epoch_loss = 0
    for batch, _ in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        batch = batch.to(device)

        # Sample random timesteps
        t = torch.randint(0, num_steps, (batch_size,), device=batch.device)
        # Forward pass
        
        predicted_noise, actual_noise = model(batch, t)
        optimizer.zero_grad()

        loss = criterion(predicted_noise, actual_noise )
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    
    avg_loss = epoch_loss / len(dataloader)
    losses.append(avg_loss)
    print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

# Plot the loss curve
plt.figure(figsize=(10, 5))
plt.plot(range(1, num_epochs+1), losses)
plt.title("Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Average Loss")
plt.show()


  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
Epoch 1/10:  22%|██▏       | 7826/34855 [01:30<05:13, 86.15it/s] 


KeyboardInterrupt: 

In [18]:
import torch

# Assuming you have a trained model loaded
# model = torch.load('path_to_trained_model.pth')
# or
# model.load_state_dict(torch.load('path_to_state_dict.pth'))

# Set model to evaluation mode
model.eval()
model = model.cpu()

# Number of diffusion steps
num_steps = model.num_steps
batch_size = 1  # Single sample generation
seq_length = model.position_encoding.size(1)
vocab_size = model.embedding.num_embeddings

# Storage for generated sequences
generated_sequences = []

# Perform inference
with torch.no_grad():
    for _ in range(100):  # Number of samples to generate
        print(_)
        # Initialize with random tokens
        initial_tokens = torch.tensor([dataset.char_to_idx[x] for x in  "because i am the" ], dtype=torch.long).unsqueeze(0)
        
        # Convert initial tokens to float embeddings
        embeddings = model.embedding(initial_tokens) + model.position_encoding[:, :seq_length, :]

        # Perform progressive denoising
        tokens = embeddings.clone()
        for t in range(num_steps - 1, -1, -1):
            t_tensor = torch.tensor([t], dtype=torch.float32)

            # Pass through the model
            predicted_noise, _ = model.generate(tokens, t_tensor)
            
            # Compute the alpha for this step
            alphas = model.noise_schedule(t_tensor).view(-1, 1, 1)
            
            # Denoise the embeddings
            tokens = (tokens - (1 - alphas).sqrt() * predicted_noise) / alphas.sqrt()

        # Convert final embeddings back to token indices
    final_embeddings = tokens.squeeze(0)  # (seq_len, embed_dim)
    vocab_embeddings = model.embedding.weight
 # Compute the pairwise cosine similarity
    similarities = F.cosine_similarity(final_embeddings.unsqueeze(1), vocab_embeddings.unsqueeze(0), dim=-1)
    
    # Get the indices of the most similar tokens
    final_tokens = similarities.argmax(dim=-1).cpu().numpy()

# Optionally, convert final tokens back to text if you have a tokenizer
# final_texts = [tokenizer.decode(seq) for seq in generated_sequences]
# for i, text in enumerate(final_texts):
#     print(f"Generated Text {i+1}: {text}")

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99


In [20]:
print("Generated tokens:", final_tokens)
print("Decoded text:", "".join([dataset.idx_to_char[idx] for idx in final_tokens]))


Generated tokens: [12 35 12  1 22 37 20  1 12  1 20  1  1 20 27 11]
Decoded text: ;v;
ixg
;
g

gn:


In [None]:
# Compute the pairwise cosine similarity
    similarities = F.cosine_similarity(final_embeddings.unsqueeze(1), vocab_embeddings.unsqueeze(0), dim=-1)
    
    # Get the indices of the most similar tokens
    final_tokens = similarities.argmax(dim=-1).cpu().numpy()

39