In [1]:
from dataloader import *
from VAE import *
from scores import *

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.ticker as ticker
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import random

In [2]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Prepare data

In [3]:
# Train / test splits, read with `load_data` from the local `dataloader` module.
train_vocab, test_vocab = load_data('./data/train.txt'), load_data('./data/test.txt')

## Get different tense pairs
### Note: these pairs are not used by the conditional VAE training below.

In [4]:
def get_tense_paris(train_vocab, input_tense, target_tense):
    '''
        Build (source word, target word) pairs from a list of tense tuples.

        train_vocab : iterable of indexable tuples, one per verb
        input_tense : index of the tense used as the source word
        target_tense: index of the tense used as the target word
        Returns a list of (entry[input_tense], entry[target_tense]) tuples in input order.
        (The name is a historical misspelling of "pairs"; it is kept for existing callers.)
    '''
    return [(entry[input_tense], entry[target_tense]) for entry in train_vocab]

# Simple Present -> Third Person (1), Present Progressive (2), Past (3)
train_st_tp, train_st_pp, train_st_past = (
    get_tense_paris(train_vocab, 0, target_tense) for target_tense in (1, 2, 3)
)

# Train VAE

In [5]:
# Character vocabulary: seq_from_str maps 'a'..'z' to 1..26; 0 and vocab_size-1 are the markers.
vocab_size = 28  # 26 letters + SOS + EOS
SOS_token, EOS_token = 0, vocab_size - 1

## Setting hyperparameters

In [6]:
#----------Hyper Parameters----------#
hidden_size, latent_size = 256, 64
teacher_forcing_ratio = 0.75   # probability that a training step uses teacher forcing
empty_input_ratio = 0.1
KLD_weight = 0.0
lr = 0.01                      # SGD learning rate
# NOTE(review): latent_size, empty_input_ratio and KLD_weight are not referenced anywhere in the
# visible notebook code (CondVAE is built without them) — confirm they are actually used.

In [7]:
def seq_from_str(target):
    '''
        Encode a word as a list of character indices ('a' -> 1, ..., 'z' -> 26).

        For lowercase input, 0 (SOS) and 27 (EOS) are never produced here.
    '''
    base = ord('a') - 1
    return list(map(lambda ch: ord(ch) - base, target))

def str_from_tensor(target):
    '''
        Decode a sequence of per-step score vectors back into a string.

        The index chosen by topk(1) at each step becomes chr(index + ord('a') - 1), i.e. the
        inverse of seq_from_str. Marker indices are decoded too: 0 becomes '`' and 27 becomes '{'.
    '''
    offset = ord('a') - 1
    chars = []
    for scores in target:
        best_idx = scores.topk(1)[1]
        chars.append(chr(best_idx + offset))
    return ''.join(chars)

## Use KL annealing

In [8]:
def KL_annealing(current_iter, policy = 'mono', reach_max = 3000, period = 6000):
    '''
        Weight (beta) applied to the KL-divergence term at a given iteration.

        policy = 'mono'     : beta ramps linearly from 0 to 1 over the first `reach_max`
                              iterations and then stays at 1.
        policy = 'cyclical' : the same ramp restarts every `period` iterations; within each
                              cycle beta is 1 from `reach_max` until the cycle ends.
        Returns 1 once the ramp is complete, otherwise step / reach_max.
        Raises ValueError (naming the bad value) for an unknown policy.
    '''
    if policy == 'mono':
        step = current_iter
    elif policy == 'cyclical':
        step = current_iter % period
    else:
        raise ValueError("Unknown KL annealing policy %r; expected 'mono' or 'cyclical'" % (policy,))

    # Linear ramp 0 -> 1 over `reach_max` steps, then saturate at 1.
    beta = 1 if step >= reach_max else step/reach_max
    return beta

## Infer all four tenses from the simple present (for the BLEU-4 score)

In [9]:
def infer_by_simple(vae_model, data_tuple):
    '''
        Predict all four tenses of one verb, always feeding the simple present to the encoder.

        vae_model : conditional VAE, called as vae_model(input_seq, hidden, input_cond, target_cond)
        data_tuple: the four forms of one verb, ordered
                    (simple present, third person, present progressive, past)
        Returns a list of four predicted strings, one per target tense, in that order.
        A trailing EOS character is removed only when the decoder actually produced one.

        Side effect: leaves vae_model in eval mode.
    '''
    # Character that the EOS index decodes to under str_from_tensor.
    eos_char = chr(EOS_token + ord('a') - 1)
    input_tense = 0  # Input: simple present

    pred_tuple = []
    vae_model.eval()

    with torch.no_grad():
        for target_tense in range(4):  # Target: each of the 4 tenses
            # Fresh encoder input and hidden state for every prediction
            input_seq = seq_from_str(data_tuple[input_tense])
            hidden = torch.zeros(1, 1, hidden_size, device=device)

            result, mu, logvar = vae_model(input_seq, hidden, input_tense, target_tense)

            pred_seq = str_from_tensor(result)
            # Strip EOS only if present; an unconditional [:-1] would cut a real character
            # whenever decoding stops (e.g. at a length limit) without emitting EOS.
            if pred_seq.endswith(eos_char):
                pred_seq = pred_seq[:-1]
            pred_tuple.append(pred_seq)

    return pred_tuple

## Training Functions

In [10]:
def train_condVAE(vae_model, input_seq, input_cond, target_seq, target_cond, use_teacher_forcing, optimizer, \
                  criterion_CE, criterion_KLD, kl_annealing_beta = 1):
    '''
        Run one optimization step of the conditional VAE on a single word.

        vae_model          : model called as vae_model(input_seq, hidden, input_cond, target_cond,
                             use_teacher_forcing, teacher_seq) -> (result, mu, logvar)
        input_seq          : encoder input as a list of character indices (see seq_from_str)
        input_cond         : encoder condition (tense index)
        target_seq         : ground truth as a list of character indices, without EOS. It is
                             not modified; EOS is appended to a private copy.
        target_cond        : decoder condition (tense index)
        use_teacher_forcing: if True, target_seq is passed to the model for teacher forcing
        optimizer          : optimizer over vae_model's parameters (zeroed and stepped here)
        criterion_CE       : called as criterion_CE(hat_y, y)
        criterion_KLD      : called as criterion_KLD(mu, logvar)
        kl_annealing_beta  : weight multiplied into the KL term before backpropagation

        Returns (ce_loss, kld_loss, hat_y): the two loss terms as Python floats (kld_loss already
        includes kl_annealing_beta) and the model output truncated to the compared length.
    '''
    # Put the initial hidden state on the same device as the model's parameters, so a model
    # kept on the CPU still works on a machine that has CUDA.
    first_param = next(vae_model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Initialize hidden feature
    hidden = torch.zeros(1, 1, hidden_size, device=device)

    # Run model: input_cond is the encoder condition, target_cond the decoder condition.
    # The ground truth is only handed over when teacher forcing is on.
    optimizer.zero_grad()
    teacher_seq = target_seq if use_teacher_forcing else None
    result, mu, logvar = vae_model(input_seq, hidden, input_cond, target_cond, use_teacher_forcing, teacher_seq)

    # Ground truth should have EOS at the end. Build a new list rather than appending, so the
    # caller's target_seq is left untouched (reusing it would otherwise accumulate EOS tokens).
    target_with_eos = target_seq + [EOS_token]

    # Trim both sequences to the shorter length. NOTE(review): output steps past min_len, and
    # target characters the model never reached, contribute no loss.
    min_len = min(len(target_with_eos), len(result))

    # hat_y is given to criterion_CE as-is; y stays as class indices (no one-hot encoding).
    hat_y = result[:min_len]
    y = torch.tensor(target_with_eos[:min_len], device=device)

    ce_loss = criterion_CE(hat_y, y)
    kld_loss = kl_annealing_beta * criterion_KLD(mu, logvar)  # KL annealing

    loss = ce_loss + kld_loss
    loss.backward()
    optimizer.step()

    return ce_loss.item(), kld_loss.item(), hat_y

In [16]:
def trainIter_condVAE(vae_model, data, n_iters, print_every=1000, save_every=1000, record_every=1000,
                      learning_rate=0.01, teacher_forcing_ratio = 1.0, 
                      optimizer = None, scheduler = None,
                      criterion_CE = VAE_Loss_CE, criterion_KLD = VAE_Loss_KLD,
                      date = '', kl_annealing = 'mono', kl_reach_max = 3000, kl_period = 6000):
    '''
        Train the conditional VAE for n_iters single-word steps.

        data: A list of 4-tuple
              the tense order should be : (simple present, third person, present progressive, past)
        Each step draws a random verb and a random tense, and trains the model to reconstruct
        that word under the same tense condition (input tense == target tense).

        print_every / record_every / save_every : logging, history and checkpoint intervals
        learning_rate : only used to build the default SGD optimizer when `optimizer` is None
        scheduler     : optional LR scheduler stepped once per iteration; ReduceLROnPlateau
                        receives the step loss, any other scheduler is stepped with no argument
        kl_annealing, kl_reach_max, kl_period : passed to KL_annealing as policy, reach_max, period
        date          : suffix appended to checkpoint file names under ./models/

        Returns (loss_list, ce_loss_list, kld_loss_list, bleu_list), one entry every
        `record_every` iterations. Each loss entry is the loss of that single step (not a window
        average), and the KL entry already includes the annealing weight.
    '''
    import os

    loss_list = []
    ce_loss_list = []
    kld_loss_list = []
    bleu_list = []

    # Check optimizer; default: SGD
    if optimizer is None:
        optimizer = optim.SGD(vae_model.parameters(), lr=learning_rate)

    for i in range(n_iters):
        iter_no = i + 1

        # Randomly generate a training sample; input and target share the same tense.
        chosen_data = random.choice(data)
        input_tense = random.randint(0, 3)  # Draw input tense
        target_tense = input_tense
        input_seq = seq_from_str(chosen_data[input_tense])
        target_seq = seq_from_str(chosen_data[target_tense])

        use_teacher_forcing = random.random() < teacher_forcing_ratio

        # BLEU-4 over all four tenses of this verb. Computed before the update, so it reflects
        # the model state entering this step.
        if iter_no % record_every == 0 or iter_no % print_every == 0:
            pred = infer_by_simple(vae_model, chosen_data)
            bleu_score = compute_bleu(pred, chosen_data)

        # KL annealing's beta
        beta = KL_annealing(i, policy=kl_annealing, reach_max=kl_reach_max, period=kl_period)

        # infer_by_simple leaves the model in eval mode, so switch back before every update.
        vae_model.train()
        ce_loss, kld_loss, hat_y = train_condVAE(vae_model, input_seq, input_tense, target_seq, target_tense,
                                                 use_teacher_forcing, optimizer, criterion_CE, criterion_KLD, beta)
        loss = ce_loss + kld_loss

        if scheduler is not None:
            # Only ReduceLROnPlateau takes a metric; other schedulers would treat a positional
            # argument as an epoch number.
            if isinstance(scheduler, optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(loss)
            else:
                scheduler.step()

        if iter_no % record_every == 0:
            loss_list.append(loss)
            ce_loss_list.append(ce_loss)
            kld_loss_list.append(kld_loss)
            bleu_list.append(bleu_score)

        if iter_no % print_every == 0:
            # Decode the training output only when it is actually printed.
            pred_seq = str_from_tensor(hat_y)
            print('-----------------')
            print('Iter %d: loss = %.4f' % (iter_no, loss))
            print('ce_loss = ', ce_loss)
            print('kld_loss = ', kld_loss)
            print('    ==================')
            print('    pred = ', pred)
            print('    chosen_data = ', chosen_data)
            print('    BLEU-4 score = ', bleu_score)
            print('    ==================')
            print('input_seq = ', chosen_data[input_tense])
            print('pred_seq = ', pred_seq)
            print('target_seq = ', chosen_data[target_tense])

        if iter_no % save_every == 0:
            # Create the checkpoint directory on demand instead of failing mid-training.
            os.makedirs('./models', exist_ok=True)
            torch.save(vae_model, './models/condVAE_' + str(iter_no) + date)

    return loss_list, ce_loss_list, kld_loss_list, bleu_list

In [17]:
# Build the conditional VAE (from the local VAE module) and move it to `device`.
# NOTE(review): latent_size is not passed here, and teacher_forcing_ratio fills the 4th positional
# argument — confirm against CondVAE's signature that this is the intended slot.
my_vae = CondVAE(vocab_size, hidden_size, vocab_size, teacher_forcing_ratio).to(device)

In [18]:
# Plain SGD over every model parameter at the configured learning rate.
optimizer = optim.SGD(params=my_vae.parameters(), lr=lr)

## Train

In [None]:
# Full run: 2M single-word steps, a history point every 500 steps, a printout every 1000 and a
# checkpoint every 100000 (written to ./models/condVAE_<iter>_0813_2105).
loss_list, ce_loss_list, kld_loss_list, bleu_list = trainIter_condVAE(
    my_vae,
    train_vocab,
    n_iters=2_000_000,
    print_every=1000,
    save_every=100000,
    record_every=500,
    learning_rate=lr,
    teacher_forcing_ratio=teacher_forcing_ratio,
    optimizer=optimizer,
    criterion_CE=VAE_Loss_CE,
    criterion_KLD=VAE_Loss_KLD,
    date='_0813_2105',
)

-----------------
Iter 1000: loss = 2.8532
ce_loss =  2.852135419845581
kld_loss =  0.0010822410695254803
    pred =  ['sesesesesesesesesesesese', 'sesesesesesesesesesesese', 'sesesesesesesesesesesese', 'sesesesesesesesesesesese']
    chosen_data =  ['drag', 'drags', 'dragging', 'dragged']
    BLEU-4 score =  0
input_seq =  dragged
pred_seq =  seeeees
target_seq =  dragged
-----------------
Iter 2000: loss = 2.5671
ce_loss =  2.5646588802337646
kld_loss =  0.002437703311443329
    pred =  ['sesesesesesesesesesesese', 'sesesesesesesesesesesese', 'sesesesesesesesesesesese', 'sesesesesesesesesesesese']
    chosen_data =  ['arouse', 'arouses', 'arousing', 'aroused']
    BLEU-4 score =  0
input_seq =  arouse
pred_seq =  sreene
target_seq =  arouse
-----------------
Iter 3000: loss = 2.1872
ce_loss =  2.186622142791748
kld_loss =  0.0005928387399762869
    pred =  ['sesesesesesesesesesesese', 'sesesesesesesesesesesese', 'sesesesesesesesesesesese', 'sesesesesesesesesesesese']
    chosen_data 

-----------------
Iter 24000: loss = 3.0130
ce_loss =  3.0123958587646484
kld_loss =  0.0005922913551330566
    pred =  ['consee', 'consee', 'consee', 'consee']
    chosen_data =  ['object', 'objects', 'objecting', 'objected']
    BLEU-4 score =  0
input_seq =  object
pred_seq =  cnlert
target_seq =  object
-----------------
Iter 25000: loss = 1.8364
ce_loss =  1.8363333940505981
kld_loss =  5.9932470321655273e-05
    pred =  ['searended', 'consended', 'searended', 'seetere']
    chosen_data =  ['extract', 'extracts', 'extracting', 'extracted']
    BLEU-4 score =  0
input_seq =  extracted
pred_seq =  sxpeetked
target_seq =  extracted
-----------------
Iter 26000: loss = 2.4252
ce_loss =  2.4243199825286865
kld_loss =  0.0008727014064788818
    pred =  ['snounted', 'surese', 'snounted', 'snounted']
    chosen_data =  ['give', 'gives', 'giving', 'gave']
    BLEU-4 score =  0
input_seq =  gave
pred_seq =  srre
target_seq =  gave
-----------------
Iter 27000: loss = 1.9796
ce_loss =  1.979

-----------------
Iter 47000: loss = 1.9183
ce_loss =  1.9179389476776123
kld_loss =  0.0003822147846221924
    pred =  ['serre', 'serre', 'serre', 'serre']
    chosen_data =  ['suffuse', 'suffuses', 'suffusing', 'suffused']
    BLEU-4 score =  0
input_seq =  suffuse
pred_seq =  serfere
target_seq =  suffuse
-----------------
Iter 48000: loss = 1.6734
ce_loss =  1.6732635498046875
kld_loss =  9.736418724060059e-05
    pred =  ['shantered', 'shantered', 'shantered', 'shantered']
    chosen_data =  ['bounce', 'bounces', 'bouncing', 'bounced']
    BLEU-4 score =  0
input_seq =  bounces
pred_seq =  seunted
target_seq =  bounces
-----------------
Iter 49000: loss = 1.6302
ce_loss =  1.6298900842666626
kld_loss =  0.0003027915954589844
    pred =  ['sumplied', 'souttin', 'souttin', 'souttin']
    chosen_data =  ['expend', 'expends', 'expending', 'expended']
    BLEU-4 score =  0
input_seq =  expend
pred_seq =  sxpert
target_seq =  expend
-----------------
Iter 50000: loss = 1.7426
ce_loss = 

In [None]:
# The import cell switched matplotlib to the non-interactive 'agg' backend, so this figure is not
# rendered inline; draw it explicitly, label it, and save it to disk instead.
fig, ax = plt.subplots()
ax.plot(kld_loss_list)
ax.set(title='KL-divergence term during training',
       xlabel='record index (one point per record_every iterations)',
       ylabel='KLD loss (annealing-weighted)')
fig.savefig('./kld_loss.png')
plt.close(fig)

# Evaluation

In [None]:
def val(vae_model, data_pairs, num_eval_data ,criterion_CE = VAE_Loss_CE, criterion_KLD = VAE_Loss_KLD):
    '''
        Evaluate the conditional VAE on randomly drawn tense conversions and print each result.

        vae_model    : conditional VAE, called as vae_model(input_seq, hidden, input_cond, target_cond)
        data_pairs   : list of 4-tuples ordered (simple present, third person, present progressive, past)
        num_eval_data: number of random samples to evaluate
        For every sample a verb plus independent input and target tenses are drawn, the model runs
        without gradients, and the loss and decoded prediction are printed.

        Returns (loss_list, ce_loss_list, kld_loss_list): per-sample loss values as returned by the
        criteria (not converted with .item()). The KL term is unweighted (no annealing here).
        Side effect: leaves vae_model in eval mode.
    '''
    loss_list = []
    ce_loss_list = []
    kld_loss_list = []

    # Allocate tensors on the same device as the model's parameters.
    first_param = next(vae_model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    vae_model.eval()

    with torch.no_grad():
        for _ in range(num_eval_data):
            # Randomly draw a verb and independent input/target tenses from data_pairs
            chosen_data = random.choice(data_pairs)
            input_tense = random.randint(0, 3)   # Draw input tense
            target_tense = random.randint(0, 3)  # Draw target tense
            input_seq = seq_from_str(chosen_data[input_tense])
            # Ground truth should have EOS at the end
            target_seq = seq_from_str(chosen_data[target_tense]) + [EOS_token]

            # Initialize hidden feature
            hidden = torch.zeros(1, 1, hidden_size, device=device)

            result, mu, logvar = vae_model(input_seq, hidden, input_tense, target_tense)

            # Trim both sequences to the length of the shorter one before computing the loss
            min_len = min(len(target_seq), len(result))
            hat_y = result[:min_len]
            y = torch.tensor(target_seq[:min_len], device=device)

            ce_loss = criterion_CE(hat_y, y)
            kld_loss = criterion_KLD(mu, logvar)  # no annealing weight at evaluation time
            loss = ce_loss + kld_loss

            loss_list.append(loss)
            ce_loss_list.append(ce_loss)
            kld_loss_list.append(kld_loss)

            # Convert predicted result into str
            pred_seq = str_from_tensor(hat_y)
            print('-----------------')
            print('loss = ', loss)
            print('input_seq = ', chosen_data[input_tense])
            print('pred_seq = ', pred_seq)
            print('target_seq = ', chosen_data[target_tense])

    return loss_list, ce_loss_list, kld_loss_list

In [None]:
# Evaluate 200 random conversions from the training split. `val` has no `criterion` parameter;
# its CE / KLD criteria default to VAE_Loss_CE and VAE_Loss_KLD.
# NOTE(review): test_vocab is loaded at the top but never evaluated — confirm whether this cell
# should use it (its file format may differ from train.txt).
val_loss_list, val_ce_loss_list, val_kld_loss_list = val(my_vae, train_vocab, num_eval_data=200)