In [None]:
import os 
import sys
from tqdm import tqdm
import importlib
import numpy as np
import pickle
import tensorflow as tf
import matplotlib.pyplot as plt

# TODO: point this at the directory that contains the `sgtlstm` package.
# The assignment previously had only a commented-out placeholder after `=`,
# which is a SyntaxError; a string placeholder keeps the cell parseable.
module_path = '.../path-to-module/'
if module_path not in sys.path:
    sys.path.append(module_path)
    
from tensorflow_probability import distributions as tfd

In [None]:
# importlib.reload() re-executes only the module object it is given; already
# imported submodules (sgtlstm.utils, sgtlstm.SeqGan, ...) stay cached, so edits
# to them were never picked up. Dropping every cached sgtlstm module makes the
# `from sgtlstm... import ...` lines below load fresh code when this cell is re-run.
for _cached_name in [name for name in sys.modules
                     if name == 'sgtlstm' or name.startswith('sgtlstm.')]:
    del sys.modules[_cached_name]
    
from sgtlstm.utils import create_dataset
from sgtlstm.SeqGan import build_G, build_D
from sgtlstm.oracle import get_G_metrics, get_hidden_metrics
from sgtlstm.pretrain import pretrain_discriminator, pretrain_generator
from sgtlstm.train import generate_sequences
from sgtlstm.TimeLSTM import TimeLSTM0, TimeLSTM1, TimeLSTM2, TimeLSTM3

import tensorflow as tf
from tensorflow.keras.optimizers import Adam, SGD

## Load data 

In [None]:
# TODO: replace both placeholders with the real locations of the pickled sequences.
# (The assignments previously had only a comment after `=`, which is a SyntaxError.)
pos_data_path = '...path-to-data/positive_long_sequences.pickle'
neg_data_path = '...path-to-data/negative_long_sequences.pickle'

def load_sequence_from_pickle_to_numpy(pickle_file_path):
    """
    Read a pickled collection of sequences and split it into two arrays.

    Every element of a sequence is indexed as a pair: position 0 is the event
    type and position 1 is the time delta.

    :param pickle_file_path: path of the pickle file to read
    :return: (event_type_seqs, time_delta_seqs), two float64 arrays of shape
        (N, T, 1) where N is the number of sequences and T the length of the
        first one; two empty 1-D arrays when the file holds no sequences or
        the first sequence is empty
    """
    # NOTE: pickle.load can execute arbitrary code; only open trusted files.
    with open(pickle_file_path, 'rb') as handle:
        loaded = pickle.load(handle)

    has_content = bool(loaded) and bool(loaded[0])
    if not has_content:
        return np.array([]), np.array([])

    n_seqs, seq_len = len(loaded), len(loaded[0])
    stacked = np.array(loaded)

    def _column(idx):
        # Pick one field of every pair and restore a trailing axis of size 1.
        return stacked[:, :, idx].astype(np.float64).reshape((n_seqs, seq_len, 1))

    return _column(0), _column(1)
    
# Load the positive and the negative sequences; each call returns an
# (event types, time deltas) pair of arrays.
pos_event_type_seqs, pos_timestamp_seqs = load_sequence_from_pickle_to_numpy(
    pickle_file_path=pos_data_path)
neg_event_type_seqs, neg_timestamp_seqs = load_sequence_from_pickle_to_numpy(
    pickle_file_path=neg_data_path)

In [None]:
# Downsample the negative class to (at most) the size of the positive class.
# Sampling without replacement raises ValueError when more items are requested
# than exist, so the sample size is capped at the number of negatives available.
# NOTE: no random seed is set here, so the selected subset differs between runs.
N_neg_sample = min(pos_event_type_seqs.shape[0], neg_event_type_seqs.shape[0])
# Passing an int n to np.random.choice samples from np.arange(n).
neg_sample_idx = np.random.choice(neg_event_type_seqs.shape[0], size=N_neg_sample, replace=False)

# Apply the same row selection to both arrays so event types and timestamps stay aligned.
neg_event_type_seqs = neg_event_type_seqs[neg_sample_idx, :, :]
neg_timestamp_seqs = neg_timestamp_seqs[neg_sample_idx, :, :]

## Global Variables 

In [None]:
# --- batching / sequence length ---
BATCH_SIZE = 64
T = 20 + 1  # presumably 20 events plus one leading 'INIT' token -- TODO confirm

# --- event vocabulary ---
# NOTE(review): an earlier comment here read "remove padding token, shift start
# token to 0", yet VOCAB still lists 'END/PADDING' at index 0 -- confirm intent.
VOCAB = ['END/PADDING', 'INIT', 'start', 'view', 'click', 'install']
EVENT_VOCAB_DIM = len(VOCAB)
END_TOKEN = 0  # same value as the position of 'END/PADDING' in VOCAB

# --- model sizes ---
EMB_DIM = 6
HIDDEN_DIM = 100

MAX_TIME = 1024  # not referenced by any other cell in this part of the notebook

## Pretrain G

### Prepare generator pre-training data (positive sequences only; no train/eval split is made here)

In [None]:
# The generator is pre-trained on the positive sequences only.
pretrain_G_et, pretrain_G_ts = pos_event_type_seqs, pos_timestamp_seqs
N_total_G = pretrain_G_et.shape[0]

# Features are the (event types, timestamps) pair; every sequence gets label 1.
pretrain_G_features = (pretrain_G_et, pretrain_G_ts)
pretrain_G_labels = np.ones((pos_event_type_seqs.shape[0], 1))

In [None]:
EPOCHS = 1
# Number of whole batches available over all epochs (any partial batch is dropped by int()).
_TOTAL_STEPS = int(EPOCHS * N_total_G / BATCH_SIZE)

# Per-step histories filled in by the generator pre-training loop.
pretrain_gen_ce_loss_history = []
pretrain_gen_gaussian_loss_history = []
pretrain_gen_metrics_history = []

# Dataset of (features, label) batches; all labels are 1.
# buffer_size is set to the full dataset size.
pretrain_G_dataset = create_dataset(
    pretrain_G_features,
    np.ones((N_total_G, 1)),
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    buffer_size=N_total_G,
)

# Generator to be pre-trained.
pretrained_generator = build_G(
    batch_size=BATCH_SIZE,
    event_vocab_dim=EVENT_VOCAB_DIM,
    emb_dim=EMB_DIM,
    hidden_dim=HIDDEN_DIM,
)

In [None]:
step = 0  # stays 0 if the dataset yields no batches
# `lr` is a deprecated alias in tf.keras optimizers; `learning_rate` is the supported argument name.
OPTIMIZER = Adam(learning_rate=1e-3)
WEIGHT_GAUSSIAN_LOSS = 1

# One optimisation step per batch; the batch labels are not used during generator pre-training.
for step, (feature_sample, _) in enumerate(tqdm(pretrain_G_dataset.take(_TOTAL_STEPS)), start=1):
    print('Training Step:', step)

    gen_ce_loss, gen_gaussian_loss = pretrain_generator(feature_sample,
                                                        pretrained_generator,
                                                        verbose=True,
                                                        weight_gaussian_loss=WEIGHT_GAUSSIAN_LOSS,
                                                        optimizer=OPTIMIZER)

    # Calculate G metrics: compare freshly generated sequences against the current real batch.
    # (presumably the first argument is the number of sequences to generate -- TODO confirm)
    batch_gen_seqs = generate_sequences(BATCH_SIZE, pretrained_generator,
                                        batch_size=BATCH_SIZE, T=T, recover_to_timestamp=False)
    batch_gen_seqs = np.array(batch_gen_seqs)
    # Join event types and timestamps along the last axis into one array per batch.
    pos_sample = np.concatenate([feature_sample[0].numpy(), feature_sample[1].numpy()], axis=2)
    # batch_metrics = [rbq, fid, mad, mmd, mmd_et, mmd_ts]
    batch_metrics = get_G_metrics(pos_sample, batch_gen_seqs)
    print('batch metrics:', batch_metrics)

    pretrain_gen_ce_loss_history.append(gen_ce_loss.numpy())
    pretrain_gen_gaussian_loss_history.append(gen_gaussian_loss.numpy())
    pretrain_gen_metrics_history.append(batch_metrics)

In [None]:
# Cross-entropy loss per generator pre-training step.
x = range(len(pretrain_gen_ce_loss_history))
fig, ax = plt.subplots(dpi=100)
ax.plot(x, pretrain_gen_ce_loss_history)
ax.set_title('Pre-training Generator Categorical Cross-Entropy Loss History')
ax.set_xlabel('Pre-training steps')

# Gaussian loss per generator pre-training step.
x = range(len(pretrain_gen_gaussian_loss_history))
fig, ax = plt.subplots(dpi=100)
ax.plot(x, pretrain_gen_gaussian_loss_history)
ax.set_title('Pre-training Generator Gaussian Loss History')
ax.set_xlabel('training steps')

In [None]:
# TODO: replace the placeholder with the directory where loss histories should be stored.
# (The assignment previously had only a comment after `=`, which is a SyntaxError.)
loss_save_dir = '.../path-to-experiment-results/loss'
# exist_ok=True replaces the separate os.path.exists check and is safe to re-run.
os.makedirs(loss_save_dir, exist_ok=True)

# Persist both generator pre-training loss histories as pickled lists.
with open(os.path.join(loss_save_dir, 'pretrain_gen_ce_loss_history.pickle'), 'wb') as f:
    pickle.dump(pretrain_gen_ce_loss_history, f)

with open(os.path.join(loss_save_dir, 'pretrain_gen_gaussian_loss_history.pickle'), 'wb') as f:
    pickle.dump(pretrain_gen_gaussian_loss_history, f)

In [None]:
# TODO: replace the placeholder with the directory where metric histories should be stored.
# (The assignment previously had only a comment after `=`, which is a SyntaxError.)
metrics_save_dir = '.../path-to-experiment-results/metrics'
# exist_ok=True replaces the separate os.path.exists check and is safe to re-run.
os.makedirs(metrics_save_dir, exist_ok=True)

# Persist the per-step generator metrics as a pickled list.
with open(os.path.join(metrics_save_dir, 'pretrain_gen_metrics_history.pickle'), 'wb') as f:
    pickle.dump(pretrain_gen_metrics_history, f)

### Save Pretrained G

In [None]:
# TODO: replace the placeholder with the directory where model weights should be stored.
# (The assignment previously had only a comment after `=`, which is a SyntaxError.)
model_save_dir = '.../path-to-experiment-results/models'

G_save_path = model_save_dir + '/pretrained_gen_weights/model.tf'

# Create the folder the weights file lives in. The earlier makedirs call ended
# with a stray extra quote, which was a second SyntaxError in this cell.
os.makedirs(os.path.dirname(G_save_path), exist_ok=True)

In [None]:
# Write the pre-trained generator's weights to G_save_path.
pretrained_generator.save_weights(G_save_path)

In [None]:
# Build a second generator with the same hyper-parameters, then restore the saved weights into it.
reload_pretrained_gen = build_G(batch_size=BATCH_SIZE,
                                event_vocab_dim=EVENT_VOCAB_DIM,
                                emb_dim=EMB_DIM,
                                hidden_dim=HIDDEN_DIM)

# The model is built for two inputs of identical shape (BATCH_SIZE, T, 1)
# (presumably event types and timestamps -- TODO confirm against build_G).
reload_pretrained_gen.build(input_shape=((BATCH_SIZE, T, 1),) * 2)
reload_pretrained_gen.load_weights(G_save_path)

In [None]:
# Display the model summary of the reloaded generator.
reload_pretrained_gen.summary()

## Pretrain D 

In [None]:
# Discriminator pre-training set: positive sequences followed by negative sequences.
pretrain_D_et = np.concatenate((pos_event_type_seqs, neg_event_type_seqs), axis=0)
pretrain_D_ts = np.concatenate((pos_timestamp_seqs, neg_timestamp_seqs), axis=0)
pretrain_D_features = (pretrain_D_et, pretrain_D_ts)
N_pretrain_D = pretrain_D_ts.shape[0]

# Labels in the same order as the features: 1 per positive sequence, 0 per negative one.
pretrain_D_labels = np.concatenate(
    (np.ones((pos_event_type_seqs.shape[0], 1)),
     np.zeros((neg_event_type_seqs.shape[0], 1))),
    axis=0,
)

In [None]:
EPOCHS = 1
# Number of whole batches available over all epochs (any partial batch is dropped by int()).
_TOTAL_STEPS = int(EPOCHS * N_pretrain_D / BATCH_SIZE)

# Histories for discriminator pre-training; only the token-loss list is
# appended to by the training loop in this notebook section.
pretrain_disc_token_loss_history = []
pretrain_disc_gaussian_loss_history = []

# Dataset of (features, label) batches; buffer_size is set to the full dataset size.
pretrain_D_dataset = create_dataset(
    pretrain_D_features,
    pretrain_D_labels,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    buffer_size=N_pretrain_D,
)

# Discriminator to be pre-trained.
pretrained_discriminator = build_D(
    T=T,
    event_vocab_dim=EVENT_VOCAB_DIM,
    emb_dim=EMB_DIM,
    hidden_dim=HIDDEN_DIM,
)

In [None]:
step = 0  # stays 0 if the dataset yields no batches
# `lr` is a deprecated alias in tf.keras optimizers; `learning_rate` is the supported argument name.
OPTIMIZER = Adam(learning_rate=1e-4)

# One optimisation step per (features, labels) batch; the returned loss is recorded per step.
for step, (features_batch, real_labels) in enumerate(tqdm(pretrain_D_dataset.take(_TOTAL_STEPS)), start=1):
    print('Training Step:', step)

    disc_token_loss = pretrain_discriminator(features_batch, real_labels, pretrained_discriminator,
                                             verbose=True, optimizer=OPTIMIZER)
    pretrain_disc_token_loss_history.append(disc_token_loss.numpy())

### Pretrain D: loss over pre-training steps

In [None]:
# Discriminator loss per pre-training step.
x = range(len(pretrain_disc_token_loss_history))
fig, ax = plt.subplots(dpi=100)
ax.plot(x, pretrain_disc_token_loss_history)
ax.set_title('Pre-training Discriminator CE Loss History')
ax.set_xlabel('Pre-training steps')

In [None]:
# TODO: replace the placeholder with the directory where loss histories should be stored.
# (The assignment previously had only a comment after `=`, which is a SyntaxError.)
loss_save_dir = '.../path-to-experiment-results/loss'
# exist_ok=True replaces the separate os.path.exists check and is safe to re-run.
os.makedirs(loss_save_dir, exist_ok=True)

# Persist the discriminator pre-training loss history as a pickled list.
with open(os.path.join(loss_save_dir, 'pretrain_disc_token_loss_history.pickle'), 'wb') as f:
    pickle.dump(pretrain_disc_token_loss_history, f)

### Save Pretrained D

In [None]:
# TODO: replace the placeholder with the directory where model weights should be stored.
# (The assignment previously had only a comment after `=`, which is a SyntaxError.)
model_save_dir = '.../path-to-experiment-results/models'

D_save_path = model_save_dir + '/pretrained_disc_weights/pretrained_disc_weights/model.tf'

# Create the folder that actually contains D_save_path. The earlier code created
# only '<model_save_dir>/pretrained_disc_weights', one level above where the
# weights file is written. The save path itself is left unchanged.
os.makedirs(os.path.dirname(D_save_path), exist_ok=True)

In [None]:
# Write the pre-trained discriminator's weights to D_save_path.
pretrained_discriminator.save_weights(D_save_path)

In [None]:
# Build a second discriminator with the same hyper-parameters, then restore the saved weights into it.
reload_pretrained_disc = build_D(T=T,
                                 event_vocab_dim=EVENT_VOCAB_DIM,
                                 emb_dim=EMB_DIM,
                                 hidden_dim=HIDDEN_DIM)

# The model is built for two inputs of identical shape (BATCH_SIZE, T, 1);
# later cells call it with an (event types, timestamps) pair.
reload_pretrained_disc.build(input_shape=((BATCH_SIZE, T, 1),) * 2)
reload_pretrained_disc.load_weights(D_save_path)

In [None]:
# Display the model summary of the reloaded discriminator.
reload_pretrained_disc.summary()

## Generate and predict seqs

In [None]:
# Sample sequences from the reloaded pre-trained generator
# (presumably N_gen is the number of sequences requested -- TODO confirm).
generator = reload_pretrained_gen
N_gen = 1000

generated_seqs = generate_sequences(
    N_gen,
    generator,
    batch_size=BATCH_SIZE,
    T=T,
    recover_to_timestamp=False,
)

In [None]:
# Show only the first few generated sequences; echoing the whole collection floods the cell output.
generated_seqs[:5]

In [None]:
# Score the generated sequences with the reloaded discriminator.
generated_seqs = np.array(generated_seqs)
# Index with a list so the last axis is kept (size 1) for each input.
gen_event_types = generated_seqs[:, :, [0]]
gen_timestamps = generated_seqs[:, :, [1]]
pred_generated = reload_pretrained_disc((gen_event_types, gen_timestamps))
# Number of discriminator outputs above 0.5.
np.sum((pred_generated > 0.5).numpy())

In [None]:
# Average discriminator output over the first axis of the predictions.
np.mean(pred_generated.numpy(), axis=0)

In [None]:
# Score the first 1000 positive sequences and count the outputs above 0.5.
# NOTE: the result is stored in `pred_generated` although the inputs here are positive data.
pred_generated = reload_pretrained_disc(
    (pos_event_type_seqs[:1000, :, :], pos_timestamp_seqs[:1000, :, :])
)
np.sum((pred_generated > 0.5).numpy())

In [None]:
# Score the first 1000 negative sequences and count the outputs above 0.5.
# NOTE: the result is stored in `pred_generated` although the inputs here are negative data.
pred_generated = reload_pretrained_disc(
    (neg_event_type_seqs[:1000, :, :], neg_timestamp_seqs[:1000, :, :])
)
np.sum((pred_generated > 0.5).numpy())