In [None]:
from google.colab import drive

# Make Google Drive visible inside the Colab VM at /content/drive.
drive.mount("/content/drive")

# Drive root as a relative path. NOTE(review): this only resolves while the working directory is /content.
HOME = "drive/MyDrive"

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import random
import pandas as pd
import numpy as np

In [None]:
dataset_quotes = []

import os
with open("drive/MyDrive/Data/PlainSimpleGarak-data.txt", 'r') as f: # open in readonly mode
  while True:
    line = f.readline().strip()
    if line == '':
        break
    else:
      dataset_quotes.append(line)

dataset_startwords = []
for s in dataset_quotes:
  dataset_startwords.append(s.split()[0].lower())

dataset_raw = pd.read_csv('drive/MyDrive/Data/quotes.csv', error_bad_lines=False, engine='python')

dataset_quotes_nonascii = list(dataset_raw["Quote"] )

for s in dataset_quotes_nonascii:
  if (len(s) == len(s.encode())):
    if ("$" not in s and "&" not in s):
      dataset_quotes.append(s)

In [None]:
# Multi-character substitutions applied by preprocess_txt, in this exact order
# (the tab replacement must run before ' - ' so that "a\t-\tb" collapses too).
_PREPROCESS_SUBSTITUTIONS = (
    ('\t', ' '),
    ('--', ' '),
    (' - ', ' '),
    ('%', ' percent'),
)
# Single characters deleted outright by preprocess_txt. Apostrophes are intentionally kept.
_PREPROCESS_DELETIONS = str.maketrans('', '', '[]:;?#!/.,()*+"')


def preprocess_txt(dataset_quotes, appende=True):
  """Normalise every string in `dataset_quotes` in place for word-level tokenisation.

  Tabs, '--' and ' - ' become spaces, '%' becomes ' percent', the characters
  [ ] : ; ? # ! / . , ( ) * + " are deleted, and the text is lower-cased.

  Args:
    dataset_quotes: list of strings; each element is replaced in place.
    appende: when truthy, append the ' <END>' end-of-sequence token to each string.

  Returns:
    The same list object (callers that ignore the return value are unaffected).
  """
  for i, text in enumerate(dataset_quotes):
    for old, new in _PREPROCESS_SUBSTITUTIONS:
      text = text.replace(old, new)
    text = text.translate(_PREPROCESS_DELETIONS).lower()
    if appende:
      # '<END>' is appended after lower-casing, so the token stays upper-case in the vocabulary.
      text += " <END>"
    dataset_quotes[i] = text
  return dataset_quotes

# Normalise the training quotes (each gains a trailing '<END>' token) and the prompt words (no '<END>').
# Both calls modify the lists in place. Re-running this cell without reloading the data would
# preprocess the quotes again and append a second '<END>'.
preprocess_txt(dataset_quotes)
preprocess_txt(dataset_startwords, False)

# Deduplicate the prompt words and drop any that preprocessing emptied (e.g. "..."), since they
# would encode to a zero-length prompt. Sort so the list does not depend on set iteration order.
dataset_startwords = sorted({w for w in dataset_startwords if w.split()})

In [None]:
def encode(string, word2index):
  """Encode a whitespace-separated string as a (1, n_words) LongTensor of vocabulary indices.

  Words missing from `word2index` map to the '<UNKNOWN>' index (0 if that entry is absent)
  instead of raising KeyError. An empty string yields a tensor of shape (1, 0).
  """
  unknown_index = word2index.get('<UNKNOWN>', 0)
  return torch.LongTensor([[word2index.get(wd, unknown_index) for wd in string.split()]])

def decode(vec, index2word):
  """Translate a sequence of indices back into words.

  Indices with no entry in `index2word` come back as None.
  """
  lookup = index2word.get
  return list(map(lookup, vec))

class Dataset(torch.utils.data.Dataset):
  """Sliding-window next-word dataset over a flat list of words.

  Item i is (tokens[i : i+seq_len], tokens[i+1 : i+seq_len+1]) as int64 tensors of
  vocabulary indices, so the target is the input shifted one word to the left.

  Args:
    txt: sequence of words; every word must be a key of `word2index`.
    seq_len: number of words per window.
    word2index: mapping from word to vocabulary index.
  """
  def __init__(self, txt, seq_len, word2index):
    self.encoded = [word2index[wd] for wd in txt]
    self.seq_len = seq_len
    # Convert once up front; slicing one tensor per item is cheaper than rebuilding from a list.
    self._encoded_tensor = torch.tensor(self.encoded, dtype=torch.long)

  def __len__(self):
    # Each window needs seq_len inputs plus one extra word for the shifted target. Clamp at 0
    # so text shorter than a window gives an empty dataset instead of a negative length.
    return max(0, len(self.encoded) - self.seq_len)

  def __getitem__(self, index):
    tokens = self._encoded_tensor
    end = index + self.seq_len
    return tokens[index:end], tokens[index + 1:end + 1]

# Flatten every preprocessed quote into one word stream; this is the training text for Dataset.
# str.join avoids the quadratic cost of growing a string with + inside a loop.
txt_quotes = " ".join(dataset_quotes).split()

# Sorted vocabulary so word indices are reproducible across runs.
list_words = sorted(set(txt_quotes))

# Real words get indices 1..len(list_words); index 0 is reserved for '<UNKNOWN>'.
word2index = {tkn: i for i, tkn in enumerate(list_words, 1)}
word2index['<UNKNOWN>'] = 0
index2word = {v: k for k, v in word2index.items()}

In [None]:
class Net(nn.Module):
  """Word-level language model: embedding -> GRU -> linear projection to vocabulary logits.

  Args:
    embed_size: number of embedding rows, i.e. the vocabulary size; also the number of output logits.
    input_dim: embedding dimension fed into the GRU.
    hidden_dim: GRU hidden-state size.
    batch_first: forwarded to nn.GRU; when True the index input is (batch, seq).
    n_layers: number of stacked GRU layers.
    dropout: dropout between stacked GRU layers; only has an effect when n_layers > 1.
  """
  def __init__(self, embed_size, input_dim, hidden_dim, batch_first=True, n_layers = 1, dropout = 0.2):
    super().__init__()

    self.n_layers = n_layers
    self.hidden_dim = hidden_dim

    self.embedding_layer = nn.Embedding(num_embeddings=embed_size, embedding_dim=input_dim)
    # nn.GRU applies dropout only between layers and warns when it is non-zero with a single
    # layer, so pass 0 in that case. The single-layer model is numerically unchanged.
    self.rnn_layer = nn.GRU(input_dim, hidden_dim, batch_first=batch_first, num_layers=n_layers,
                            dropout=dropout if n_layers > 1 else 0.0)
    self.linear = nn.Linear(hidden_dim, embed_size)

  def forward(self, x):
    """Return vocabulary logits for every position of the index tensor `x`.

    With batch_first=True, `x` is (batch, seq) and the result is (batch, seq, embed_size).
    """
    embedded = self.embedding_layer(x)
    rnn_out, _ = self.rnn_layer(embedded)  # final hidden state is not needed
    return self.linear(rnn_out)

In [None]:
# Model dimensions. The vocabulary size fixes both the embedding table and the output layer.
vocab_size = len(word2index)
input_size = 128   # embedding dimension fed into the GRU
hidden_size = 256  # GRU hidden-state size

# Seed the RNGs this notebook draws from: `random` for prompt choice, and torch for weight
# init and DataLoader shuffling. This makes runs more repeatable, though GPU kernels may
# still add nondeterminism.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Two independently initialised copies of the same architecture, each with its own Adam optimiser.
# Use the GPU when one is available and fall back to the CPU instead of failing in .cuda().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model1 = Net(vocab_size, input_size, hidden_size, batch_first=True).to(device)
model2 = Net(vocab_size, input_size, hidden_size, batch_first=True).to(device)

criterion = nn.CrossEntropyLoss()
optimizer1 = optim.Adam(params=model1.parameters())
optimizer2 = optim.Adam(params=model2.parameters())

SEQ_LEN = 5      # words per training window
BATCH_SIZE = 64
dataset = Dataset(txt_quotes, SEQ_LEN, word2index)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# To resume training, set this to the epoch number of a saved CPOINT-<epoch> file; 0 trains from scratch.
epoch_to_load = 0

if epoch_to_load != 0:
  checkpoint_path = "drive/MyDrive/Data/Checkpoint1/" + "CPOINT-" + str(epoch_to_load)
  # map_location restores tensors onto the device the models live on, so a checkpoint
  # written in a GPU session can be loaded on a CPU-only runtime.
  checkpoint = torch.load(checkpoint_path, map_location=next(model1.parameters()).device)
  model1.load_state_dict(checkpoint['model1_state_dict'])
  model2.load_state_dict(checkpoint['model2_state_dict'])
  optimizer1.load_state_dict(checkpoint['optimizer1_state_dict'])
  optimizer2.load_state_dict(checkpoint['optimizer2_state_dict'])
  # The checkpoint is written after this epoch finished training, so resume with the next one.
  epoch_to_load = epoch_to_load + 1

In [None]:
def test_model(model, word2index, index2word, string="", maxlen=25, verbose=False):
  """Greedily extend `string` with `model` and print the prompt and the generated sequence.

  At each of at most `maxlen` steps, the whole sequence so far is run through the model
  and the argmax word at the last position is appended. Generation stops early once
  '<END>' is produced. Nothing is returned; the model is left in eval mode.

  Args:
    model: module mapping a (1, seq) index tensor to (1, seq, vocab) logits (e.g. Net).
    word2index: word -> index mapping used to encode `string`.
    index2word: index -> word mapping used to print results.
    string: whitespace-separated prompt.
    maxlen: maximum number of words to generate.
    verbose: also print the input and per-position prediction at every step.
  """
  model.eval()
  # Run on whatever device the model lives on rather than assuming CUDA.
  device = next(model.parameters()).device
  eval_input = encode(string, word2index).to(device)
  print("INITIAL INPUT: " + string)

  if eval_input.size(1) == 0:
    # nn.GRU rejects zero-length sequences, so there is nothing to extend.
    print("GENERATED SEQUENCE: (skipped: empty prompt)")
    print("")
    return

  if verbose:
    print("---")

  end_index = word2index.get('<END>')
  # Inference only: skip building autograd graphs to save memory and time.
  with torch.no_grad():
    for _ in range(maxlen):
      logits = model(eval_input)
      pred = logits.argmax(-1)  # argmax of logits equals argmax of their softmax

      if verbose:
        print("INPUT: " + " ".join( decode(eval_input.tolist()[0],index2word)))
        print("OUTPUT: " + " ".join( decode(pred[0].tolist(), index2word)))

      next_token = pred[:, -1:]  # (batch, 1): prediction at the last position
      eval_input = torch.cat((eval_input, next_token), 1)

      # Check only the newly generated word; prompt words are not treated as a stop signal.
      if end_index is not None and bool((next_token == end_index).all()):
        break

  print("GENERATED SEQUENCE: " + " ".join( decode(eval_input.tolist()[0],index2word)))
  print("")

In [None]:
# Train model1 and model2 side by side on identical batches. After every epoch, print:
#   - the mean training losses,
#   - one teacher-forced sample prediction per model,
#   - free-running generations from a random prompt word.
# A checkpoint is saved every 5 epochs.
import os

NUM_EPOCHS = 1001  # epochs epoch_to_load .. NUM_EPOCHS - 1 are run
CHECKPOINT_DIR = "drive/MyDrive/Data/Checkpoint1/"


def _train_step(model, optimizer, inputs, targets):
  """Do one optimisation step of `model` on a batch; return (logits, detached loss)."""
  optimizer.zero_grad()
  logits = model(inputs)
  # CrossEntropyLoss expects the class dimension second: (batch, seq, vocab) -> (batch, vocab, seq).
  loss = criterion(logits.transpose(1, 2), targets)
  loss.backward()
  optimizer.step()
  return logits, loss.detach()


def _print_sample(inputs, targets, logits):
  """Print the first sequence of a batch: its target, its input and the argmax prediction."""
  print("TARGET: " + " ".join( decode(targets[0].tolist(),index2word)))
  print("INPUT: " + " ".join( decode(inputs[0].tolist(),index2word)))
  pred = logits[0].argmax(-1)  # argmax of logits equals argmax of their softmax
  print("PREDICTION: " + " ".join(decode(pred.tolist(),index2word)))


# Send batches to wherever the models live instead of assuming CUDA.
device = next(model1.parameters()).device

for epoch in range(epoch_to_load, NUM_EPOCHS):
  model1.train()
  model2.train()

  loss_sum1 = loss_sum2 = 0.0
  n_batches = 0
  for inputs, targets in dataloader:
    inputs, targets = inputs.to(device), targets.to(device)
    output1, loss1 = _train_step(model1, optimizer1, inputs, targets)
    output2, loss2 = _train_step(model2, optimizer2, inputs, targets)
    # Accumulate on-device; converting to a Python float once per epoch avoids a sync per batch.
    loss_sum1 = loss_sum1 + loss1
    loss_sum2 = loss_sum2 + loss2
    n_batches += 1

  if n_batches == 0:
    raise RuntimeError("dataloader produced no batches; the text is shorter than the training window")

  model1.eval()
  model2.eval()
  # Report the mean loss over the epoch rather than that of the last (noisiest) batch.
  print("Epoch {:02d} / {} Loss1 {:.4f} Loss2 {:.4f}".format(
      epoch + 1, NUM_EPOCHS, float(loss_sum1) / n_batches, float(loss_sum2) / n_batches))

  print("===========================================================================")

  _print_sample(inputs, targets, output1)

  print("---------------------------------------------------------------------------")

  _print_sample(inputs, targets, output2)

  print("===========================================================================")

  randword = random.choice(dataset_startwords)

  for net in (model1, model2):
    test_model(net, word2index, index2word, randword)
    test_model(net, word2index, index2word, "i am plain simple garak " + randword)

  print("===========================================================================")

  if epoch % 5 == 0:
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    torch.save(
        {'model1_state_dict': model1.state_dict(),
         'model2_state_dict': model2.state_dict(),
         'optimizer1_state_dict': optimizer1.state_dict(),
         'optimizer2_state_dict': optimizer2.state_dict(),
         'epoch': epoch},
        CHECKPOINT_DIR + "CPOINT-" + str(epoch)
               )

  print()