In [1]:
import random
import pickle

import torch
import torch.nn as nn
import torch.optim as optim

from train import run
from model import EncoderRNN, LuongAttnDecoderRNN
from chat import GreedySearchDecoder, chat

In [2]:
def _load_train_phase(pairs_path, voc_path):
    """Load the training pairs and the pickled vocabulary into a ``phase`` dict.

    Args:
        pairs_path: Text file with one example per line, fields separated by tabs.
        voc_path: Pickle file holding the vocabulary object (must expose ``num_words``).

    Returns:
        ``{"train": {"pairs": [[field, ...], ...], "voc": voc}}``
    """
    pairs = []
    with open(pairs_path, "r") as file_obj:
        for line in file_obj:
            line = line.rstrip("\n")
            # Skip blank lines (e.g. a trailing empty line) instead of turning
            # them into a one-field [""] pair.
            if line:
                pairs.append(line.split("\t"))
    # NOTE(review): pickle.load can execute arbitrary code; only load voc files
    # produced by this project.
    with open(voc_path, "rb") as f:
        voc = pickle.load(f)
    return {"train": {"pairs": pairs, "voc": voc}}


def _build_models(voc, device, encoder_name, decoder_name, attn_model, hidden_size,
                  encoder_n_layers, decoder_n_layers, dropout, bidirection):
    """Create encoder/decoder sharing one embedding, move them to ``device``, set train mode."""
    # A single embedding table is shared by the encoder and the decoder.
    embedding = nn.Embedding(voc.num_words, hidden_size).to(device)
    encoder = EncoderRNN(hidden_size, embedding, encoder_n_layers, dropout, gate=encoder_name,
                         bidirectional=bidirection).to(device)
    decoder = LuongAttnDecoderRNN(attn_model, embedding, hidden_size, voc.num_words,
                                  decoder_n_layers, dropout, gate=decoder_name).to(device)
    encoder.train()
    decoder.train()
    return encoder, decoder


def _build_optimizers(encoder, decoder, opt, lr, wd, device):
    """Create one optimizer per model ("ADAM" or "SGD"), both using ``lr`` and weight decay ``wd``."""
    if opt == "ADAM":
        optimizer_cls = optim.Adam
    elif opt == "SGD":
        optimizer_cls = optim.SGD
    else:
        raise ValueError("Wrong optimizer type has been given as an argument.")
    encoder_optimizer = optimizer_cls(encoder.parameters(), lr=lr, weight_decay=wd)
    decoder_optimizer = optimizer_cls(decoder.parameters(), lr=lr, weight_decay=wd)

    # Keep any optimizer state tensors on the same device as the models. Freshly
    # built optimizers have empty state, so this only matters if state is loaded
    # (e.g. from a checkpoint). Use `device` rather than a hard-coded .cuda() so
    # CPU-only runs work too.
    for optimizer in (encoder_optimizer, decoder_optimizer):
        for state in optimizer.state.values():
            for key, value in state.items():
                if isinstance(value, torch.Tensor):
                    state[key] = value.to(device)
    return encoder_optimizer, decoder_optimizer


def run_train(device, encoder_name="GRU", decoder_name="GRU", encoder_direction=2, opt="ADAM", 
              EPOCH_NUM=50, DROPOUT=0.1, CLIP=10.0, LR=0.001, WD=1e-5, HIDDEN_SIZE=300, 
              ENCODER_N_LAYERS=2, DECODER_N_LAYERS=2, BATCH_SIZE=64, BIDIRECTION=True, attn_model="dot",
              pairs_path="formatted_movie_QR_lines_train.txt", voc_path="voc.pickle"):
    """Build a seq2seq encoder/decoder, train it with ``run`` and return it.

    Args:
        device: torch device the embedding, models and optimizer state are placed on.
        encoder_name, decoder_name: ``gate`` type passed to EncoderRNN / LuongAttnDecoderRNN.
        encoder_direction: Currently unused (directionality is set by ``BIDIRECTION``);
            kept for backward compatibility.
        opt: "ADAM" or "SGD".
        EPOCH_NUM, BATCH_SIZE, CLIP: forwarded to ``run``.
        DROPOUT: dropout passed to both models.
        LR, WD: learning rate and weight decay, applied to either optimizer.
        HIDDEN_SIZE: embedding and hidden size.
        ENCODER_N_LAYERS, DECODER_N_LAYERS: recurrent layer counts.
        BIDIRECTION: passed as ``bidirectional`` to EncoderRNN.
        attn_model: attention variant passed to LuongAttnDecoderRNN.
        pairs_path: tab-separated training pairs file.
        voc_path: pickled vocabulary file.

    Returns:
        ``(encoder, decoder, voc)``: the trained models and the loaded vocabulary.

    Raises:
        ValueError: if ``opt`` is not "ADAM" or "SGD".
    """
    phase = _load_train_phase(pairs_path, voc_path)
    voc = phase["train"]["voc"]

    # Shuffle the training pairs once, before training, with a fixed seed so the
    # order is reproducible. Note: this reseeds Python's global `random` module.
    random.seed(1)
    random.shuffle(phase["train"]["pairs"])

    print('Building training set encoder and decoder ...')
    encoder, decoder = _build_models(voc, device, encoder_name, decoder_name, attn_model,
                                     HIDDEN_SIZE, ENCODER_N_LAYERS, DECODER_N_LAYERS,
                                     DROPOUT, BIDIRECTION)
    print('Models built and ready to go!')

    print('Building optimizers ...')
    encoder_optimizer, decoder_optimizer = _build_optimizers(encoder, decoder, opt, LR, WD, device)

    print("Starting Training!")
    run(encoder, decoder, encoder_optimizer, decoder_optimizer, EPOCH_NUM, BATCH_SIZE, CLIP, phase)

    return encoder, decoder, voc

In [3]:
# Experiment configuration. The encoder and decoder must use the same recurrent
# gate type, so a single setting drives both.
rnn_gate = "GRU"  # one of: GRU, LSTM, MogLSTM
epochs = 10
seed = 1

# run_train itself only seeds Python's `random` (for shuffling); seed torch too so
# weight initialisation is reproducible across runs (CUDA kernels may still add
# some nondeterminism).
torch.manual_seed(seed)

# Get device object
USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")
print(device)

encoder, decoder, voc = run_train(encoder_name=rnn_gate, decoder_name=rnn_gate,
                                  EPOCH_NUM=epochs, device=device)

cuda
Building training set encoder and decoder ...
Models built and ready to go!
Building optimizers ...
Starting Training!
Number of total pairs used for [TRAIN]: 42532
Number of batches used for a [TRAIN] epoch: 664
Training for 10 epochs...
[TRAIN] Epoch: 1 Loss: 4.80516 BLEU score: 0.08961 33.8 s
[TRAIN] Epoch: 2 Loss: 4.48559 BLEU score: 0.09556 33.68 s
[TRAIN] Epoch: 3 Loss: 4.22112 BLEU score: 0.11438 33.62 s
[TRAIN] Epoch: 4 Loss: 3.81435 BLEU score: 0.14507 33.59 s
[TRAIN] Epoch: 5 Loss: 3.31764 BLEU score: 0.19401 33.62 s
[TRAIN] Epoch: 6 Loss: 2.82097 BLEU score: 0.25004 33.66 s
[TRAIN] Epoch: 7 Loss: 2.37608 BLEU score: 0.30486 33.73 s
[TRAIN] Epoch: 8 Loss: 2.0121 BLEU score: 0.3474 33.59 s
[TRAIN] Epoch: 9 Loss: 1.72157 BLEU score: 0.38008 33.72 s
[TRAIN] Epoch: 10 Loss: 1.49852 BLEU score: 0.40151 33.6 s


In [4]:
def run_chat(encoder, decoder, device, voc):
    """Start an interactive chat session driven by greedy decoding.

    Args:
        encoder, decoder: trained models (e.g. as returned by ``run_train``).
        device: torch device to run inference on.
        voc: vocabulary passed through to ``chat``.

    The models are put in eval mode for the session and their previous
    train/eval modes are restored afterwards, even if ``chat`` raises.
    """
    encoder = encoder.to(device)
    decoder = decoder.to(device)

    # run_train leaves both models in train() mode, which would keep dropout
    # active during inference; switch to eval() for the session.
    was_training = (encoder.training, decoder.training)
    encoder.eval()
    decoder.eval()
    try:
        # Initialize search module
        searcher = GreedySearchDecoder(encoder, decoder)
        # Decoding never back-propagates, so skip building autograd graphs.
        with torch.no_grad():
            chat(searcher, voc)
    finally:
        encoder.train(was_training[0])
        decoder.train(was_training[1])

In [5]:
# Chat with the freshly trained models; the transcript below shows an example session.
run_chat(encoder=encoder, decoder=decoder, device=device, voc=voc)

> who are you?
Bot: i m cynthia . . .
> q
