In [1]:
from diffusion_transformer import TokenDiffusionModel
import urllib.request
import torch
import random
import numpy as np

# Seed every RNG imported above so that Restart & Run All reproduces the same
# results. `random` drives the token masking further down; torch presumably
# drives weight initialisation and dropout inside the model.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'

# urllib.request.urlretrieve(url, 'input.txt')


# Read the whole corpus into memory. The encoding is spelled out so the result
# does not depend on the platform's default locale.
with open('input.txt', 'r', encoding='utf-8') as file:
    text = file.read()

# Tokenize into lowercase, whitespace-separated words. str.split() with no
# arguments never produces empty strings, so no extra filtering is required.
words = text.lower().split()

# Build the vocabulary in sorted order. Iterating a set of strings gives an
# order that can change between interpreter runs (hash randomization), which
# would assign different ids to the same word on every kernel restart.
unqiue_words = sorted(set(words))

In [3]:
# --- Hyperparameters ---------------------------------------------------------
max_seq_len = 64       # tokens per training sequence
num_iterations = 100   # iterative refinement steps handed to the model
embedding_dim = 128    # width of the token embeddings
hidden_dim = 128       # transformer hidden layer size
vocab_size = len(unqiue_words) + 1  # one extra id on top of the vocabulary, for the mask

# --- Model -------------------------------------------------------------------
# Build the model, move it to the selected device and show its architecture.
model = TokenDiffusionModel(
    vocab_size,
    embedding_dim,
    hidden_dim,
    num_iterations,
    max_seq_len,
)
model = model.to(device)
print(model)

TokenDiffusionModel(
  (embedding): Embedding(23642, 128)
  (transformer_decoder_layer): TransformerDecoderLayer(
    (self_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
    )
    (multihead_attn): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
    )
    (linear1): Linear(in_features=128, out_features=128, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
    (linear2): Linear(in_features=128, out_features=128, bias=True)
    (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (norm3): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    (dropout1): Dropout(p=0.1, inplace=False)
    (dropout2): Dropout(p=0.1, inplace=False)
    (dropout3): Dropout(p=0.1, inplace=False)
  )
  (transformer_decoder): TransformerDecoder(
    (layers): ModuleList(
      (0

In [4]:
# Map every vocabulary word to an integer id (its position in unqiue_words).
word_to_index = {word: i for i, word in enumerate(unqiue_words)}

# Encode the corpus as one flat list of token ids.
indices = [word_to_index[word] for word in words]

# Slide a window of max_seq_len tokens over the corpus with stride 1.
# A corpus of N tokens contains N - max_seq_len + 1 full windows; the "+ 1"
# keeps the final window (and yields one window when N == max_seq_len).
# A corpus shorter than max_seq_len still yields no windows.
sequences = [
    indices[i:i + max_seq_len]
    for i in range(len(indices) - max_seq_len + 1)
]


In [5]:
# Id of the mask token: the last slot of the vocabulary. vocab_size is
# len(unqiue_words) + 1 and word ids run from 0 to len(unqiue_words) - 1, so
# this id never collides with a real word, whatever corpus is loaded.
mask_id = vocab_size - 1

def add_noise_to_sequence(sequence, noise_type='mask', mask_prob=0.25, mask_token=None):
    """Corrupt a token sequence by masking and return (noisy, target) tensors.

    Each position is independently replaced by the mask token with
    probability ``mask_prob``. The input sequence itself is not modified.

    Args:
        sequence: Iterable of integer token ids.
        noise_type: Kind of corruption to apply. Only ``'mask'`` is
            implemented (shuffling and random replacement are not).
        mask_prob: Per-token probability of being masked.
        mask_token: Id written into masked positions. Defaults to the
            module-level ``mask_id`` when left as ``None``.

    Returns:
        A pair of 1-D ``torch.long`` tensors of the same length as
        ``sequence``:
        * the noisy sequence, with masked positions set to the mask token;
        * the target, holding the original token at masked positions and
          ``-1`` everywhere else (the value the training loss is configured
          to ignore).

    Raises:
        ValueError: If ``noise_type`` is not ``'mask'``.
    """
    if noise_type != 'mask':
        # Previously an unknown type fell through silently and produced an
        # un-noised sequence whose targets were all ignored.
        raise ValueError(
            f"Unsupported noise_type {noise_type!r}; only 'mask' is implemented"
        )

    if mask_token is None:
        mask_token = mask_id

    noisy_sequence = list(sequence)  # copy so the caller's data stays intact
    target = [-1] * len(noisy_sequence)

    for i, token in enumerate(noisy_sequence):
        if random.random() < mask_prob:
            # Remember the original token, then hide it behind the mask token.
            target[i] = token
            noisy_sequence[i] = mask_token

    # Explicit dtype keeps the result integer-typed even for an empty sequence.
    return (
        torch.tensor(noisy_sequence, dtype=torch.long),
        torch.tensor(target, dtype=torch.long),
    )

# Noise every training sequence once, in order, keeping each (noisy, target)
# pair together before splitting them into two parallel lists.
noised_pairs = [
    add_noise_to_sequence(seq, noise_type='mask', mask_prob=0.25)
    for seq in sequences
]

noisy_sequences = [noisy for noisy, _ in noised_pairs]
targets = [tgt for _, tgt in noised_pairs]


In [10]:
# Group the noised sequences and their targets into mini-batches.
batch_size = 32

batches = []
for start in range(0, len(noisy_sequences), batch_size):
    stop = start + batch_size
    # Stack the 1-D per-sequence tensors into one 2-D tensor per batch and
    # move it to the training device. The final batch may be smaller.
    inputs = torch.vstack(noisy_sequences[start:stop]).to(device)
    labels = torch.vstack(targets[start:stop]).to(device)
    batches.append((inputs, labels))

In [23]:
# Sanity-check tensor shapes on the first batch. The batch is fetched and run
# through the model here so the cell works on a fresh kernel instead of
# relying on variables left behind by the training loop.
sequence, target = batches[0]
with torch.no_grad():  # shape check only, no gradients needed
    output = model(sequence)
print(target.shape)
print(output.shape)

torch.Size([32, 64])
torch.Size([32, 64, 23642])


In [24]:
# Training: a single pass over the pre-built batches.

# Positions whose target is -1 are excluded from the loss.
criterion = torch.nn.CrossEntropyLoss(ignore_index=-1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

for sequence, target in batches:
    # Forward pass, then flatten so each token position is one
    # classification example over the vocabulary.
    output = model(sequence)
    flat_logits = output.view(-1, vocab_size)
    flat_target = target.view(-1)
    loss = criterion(flat_logits, flat_target)

    # Backward pass and parameter update.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(loss.item())

10.30741024017334
9.917768478393555
9.827102661132812
9.762602806091309
