In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np

import pytorch_lightning as pl
from dotmap import DotMap

In [2]:
# Configuration shared by the cells below.
# Prefer the GPU when one is present; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Token index reserved for padding; passed as ignore_index to the training loss.
PAD_ID = 1
# Seed every RNG the notebook touches so sampling (torch.multinomial) is reproducible.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [10]:
class RNN(pl.LightningModule):
    """Character-level LSTM language model wrapped as a Lightning module.

    Each input token id is embedded, run through a stacked LSTM, and projected to
    one unnormalised logit per output class.

    Args:
        input_size: number of distinct input token ids (rows of the embedding table).
        output_size: number of output classes (logits produced per time step).
        hidden_size: LSTM hidden size; also used as the embedding width so the
            embedding output matches the LSTM's expected input width.
        num_layers: number of stacked LSTM layers.
        batch_size: stored on the module for reference; not used internally.
        all_chars: optional vocabulary (string or sequence) mapping token id -> char.
            Required by ``char2tensor`` and ``generate``; presumably
            ``len(all_chars) == output_size`` — TODO confirm.
        lr: learning rate for the Adam optimizer returned by ``configure_optimizers``.
    """

    def __init__(self, input_size, output_size, hidden_size, num_layers, batch_size=4,
                 all_chars=None, lr=0.003):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.input_size = input_size
        self.batch_size = batch_size
        self.all_chars = all_chars
        self.lr = lr

        # The embedding width must equal the LSTM input width. The original wired
        # Embedding(..., output_size) into LSTM(input_size=input_size), which only
        # worked when output_size == input_size.
        self.embedding = nn.Embedding(input_size, hidden_size)
        # batch_first=True: inputs/outputs are (batch, seq, feature); the hidden
        # state is still (num_layers, batch, hidden_size).
        self.rnn = nn.LSTM(input_size=hidden_size, hidden_size=hidden_size,
                           num_layers=num_layers, batch_first=True)
        self.decoder = nn.Linear(hidden_size, output_size)

        # The decoder emits raw logits, so use CrossEntropyLoss (log_softmax + NLL);
        # NLLLoss applied directly to logits does not compute a valid loss.
        self.criterion = nn.CrossEntropyLoss(ignore_index=PAD_ID)

    def forward(self, input_seq, hidden_state):
        """Run the LSTM over ``input_seq`` starting from ``hidden_state``.

        Args:
            input_seq: LongTensor of token ids, either ``(batch,)`` for a single
                time step or ``(batch, seq_len)`` for a whole sequence.
            hidden_state: ``(h, c)`` tuple, each ``(num_layers, batch, hidden_size)``
                (see ``init_hidden``).

        Returns:
            ``(logits, (h, c))`` where logits is ``(batch, output_size)`` for a
            single step or ``(batch, seq_len, output_size)`` for a sequence.
        """
        single_step = input_seq.dim() == 1
        if single_step:
            input_seq = input_seq.unsqueeze(1)          # (batch,) -> (batch, 1)
        embedded = self.embedding(input_seq)            # (batch, seq_len, hidden_size)
        output, hidden_state = self.rnn(embedded, hidden_state)
        logits = self.decoder(output)                   # (batch, seq_len, output_size)
        if single_step:
            logits = logits.squeeze(1)
        # The hidden state is NOT detached here, so gradients flow across time
        # steps within a sequence. Callers that carry state across batches
        # should detach it themselves.
        return logits, hidden_state

    def init_hidden(self, batch_size):
        """Return zero-initialised ``(h, c)``, each ``(num_layers, batch_size, hidden_size)``.

        Tensors are allocated on ``self.device`` so they follow the module when
        Lightning moves it; the notebook-global ``device`` may not match.
        """
        shape = (self.num_layers, batch_size, self.hidden_size)
        h = torch.zeros(shape, device=self.device)
        c = torch.zeros(shape, device=self.device)
        return h, c

    def training_step(self, batch, batch_idx):
        """Compute next-token cross-entropy for one batch.

        Assumes ``batch`` is ``(src, tgt, lengths)`` with ``src`` and ``tgt`` LongTensors
        of shape ``(batch, seq_len)`` — TODO confirm against the DataLoader.
        ``lengths`` is unused; padded target positions are excluded via
        ``ignore_index=PAD_ID``.
        """
        src, tgt, _lengths = batch
        hidden_state = self.init_hidden(src.shape[0])
        logits, _ = self(src, hidden_state)             # (batch, seq_len, output_size)
        # Flatten time into the batch dimension; the loss is the mean over all
        # non-pad target tokens.
        loss = self.criterion(logits.reshape(-1, self.output_size), tgt.reshape(-1))
        self.log('train_loss', loss)
        # Return the tensor (not loss.item()) so Lightning can backpropagate it.
        return {'loss': loss}

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters.

        Lightning calls this hook to obtain the optimizer; the class previously
        defined none.
        """
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def char2tensor(self, string):
        """Encode ``string`` as a 1-D LongTensor of indices into ``self.all_chars``.

        Raises:
            ValueError: if the module has no vocabulary or a character is not in it.
        """
        if self.all_chars is None:
            raise ValueError("RNN was constructed without all_chars; cannot encode text")
        return torch.tensor([self.all_chars.index(ch) for ch in string],
                            dtype=torch.long, device=self.device)

    def generate(self, initial_char='A', predict_len=15, temperature=0.85):
        """Sample ``predict_len`` characters continuing the prompt ``initial_char``.

        Args:
            initial_char: non-empty prompt; every character must be in ``all_chars``.
            predict_len: number of characters to sample after the prompt.
            temperature: positive softmax temperature; lower values sharpen the
                distribution (greedier sampling).

        Returns:
            The prompt followed by the sampled characters.

        Raises:
            ValueError: on an empty prompt, non-positive temperature, or missing vocabulary.
        """
        if not initial_char:
            raise ValueError("initial_char must be a non-empty string")
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                hidden = self.init_hidden(batch_size=1)
                prompt = self.char2tensor(initial_char)         # (len(initial_char),)
                predicted = initial_char
                # Prime the hidden state on every prompt character except the last,
                # which becomes the first input of the sampling loop.
                if prompt.numel() > 1:
                    _, hidden = self(prompt[:-1].unsqueeze(0), hidden)
                last_char = prompt[-1:]                         # shape (1,)
                for _ in range(predict_len):
                    logits, hidden = self(last_char, hidden)    # (1, output_size)
                    # softmax is the normalised, overflow-safe form of exp(logits / T).
                    probs = torch.softmax(logits.view(-1) / temperature, dim=0)
                    top_char = torch.multinomial(probs, 1)      # shape (1,)
                    predicted += self.all_chars[top_char.item()]
                    last_char = top_char
        finally:
            self.train(was_training)
        return predicted

In [None]:
# Model hyperparameters. RNN() has four required arguments, so a bare RNN() raises TypeError.
# NOTE(review): VOCAB_SIZE is a placeholder — set it to the size of the real
# token vocabulary (which must contain PAD_ID).
VOCAB_SIZE = 100
HIDDEN_SIZE = 256
NUM_LAYERS = 2

model = RNN(
    input_size=VOCAB_SIZE,   # number of distinct input tokens (embedding rows)
    output_size=VOCAB_SIZE,  # one logit per vocabulary entry
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYERS,
)

In [11]:
# Smoke-test mode: fast_dev_run pushes a single batch through the train/val/test
# loops (see the log below) to surface wiring errors before a real run.
FAST_DEV_RUN = True
trainer = pl.Trainer(fast_dev_run=FAST_DEV_RUN)

GPU available: False, used: False
TPU available: None, using: 0 TPU cores
Running in fast_dev_run mode: will run a full train, val and test loop using 1 batch(es).


In [None]:
# Trainer.fit needs the module to train; calling it with no arguments raises TypeError.
# NOTE(review): RNN defines no train_dataloader(), so a DataLoader yielding
# (src, tgt, lengths) batches must also be passed here — TODO supply it, e.g.
# trainer.fit(model, <train DataLoader>).
trainer.fit(model)