In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import string
import random
import math

In [12]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [13]:
class PositionalEncoding(nn.Module):
    def __init__(self, embed_dim, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-math.log(10000.0) / embed_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class TransformerModel(nn.Module):
    def __init__(self, input_dim, embed_dim, nhead, nhid, nlayers, output_dim, dropout=0.5):
        super(TransformerModel, self).__init__()
        self.model_type = 'Transformer'
        self.embed_dim = embed_dim
        self.pos_encoder = PositionalEncoding(embed_dim, dropout)
        self.encoder = nn.Embedding(input_dim, embed_dim)
        self.transformer = nn.Transformer(embed_dim, nhead, nlayers, nlayers, nhid, dropout)
        self.decoder = nn.Linear(embed_dim, output_dim)
        
        self.init_weights()
        
    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)
        
    def forward(self, src, src_mask):
        src = self.encoder(src) * math.sqrt(self.embed_dim)
        src = self.pos_encoder(src)
        output = self.transformer(src, src, src_mask)
        output = self.decoder(output)
        return output

In [14]:
all_chars = string.printable
n_chars = len(all_chars)
file = open('../Data/shakespeare.txt').read()
file_len = len(file)

In [15]:
def get_random_seq():
    seq_len = 128  # The length of an input sequence.
    start_index = random.randint(0, file_len - seq_len)
    end_index = start_index + seq_len + 1
    return file[start_index:end_index]

def seq_to_onehot(seq):
    tensor = torch.zeros(len(seq), n_chars, dtype=torch.float)
    for t, char in enumerate(seq):
        index = all_chars.index(char)
        tensor[t][index] = 1.0
    return tensor

def seq_to_index(seq):
    tensor = torch.zeros(len(seq), dtype=torch.long)
    for t, char in enumerate(seq):
        tensor[t] = all_chars.index(char)
    return tensor

def get_input_and_target():
    seq = get_random_seq()
    input = seq_to_index(seq[:-1])  # Input is represented in index.
    target = seq_to_index(seq[1:])  # Target is represented in index.
    return input, target

def generate_square_subsequent_mask(sz):
    mask = torch.triu(torch.ones(sz, sz) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask

In [16]:
# def generate_text(model, start_seq, max_len=100):
#     model.eval()
#     with torch.no_grad():
#         input = seq_to_index(start_seq)
#         input = input.unsqueeze(1)  # Adding batch dimension
#         generated = input
        
#         for _ in range(max_len):
#             src_mask = generate_square_subsequent_mask(generated.size(0)).to(generated.device)
#             output = model(generated, src_mask)
#             next_char = torch.argmax(output[-1, :], dim=-1)
#             next_char = next_char.unsqueeze(0).unsqueeze(1)  # Shape: (1, 1)
#             generated = torch.cat((generated, next_char), dim=0)  # Concatenate along the sequence dimension
            
#         generated_seq = ''.join([all_chars[idx] for idx in generated.squeeze().tolist()])
#     return generated_seq
# def generate_text(model, start_seq, max_len=100):
#     model.eval()
#     with torch.no_grad():
#         input = seq_to_index(start_seq)
#         input = input.unsqueeze(1)  # Adding batch dimension
#         generated = input
        
#         for _ in range(max_len):
#             src_mask = generate_square_subsequent_mask(generated.size(0)).to(generated.device)
#             output = model(generated, src_mask)
#             next_char = torch.argmax(output[-1, :], dim=-1)
#             next_char = next_char.unsqueeze(0).unsqueeze(1)  # Shape: (1, 1)
#             generated = torch.cat((generated, next_char), dim=0)  # Concatenate along the sequence dimension
            
#         generated_seq = ''.join([all_chars[idx] for idx in generated.squeeze().tolist()])
#     return generated_seq
def generate_text(model, start_seq, max_len=100):
    model.eval()
    with torch.no_grad():
        input = seq_to_index(start_seq).unsqueeze(1)  # Shape: (seq_len, batch_size=1)
        generated = input

        for _ in range(max_len):
            src_mask = generate_square_subsequent_mask(generated.size(0)).to(generated.device)
            output = model(generated, src_mask)
            
            # Select the last time step's output
            next_char_logits = output[-1, 0, :]
            next_char = torch.argmax(next_char_logits, dim=-1).unsqueeze(0).unsqueeze(1)  # Shape: (1, 1)
            
            generated = torch.cat((generated, next_char), dim=0)  # Concatenate along the sequence dimension

        generated_seq = ''.join([all_chars[idx] for idx in generated.squeeze().tolist()])
    return generated_seq

def eval_step(net, init_seq='W', predicted_len=100):
    # Initialize the hidden state, input and the predicted sequence.
    model.eval()
    with torch.no_grad():
        input = seq_to_index(init_seq).unsqueeze(1)  # Shape: (seq_len, batch_size=1)
        generated = input

    # Use initial string to "build up" hidden state.
#     for t in range(len(init_seq) - 1):
#         output, hidden = net(init_input[t], hidden)
        
#     # Set current input as the last character of the initial string.
#     input = init_input[-1]
    
#     # Predict more characters after the initial string.
#     for t in range(predicted_len):
#         # Get the current output and hidden state.
#         output, hidden = net(input, hidden)
        
#         # Sample from the output as a multinomial distribution.
#         predicted_index = torch.multinomial(output.view(-1).exp(), 1)[0]
        
#         # Add predicted character to the sequence and use it as next input.
#         predicted_char  = all_chars[predicted_index]
#         predicted_seq  += predicted_char
        
#         # Use the predicted character to generate the input of next round.
#         input = seq_to_onehot(predicted_char)[0].to(device)
        for t in range(predicted_len):
            src_mask = generate_square_subsequent_mask(generated.size(0)).to(generated.device)
            output = model(generated, src_mask)
            
            # Select the last time step's output
            next_char_logits = output[-1, 0, :]
            next_char = torch.argmax(next_char_logits, dim=-1).unsqueeze(0).unsqueeze(1)  # Shape: (1, 1)
            
            generated = torch.cat((generated, next_char), dim=0)  # Concatenate along the sequence dimension

        generated_seq = ''.join([all_chars[idx] for idx in generated.squeeze().tolist()])
    return generated_seq


In [17]:
input_dim = len(all_chars)
embed_dim = 128
nhead = 2
nhid = 256
nlayers = 2
output_dim = len(all_chars)
dropout = 0.2

model = TransformerModel(input_dim, embed_dim, nhead, nhid, nlayers, output_dim, dropout)



In [18]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
# epochs = 1
# for epoch in range(epochs):
#     model.train()
#     total_loss = 0.
#     for _ in range(10):  # Adjust the number of batches
#         input, target = get_input_and_target()
#         # Add batch dimension, shape: (seq_len, batch_size=1)
#         input = input.unsqueeze(1)
#         target = target.unsqueeze(1)
        
#         src_mask = generate_square_subsequent_mask(input.size(0)).to(input.device)
        
#         optimizer.zero_grad()
#         output = model(input, src_mask)
        
#         # Reshape output to (seq_len * batch_size, output_dim) and target to (seq_len * batch_size)
#         output = output.view(-1, output_dim)
#         target = target.view(-1)
        
#         loss = criterion(output, target)
#         loss.backward()
#         optimizer.step()
        
#         total_loss += loss.item()
#     print(f'Epoch {epoch+1}, Loss: {total_loss / 100}')

In [19]:
iters = 500
for i in range(iters):  # Adjust the number of batches
    input, target = get_input_and_target()
    # Add batch dimension, shape: (seq_len, batch_size=1)
    input = input.unsqueeze(1)
    target = target.unsqueeze(1)

    src_mask = generate_square_subsequent_mask(input.size(0)).to(input.device)

    optimizer.zero_grad()
    output = model(input, src_mask)

    # Reshape output to (seq_len * batch_size, output_dim) and target to (seq_len * batch_size)
    output = output.view(-1, output_dim)
    target = target.view(-1)

    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

    total_loss += loss.item()
    print(f'Epoch {epoch+1}, Loss: {total_loss / 100}')

NameError: name 'total_loss' is not defined

In [None]:
# # Generate a text sequence starting with a given seed
# seed_text = "The"
# generated_text = generate_text(model, seed_text, max_len=500)
# print(generated_text)
print(eval_step(model, predicted_len=600))