In [1]:
from __future__ import print_function
from math import ceil
import numpy as np
import sys
import pdb

import torch
import torch.optim as optim
import torch.nn as nn

import generator
import discriminator
import helpers


CUDA = False
VOCAB_SIZE = 15000
MAX_SEQ_LEN = 40
START_LETTER = 0
BATCH_SIZE = 32
MLE_TRAIN_EPOCHS = 10
ADV_TRAIN_EPOCHS = 5
POS_NEG_SAMPLES = 100000

GEN_EMBEDDING_DIM = 32
GEN_HIDDEN_DIM = 32
DIS_EMBEDDING_DIM = 64
DIS_HIDDEN_DIM = 64

oracle_samples_path = './oracle_samples.trc'
oracle_state_dict_path = './oracle_EMBDIM32_HIDDENDIM32_VOCAB5000_MAXSEQLEN20.trc'
pretrained_gen_path = './gen_MLEtrain_EMBDIM32_HIDDENDIM32_VOCAB5000_MAXSEQLEN20.trc'
pretrained_dis_path = './dis_pretrain_EMBDIM_64_HIDDENDIM64_VOCAB5000_MAXSEQLEN20.trc'


def train_generator_MLE(gen, gen_opt, real_data_samples, epochs):
    """
    Max Likelihood Pretraining for the generator
    """
    for epoch in range(epochs):
        print('epoch %d : ' % (epoch + 1), end='')
        sys.stdout.flush()
        total_loss = 0

        for i in range(0, POS_NEG_SAMPLES, BATCH_SIZE):
            inp, target = helpers.prepare_generator_batch(real_data_samples[i:i + BATCH_SIZE], start_letter=START_LETTER,
                                                          gpu=CUDA)
            gen_opt.zero_grad()
            loss = gen.batchNLLLoss(inp, target)
            loss.backward()
            gen_opt.step()

            total_loss += loss.data.item()

            if (i / BATCH_SIZE) % ceil(
                            ceil(POS_NEG_SAMPLES / float(BATCH_SIZE)) / 10.) == 0:  # roughly every 10% of an epoch
                print('.', end='')
                sys.stdout.flush()

        # each loss in a batch is loss per sample
        total_loss = total_loss / ceil(POS_NEG_SAMPLES / float(BATCH_SIZE)) / MAX_SEQ_LEN


        print(' average_train_NLL = %.4f' % (total_loss))


def train_generator_PG(gen, gen_opt, validation_data_samples, dis, num_batches):
    """
    The generator is trained using policy gradients, using the reward from the discriminator.
    Training is done for num_batches batches.
    """

    for batch in range(num_batches):
        s = gen.sample(BATCH_SIZE*2)        # 64 works best
        inp, target = helpers.prepare_generator_batch(s, start_letter=START_LETTER, gpu=CUDA)
        rewards = dis.batchClassify(target)

        gen_opt.zero_grad()
        pg_loss = gen.batchPGLoss(inp, target, rewards)
        pg_loss.backward()
        gen_opt.step()

    validation_loss=0
    VAL_SIZE=validation_data_samples.shape[0]
    for i in range(0, VAL_SIZE, BATCH_SIZE):
        inp, target = helpers.prepare_generator_batch(validation_data_samples[i:i + BATCH_SIZE], start_letter=START_LETTER,
                                                      gpu=CUDA)
        gen_opt.zero_grad()
        loss = gen.batchNLLLoss(inp, target)
        validation_loss+= loss
    #helpers.batchwise_oracle_nll(gen, oracle, POS_NEG_SAMPLES, BATCH_SIZE, MAX_SEQ_LEN,
                   #                                start_letter=START_LETTER, gpu=CUDA)

    print(' validation_loss = %.4f' % validation_loss)


def train_discriminator(discriminator, dis_opt, real_data_samples, generator, trainset, d_steps, epochs):
    """
    Training the discriminator on real_data_samples (positive) and generated samples from generator (negative).
    Samples are drawn d_steps times, and the discriminator is trained for epochs epochs.
    """

    # generating a small validation set before training (using oracle and generator)
    len_sampled=10
    perm=np.random.permutation(trainset.shape[0])
    pos_val=torch.tensor(trainset[perm[:len_sampled]])
    neg_val = generator.sample(len_sampled)
    val_inp, val_target = helpers.prepare_discriminator_data(pos_val, neg_val, gpu=CUDA)

    for d_step in range(d_steps):
        s = helpers.batchwise_sample(generator, POS_NEG_SAMPLES, BATCH_SIZE)
        dis_inp, dis_target = helpers.prepare_discriminator_data(real_data_samples, s, gpu=CUDA)
        for epoch in range(epochs):
            print('d-step %d epoch %d : ' % (d_step + 1, epoch + 1), end='')
            sys.stdout.flush()
            total_loss = 0
            total_acc = 0

            for i in range(0, 2 * POS_NEG_SAMPLES, BATCH_SIZE):
                inp, target = dis_inp[i:i + BATCH_SIZE], dis_target[i:i + BATCH_SIZE]
                dis_opt.zero_grad()
                out = discriminator.batchClassify(inp)
                loss_fn = nn.BCELoss()
                loss = loss_fn(out, target)
                loss.backward()
                dis_opt.step()

                total_loss += loss.data.item()
                total_acc += torch.sum((out>0.5)==(target>0.5)).data.item()

                if (i / BATCH_SIZE) % ceil(ceil(2 * POS_NEG_SAMPLES / float(
                        BATCH_SIZE)) / 10.) == 0:  # roughly every 10% of an epoch
                    print('.', end='')
                    sys.stdout.flush()

            total_loss /= ceil(2 * POS_NEG_SAMPLES / float(BATCH_SIZE))
            total_acc /= float(2 * POS_NEG_SAMPLES)

            val_pred = discriminator.batchClassify(val_inp)
            print(' average_loss = %.4f, train_acc = %.4f, val_acc = %.4f' % (
                total_loss, total_acc, torch.sum((val_pred>0.5)==(val_target>0.5)).data.item()/200.))


In [2]:
def loadData(filepath):
    ret=[]
    with open(filepath,'r') as fin:
        for line in fin:
            ret.append([int(token) for token in line.split()])
    return np.array(ret)
trainset=loadData('./dataset/train.vec')
validationset=loadData('./dataset/valid.vec')
testset=loadData('./dataset/test.vec')

trainset_tensor=torch.tensor(trainset)
validationset_tensor=torch.tensor(validationset)
oracle=None

gen = generator.Generator(GEN_EMBEDDING_DIM, GEN_HIDDEN_DIM, VOCAB_SIZE, MAX_SEQ_LEN, gpu=CUDA)
dis = discriminator.Discriminator(DIS_EMBEDDING_DIM, DIS_HIDDEN_DIM, VOCAB_SIZE, MAX_SEQ_LEN, gpu=CUDA)

if CUDA:
    #oracle = oracle.cuda()
    gen = gen.cuda()
    dis = dis.cuda()
    oracle_samples = oracle_samples.cuda()

In [None]:
# GENERATOR MLE TRAINING
print('Starting Generator MLE Training...')
gen_optimizer = optim.Adam(gen.parameters(), lr=1e-2)
train_generator_MLE(gen, gen_optimizer, trainset_tensor, MLE_TRAIN_EPOCHS)

# torch.save(gen.state_dict(), pretrained_gen_path)
# gen.load_state_dict(torch.load(pretrained_gen_path))

# PRETRAIN DISCRIMINATOR
print('\nStarting Discriminator Training...')
dis_optimizer = optim.Adagrad(dis.parameters())
train_discriminator(dis, dis_optimizer, trainset_tensor, gen, trainset, 50, 3)

# torch.save(dis.state_dict(), pretrained_dis_path)
# dis.load_state_dict(torch.load(pretrained_dis_path))

Starting Generator MLE Training...
epoch 1 : .....

In [4]:
# ADVERSARIAL TRAINING
print('\nStarting Adversarial Training...')
oracle_loss = -1 #helpers.batchwise_oracle_nll(gen, oracle, POS_NEG_SAMPLES, BATCH_SIZE, MAX_SEQ_LEN,
#                                           start_letter=START_LETTER, gpu=CUDA)
print('\nInitial Oracle Sample Loss : %.4f' % oracle_loss)

for epoch in range(ADV_TRAIN_EPOCHS):
    print('\n--------\nEPOCH %d\n--------' % (epoch+1))
    # TRAIN GENERATOR
    print('\nAdversarial Training Generator : ', end='')
    sys.stdout.flush()
    train_generator_PG(gen, gen_optimizer, validationset_tensor, dis, 1)

    # TRAIN DISCRIMINATOR
    print('\nAdversarial Training Discriminator : ')
    train_discriminator(dis, dis_optimizer, trainset_tensor, gen, trainset, 5, 3)


Starting Adversarial Training...

Initial Oracle Sample Loss : -1.0000

--------
EPOCH 1
--------

Adversarial Training Generator :  validation_loss = 306.1471

Adversarial Training Discriminator : 
d-step 1 epoch 1 : .... average_loss = 0.0004, train_acc = 1.0000, val_acc = 0.1000
d-step 1 epoch 2 : .... average_loss = 0.0004, train_acc = 1.0000, val_acc = 0.1000
d-step 1 epoch 3 : .... average_loss = 0.0004, train_acc = 1.0000, val_acc = 0.1000
d-step 2 epoch 1 : .... average_loss = 0.0004, train_acc = 1.0000, val_acc = 0.1000
d-step 2 epoch 2 : .... average_loss = 0.0004, train_acc = 1.0000, val_acc = 0.1000
d-step 2 epoch 3 : .... average_loss = 0.0004, train_acc = 1.0000, val_acc = 0.1000
d-step 3 epoch 1 : .... average_loss = 0.0086, train_acc = 0.9922, val_acc = 0.1000
d-step 3 epoch 2 : .... average_loss = 0.0007, train_acc = 1.0000, val_acc = 0.1000
d-step 3 epoch 3 : .... average_loss = 0.0005, train_acc = 1.0000, val_acc = 0.1000
d-step 4 epoch 1 : .... average_loss = 0.000