In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
import string
import random
import math

### Select Device

In [14]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
print(device)

cpu


### Define Transformer Architecture

In [15]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first input.

    Even feature columns hold ``sin(pos * w_k)`` and odd columns hold
    ``cos(pos * w_k)``, with ``w_k = 10000^(-2k / embed_dim)``.

    Args:
        embed_dim: feature size of the input (odd sizes are supported).
        dropout: dropout probability applied after adding the encoding.
        max_len: longest sequence the precomputed table covers.
    """

    def __init__(self, embed_dim, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd embed_dim there is one fewer odd column than even column,
        # so only the first embed_dim // 2 frequencies are needed for the cosines.
        pe[:, 1::2] = torch.cos(position * div_term[: embed_dim // 2])
        # (max_len, 1, embed_dim): broadcasts over the batch axis of a
        # (seq_len, batch, embed_dim) input.
        self.register_buffer('pe', pe.unsqueeze(1))

    def forward(self, x):
        """Add the encoding of the first ``x.size(0)`` positions and apply dropout.

        Args:
            x: sequence-first tensor of shape (seq_len, batch, embed_dim).

        Raises:
            ValueError: if ``seq_len`` exceeds the ``max_len`` given at construction.
        """
        seq_len = x.size(0)
        if seq_len > self.pe.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pe.size(0)}"
            )
        return self.dropout(x + self.pe[:seq_len])

class TransformerModel(nn.Module):
    """Character-level next-token model built on ``nn.Transformer``.

    The same embedded sequence is fed as both the encoder source and the
    decoder target. Every attention path must therefore be causally masked:
    encoder self-attention, decoder self-attention, and decoder-to-encoder
    cross-attention. Otherwise position t can read token t+1, which is exactly
    the label it is trained to predict.

    Args:
        input_dim: number of rows in the token embedding table (must exceed the
            largest input index).
        embed_dim: model width (``d_model``).
        nhead: number of attention heads.
        nhid: feed-forward hidden size.
        nlayers: number of encoder layers and, separately, of decoder layers.
        output_dim: number of output classes (logits per position).
        dropout: dropout probability.
    """

    def __init__(self, input_dim, embed_dim, nhead, nhid, nlayers, output_dim, dropout=0.5):
        super().__init__()
        self.model_type = 'Transformer'
        self.embed_dim = embed_dim
        self.pos_encoder = PositionalEncoding(embed_dim, dropout)
        # NOTE: despite their names, `encoder` is the token embedding and
        # `decoder` is the output projection; names kept so state_dict keys
        # stay unchanged.
        self.encoder = nn.Embedding(input_dim, embed_dim)
        self.transformer = nn.Transformer(
            d_model=embed_dim,
            nhead=nhead,
            num_encoder_layers=nlayers,
            num_decoder_layers=nlayers,
            dim_feedforward=nhid,
            dropout=dropout,
        )
        self.decoder = nn.Linear(embed_dim, output_dim)

        self.init_weights()

    def init_weights(self):
        """Uniform-initialise the embedding and output weights; zero the output bias."""
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, src_mask=None):
        """Return per-position logits for a batch of index sequences.

        Args:
            src: (seq_len, batch) LongTensor of token indices (sequence-first,
                matching ``nn.Transformer``'s default ``batch_first=False``).
            src_mask: optional (seq_len, seq_len) additive causal mask. When
                omitted, a causal mask is built on ``src``'s device.

        Returns:
            (seq_len, batch, output_dim) tensor of logits.
        """
        if src_mask is None:
            src_mask = self.transformer.generate_square_subsequent_mask(src.size(0)).to(src.device)
        x = self.encoder(src) * math.sqrt(self.embed_dim)
        x = self.pos_encoder(x)
        # Mask all three attention paths; masking only the encoder would leak
        # future tokens through the decoder's self- and cross-attention.
        output = self.transformer(x, x, src_mask=src_mask, tgt_mask=src_mask, memory_mask=src_mask)
        return self.decoder(output)

### Process Input

In [16]:
all_chars = string.printable
n_chars = len(all_chars)
DATA_PATH = '../Data/shakespeare.txt'

# Close the file deterministically and pin the encoding instead of relying on
# the platform default.
with open(DATA_PATH, encoding='utf-8') as corpus_file:
    raw_text = corpus_file.read()

# seq_to_index only understands string.printable; drop anything else so a
# randomly sampled sequence can never raise in the middle of training.
file = ''.join(ch for ch in raw_text if ch in all_chars)
file_len = len(file)

### Encode Input

In [17]:
def get_random_seq(seq_len=128, text=None):
    """Return a random contiguous slice of ``seq_len + 1`` characters.

    The extra character lets callers split the slice into an input
    (``[:-1]``) and a next-character target (``[1:]``), each of length
    ``seq_len``.

    Args:
        seq_len: length of the input sequence (default 128).
        text: corpus to sample from; defaults to the module-level ``file``.

    Raises:
        ValueError: if the corpus has fewer than ``seq_len + 1`` characters.
    """
    if text is None:
        text = file
    window = seq_len + 1
    if len(text) < window:
        raise ValueError(f"corpus has {len(text)} characters; need at least {window}")
    # randint is inclusive on both ends, so the last start that still leaves a
    # full window is len(text) - window.
    start_index = random.randint(0, len(text) - window)
    return text[start_index:start_index + window]

def seq_to_index(seq, alphabet=None):
    """Encode a string as a 1-D LongTensor of character indices.

    Each character maps to its position in ``alphabet`` (default:
    ``all_chars``). Indices therefore lie in ``[0, len(alphabet))``, the same
    range as the model's ``output_dim = len(all_chars)`` classes, so the result
    can be used directly as a ``CrossEntropyLoss`` target.

    Args:
        seq: string to encode.
        alphabet: string of allowed characters; defaults to ``all_chars``.

    Raises:
        ValueError: if ``seq`` contains a character that is not in ``alphabet``.
    """
    if alphabet is None:
        alphabet = all_chars
    indices = []
    for ch in seq:
        idx = alphabet.find(ch)
        if idx < 0:
            raise ValueError(f"character {ch!r} is not in the alphabet")
        indices.append(idx)
    return torch.tensor(indices, dtype=torch.long)

def get_input_and_target():
    """Sample a random excerpt and split it into a next-character training pair.

    Returns:
        A tuple ``(input, target)`` of index tensors produced by
        ``seq_to_index``. ``input`` encodes every character of the excerpt but
        the last; ``target`` encodes the same window shifted one character to
        the right, so ``target[t]`` is the character following ``input[t]``.
    """
    excerpt = get_random_seq()
    source_ids = seq_to_index(excerpt[:-1])
    next_ids = seq_to_index(excerpt[1:])
    return source_ids, next_ids

def generate_square_subsequent_mask(sz):
    """Build an additive causal attention mask of shape (sz, sz).

    Entry (i, j) is 0.0 when j <= i (position i may attend to j) and -inf when
    j > i (future positions are blocked). The result is a float32 tensor.
    """
    blocked = torch.full((sz, sz), float('-inf'), dtype=torch.float32)
    # triu(diagonal=1) keeps -inf strictly above the diagonal and zeroes the rest.
    return torch.triu(blocked, diagonal=1)

In [None]:
def generate_text(model, target=None, start_seq='Wha', gen_len=10, temperature=1.0,
                  max_len=None, context_len=128):
    """Autoregressively extend ``start_seq`` with characters predicted by ``model``.

    Args:
        model: the character model. It is called as ``model(src, mask)`` with
            ``src`` a (seq_len, 1) LongTensor and ``mask`` a causal mask, and
            is expected to return (seq_len, 1, n_classes) logits.
        target: ignored; kept only so existing positional callers still work.
            A fixed mask cannot match the growing context, so a fresh causal
            mask is built on every step instead.
        start_seq: non-empty seed text; every character must be encodable by
            ``seq_to_index``.
        gen_len: number of characters to generate.
        temperature: softmax temperature applied to the logits before
            sampling. Values <= 0 pick the arg-max greedily.
        max_len: alias for ``gen_len``; overrides it when given.
        context_len: at most this many trailing characters are fed to the
            model (``get_random_seq`` trains on 128-character inputs).

    Returns:
        ``start_seq`` followed by the generated characters.

    Raises:
        ValueError: if ``start_seq`` is empty.
    """
    if max_len is not None:
        gen_len = max_len
    if not start_seq:
        raise ValueError("start_seq must contain at least one character")

    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    # Derive the decoder from seq_to_index itself, so decoding always uses the
    # same index convention as the encoding the model was trained on.
    id_to_char = {int(i): ch for i, ch in zip(seq_to_index(all_chars), all_chars)}

    context = seq_to_index(start_seq).tolist()
    generated = []
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(gen_len):
                window = context[-context_len:]
                src = torch.tensor(window, dtype=torch.long, device=run_device).unsqueeze(1)
                mask = generate_square_subsequent_mask(src.size(0)).to(run_device)
                # Logits for the character that follows the last context position.
                logits = model(src, mask)[-1, 0, :]
                if temperature > 0:
                    probs = torch.softmax(logits / temperature, dim=-1)
                    next_id = int(torch.multinomial(probs, num_samples=1))
                else:
                    next_id = int(torch.argmax(logits))
                context.append(next_id)
                # A class with no character mapping contributes no text.
                generated.append(id_to_char.get(next_id, ''))
    finally:
        # Restore whatever mode the caller had the model in.
        model.train(was_training)
    return start_seq + ''.join(generated)

In [None]:
# Seed every RNG the notebook uses (model init via torch, corpus sampling via
# random) so Restart & Run All is reproducible.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# input_dim is the embedding-table size. It only has to exceed the largest
# index seq_to_index can emit (around len(all_chars)); 512 leaves headroom.
input_dim = 512
embed_dim = 128
nhead = 2
nhid = 256
nlayers = 2
output_dim = len(all_chars)  # one logit per character in string.printable
dropout = 0.2

model = TransformerModel(input_dim, embed_dim, nhead, nhid, nlayers, output_dim, dropout)



In [None]:
criterion = nn.CrossEntropyLoss()
# Put the model on the selected device before building the optimizer, and keep
# every batch on the same device below.
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
iters = 200   # random training sequences per epoch
epochs = 20

for epoch in range(epochs):
    model.train()
    total_loss = 0.0
    for _ in range(iters):
        src, tgt = get_input_and_target()
        # Add a batch dimension: (seq_len,) -> (seq_len, batch_size=1).
        src = src.unsqueeze(1).to(device)
        tgt = tgt.unsqueeze(1).to(device)
        src_mask = generate_square_subsequent_mask(src.size(0)).to(device)

        optimizer.zero_grad()
        output = model(src, src_mask)
        # Flatten to (seq_len * batch_size, output_dim) logits vs (seq_len * batch_size,) targets.
        loss = criterion(output.reshape(-1, output_dim), tgt.reshape(-1))
        loss.backward()
        optimizer.step()
        # .item() detaches; summing the tensor itself would chain every step's
        # autograd graph into total_loss.
        total_loss += loss.item()
    print(f'Epoch {epoch+1}, Loss: {total_loss / iters:.4f}')
    print("Generated output:", generate_text(model))

In [None]:
# NOTE(review): dead cell. This is commented-out code that appears to be an
# alternative version of the training loop above and is not executed. If it is
# ever revived, note that it divides by 100 although iters = 500, and that it
# prints from inside the loop. Otherwise it is safe to delete.
# total_loss = 0
# iters = 500
# for i in range(iters):  # Adjust the number of batches
#     input, target = get_input_and_target()
#     # Add batch dimension, shape: (seq_len, batch_size=1)
#     input = input.unsqueeze(1)
#     target = target.unsqueeze(1)

#     src_mask = generate_square_subsequent_mask(input.size(0)).to(input.device)

#     optimizer.zero_grad()
#     output = model(input, src_mask)

#     # Reshape output to (seq_len * batch_size, output_dim) and target to (seq_len * batch_size)
#     output = output.view(-1, output_dim)
#     target = target.view(-1)

#     loss = criterion(output, target)
#     loss.backward()
#     optimizer.step()

#     total_loss += loss.item()
#     print(f'Epoch {epoch+1}, Loss: {total_loss / 100}')

In [None]:
# Seed generation with a random excerpt of the corpus and continue it for
# 100 characters.
seed_text = get_random_seq()
print("Input", seed_text)
print("Output")
print(generate_text(model, start_seq=seed_text, gen_len=100))