In [1]:
import os 
import sys
from tqdm import tqdm
import pickle
import importlib
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

module_path = '/home/lun/project-basileus/seq-gan/sgtlstm'
if module_path not in sys.path:
    sys.path.append(module_path)

In [13]:
if 'sgtlstm' in sys.modules:
    importlib.reload(sys.modules['sgtlstm'])

from sgtlstm.utils import load_fixed_length_sequence_from_pickle, create_dataset
from sgtlstm.utils import create_dataset
from sgtlstm.SeqGan import build_G, build_D
from sgtlstm.train import generate_one_sequence_by_rollout
from sgtlstm.pretrain import pretrain_generator, pretrain_discriminator, create_self_regression_data_batch
from sgtlstm.TimeLSTM import TimeLSTM0, TimeLSTM1, TimeLSTM2, TimeLSTM3

In [3]:
import tensorflow as tf
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.layers import LSTM

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras import Sequential
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.layers import Input, LSTM, Embedding, Reshape, Dense
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd

In [4]:
BATCH_SIZE = 128
T = 10 + 1
VOCAB = ['END/PADDING', 'INIT', 'start', 'click', 'install']
EVENT_VOCAB_DIM = len(VOCAB)
EMB_DIM = 5
HIDDEN_DIM = 64
K_MIST = 2

END_TOKEN = 0
MAX_TIME = 1024

## Load data 

In [5]:
valid_data_path = '/home/lun/project-basileus/seq-gan/data/fixed_length_with_init_token/valid_sequences.pickle'
invalid_data_path = '/home/lun/project-basileus/seq-gan/data/fixed_length_with_init_token/invalid_sequences.pickle'

valid_event_type_seqs, valid_timestamp_seqs = load_fixed_length_sequence_from_pickle(valid_data_path, to_timedelta=True, end_token=0)
invalid_event_type_seqs, invalid_timestamp_seqs = load_fixed_length_sequence_from_pickle(invalid_data_path, to_timedelta=True, end_token=0)

## Load pretrained G model

In [6]:
generator = build_G(
    T = T,
    event_vocab_dim = EVENT_VOCAB_DIM,
    emb_dim = EMB_DIM,
    hidden_dim= HIDDEN_DIM,
    k_mixt = K_MIST,
    return_sequence=False,
)
generator.build(input_shape=((BATCH_SIZE, T, 1), (BATCH_SIZE, T, 1)))

G_save_path = './gan_model_weights/pretrain_20000_gen_clip/model.tf'
generator.load_weights(G_save_path)

Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: Bad argument number for Name: 3, expecting 4
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: Bad argument number for Name: 3, expecting 4
sigma > 1 !!


<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7ff3a5209410>

In [16]:
curr_state_et = np.array([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]).reshape((1, T, 1))
curr_state_ts = np.array([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]).reshape((1, T, 1))

In [17]:
pred_token_prob, gaussian_log, mask, alpha, mu, sigma = generator.predict((curr_state_et, curr_state_ts))


In [18]:
pred_token_prob

array([[3.62814531e-05, 3.44735343e-05, 9.40776198e-01, 3.36491809e-02,
        2.55038664e-02]])

In [22]:
# states_et, states_ts, episode_token_probs, gaussian_log = generate_one_seq(generator)
# states_et.squeeze()

def recover_timedelta_to_timestamp(time_seq):
    csum = []
    curr = 0
    
    for dt in time_seq:
        if dt != 0:
            curr += dt
            csum.append(curr)
        else:
            csum.append(0)
    
    return csum

def generate_one_sequence_by_rollout_2(generator, T, event_vocab_dim, end_token=0, init_token=1, max_time=1024, verbose=False):
    # Begin from dummy init state (init_token=1, init_timestamp=0.0)
    curr_state_et = np.zeros([T])
    curr_state_et[0] = init_token
    curr_state_et = curr_state_et.reshape((1, T, 1))

    curr_state_ts = np.zeros([T])
    curr_state_ts[0] = 0.0
    curr_state_ts = curr_state_ts.reshape((1, T, 1))

    # whole trajectory
    states_et = (curr_state_et)
    states_ts = (curr_state_ts)
    episode_token_probs = tf.constant([1., ], dtype=tf.float64)

    for step in range(1, 3):  # sequence length
        token_prob, gaussian_log, mask, alpha, mu, sigma = generator([curr_state_et, curr_state_ts])

        # generate one timstamp using [alpha, mu, sigma]
        gm = tfd.MixtureSameFamily(
            mixture_distribution=tfd.Categorical(
                probs=alpha),
            components_distribution=tfd.Normal(
                loc=mu,
                scale=sigma))

        # sample next event token and time stamp
        sampled_et = tf.random.categorical(token_prob, num_samples=1)
        sampled_ts = tf.clip_by_value(gm.sample(), clip_value_min=1, clip_value_max=max_time)  # shape=[BATCH_SIZE,]

        taken_action_idx = sampled_et.numpy().item()

        if taken_action_idx == end_token:
            if verbose:
                print('Generation ended early!')
            break  # episode is over

        taken_action_prob = token_prob[0][taken_action_idx]
        taken_action_prob = tf.reshape(taken_action_prob, [1, ])
        episode_token_probs = tf.concat([episode_token_probs, taken_action_prob], axis=0)

        new_state_et = np.copy(curr_state_et)
        new_state_ts = np.copy(curr_state_ts)

        # TODO: 0 means 1 generation per batch
        new_state_et[0, step, :] = sampled_et
        new_state_ts[0, step, :] = sampled_ts

        if verbose:
            print('new_state_et', tf.squeeze(new_state_et))

        states_et = np.concatenate((states_et, new_state_et))
        states_ts = np.concatenate((states_ts, new_state_ts))

        curr_state_et = new_state_et
        curr_state_ts = new_state_ts
        if verbose:
            print('Generation done!')

    if verbose:
        print('episode length={}'.format(step + 1))
        print('state_et =', states_et)
        print('state_ts =', states_ts)
        print('episode_token_probs =', episode_token_probs)
        print('gaussian_log =', gaussian_log)

    return states_et, states_ts, episode_token_probs, gaussian_log

In [23]:
N_gen = 100 # 
generated_seqs = []

for i in tqdm(range(N_gen)):
    states_et, states_ts, episode_token_probs, gaussian_log = generate_one_sequence_by_rollout_2(generator,
                                                                                               T, EVENT_VOCAB_DIM,
                                                                                               verbose=False)
    type_seq = states_et[-1,:,:].squeeze().tolist()
    time_seq = states_ts[-1,:,:].squeeze().tolist()
    recovered_time_seq = recover_timedelta_to_timestamp(time_seq)
    generated_seqs.append(list(zip(type_seq, recovered_time_seq)))
    if i % 1 == 0:
        print(i)
        print(list(zip(type_seq, recovered_time_seq)))

  2%|▏         | 2/100 [00:00<00:12,  7.89it/s]

0
[(1.0, 0), (2.0, 1.0), (2.0, 10.245257178770935), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0)]
1
[(1.0, 0), (1.0, 1.0), (3.0, 13.766319805319563), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0)]


  4%|▍         | 4/100 [00:00<00:11,  8.03it/s]

2
[(1.0, 0), (2.0, 1.0), (1.0, 8.020386331176132), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0)]
3
[(1.0, 0), (4.0, 1.0), (1.0, 16.474744835290625), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0)]


  6%|▌         | 6/100 [00:00<00:11,  8.16it/s]

4
[(1.0, 0), (4.0, 1.0), (3.0, 4.392529027033715), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0)]
5
[(1.0, 0), (2.0, 1.0), (3.0, 9.196848684760143), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0)]


  8%|▊         | 8/100 [00:00<00:11,  8.22it/s]

6
[(1.0, 0), (2.0, 1.0), (2.0, 2.0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0)]
7
[(1.0, 0), (3.0, 1.0), (2.0, 2.6226371499719336), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0)]
8
[(1.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0)]


 11%|█         | 11/100 [00:01<00:10,  8.70it/s]

9
[(1.0, 0), (1.0, 1.0), (4.0, 12.009965967153462), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0)]
10
[(1.0, 0), (2.0, 1.0), (4.0, 2.0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0)]
11
[(1.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0)]


 14%|█▍        | 14/100 [00:01<00:09,  8.86it/s]

12
[(1.0, 0), (2.0, 1.0), (1.0, 8.27709095017334), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0)]
13
[(1.0, 0), (2.0, 1.0737952675020808), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0)]
14
[(1.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0)]


 16%|█▌        | 16/100 [00:01<00:09,  8.58it/s]

15
[(1.0, 0), (4.0, 1.0), (3.0, 4.5838130836813535), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0), (0.0, 0)]





KeyboardInterrupt: 