In [2]:
# Mount Google Drive into the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount("/content/drive")
# Relative path to the Drive root.
# NOTE(review): the visible cells below spell out 'drive/MyDrive' literally instead of
# using HOME, and the relative form presumably only resolves while the working
# directory is /content - confirm whether HOME is still needed.
HOME = 'drive/MyDrive'

Mounted at /content/drive


In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import nltk
nltk.download('punkt')
import random
import pandas as pd
import numpy as np
import os

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [11]:
#Encoding texts to indices
def encode(string, word2index):
  """Tokenize `string` with NLTK and look up each token's index in `word2index`.

  Returns a LongTensor built from a one-element list of index lists, i.e. a
  single row holding one index per token. A token that is missing from
  `word2index` raises KeyError.
  """
  token_ids = []
  for token in nltk.word_tokenize(string):
    token_ids.append(word2index[token])
  return torch.LongTensor([token_ids])

#Decoding indices to texts
def decode(vec, index2word):
  """Map every index in `vec` back to its word, preserving order.

  Lookups go through `index2word.get`, so an index without an entry yields
  None instead of raising.
  """
  words = []
  for idx in vec:
    words.append(index2word.get(idx))
  return words

# Load the saved vocabulary mappings (word -> index and index -> word) from Drive.
# NOTE(review): torch.load unpickles the file; only load it from a trusted source.
x = torch.load('drive/MyDrive/saved_dicts')
word2index, index2word = x['word2index'], x['index2word']

In [6]:
class Net_variant(nn.Module):
  """Shared embedding followed by two parallel bidirectional-GRU + linear heads.

  Every forward pass sends the embedded input through exactly one of the two
  heads, picked at random, and returns that head's scores for each position.

  Args:
    embed_size: number of rows in the embedding table; it is also the size of
      the last output dimension (this notebook passes the vocabulary size).
    input_dim: embedding dimension, i.e. the GRU input size.
    hidden_dim: GRU hidden size per direction.
    batch_first: forwarded unchanged to nn.GRU.
    n_layers: number of stacked GRU layers in each head.
    dropout: dropout between stacked GRU layers; only used when n_layers > 1.
  """

  def __init__(self, embed_size, input_dim, hidden_dim, batch_first=True, n_layers = 1, dropout = 0.2):
    super().__init__()

    self.n_layers = n_layers
    self.hidden_dim = hidden_dim

    # nn.GRU applies dropout only between stacked layers and warns when a
    # non-zero value is combined with a single layer, so pass 0 in that case.
    rnn_dropout = dropout if n_layers > 1 else 0.0

    #shared embedding layer
    self.embedding_layer = nn.Embedding(num_embeddings=embed_size, embedding_dim=input_dim)

    #GRU 1
    self.rnn_layer1 = nn.GRU(input_dim, hidden_dim, batch_first=batch_first, num_layers=n_layers, dropout=rnn_dropout, bidirectional=True)
    # bidirectional output concatenates both directions, hence hidden_dim*2
    self.linear1 = nn.Linear(hidden_dim*2, embed_size)

    #GRU 2
    self.rnn_layer2 = nn.GRU(input_dim, hidden_dim, batch_first=batch_first, num_layers=n_layers, dropout=rnn_dropout, bidirectional=True)
    self.linear2 = nn.Linear(hidden_dim*2, embed_size)

  def forward(self, x):
    """Embed `x` and return the scores produced by one randomly chosen head.

    Args:
      x: integer tensor of indices accepted by the embedding layer.

    Returns:
      The chosen linear layer's output, whose last dimension has size `embed_size`.
    """
    embedded = self.embedding_layer(x)

    #Randomly selects which GRU layer should be used
    # NOTE(review): the draw is made in eval mode as well, so repeated calls on
    # the same input can go through different heads - confirm this is intended.
    if random.randrange(2) == 0:
      rnn_layer, linear = self.rnn_layer1, self.linear1
    else:
      rnn_layer, linear = self.rnn_layer2, self.linear2

    # The final hidden state is not needed; only the per-position outputs are.
    rnn_out, _ = rnn_layer(embedded)
    return linear(rnn_out)

In [7]:
def test_model(model, word2index, index2word, string="", maxlen=25, verbose=False):
  """Greedily extend `string` with the model's predictions and print the result.

  At each step the whole sequence so far is fed to the model, the highest-scoring
  index at the last position is appended, and generation stops once the 'END'
  index is present in the sequence or `maxlen` steps have run.

  Args:
    model: network called as `model(indices)`; it is switched to eval mode.
    word2index: mapping from token to index; must contain the key 'END'.
    index2word: mapping from index back to token, used for printing.
    string: the input text; every token in it must be a key of `word2index`.
    maxlen: maximum number of tokens to generate.
    verbose: if True, print the input and the full prediction of every step;
      if False, print only the initial input and the final sequence.

  Raises:
    ValueError: if `string` contains no tokens.
    KeyError: if a token of `string`, or 'END', is missing from `word2index`.
  """
  model.eval()

  eval_input = encode(string, word2index)
  if eval_input.numel() == 0:
    # The default string="" would otherwise fail later with an opaque error.
    raise ValueError("test_model needs a non-empty input string")

  # Follow the model's own device instead of assuming CUDA is present.
  device = next(model.parameters()).device
  eval_input = eval_input.to(device)
  print("INITIAL INPUT: " + string)

  if verbose:
    print("---")

  end_index = word2index['END']

  # Inference only: skip building an autograd graph on every step.
  with torch.no_grad():
    for _ in range(maxlen):
      output = model(eval_input)
      # softmax does not change which index is largest, so argmax on the raw scores suffices
      pred = output.argmax(-1)

      if verbose:
        print("INPUT: " + " ".join( decode(eval_input.tolist()[0],index2word)))
        print("OUTPUT: " + " ".join( decode(pred[0].tolist(), index2word)))

      # Append the prediction for the last position as a new column.
      eval_input = torch.cat((eval_input, pred[:, -1:]), 1)

      if end_index in eval_input:
        break

  print("GENERATED SEQUENCE: " + " ".join( decode(eval_input.tolist()[0],index2word)))
  print("")

In [14]:
#Hyperparameters
vocab_size = len(word2index)
input_size =  128
hidden_size = 128

# Use the GPU when one is available; otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Net_variant(vocab_size, input_size, hidden_size, batch_first=True)
model.to(device)

# map_location remaps the checkpoint's tensors onto `device`, so a checkpoint
# holding CUDA tensors can still be loaded on a CPU-only runtime.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load("drive/MyDrive/Data/Checkpoint1/" + "CPOINT_FINETUNE-75", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])

  "num_layers={}".format(dropout, num_layers))


<All keys matched successfully>

In [150]:
####################################################################################################################################################################################

# ENTER YOUR INPUT IN HERE
#EXAMPLE: input_string = "I am not a spy."

# If the cell prints "Token [ X ] not in the dict.", the word X is not in the vocabulary and the input cannot be used - replace that word with a different one.

####################################################################################################################################################################################

#Change here
input_string = "Shut up, Wesley."













# Lower-case and tokenize the input, then report every token the vocabulary lacks.
input_string = input_string.lower()
input_tokens = nltk.word_tokenize(input_string)
unknown_tokens = [token for token in input_tokens if token not in word2index]
for token in unknown_tokens:
  print("Token [ {} ] not in the dict.".format(token))

# Generate only when every token is known.
flag_generate = not unknown_tokens

if flag_generate:
  test_model(model, word2index, index2word, " ".join(input_tokens))

INITIAL INPUT: shut up , wesley .
GENERATED SEQUENCE: shut up , wesley . the klingon fleet will reach cardassian territory in less than an hour . END

