In [1]:
import os 
import sys
from tqdm import tqdm
import importlib
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

# Location of the local `sgtlstm` sources; registered on sys.path at most once
# so that re-running this cell does not pile up duplicate entries.
module_path = '/home/lun/project-basileus/seq-gan/sgtlstm'
if not any(entry == module_path for entry in sys.path):
    sys.path.append(module_path)

In [2]:
# Re-import the package when it is already loaded, so edits made on disk are
# picked up without restarting the kernel.
# NOTE(review): only the top-level package object is handed to reload();
# confirm whether already-imported submodules need an explicit reload too.
_pkg_name = 'sgtlstm'
if _pkg_name in sys.modules:
    importlib.reload(sys.modules[_pkg_name])

from sgtlstm.utils import load_fixed_length_sequence_from_pickle, create_dataset
from sgtlstm.SeqGan import build_G, build_D
from sgtlstm.TimeLSTM import TimeLSTM0, TimeLSTM1, TimeLSTM2, TimeLSTM3

## Load data 

In [3]:
# Valid sequences: pickle path followed directly by its load call.
valid_data_path = '/home/lun/project-basileus/seq-gan/data/fixed_length_with_init_token/valid_sequences.pickle'
valid_event_type_seqs, valid_timestamp_seqs = load_fixed_length_sequence_from_pickle(
    valid_data_path,
    to_timedelta=True,
    end_token=0,
)

# Invalid sequences: loaded with exactly the same options as the valid set.
invalid_data_path = '/home/lun/project-basileus/seq-gan/data/fixed_length_with_init_token/invalid_sequences.pickle'
invalid_event_type_seqs, invalid_timestamp_seqs = load_fixed_length_sequence_from_pickle(
    invalid_data_path,
    to_timedelta=True,
    end_token=0,
)

## Global Variables 

In [4]:
# Mini-batch size; used for the declared model input shapes and for batching
# the training dataset.
BATCH_SIZE = 128
# Sequence length passed to build_D / build_G and to the training helpers.
# Presumably 10 events plus one leading INIT token (the data directory is
# named "fixed_length_with_init_token") -- TODO confirm.
T = 10 + 1
# Event vocabulary; its length is the vocabulary size handed to the models.
# NOTE(review): index 0 ('END/PADDING') appears to correspond to END_TOKEN
# below -- confirm that list position is the token id used by the loader.
VOCAB = ['END/PADDING', 'INIT', 'start', 'click', 'install']
EVENT_VOCAB_DIM = len(VOCAB)
# Passed to the model builders as `emb_dim`.
EMB_DIM = 5
# Passed to the model builders as `hidden_dim`.
HIDDEN_DIM = 64
# Passed to the model builders as `k_mixt`.
# NOTE(review): the name looks like a typo of K_MIXT; renaming it requires
# updating every cell that references it.
K_MIST = 7

# Same value as the `end_token=0` literal used when loading the pickles.
END_TOKEN = 0
# NOTE(review): not referenced in any of the visible cells -- confirm it is still needed.
MAX_TIME = 1024

## Load pretrained model

In [None]:
# Checkpoint with the pretrained discriminator weights (relative to the
# current working directory).
D_save_path = './gan_model_weights/pretrained_1000_disc.h5'

# Both model inputs are declared with the same (BATCH_SIZE, T, 1) shape.
_d_input_shape = (BATCH_SIZE, T, 1)

discriminator = build_D(T=T,
                        event_vocab_dim=EVENT_VOCAB_DIM,
                        emb_dim=EMB_DIM,
                        hidden_dim=HIDDEN_DIM,
                        k_mixt=K_MIST)
# build() runs before load_weights(); presumably so the variables exist when
# the checkpoint is restored -- TODO confirm.
discriminator.build(input_shape=(_d_input_shape, _d_input_shape))
discriminator.load_weights(D_save_path)

In [None]:
# Checkpoint with the pretrained generator weights (relative to the current
# working directory).
G_save_path = './gan_model_weights/pretrained_1000_gen.h5'

# Both model inputs are declared with the same (BATCH_SIZE, T, 1) shape.
_g_input_shape = (BATCH_SIZE, T, 1)

generator = build_G(T=T,
                    event_vocab_dim=EVENT_VOCAB_DIM,
                    emb_dim=EMB_DIM,
                    hidden_dim=HIDDEN_DIM,
                    k_mixt=K_MIST,
                    return_sequence=False)
# build() runs before load_weights(); presumably so the variables exist when
# the checkpoint is restored -- TODO confirm.
generator.build(input_shape=(_g_input_shape, _g_input_shape))
generator.load_weights(G_save_path)

## Train G and D

In [None]:
# Only the valid sequences are used as training data here.
train_et, train_ts = valid_event_type_seqs, valid_timestamp_seqs
train_features = (train_et, train_ts)

# One row per training sequence.
N_train = train_et.shape[0]

# Every training sequence gets the label 1, stored as an (N_train, 1) array.
train_labels = np.ones((N_train, 1))

In [None]:
# Adam with an explicit step size. `learning_rate` is the supported keyword;
# the old `lr` alias is deprecated and is rejected by newer Keras releases.
# NOTE(review): `optimizer` is not passed to train_generator /
# train_discriminator in the visible training loop -- confirm it is still used.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
# optimizer = SGD(learning_rate=1e-4)

# Batched training stream built from the (event types, timestamps) features
# and their labels. buffer_size is set to the number of training sequences;
# presumably this is a shuffle buffer covering the whole set -- TODO confirm
# against create_dataset.
dataset = create_dataset(train_features,
                         train_labels,
                         batch_size=BATCH_SIZE,
                         epochs=3,
                         buffer_size=N_train)

# Per-update loss traces, appended to by the training loop and plotted later.
gen_token_loss_history = []
gen_gaussian_loss_history = []

disc_token_loss_history = []
disc_gaussian_loss_history = []

# Forwarded to train_generator as `weight_gaussian_loss`.
WEIGHT_GAUSSIAN_LOSS = 1e-2
# Number of generator / discriminator updates performed per training step.
_G_STEPS = 2
_D_STEPS = 1

In [None]:
# Upper bound on the number of batches drawn from `dataset`.
_TOTAL_STEPS = 240

# Kept at 0 so `step` is defined even when the dataset yields no batches.
step = 0

# NOTE(review): train_generator / train_discriminator are not defined or
# imported in the visible cells -- confirm they are in scope on a fresh kernel.
for step, (features_batch, _) in enumerate(tqdm(dataset.take(_TOTAL_STEPS)), start=1):
    print('Training Step:', step)

    # Generator updates: _G_STEPS per batch; the real batch is not passed in.
    for _ in range(_G_STEPS):
        gen_token_loss, gen_gaussian_loss = train_generator(
            generator,
            discriminator,
            T,
            event_vocab_dim=EVENT_VOCAB_DIM,
            verbose=False,
            weight_gaussian_loss=WEIGHT_GAUSSIAN_LOSS,
        )
        gen_token_loss_history.append(gen_token_loss.numpy())
        gen_gaussian_loss_history.append(gen_gaussian_loss.numpy())

    # Discriminator updates: _D_STEPS per batch, each receiving the real batch.
    for _ in range(_D_STEPS):
        disc_token_loss, disc_gaussian_loss = train_discriminator(
            features_batch,
            generator,
            discriminator,
            T,
            event_vocab_dim=EVENT_VOCAB_DIM,
            verbose=False,
        )
        disc_token_loss_history.append(disc_token_loss.numpy())
        disc_gaussian_loss_history.append(disc_gaussian_loss.numpy())

## Loss over training

In [None]:
# One figure per generator loss component. Each history holds one entry per
# generator update, so the x axis counts updates (there can be several per
# outer training step).
for history, title in (
    (gen_token_loss_history, 'Generator Token Loss History'),
    (gen_gaussian_loss_history, 'Generator Gaussian Loss History'),
):
    plt.figure(dpi=100)
    plt.plot(range(len(history)), history)
    plt.title(title)
    plt.xlabel('training steps')
    plt.ylabel('loss')

In [None]:
# One figure per discriminator loss component. Each history holds one entry
# per discriminator update.
for history, title in (
    (disc_token_loss_history, 'Discriminator Token Loss History'),
    (disc_gaussian_loss_history, 'Discriminator Gaussian Loss History'),
):
    plt.figure(dpi=100)
    plt.plot(range(len(history)), history)
    plt.title(title)
    plt.xlabel('training steps')
    plt.ylabel('loss')

## Generate sequences after training

In [None]:
# states_et, states_ts, episode_token_probs, gaussian_log = generate_one_seq(generator)
# states_et.squeeze()

def recover_timedelta_to_timestamp(time_seq):
    """Turn a sequence of time deltas into running totals.

    Each non-zero delta is added to a running sum and the sum so far is
    emitted. A delta equal to zero emits 0 at that position and leaves the
    running sum untouched.

    Args:
        time_seq: iterable of numeric time deltas.

    Returns:
        A list with one entry per element of ``time_seq``.
    """
    timestamps = []
    running_total = 0

    for delta in time_seq:
        if delta == 0:
            # Zero deltas are passed through as 0 and do not advance the clock.
            timestamps.append(0)
            continue
        running_total += delta
        timestamps.append(running_total)

    return timestamps

# recover_timedelta_to_timestamp(time_seq)

In [None]:
N_gen = 1000  # number of sequences to sample from the generator
generated_seqs = []

for i in range(N_gen):
    states_et, states_ts, episode_token_probs, gaussian_log = generate_one_sequence_by_rollout(
        generator, T, EVENT_VOCAB_DIM, verbose=False)

    # Take the last entry along the first axis of each state array and flatten
    # it into a plain Python list.
    type_seq = states_et[-1, :, :].squeeze().tolist()
    time_seq = states_ts[-1, :, :].squeeze().tolist()

    # Convert the time deltas into running timestamps before pairing them
    # with the event types.
    recovered_time_seq = recover_timedelta_to_timestamp(time_seq)
    seq = list(zip(type_seq, recovered_time_seq))
    generated_seqs.append(seq)

    # Print every 50th sample as a spot check.
    if i % 50 == 0:
        print(i)
        print(seq)