In [None]:
import os 
import sys
from tqdm import tqdm
import importlib
import numpy as np
import pickle
import tensorflow as tf
import matplotlib.pyplot as plt

# Make the local `sgtlstm` package importable.
# TODO: replace the placeholder below with the real path before running; as written the
# assignment is incomplete and the cell will not parse.
module_path = #'.../path-to-module/'
if module_path not in sys.path:
    sys.path.append(module_path)

In [None]:
def purge_cached_modules(package_name):
    """
    Remove `package_name` and every one of its submodules from ``sys.modules``.

    ``importlib.reload(sys.modules['sgtlstm'])`` only re-executes the package's
    ``__init__``; submodules such as ``sgtlstm.utils`` or ``sgtlstm.train`` stay cached,
    so edits to them were never picked up. Dropping all cached entries forces the
    ``from sgtlstm.... import ...`` statements that follow to load fresh code. On a fresh
    kernel nothing is cached and this is a no-op.

    :param package_name: top-level package name, e.g. 'sgtlstm'
    :return: sorted list of the module names that were removed
    """
    prefix = package_name + '.'
    # Snapshot the keys first: we mutate sys.modules below.
    stale = sorted(name for name in list(sys.modules)
                   if name == package_name or name.startswith(prefix))
    for name in stale:
        del sys.modules[name]
    return stale


purge_cached_modules('sgtlstm')

from sgtlstm.utils import create_dataset, recover_timedelta_to_timestamp
from sgtlstm.SeqGan import build_G, build_D, build_critic
from sgtlstm.oracle import get_G_metrics, get_hidden_metrics
from sgtlstm.TimeLSTM import TimeLSTM0, TimeLSTM1, TimeLSTM2, TimeLSTM3
from sgtlstm.train import train_discriminator, train_generator, generate_sequences

## Load data 

In [None]:
# TODO: fill in the paths to the pickled positive / negative sequence files before running;
# as written these assignments are placeholders and the cell will not parse.
pos_data_path = # '...path-to-data/positive_long_sequences.pickle'
neg_data_path = # '...path-to-data/negative_long_sequences.pickle'

def load_sequence_from_pickle_to_numpy(pickle_file_path):
    """
    Load fixed-length event sequences from a pickle file into numpy arrays.

    The pickle must hold N sequences, each made of T ``(event_type, time_delta)``
    pairs; every sequence must have the same length T as the first one.

    NOTE: ``pickle.load`` can execute arbitrary code -- only load files you trust.

    :param pickle_file_path: e.g. /.../project-basileus/seq-gan/data/fixed_length/valid_sequences.pickle
    :return: (event_type_seqs, time_delta_seqs), each a float64 array of shape (N, T, 1).
        If the file holds nothing (None / empty) or its first sequence is empty, two
        empty 1-D arrays are returned instead.
    :raises ValueError: if the sequences do not all have the same length.
    """
    with open(pickle_file_path, 'rb') as f:
        raw_seqs = pickle.load(f)

    # Use len() instead of truthiness: `not raw_seqs` raises for a pickled numpy array.
    if raw_seqs is None or len(raw_seqs) == 0 or len(raw_seqs[0]) == 0:
        return np.array([]), np.array([])

    N = len(raw_seqs)
    T = len(raw_seqs[0])

    # Ragged input would otherwise fail deep inside numpy (or become an object array on
    # older numpy versions) with an unhelpful message.
    bad_lengths = sorted({len(seq) for seq in raw_seqs if len(seq) != T})
    if bad_lengths:
        raise ValueError(
            f'All sequences must have length {T} (the length of the first one); '
            f'also found lengths {bad_lengths}.'
        )

    seqs = np.array(raw_seqs)

    # Column 0 is the event type, column 1 the time delta; add a trailing feature axis.
    et_seqs = seqs[:, :, 0].astype(np.float64).reshape((N, T, 1))
    ts_seqs = seqs[:, :, 1].astype(np.float64).reshape((N, T, 1))
    return et_seqs, ts_seqs
    
# Load positive then negative sequences; the second array of each pair holds column 1
# of the pickled (event_type, time_delta) pairs.
(pos_event_type_seqs, pos_timestamp_seqs), (neg_event_type_seqs, neg_timestamp_seqs) = (
    load_sequence_from_pickle_to_numpy(path) for path in (pos_data_path, neg_data_path)
)

## Global Variables 

In [None]:
# Model / training sizes.
BATCH_SIZE = 64
HIDDEN_DIM = 100
EMB_DIM = 6

# Sequence length: 20 events plus one extra position (presumably the INIT token -- TODO confirm).
T = 20 + 1

# Event vocabulary; index 0 doubles as the end-of-sequence / padding symbol.
VOCAB = ['END/PADDING', 'INIT', 'start', 'view', 'click', 'install']
EVENT_VOCAB_DIM = len(VOCAB)
END_TOKEN = 0

MAX_TIME = 1024

## Load pretrained models (D and G)

In [None]:
discriminator = build_D(
    T = T,
    event_vocab_dim = EVENT_VOCAB_DIM,
    emb_dim = EMB_DIM,
    hidden_dim= HIDDEN_DIM,
)

discriminator.build(input_shape=((BATCH_SIZE, T, 1), (BATCH_SIZE, T, 1)))

D_save_path = #'.../path-to-experiment-results/models/pretrained_disc_weights/model.tf'
discriminator.load_weights(D_save_path)

In [None]:
generator = build_G(
    batch_size=BATCH_SIZE,
    event_vocab_dim = EVENT_VOCAB_DIM,
    emb_dim = EMB_DIM,
    hidden_dim= HIDDEN_DIM)

generator.build(input_shape=((BATCH_SIZE, T, 1), (BATCH_SIZE, T, 1)))

G_save_path = #'.../path-to-experiment-results/models/pretrained_gen_weights/model.tf'
generator.load_weights(G_save_path)

## Create a critic network

In [None]:
# The critic uses the same size hyper-parameters as the generator. It is handed to
# train_generator, which returns critic_network_loss (presumably it is updated there).
critic = build_critic(batch_size=BATCH_SIZE,
                      event_vocab_dim=EVENT_VOCAB_DIM,
                      emb_dim=EMB_DIM,
                      hidden_dim=HIDDEN_DIM)

# Train G and D

In [None]:
# Only the positive sequences are used as training data, each labelled 1.
train_et, train_ts = pos_event_type_seqs, pos_timestamp_seqs
train_features = (train_et, train_ts)
N_train = train_et.shape[0]
train_labels = np.ones((N_train, 1))
N_train

In [None]:
EPOCHS = 1

# Batched tf dataset over the training data; see sgtlstm.utils.create_dataset for how
# epochs and buffer_size (shuffling) are applied.
dataset = create_dataset(
    train_features,
    train_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    buffer_size=N_train,
)

# Per-update records appended to by the training loop (seven independent lists).
(gen_token_loss_history,
 gen_gaussian_loss_history,
 disc_ce_loss_history,
 critic_network_loss_history,
 average_true_return_history,
 gen_metrics_history,
 hidden_metrics_history) = ([] for _ in range(7))

In [None]:
# Global step counter; the training cell increments it and uses it to schedule checkpoints.
step = 0
# Number of mini-batches to draw across all epochs (a partial final batch is not counted).
_TOTAL_STEPS = int((EPOCHS * N_train) / BATCH_SIZE)
_TOTAL_STEPS

In [None]:
# `SGD` was never imported in this notebook (NameError on a fresh kernel); reference it
# through tf.keras, since tensorflow is imported as `tf` in the first cell.
G_optimizer = tf.keras.optimizers.SGD(learning_rate=1e-4)
D_optimizer = tf.keras.optimizers.SGD(learning_rate=1e-4)

WEIGHT_GAUSSIAN_LOSS = 1  # forwarded to train_generator as weight_gaussian_loss
_G_STEPS = 1  # generator updates per real batch
_D_STEPS = 1  # discriminator updates per real batch

# Checkpoint cadence and location. The default location is the original author's
# machine; override it with the SGTLSTM_CHECKPOINT_ROOT environment variable.
SAVE_EVERY_N_STEPS = 100
CHECKPOINT_ROOT = os.environ.get(
    'SGTLSTM_CHECKPOINT_ROOT',
    '/home/lun/project-basileus/seq-gan/experiment_results/long_seqs_v10',
)

# save_model_weights is defined in the last cell of this notebook. Fail now instead of
# after SAVE_EVERY_N_STEPS steps of training if that cell has not been run yet.
if 'save_model_weights' not in globals():
    raise NameError(
        "save_model_weights is not defined; run the cell that defines it "
        "(at the end of the notebook) before starting training."
    )

for features_batch, _ in tqdm(dataset.take(_TOTAL_STEPS)):
    step += 1
    print('Training Step:', step)

    # train the generator; train_generator also reports the critic loss and average return
    for _ in range(_G_STEPS):
        gen_token_loss, gen_gaussian_loss, critic_network_loss, average_true_return = train_generator(
            generator, discriminator, critic,
            batch_size=BATCH_SIZE, T=T, verbose=True,
            weight_gaussian_loss=WEIGHT_GAUSSIAN_LOSS,
            optimizer=G_optimizer,
        )
        gen_token_loss_history.append(gen_token_loss.numpy())
        gen_gaussian_loss_history.append(gen_gaussian_loss.numpy())
        critic_network_loss_history.append(critic_network_loss.numpy())
        average_true_return_history.append(average_true_return.numpy())

    # train the discriminator on this real batch
    for _ in range(_D_STEPS):
        disc_ce_loss = train_discriminator(features_batch, generator, discriminator,
                                           batch_size=BATCH_SIZE, T=T, verbose=True,
                                           optimizer=D_optimizer)
        disc_ce_loss_history.append(disc_ce_loss.numpy())

    # calculate G metrics: compare this real batch with a freshly generated batch
    batch_gen_seqs = generate_sequences(BATCH_SIZE, generator, batch_size=BATCH_SIZE, T=T, recover_to_timestamp=False)
    batch_gen_seqs = np.array(batch_gen_seqs)
    pos_sample = np.concatenate([features_batch[0].numpy(), features_batch[1].numpy()], axis=2)
    # batch_metrics : [rbq, fid, mad, mmd, mmd_et, mmd_ts]
    batch_metrics = get_G_metrics(pos_sample, batch_gen_seqs)
    print('batch metrics:', batch_metrics)
    gen_metrics_history.append(batch_metrics)

    # calculate hidden metrics on the discriminator's second output for real vs. generated
    pos_time_comb = discriminator(features_batch)[1]
    batch_time_comb = discriminator([batch_gen_seqs[:, :, [0]], batch_gen_seqs[:, :, [1]]])[1]
    # hidden_metrics : [fid, mmd]
    hidden_metrics = get_hidden_metrics(pos_time_comb, batch_time_comb)
    print('hidden metrics:', hidden_metrics)
    hidden_metrics_history.append(hidden_metrics)

    # save weights every SAVE_EVERY_N_STEPS steps (the old comment said 200; the code used 100)
    if step % SAVE_EVERY_N_STEPS == 0:
        print('Saving weights...')
        save_path_prefix = os.path.join(CHECKPOINT_ROOT, f'oracle_train_{step}')
        save_model_weights(save_path_prefix, generator, discriminator, critic)
        print('All Saved!')

## Loss over training

In [None]:
# One figure per recorded loss. The x-axis counts appended entries: one per G (or D)
# update, which equals training steps while _G_STEPS == _D_STEPS == 1.
for history, title in (
    (gen_token_loss_history, 'Generator Token Loss History'),  # title previously misspelled 'Toke'
    (gen_gaussian_loss_history, 'Generator Gaussian Loss History'),
    (disc_ce_loss_history, 'Discriminator CE Loss History'),
):
    x = range(len(history))
    plt.figure(dpi=100)
    plt.plot(x, history)
    plt.title(title)
    plt.xlabel('training steps')

In [None]:
# Critic loss and average true return, one figure each (x = index of the recorded entry).
for history, title in (
    (critic_network_loss_history, 'Critic Loss History'),
    (average_true_return_history, 'Average True Return History'),
):
    x = range(len(history))
    plt.figure(dpi=100)
    plt.plot(x, history)
    plt.title(title)
    plt.xlabel('training steps')

## Generate sequences 

In [None]:
# Sample N_gen sequences from the trained generator
# (recover_to_timestamp=False: times are not converted back to timestamps).
N_gen = 100
generated_seqs = generate_sequences(
    N_gen,
    generator,
    batch_size=BATCH_SIZE,
    T=T,
    recover_to_timestamp=False,
)

In [None]:
# For a performance test, these sequences can be saved with np.save
# to '.../path-to-gan-generated/performance_test/'.
generated_seqs

## Predict using the GAN-trained D

In [None]:
# Score the generated sequences with the discriminator as it stands after adversarial training.
generated_seqs = np.array(generated_seqs)
pred_1, _ = discriminator(tuple(generated_seqs[:, :, [channel]] for channel in (0, 1)))

In [None]:
np.mean(pred_1.numpy(), axis=0)  # average discriminator output over the generated batch

## Predict using the pre-trained D

In [None]:
reload_pretrained_D = build_D(
    T = T,
    event_vocab_dim = EVENT_VOCAB_DIM,
    emb_dim = EMB_DIM,
    hidden_dim= HIDDEN_DIM,
)

reload_pretrained_D.build(input_shape=((BATCH_SIZE, T, 1), (BATCH_SIZE, T, 1)))

D_save_path = #'.../path-to-experiment-results/models/pretrained_disc_weights/model.tf'
reload_pretrained_D.load_weights(D_save_path)

In [None]:
# Score the same generated sequences with the reloaded pre-trained discriminator.
generated_seqs = np.array(generated_seqs)
pred_2, _ = reload_pretrained_D(tuple(generated_seqs[:, :, [channel]] for channel in (0, 1)))

In [None]:
np.mean(pred_2.numpy(), axis=0)  # average pre-trained-discriminator output over the batch

## Save G, D and critic weights, loss histories and metrics

In [None]:
loss_save_dir = #'.../path-to-experiment-results/loss'
if not os.path.exists(loss_save_dir):
    os.makedirs(loss_save_dir)

with open(os.path.join(loss_save_dir, 'gen_token_loss_history.pickle'), 'wb') as f:
    pickle.dump(gen_token_loss_history, f)

with open(os.path.join(loss_save_dir, 'gen_gaussian_loss_history.pickle'), 'wb') as f:
    pickle.dump(gen_gaussian_loss_history, f)
    
with open(os.path.join(loss_save_dir, 'critic_network_loss_history.pickle'), 'wb') as f:
    pickle.dump(critic_network_loss_history, f)

with open(os.path.join(loss_save_dir, 'disc_ce_loss_history.pickle'), 'wb') as f:
    pickle.dump(disc_ce_loss_history, f)

In [None]:
metrics_save_dir = #'.../path-to-experiment-results/metrics'
if not os.path.exists(metrics_save_dir):
    os.makedirs(metrics_save_dir)
    
with open(os.path.join(metrics_save_dir, 'gen_metrics_history.pickle'), 'wb') as f:
    pickle.dump(gen_metrics_history, f)

with open(os.path.join(metrics_save_dir, 'hidden_metrics_history.pickle'), 'wb') as f:
    pickle.dump(hidden_metrics_history, f)

In [None]:
D_save_dir = #'.../path-to-experiment-results/models/disc_weights'
if not os.path.exists(D_save_dir):
    os.makedirs(D_save_dir)
    
D_save_path = os.path.join(D_save_dir, 'disc_model.tf')
discriminator.save_weights(D_save_path)

In [None]:
G_save_dir = #'.../path-to-experiment-results/models/gen_weights'
if not os.path.exists(G_save_dir):
    os.makedirs(G_save_dir)
    
G_save_path = os.path.join(G_save_dir, 'gen_model.tf')
generator.save_weights(G_save_path)

In [None]:
critic_save_dir = #'.../path-to-experiment-results/models/critic_weights'
if not os.path.exists(critic_save_dir):
    os.makedirs(critic_save_dir)
    
critic_save_path = os.path.join(critic_save_dir, 'critic_model.tf')
critic.save_weights(critic_save_path)

In [None]:
def save_model_weights(save_path_prefix, G, D, critic):
    """
    Save generator, discriminator and critic weights under one checkpoint directory.

    Layout::

        <save_path_prefix>/gen_weights/gen_model.tf
        <save_path_prefix>/disc_weights/disc_model.tf
        <save_path_prefix>/critic_weights/critic_model.tf

    Each sub-directory is created if missing, matching the standalone save cells,
    which call os.makedirs before save_weights.

    :param save_path_prefix: checkpoint directory, e.g. .../oracle_train_100
    :param G: generator model (anything exposing ``save_weights(path)``)
    :param D: discriminator model
    :param critic: critic model
    """
    for label, model, sub_dir, file_name in (
        ('G', G, 'gen_weights', 'gen_model.tf'),
        ('D', D, 'disc_weights', 'disc_model.tf'),
        ('Critic', critic, 'critic_weights', 'critic_model.tf'),
    ):
        save_dir = os.path.join(save_path_prefix, sub_dir)
        os.makedirs(save_dir, exist_ok=True)
        save_path = os.path.join(save_dir, file_name)
        model.save_weights(save_path)
        print(f'{label} saved to:', save_path)