In [1]:
import os 
import sys
from tqdm import tqdm
import importlib
import numpy as np
import pickle
import tensorflow as tf
import matplotlib.pyplot as plt

module_path = '/home/lun/project-basileus/seq-gan/'
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
if 'sgtlstm' in sys.modules:
    importlib.reload(sys.modules['sgtlstm'])

from sgtlstm.utils import load_fixed_length_sequence_from_pickle, create_dataset, recover_timedelta_to_timestamp
from sgtlstm.SeqGan import build_G, build_D, build_critic
from sgtlstm.TimeLSTM import TimeLSTM0, TimeLSTM1, TimeLSTM2, TimeLSTM3
# from sgtlstm.train import generate_batch_sequence_by_rollout, train_discriminator, train_generator

## Load data 

In [3]:
pos_data_path = '/home/lun/project-basileus/seq-gan/data/long_seqs_v5/positive_long_sequences.pickle'
neg_data_path = '/home/lun/project-basileus/seq-gan/data/long_seqs_v5/negative_long_sequences.pickle'

pos_event_type_seqs, pos_timestamp_seqs = load_fixed_length_sequence_from_pickle(pos_data_path, to_timedelta=True, end_token=0)
neg_event_type_seqs, neg_timestamp_seqs = load_fixed_length_sequence_from_pickle(neg_data_path, to_timedelta=True, end_token=0)

## Global Variables 

In [4]:
BATCH_SIZE = 64
T = 20 + 1
VOCAB = ['END/PADDING', 'INIT', 'start', 'view', 'click', 'install']
EVENT_VOCAB_DIM = len(VOCAB)
EMB_DIM = 5
HIDDEN_DIM = 32

END_TOKEN = 0
MAX_TIME = 1024

## Load pretrained model

In [28]:
discriminator = build_D(
    T = T,
    event_vocab_dim = EVENT_VOCAB_DIM,
    emb_dim = EMB_DIM,
    hidden_dim= HIDDEN_DIM,
)

# discriminator.build(input_shape=((BATCH_SIZE, T, 1), (BATCH_SIZE, T, 1)))

# D_save_path = './experiment_results/long_seqs_no_gm/init_pretrained/pretrained_disc_weights/model.tf'
# discriminator.load_weights(D_save_path)

In [29]:
generator = build_G(
    batch_size=BATCH_SIZE,
    event_vocab_dim = EVENT_VOCAB_DIM,
    emb_dim = EMB_DIM,
    hidden_dim= HIDDEN_DIM)

generator.build(input_shape=((BATCH_SIZE, T, 1), (BATCH_SIZE, T, 1)))

G_save_path = './experiment_results/long_seqs_no_gm/init_pretrained/pretrained_gen_weights/model.tf'
generator.load_weights(G_save_path)

minimum variance 1 !


<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f7582f91fd0>

## Create a critic network

In [30]:
critic = build_critic(
    batch_size=BATCH_SIZE,
    event_vocab_dim = EVENT_VOCAB_DIM,
    emb_dim = EMB_DIM,
    hidden_dim= HIDDEN_DIM)

using sigmoid!


# functions in training

In [31]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras import Sequential
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.layers import Input, LSTM, Embedding, Reshape, Dense
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd

from sgtlstm.TimeLSTM import TimeLSTM0, TimeLSTM1, TimeLSTM2, TimeLSTM3

tf.keras.backend.set_floatx('float64')

In [32]:
def generate_batch_sequence_by_rollout(
        G, batch_size, T, end_token=0, init_token=1.0, max_time=1024, verbose=False):
    # Begin from dummy init state (init_token=1, init_timestamp=0.0)
    curr_state_et = np.zeros([batch_size, 1, 1])
    curr_state_et[:, 0, 0] = init_token

    curr_state_ts = np.zeros([batch_size, 1, 1])
    curr_state_ts[:, 0, 0] = 0.0

    all_state_et = curr_state_et
    all_state_ts = curr_state_ts

    episode_token_probs = tf.constant(1., dtype=tf.float64, shape=(batch_size, 1))
    gaussian_log = tf.constant(0., dtype=tf.float64, shape=(batch_size, 1))

    G.reset_states()

    for step in range(1, T):  # sequence length
        token_prob, time_out = G([curr_state_et, curr_state_ts])

        sampled_et = tf.random.categorical(tf.math.log(token_prob), num_samples=1, dtype=tf.int32)
        sampled_et = tf.reshape(sampled_et, [batch_size, 1, 1]).numpy().astype(float)

        # get the chosen token probability per batch for each step
        sampled_et_indices = sampled_et.squeeze().astype(int).tolist()
        sampled_token_prob = token_prob.numpy()[np.arange(len(token_prob)), sampled_et_indices].reshape((batch_size, 1))
        episode_token_probs = tf.concat([episode_token_probs, sampled_token_prob], axis=1)

        # stop genererating once hit end_token
        cond_end_token = tf.equal(curr_state_et, end_token)
        curr_state_et = tf.where(cond_end_token, curr_state_et, sampled_et)
        all_state_et = tf.concat([all_state_et, curr_state_et], axis=1)

        # generate one timstamp using time_out
        sampled_ts_raw = time_out.sample()
        sampled_ts = tf.clip_by_value(tf.reshape(sampled_ts_raw, (batch_size, 1, 1))
                                      , clip_value_min=1, clip_value_max=max_time)

        # get the gaussian log likelihood for the sampled timestamps
        sampled_gaussian_log = time_out.log_prob(tf.reshape(sampled_ts, (batch_size, 1)))
        gaussian_log = tf.concat([gaussian_log, sampled_gaussian_log], axis=1)

        # stop generating once hit end_token
        curr_state_ts = tf.where(cond_end_token, curr_state_et, sampled_ts)
        all_state_ts = tf.concat([all_state_ts, curr_state_ts], axis=1)

    return all_state_et, all_state_ts, episode_token_probs, gaussian_log


In [33]:
def generate_sequences(N_gen, generator, batch_size, T, recover_to_timestamp=True):
    """
        Generate sequences batch per batch
    :param N_gen: total number of seqs to be generated
    :param generator:
    :param batch_size:
    :param T:
    :param recover_to_timestamp: whether to recover time deltas to absolute timestamps
    :return: a python list of shape [N_gen, T, 2]
    """
    all_type_seq = None
    all_time_seq = None
    N = 0

    while N < N_gen:
        batch_state_et, batch_state_ts, _, _ = generate_batch_sequence_by_rollout(generator, batch_size, T,
                                                                                  end_token=0, init_token=1.0,
                                                                                  max_time=1024, verbose=False)

        batch_type_seq = batch_state_et.numpy()
        batch_time_seq = batch_state_ts.numpy()

        # recover time delta to time stamps
        if recover_to_timestamp:
            batch_time_seq = np.cumsum(batch_time_seq, axis=1)

        if all_type_seq is None:
            all_type_seq = batch_type_seq
        else:
            all_type_seq = np.concatenate([all_type_seq, batch_type_seq], axis=0)

        if all_time_seq is None:
            all_time_seq = batch_time_seq
        else:
            all_time_seq = np.concatenate([all_time_seq, batch_time_seq], axis=0)

        N += batch_size

    # concat type and time in depth
    concated_seq_list = np.concatenate([all_type_seq, all_time_seq], axis=2).tolist()

    return concated_seq_list[:N_gen]


In [34]:
def train_generator(generator, discriminator, critic_network, batch_size, T, verbose=False,
                    weight_gaussian_loss=1,
                    optimizer=Adam(lr=0.001)):
    # reset hidden states for critic network per batch
    critic_network.reset_states()

    with tf.GradientTape(persistent=True) as tape:
        states_et, states_ts, episode_token_probs, gaussian_log = generate_batch_sequence_by_rollout(generator,
                                                                                                     batch_size, T,
                                                                                                     end_token=0,
                                                                                                     init_token=1.0,
                                                                                                     max_time=1024,
                                                                                                     verbose=False)
        ce_loss_list = []
        gaussian_list = []
        critic_loss_list = []

        # run disc on whole sequence
        # true_return is the total reward for generating this seq
        true_return = discriminator((states_et, states_ts))

        for i in range(0, T):
            # TODO: should we include the init token in loss?
            curr_state_et = states_et[:, i:i + 1, :]
            curr_state_ts = states_ts[:, i:i + 1, :]

            curr_token_prob = episode_token_probs[:, i:i + 1]
            curr_gaussian_log = gaussian_log[:, i:i + 1]

            q_value = critic_network([curr_state_et, curr_state_ts])
            advantage = true_return - q_value

            # At this point in history, the critic estimated that we would get a
            # total reward = `q_value` in the future. We took an action with log probability
            # of `log_prob` and ended up recieving a total reward = `true_return`.
            # The actor must be updated so that it predicts an action that leads to
            # high rewards (compared to critic's estimate) with high probability.

            mask = tf.squeeze(curr_state_et)
            curr_state_et = tf.boolean_mask(curr_state_et, mask)
            curr_state_ts = tf.boolean_mask(curr_state_ts, mask)
            curr_token_prob = tf.boolean_mask(curr_token_prob, mask)
            curr_gaussian_log = tf.boolean_mask(curr_gaussian_log, mask)

            masked_q_value = tf.boolean_mask(q_value, mask)
            masked_advantage = tf.boolean_mask(advantage, mask)
            masked_true_return = tf.boolean_mask(true_return, mask)            
            
            ce_loss_list.append(-tf.reduce_mean(tf.math.log(curr_token_prob) * masked_advantage))
            gaussian_list.append(-tf.reduce_mean(curr_gaussian_log * masked_advantage))

            ce_loss_list.append(-tf.reduce_mean(tf.math.log(curr_token_prob)))
            gaussian_list.append(-tf.reduce_mean(curr_gaussian_log))

            critic_loss_list.append(tf.reduce_mean(tf.keras.losses.MSE(masked_true_return, masked_q_value)))

        total_ce_loss = tf.reduce_sum(ce_loss_list)
        total_gaussian_loss = tf.reduce_sum(gaussian_list)
        total_critic_loss = tf.reduce_sum(critic_loss_list)        
        total_generator_loss = total_ce_loss + weight_gaussian_loss * total_gaussian_loss

        average_true_return = tf.reduce_mean(true_return)

        if verbose:
            print('generator token loss:{}'.format(total_ce_loss))
            print('generator gaussian loss:{}'.format(total_gaussian_loss))
            print('generator total loss:{}'.format(total_generator_loss))
            print('generator critic loss:{}'.format(total_critic_loss))
            print('average true_return: {}'.format(average_true_return))

    # update generator
    generator_grads = tape.gradient(total_generator_loss, generator.trainable_variables)
    optimizer.apply_gradients(zip(generator_grads, generator.trainable_variables))

    # update critic network
    critic_grads = tape.gradient(total_critic_loss, critic_network.trainable_variables)
    optimizer.apply_gradients(zip(critic_grads, critic_network.trainable_variables))

    # explicitly drop tape because persistent=True
    del tape

    return total_ce_loss, total_gaussian_loss, total_critic_loss, average_true_return

In [35]:
def train_discriminator(features_batch, generator, discriminator, batch_size, T, verbose=False,
                        optimizer=Adam(lr=0.001)):
    # train the discriminator
    with tf.GradientTape() as tape:
        real_et, real_ts = features_batch
        real_labels = tf.ones((batch_size, 1))  # (batch_size, 1)

        generated_et, generated_ts, episode_token_probs, gaussian_log = generate_batch_sequence_by_rollout(generator,
                                                                                                           batch_size,
                                                                                                           T,
                                                                                                           end_token=0,
                                                                                                           init_token=1.0,
                                                                                                           max_time=1024,
                                                                                                           verbose=False)
        generated_labels = tf.zeros((batch_size, 1))

        total_et = tf.concat([generated_et, real_et], axis=0)
        total_ts = tf.concat([generated_ts, real_ts], axis=0)
        total_labels = tf.concat([generated_labels, real_labels], axis=0)

        # train discriminator
        pred_prob = discriminator((total_et, total_ts))

        # cross-entropy loss
        ce_loss = tf.reduce_mean(
            tf.keras.losses.binary_crossentropy(total_labels, pred_prob, from_logits=False))
        discriminator_loss = ce_loss

        if verbose:
            print('total discriminator loss:{}'.format(discriminator_loss))

    grads = tape.gradient(discriminator_loss, discriminator.trainable_variables)
    optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))

    return ce_loss

# Train G and D

In [36]:
train_et = pos_event_type_seqs
train_ts = pos_timestamp_seqs

train_labels = np.ones((pos_event_type_seqs.shape[0], 1))

train_features = (train_et, train_ts)
N_train = train_et.shape[0]
N_train

100000

In [37]:
optimizer = tf.keras.optimizers.Adam(lr=1e-3)
# optimizer = SGD(learning_rate=1e-4)
EPOCHS = 2
# _TOTAL_STEPS = int(EPOCHS * N_train / BATCH_SIZE)
_TOTAL_STEPS = 1000

dataset = create_dataset(train_features,
                         train_labels,
                         batch_size=BATCH_SIZE,
                         epochs=EPOCHS,
                         buffer_size=N_train)

gen_token_loss_history = []
gen_gaussian_loss_history = []
critic_network_loss_history = []
average_true_return_history = []

disc_ce_loss_history = []

WEIGHT_GAUSSIAN_LOSS = 1
_G_STEPS = 2
_D_STEPS = 1

In [None]:
step = 0

for features_batch, _ in tqdm(dataset.take(_TOTAL_STEPS)):
    step += 1
    print('Training Step:', step)
    # train the generator
    for _ in range(_G_STEPS):
        gen_token_loss, gen_gaussian_loss, critic_network_loss, average_true_return = train_generator(generator, discriminator, critic, 
                                                                                 batch_size=BATCH_SIZE, T=T, verbose=True, 
                                                                                 weight_gaussian_loss=WEIGHT_GAUSSIAN_LOSS,
                                                                                 optimizer=optimizer
                                                                                )
        gen_token_loss_history.append(gen_token_loss.numpy())
        gen_gaussian_loss_history.append(gen_gaussian_loss.numpy())    
        critic_network_loss_history.append(critic_network_loss.numpy())
        average_true_return_history.append(average_true_return.numpy())
    
    # train the discriminator
    for _ in range(_D_STEPS):
        disc_ce_loss = train_discriminator(features_batch, generator, discriminator, 
                                           batch_size=BATCH_SIZE, T=T, verbose=True, 
                                           optimizer=optimizer)
        disc_ce_loss_history.append(disc_ce_loss.numpy())

0it [00:00, ?it/s]

Training Step: 1
generator token loss:22.796257740984323
generator gaussian loss:63.61086189347006
generator total loss:86.40711963445438
generator critic loss:0.17821988375984285
average true_return: 0.3700852330782999
generator token loss:23.334678231314868
generator gaussian loss:64.32432482576189
generator total loss:87.65900305707676
generator critic loss:0.15560086589326166
average true_return: 0.37091629173850904
total discriminator loss:0.7273031986111469


1it [00:04,  4.85s/it]

Training Step: 2
generator token loss:23.033770665203203
generator gaussian loss:65.13910744838222
generator total loss:88.17287811358543
generator critic loss:0.12556359162578068
average true_return: 0.3743011502773669
generator token loss:23.414294947727374
generator gaussian loss:65.73331335445343
generator total loss:89.1476083021808
generator critic loss:0.1093400136335405
average true_return: 0.37433588839975485
total discriminator loss:0.7255377537987038


2it [00:09,  4.70s/it]

Training Step: 3
generator token loss:23.816999406150153
generator gaussian loss:66.04071858869887
generator total loss:89.85771799484903
generator critic loss:0.08085755791212974
average true_return: 0.3789667287541032
generator token loss:23.93202303982447
generator gaussian loss:66.62505085601867
generator total loss:90.55707389584315
generator critic loss:0.06844975372705968
average true_return: 0.37847192693709075
total discriminator loss:0.7226778461105166


3it [00:13,  4.68s/it]

Training Step: 4
generator token loss:24.22262777367221
generator gaussian loss:67.24242599228815
generator total loss:91.46505376596036
generator critic loss:0.047067191424493046
average true_return: 0.38291906887428484
generator token loss:23.956638556342984
generator gaussian loss:67.95692721113717
generator total loss:91.91356576748015
generator critic loss:0.03620689855677682
average true_return: 0.38322390520243355
total discriminator loss:0.7200183968763791


4it [00:17,  4.51s/it]

Training Step: 5
generator token loss:24.476752122375682
generator gaussian loss:68.483016690203
generator total loss:92.95976881257869
generator critic loss:0.02087845310003088
average true_return: 0.3894972766421816
generator token loss:24.7274019078056
generator gaussian loss:68.63263914652283
generator total loss:93.36004105432843
generator critic loss:0.0160303259643175
average true_return: 0.38954242808526607
total discriminator loss:0.7173505359325302


5it [00:22,  4.52s/it]

Training Step: 6
generator token loss:25.010843457070845
generator gaussian loss:69.4805847090945
generator total loss:94.49142816616535
generator critic loss:0.011365065490264245
average true_return: 0.39583596142623145
generator token loss:25.686498274055566
generator gaussian loss:70.25795070111732
generator total loss:95.94444897517289
generator critic loss:0.011957384290535472
average true_return: 0.3956011329570921
total discriminator loss:0.7143797737136257


6it [00:27,  4.57s/it]

Training Step: 7
generator token loss:25.674992943492434
generator gaussian loss:70.53970918739164
generator total loss:96.21470213088406
generator critic loss:0.017080892052180753
average true_return: 0.4033376383078531
generator token loss:25.971257899727096
generator gaussian loss:71.50365112321359
generator total loss:97.47490902294068
generator critic loss:0.020665192505974827
average true_return: 0.40299207064793674
total discriminator loss:0.7116050591313905


7it [00:31,  4.48s/it]

Training Step: 8
generator token loss:26.327870977901444
generator gaussian loss:71.64230314081681
generator total loss:97.97017411871825
generator critic loss:0.031011204188283896
average true_return: 0.4095607551634596
generator token loss:26.569103440870336
generator gaussian loss:71.49333919064172
generator total loss:98.06244263151206
generator critic loss:0.035201319324219923
average true_return: 0.41114424927290294
total discriminator loss:0.7084873016845451


8it [00:36,  4.57s/it]

Training Step: 9
generator token loss:26.3483098579269
generator gaussian loss:72.16592955953617
generator total loss:98.51423941746307
generator critic loss:0.05065952096493859
average true_return: 0.4202880357712048
generator token loss:26.546813618116317
generator gaussian loss:72.78650139332707
generator total loss:99.33331501144339
generator critic loss:0.04957284670112038
average true_return: 0.4201202209081799
total discriminator loss:0.7053193283783634


9it [00:40,  4.44s/it]

Training Step: 10
generator token loss:26.7068612574294
generator gaussian loss:72.7149683126914
generator total loss:99.42182957012079
generator critic loss:0.06429521464703336
average true_return: 0.4290612558977665
generator token loss:26.26406985930555
generator gaussian loss:72.23185813401382
generator total loss:98.49592799331937
generator critic loss:0.05744004025656306
average true_return: 0.42867223080976313
total discriminator loss:0.7028296814725263


10it [00:45,  4.58s/it]

Training Step: 11
generator token loss:26.68496968319332
generator gaussian loss:72.50989927829139
generator total loss:99.1948689614847
generator critic loss:0.0718003162906251
average true_return: 0.44017577529124274
generator token loss:26.69623127248059
generator gaussian loss:72.48758096337272
generator total loss:99.18381223585331
generator critic loss:0.05968645498279735
average true_return: 0.4391162038390678
total discriminator loss:0.6991138621020814


11it [00:49,  4.45s/it]

Training Step: 12
generator token loss:26.592092079857085
generator gaussian loss:72.97511386255407
generator total loss:99.56720594241115
generator critic loss:0.07120384714300262
average true_return: 0.44925975759381503
generator token loss:26.337868339885087
generator gaussian loss:72.6422354465262
generator total loss:98.98010378641129
generator critic loss:0.05976713441478809
average true_return: 0.45017882910105367
total discriminator loss:0.696074402235515


12it [00:53,  4.38s/it]

Training Step: 13
generator token loss:27.079354990885196
generator gaussian loss:71.95700624777746
generator total loss:99.03636123866265
generator critic loss:0.06809832383690391
average true_return: 0.45948497307145997
generator token loss:26.189095941801405
generator gaussian loss:72.03570056575923
generator total loss:98.22479650756063
generator critic loss:0.054632107534183405
average true_return: 0.46039412661214807
total discriminator loss:0.6931595190921043


13it [00:58,  4.53s/it]

Training Step: 14
generator token loss:26.319614790668542
generator gaussian loss:71.72721690477711
generator total loss:98.04683169544565
generator critic loss:0.06328930431922036
average true_return: 0.47263569781319736
generator token loss:26.513571997239737
generator gaussian loss:71.17014612415984
generator total loss:97.68371812139958
generator critic loss:0.047437914593335434
average true_return: 0.4717878787710107
total discriminator loss:0.6908639433175586


14it [01:02,  4.52s/it]

Training Step: 15
generator token loss:26.401247198975938
generator gaussian loss:72.01292263260012
generator total loss:98.41416983157606
generator critic loss:0.05724462430267402
average true_return: 0.4835859183872348
generator token loss:26.235683191341877
generator gaussian loss:70.9891803943406
generator total loss:97.22486358568247
generator critic loss:0.04311170360386876
average true_return: 0.48297722200700594
total discriminator loss:0.6905237579701632


15it [01:07,  4.43s/it]

Training Step: 16
generator token loss:26.02386074261586
generator gaussian loss:70.6307029067044
generator total loss:96.65456364932027
generator critic loss:0.04764662837922492
average true_return: 0.4944046016681441
generator token loss:26.251595238839933
generator gaussian loss:69.73459432955543
generator total loss:95.98618956839536
generator critic loss:0.03313822172380766
average true_return: 0.4948905521409571
total discriminator loss:0.6893196923324239


16it [01:11,  4.41s/it]

Training Step: 17
generator token loss:26.290788595479544
generator gaussian loss:70.71077059610388
generator total loss:97.00155919158342
generator critic loss:0.039863709768138494
average true_return: 0.5070263070558791
generator token loss:25.53202654018164
generator gaussian loss:70.05240711602536
generator total loss:95.58443365620701
generator critic loss:0.025116409885470144
average true_return: 0.5059656645324278
total discriminator loss:0.688751104584735


17it [01:16,  4.48s/it]

Training Step: 18
generator token loss:25.603558190545368
generator gaussian loss:69.8886111132566
generator total loss:95.49216930380197
generator critic loss:0.02799348138664117
average true_return: 0.514810638164473
generator token loss:25.68363495095087
generator gaussian loss:69.19816941485179
generator total loss:94.88180436580267
generator critic loss:0.018687046753250836
average true_return: 0.5164031905981049
total discriminator loss:0.6889913201287097


18it [01:20,  4.46s/it]

Training Step: 19
generator token loss:25.775030224420735
generator gaussian loss:68.69799318788372
generator total loss:94.47302341230446
generator critic loss:0.020774650046567294
average true_return: 0.5240970130108404
generator token loss:25.541493864157843
generator gaussian loss:68.29040254296739
generator total loss:93.83189640712523
generator critic loss:0.012346675323448991
average true_return: 0.5245098070938397
total discriminator loss:0.6891843554071122


19it [01:24,  4.40s/it]

Training Step: 20
generator token loss:24.97164649672772
generator gaussian loss:68.02104829688048
generator total loss:92.9926947936082
generator critic loss:0.012291456455730793
average true_return: 0.5308341956780287
generator token loss:24.931265540142995
generator gaussian loss:67.916123257063
generator total loss:92.84738879720601
generator critic loss:0.0068085286345757265
average true_return: 0.5297069291089757
total discriminator loss:0.685604248931835


20it [01:29,  4.35s/it]

Training Step: 21
generator token loss:25.029304598944663
generator gaussian loss:67.4062833847925
generator total loss:92.43558798373716
generator critic loss:0.007277943697897951
average true_return: 0.5385569425450634
generator token loss:24.650299947470725
generator gaussian loss:66.91150691578495
generator total loss:91.56180686325568
generator critic loss:0.0037560008734652014
average true_return: 0.5350870200995418
total discriminator loss:0.689323588665502


21it [01:33,  4.27s/it]

Training Step: 22
generator token loss:24.053657241819064
generator gaussian loss:66.11507161200056
generator total loss:90.16872885381963
generator critic loss:0.004079675366173015
average true_return: 0.53897254812927
generator token loss:24.386577143391698
generator gaussian loss:65.94432905160366
generator total loss:90.33090619499535
generator critic loss:0.0035532085551214514
average true_return: 0.541672556794875
total discriminator loss:0.6875976335950035


22it [01:37,  4.31s/it]

Training Step: 23
generator token loss:23.995985354515362
generator gaussian loss:65.29238123501543
generator total loss:89.2883665895308
generator critic loss:0.004624472265388474
average true_return: 0.5410900720173011
generator token loss:23.779215766942418
generator gaussian loss:65.20812180758791
generator total loss:88.98733757453033
generator critic loss:0.0056734973600531665
average true_return: 0.543482212676319
total discriminator loss:0.682787964274621


23it [01:41,  4.32s/it]

Training Step: 24
generator token loss:23.685527278955593
generator gaussian loss:64.71926981816141
generator total loss:88.404797097117
generator critic loss:0.009700061172476283
average true_return: 0.5396273162446742
generator token loss:23.736647514424288
generator gaussian loss:64.5368567207016
generator total loss:88.2735042351259
generator critic loss:0.010676363733499464
average true_return: 0.539510168238003
total discriminator loss:0.6815998046168097


24it [01:46,  4.29s/it]

Training Step: 25
generator token loss:23.212861010652542
generator gaussian loss:63.42175434562767
generator total loss:86.63461535628021
generator critic loss:0.014101482968548197
average true_return: 0.538047891802033
generator token loss:23.796234002980555
generator gaussian loss:63.44870764026936
generator total loss:87.24494164324992
generator critic loss:0.014931646726680537
average true_return: 0.5416757557154668
total discriminator loss:0.6839819588488769


25it [01:50,  4.25s/it]

Training Step: 26
generator token loss:23.184759424164888
generator gaussian loss:62.83783185568152
generator total loss:86.02259127984641
generator critic loss:0.019818845320198033
average true_return: 0.5350974158646197
generator token loss:22.775060543228708
generator gaussian loss:62.89999872756515
generator total loss:85.67505927079385
generator critic loss:0.01997408762519037
average true_return: 0.5367442670178909
total discriminator loss:0.6779968313760363


26it [01:55,  4.45s/it]

Training Step: 27
generator token loss:23.29252423868757
generator gaussian loss:62.8917264675956
generator total loss:86.18425070628317
generator critic loss:0.02113487682506873
average true_return: 0.5340264238373342
generator token loss:23.550831640732348
generator gaussian loss:62.362551667277465
generator total loss:85.91338330800981
generator critic loss:0.017885733753757897
average true_return: 0.5322392578921281
total discriminator loss:0.6787731579356904


27it [02:00,  4.57s/it]

Training Step: 28
generator token loss:23.00816471316952
generator gaussian loss:61.48339309045033
generator total loss:84.49155780361986
generator critic loss:0.023375081302811903
average true_return: 0.5273258916162822
generator token loss:23.44119456013441
generator gaussian loss:61.63482980963811
generator total loss:85.07602436977253
generator critic loss:0.018164002216641842
average true_return: 0.5290672258654312
total discriminator loss:0.6710334282375099


28it [02:04,  4.44s/it]

Training Step: 29
generator token loss:22.74188137226013
generator gaussian loss:62.1669864770873
generator total loss:84.90886784934743
generator critic loss:0.020138908413160544
average true_return: 0.524618376770339
generator token loss:22.966394602139637
generator gaussian loss:62.42880128350187
generator total loss:85.3951958856415
generator critic loss:0.01741674206946192
average true_return: 0.5230997316761863
total discriminator loss:0.6672496492846733


29it [02:08,  4.47s/it]

Training Step: 30
generator token loss:22.85463296631861
generator gaussian loss:61.60262682313621
generator total loss:84.45725978945482
generator critic loss:0.02117531181865566
average true_return: 0.5171306084158902
generator token loss:22.734901490497414
generator gaussian loss:60.99746863520279
generator total loss:83.73237012570021
generator critic loss:0.013914510474098862
average true_return: 0.5190289106703789
total discriminator loss:0.6651610555167831


30it [02:12,  4.38s/it]

Training Step: 31
generator token loss:22.85933956693101
generator gaussian loss:61.247970740138676
generator total loss:84.10731030706968
generator critic loss:0.017873610798097773
average true_return: 0.5118456534142235
generator token loss:22.39866084301765
generator gaussian loss:61.811923532272765
generator total loss:84.21058437529041
generator critic loss:0.013535092952218178
average true_return: 0.51082499936975
total discriminator loss:0.6586245116136997


31it [02:17,  4.33s/it]

Training Step: 32
generator token loss:22.76542836950657
generator gaussian loss:60.72324293922715
generator total loss:83.48867130873371
generator critic loss:0.016691366786326385
average true_return: 0.5055059580430563
generator token loss:22.945681884807012
generator gaussian loss:61.01471174718259
generator total loss:83.9603936319896
generator critic loss:0.010596025466975498
average true_return: 0.5041082550497501
total discriminator loss:0.6496286217195707


32it [02:21,  4.30s/it]

Training Step: 33
generator token loss:22.66210380429345
generator gaussian loss:60.32208448425556
generator total loss:82.984188288549
generator critic loss:0.014376866022323917
average true_return: 0.4959285422950529
generator token loss:22.522464567018453
generator gaussian loss:61.26494937424162
generator total loss:83.78741394126007
generator critic loss:0.00978135374693246
average true_return: 0.4966523470623383
total discriminator loss:0.6355071544556814


33it [02:26,  4.48s/it]

Training Step: 34
generator token loss:23.228424991370385
generator gaussian loss:60.69632798091045
generator total loss:83.92475297228084
generator critic loss:0.011210922590457983
average true_return: 0.4917213249917561
generator token loss:22.33730331355935
generator gaussian loss:60.02924143987009
generator total loss:82.36654475342944
generator critic loss:0.010692218355040058
average true_return: 0.4862617580558399
total discriminator loss:0.6212197091348011


34it [02:30,  4.47s/it]

Training Step: 35
generator token loss:23.501233400648385
generator gaussian loss:59.96987723505962
generator total loss:83.471110635708
generator critic loss:0.010095023689961815
average true_return: 0.4901343812367359
generator token loss:22.790678516799783
generator gaussian loss:60.38508612336358
generator total loss:83.17576464016337
generator critic loss:0.010034436954434074
average true_return: 0.48377010102865176
total discriminator loss:0.6132201268215186


35it [02:34,  4.40s/it]

Training Step: 36
generator token loss:23.1454288070856
generator gaussian loss:60.25981726742484
generator total loss:83.40524607451044
generator critic loss:0.012760537047320641
average true_return: 0.48682523509755904
generator token loss:22.61061212686088
generator gaussian loss:60.15677337756864
generator total loss:82.76738550442951
generator critic loss:0.01199062225035945
average true_return: 0.48406483249417287
total discriminator loss:0.593090979285626


36it [02:39,  4.36s/it]

Training Step: 37
generator token loss:23.282628256976473
generator gaussian loss:59.49235907761611
generator total loss:82.77498733459258
generator critic loss:0.02548878047271219
average true_return: 0.4832968978460723
generator token loss:22.95908201867318
generator gaussian loss:60.21561468316334
generator total loss:83.17469670183652
generator critic loss:0.02589036542208344
average true_return: 0.4828240137916905
total discriminator loss:0.5770671118622155


37it [02:43,  4.36s/it]

Training Step: 38
generator token loss:23.22741800563656
generator gaussian loss:59.5178042822234
generator total loss:82.74522228785996
generator critic loss:0.03524609700214803
average true_return: 0.48265911239680775
generator token loss:23.919448217906456
generator gaussian loss:60.11846093493378
generator total loss:84.03790915284023
generator critic loss:0.05516095108160886
average true_return: 0.4890218090951258
total discriminator loss:0.5616160265032062


38it [02:48,  4.56s/it]

Training Step: 39
generator token loss:23.2297115271329
generator gaussian loss:59.13950281328563
generator total loss:82.36921434041852
generator critic loss:0.0525381949457818
average true_return: 0.4721871464540776
generator token loss:23.639749827188464
generator gaussian loss:59.62343526992086
generator total loss:83.26318509710933
generator critic loss:0.07067610489743878
average true_return: 0.48613999772533395
total discriminator loss:0.5270708930293537


39it [02:52,  4.45s/it]

Training Step: 40
generator token loss:23.320663960719106
generator gaussian loss:58.70640276460002
generator total loss:82.02706672531913
generator critic loss:0.12611597431422697
average true_return: 0.48693762253379047
generator token loss:23.36424819348461
generator gaussian loss:58.42153115161828
generator total loss:81.7857793451029
generator critic loss:0.13351746937230174
average true_return: 0.4795491411940197
total discriminator loss:0.48214321905693663


40it [02:57,  4.40s/it]

Training Step: 41
generator token loss:22.528555466253742
generator gaussian loss:57.20463868617442
generator total loss:79.73319415242815
generator critic loss:0.15928988789045123
average true_return: 0.45970983740153537
generator token loss:23.29059098475372
generator gaussian loss:57.15169391273846
generator total loss:80.44228489749219
generator critic loss:0.18912666165017838
average true_return: 0.46339135575861795
total discriminator loss:0.45079847701817477


41it [03:01,  4.49s/it]

Training Step: 42
generator token loss:22.177224270477616
generator gaussian loss:55.31506100970049
generator total loss:77.49228528017811
generator critic loss:0.28307857677422155
average true_return: 0.4486044595986219
generator token loss:22.158133959033567
generator gaussian loss:54.812735905407614
generator total loss:76.97086986444118
generator critic loss:0.22076276482967588
average true_return: 0.43743678396914804
total discriminator loss:0.416766563674735


42it [03:06,  4.44s/it]

Training Step: 43
generator token loss:23.471333232237235
generator gaussian loss:56.33595278047733
generator total loss:79.80728601271457
generator critic loss:0.45759462335441814
average true_return: 0.4711985446901513
generator token loss:22.654873526364803
generator gaussian loss:54.847583195435
generator total loss:77.5024567217998
generator critic loss:0.41333407060527755
average true_return: 0.44230013800874224
total discriminator loss:0.4566091249267633


43it [03:10,  4.49s/it]

Training Step: 44
generator token loss:21.297131409841434
generator gaussian loss:52.16407893007462
generator total loss:73.46121033991605
generator critic loss:0.5302399480487379
average true_return: 0.3939024187049375
generator token loss:22.00596060000635
generator gaussian loss:52.54226419981792
generator total loss:74.54822479982427
generator critic loss:0.5172440455251036
average true_return: 0.42210305720229957
total discriminator loss:0.34833629829066504


44it [03:14,  4.41s/it]

Training Step: 45
generator token loss:21.503444739484014
generator gaussian loss:50.868843232445386
generator total loss:72.3722879719294
generator critic loss:0.44173922473918253
average true_return: 0.37557780413052144
generator token loss:21.132110748750257
generator gaussian loss:49.97931407244607
generator total loss:71.11142482119632
generator critic loss:0.35936316422670195
average true_return: 0.36590685222936264
total discriminator loss:0.35044553431597397


45it [03:19,  4.38s/it]

Training Step: 46
generator token loss:21.48963960350099
generator gaussian loss:50.433725147575004
generator total loss:71.923364751076
generator critic loss:0.611010585761233
average true_return: 0.3653892317148382
generator token loss:20.721275453888026
generator gaussian loss:48.61017184878301
generator total loss:69.33144730267104
generator critic loss:0.4869636529832293
average true_return: 0.3483424453679062
total discriminator loss:0.3347175495743833


46it [03:23,  4.32s/it]

Training Step: 47
generator token loss:20.482357681929752
generator gaussian loss:48.40332673394221
generator total loss:68.88568441587196
generator critic loss:0.4168871319293141
average true_return: 0.3246982381985106
generator token loss:20.509973682880226
generator gaussian loss:48.31711929266641
generator total loss:68.82709297554663
generator critic loss:0.45058675985635555
average true_return: 0.32518432407326436
total discriminator loss:0.29433889177758943


47it [03:28,  4.47s/it]

Training Step: 48
generator token loss:22.06737899013063
generator gaussian loss:49.692004227751355
generator total loss:71.75938321788199
generator critic loss:0.5509344284526495
average true_return: 0.3437741918310554
generator token loss:20.81666449275612
generator gaussian loss:48.85546162772233
generator total loss:69.67212612047845
generator critic loss:0.4481802846711809
average true_return: 0.3189099660240675
total discriminator loss:0.2710624062833745


48it [03:33,  4.57s/it]

Training Step: 49
generator token loss:20.448805130920082
generator gaussian loss:47.84920191804095
generator total loss:68.29800704896103
generator critic loss:0.34802511427839206
average true_return: 0.2803954334984248
generator token loss:21.136946547653043
generator gaussian loss:47.83696437053613
generator total loss:68.97391091818918
generator critic loss:0.39540501583586063
average true_return: 0.28840831861021465
total discriminator loss:0.24807729569080456


49it [03:37,  4.56s/it]

Training Step: 50
generator token loss:21.117147906396504
generator gaussian loss:48.01656483226624
generator total loss:69.13371273866275
generator critic loss:0.46462557560624723
average true_return: 0.2722943862743943
generator token loss:21.01152882121432
generator gaussian loss:46.79030515100615
generator total loss:67.80183397222046
generator critic loss:0.3811513829758656
average true_return: 0.26670561315750246
total discriminator loss:0.19591910271567173


50it [03:42,  4.71s/it]

Training Step: 51
generator token loss:21.525975950168036
generator gaussian loss:46.93382085250277
generator total loss:68.45979680267081
generator critic loss:0.257755211669596
average true_return: 0.24333533328931478
generator token loss:21.47906141800206
generator gaussian loss:47.469387796996095
generator total loss:68.94844921499816
generator critic loss:0.22728118349700135
average true_return: 0.2392767811946136
total discriminator loss:0.15785627974475105


51it [03:47,  4.69s/it]

Training Step: 52
generator token loss:20.200990724664287
generator gaussian loss:46.655948798606644
generator total loss:66.85693952327094
generator critic loss:0.2487474545072047
average true_return: 0.21648251843009694
generator token loss:20.901688048786642
generator gaussian loss:45.69612798126406
generator total loss:66.59781603005071
generator critic loss:0.19262166483998952
average true_return: 0.20624254787499974
total discriminator loss:0.1491863026554161


52it [03:51,  4.69s/it]

Training Step: 53
generator token loss:21.03191368790173
generator gaussian loss:45.34591590001003
generator total loss:66.37782958791176
generator critic loss:0.19453469865484796
average true_return: 0.19055319631980455
generator token loss:21.33252524498713
generator gaussian loss:46.35130000875673
generator total loss:67.68382525374386
generator critic loss:0.15564726143932758
average true_return: 0.18875387363164003
total discriminator loss:0.12176033568416295


53it [03:56,  4.70s/it]

Training Step: 54
generator token loss:21.30467732194084
generator gaussian loss:45.32045529993721
generator total loss:66.62513262187805
generator critic loss:0.17277001152801771
average true_return: 0.16954498913344845
generator token loss:21.912516922934614
generator gaussian loss:46.28998663442677
generator total loss:68.20250355736138
generator critic loss:0.15286791679194706
average true_return: 0.16692188553670675
total discriminator loss:0.1014001772187435


54it [04:01,  4.66s/it]

Training Step: 55
generator token loss:22.20187027240188
generator gaussian loss:45.727621914265164
generator total loss:67.92949218666705
generator critic loss:0.17049712629493222
average true_return: 0.1521061873340313
generator token loss:21.698143847184348
generator gaussian loss:45.951471837277865
generator total loss:67.64961568446222
generator critic loss:0.1768030690167759
average true_return: 0.1458266218198832
total discriminator loss:0.0819325612201611


55it [04:05,  4.60s/it]

Training Step: 56
generator token loss:22.053649502012814
generator gaussian loss:45.35265181872202
generator total loss:67.40630132073484
generator critic loss:0.20174232926146002
average true_return: 0.1310209485550492
generator token loss:22.312810714461854
generator gaussian loss:44.90325771377874
generator total loss:67.21606842824059
generator critic loss:0.2084066387539544
average true_return: 0.1261813753091646
total discriminator loss:0.0679498982762502


56it [04:10,  4.57s/it]

Training Step: 57
generator token loss:22.363278039472014
generator gaussian loss:45.207616050240226
generator total loss:67.57089408971224
generator critic loss:0.22414994769535118
average true_return: 0.11063292145479998
generator token loss:22.51060551187644
generator gaussian loss:44.908323578926094
generator total loss:67.41892909080254
generator critic loss:0.22279350269502113
average true_return: 0.11131820608562011
total discriminator loss:0.05584799251086978


57it [04:14,  4.47s/it]

Training Step: 58
generator token loss:22.078792079520976
generator gaussian loss:44.2265626133546
generator total loss:66.30535469287558
generator critic loss:0.2430822949160804
average true_return: 0.08817007166531968
generator token loss:22.202906836033474
generator gaussian loss:43.68686620365335
generator total loss:65.88977303968683
generator critic loss:0.23305921119148773
average true_return: 0.09400191170976313
total discriminator loss:0.04707842975817649


58it [04:19,  4.60s/it]

Training Step: 59
generator token loss:21.81619184637505
generator gaussian loss:43.10936661991027
generator total loss:64.92555846628532
generator critic loss:0.24901422230518697
average true_return: 0.07253606115510505
generator token loss:21.967643432337113
generator gaussian loss:42.839572156411755
generator total loss:64.80721558874887
generator critic loss:0.24431903251859172
average true_return: 0.06957633828041042
total discriminator loss:0.037569861416063596


59it [04:24,  4.63s/it]

Training Step: 60
generator token loss:21.54585139613039
generator gaussian loss:41.18486058026852
generator total loss:62.73071197639891
generator critic loss:0.2704432830711906
average true_return: 0.05122237507739587
generator token loss:22.07869916914607
generator gaussian loss:40.89773007013115
generator total loss:62.97642923927722
generator critic loss:0.2630166665711308
average true_return: 0.05301876471298092
total discriminator loss:0.025773741408550783


60it [04:28,  4.62s/it]

Training Step: 61
generator token loss:21.708885294532585
generator gaussian loss:40.26169615182306
generator total loss:61.970581446355645
generator critic loss:0.2758462006408947
average true_return: 0.048350893302720126
generator token loss:21.508949108022918
generator gaussian loss:39.423703831969235
generator total loss:60.93265293999215
generator critic loss:0.2740412566419569
average true_return: 0.04170203450323583
total discriminator loss:0.02455989544667282


61it [04:33,  4.61s/it]

Training Step: 62
generator token loss:21.382280151070244
generator gaussian loss:39.17866591931704
generator total loss:60.56094607038729
generator critic loss:0.2908676700672856
average true_return: 0.03287645887627792
generator token loss:21.567534847183325
generator gaussian loss:39.380122449888674
generator total loss:60.947657297072
generator critic loss:0.28799899671624424
average true_return: 0.02664109580833646
total discriminator loss:0.016083415948679992


62it [04:38,  4.74s/it]

Training Step: 63
generator token loss:21.743009253363322
generator gaussian loss:37.716410871199834
generator total loss:59.45942012456315
generator critic loss:0.29998485395132524
average true_return: 0.02009859108568799
generator token loss:21.56157334136801
generator gaussian loss:38.22871351667053
generator total loss:59.79028685803854
generator critic loss:0.29431700520726245
average true_return: 0.016851725335191813
total discriminator loss:0.008675224704860166


63it [04:42,  4.59s/it]

Training Step: 64
generator token loss:21.705832851968957
generator gaussian loss:37.664342101639946
generator total loss:59.3701749536089
generator critic loss:0.3012578702489559
average true_return: 0.012619433473253987
generator token loss:22.091864064794265
generator gaussian loss:37.44933292973168
generator total loss:59.54119699452595
generator critic loss:0.2954770502667919
average true_return: 0.009684241616608894
total discriminator loss:0.008970760978223852


64it [04:46,  4.53s/it]

Training Step: 65
generator token loss:21.683524728526415
generator gaussian loss:36.72884818786102
generator total loss:58.41237291638743
generator critic loss:0.2928090185214301
average true_return: 0.011441689317587343
generator token loss:21.41298908734032
generator gaussian loss:37.17391479401967
generator total loss:58.586903881359994
generator critic loss:0.2868788309606056
average true_return: 0.01120882290593159
total discriminator loss:0.005960980606148169


65it [04:51,  4.43s/it]

Training Step: 66
generator token loss:21.231106567869215
generator gaussian loss:36.74145940620021
generator total loss:57.972565974069425
generator critic loss:0.2890420459526869
average true_return: 0.0033935675237614865
generator token loss:21.763046342425184
generator gaussian loss:35.86213211792339
generator total loss:57.62517846034857
generator critic loss:0.2844495360531736
average true_return: 0.003381415873193089
total discriminator loss:0.0019158856611739422


66it [04:55,  4.38s/it]

Training Step: 67
generator token loss:21.74412126686343
generator gaussian loss:36.29993793388725
generator total loss:58.04405920075068
generator critic loss:0.2804667589692508
average true_return: 0.0029335256973039236
generator token loss:21.83046009473075
generator gaussian loss:36.71970531295939
generator total loss:58.550165407690145
generator critic loss:0.273577605462377
average true_return: 0.0058008448675067725
total discriminator loss:0.0016254927621651334


67it [04:59,  4.45s/it]

Training Step: 68
generator token loss:22.085032313036496
generator gaussian loss:35.88075244284408
generator total loss:57.96578475588058
generator critic loss:0.2723421320603729
average true_return: 0.002820426774434274
generator token loss:21.822246376277043
generator gaussian loss:35.09830177341876
generator total loss:56.920548149695804
generator critic loss:0.2689910103569291
average true_return: 0.002824536567861387
total discriminator loss:0.003109534414382967


68it [05:04,  4.55s/it]

Training Step: 69
generator token loss:21.96175991129418
generator gaussian loss:34.70026947609432
generator total loss:56.6620293873885
generator critic loss:0.2680463256128896
average true_return: 7.742822684453824e-07
generator token loss:22.183894381632673
generator gaussian loss:33.53302456408571
generator total loss:55.716918945718376
generator critic loss:0.2567660771247984
average true_return: 0.011052196175088562
total discriminator loss:0.0015199301229287241


69it [05:08,  4.44s/it]

Training Step: 70
generator token loss:22.34480551143158
generator gaussian loss:34.64247873742892
generator total loss:56.9872842488605
generator critic loss:0.2586601828120188
average true_return: 0.005435327152692973
generator token loss:21.89454615537613
generator gaussian loss:33.79780593806558
generator total loss:55.69235209344171
generator critic loss:0.2540190558252612
average true_return: 0.00814267198033907
total discriminator loss:0.0029886868243601786


70it [05:13,  4.43s/it]

Training Step: 71
generator token loss:21.92741167530177
generator gaussian loss:33.5890322045802
generator total loss:55.51644387988197
generator critic loss:0.2582370345418731
average true_return: 2.3138040816052606e-09
generator token loss:21.89242687308036
generator gaussian loss:33.336913325530055
generator total loss:55.229340198610416
generator critic loss:0.25235907147348763
average true_return: 0.005347315168114906
total discriminator loss:0.001466175830029642


71it [05:18,  4.53s/it]

Training Step: 72
generator token loss:22.395775683608935
generator gaussian loss:33.72128431509552
generator total loss:56.11705999870445
generator critic loss:0.2528726308008455
average true_return: 0.0026330087484324484
generator token loss:22.037114992783735
generator gaussian loss:33.64924601938603
generator total loss:55.68636101216976
generator critic loss:0.24471581329864026
average true_return: 0.010523085844987266
total discriminator loss:0.0


72it [05:22,  4.49s/it]

Training Step: 73
generator token loss:21.455148473055885
generator gaussian loss:32.769143184996395
generator total loss:54.22429165805228
generator critic loss:0.24329621603048668
average true_return: 0.01038692930296399
generator token loss:21.563760593048798
generator gaussian loss:33.29358744596279
generator total loss:54.85734803901159
generator critic loss:0.2460548736850135
average true_return: 0.005190514287010596
total discriminator loss:0.002838670786486432


73it [05:26,  4.43s/it]

Training Step: 74
generator token loss:22.21627831146531
generator gaussian loss:32.607239671704924
generator total loss:54.823517983170234
generator critic loss:0.24895610274657715
average true_return: 6.340765468469944e-13
generator token loss:22.486175829609515
generator gaussian loss:32.847945189375494
generator total loss:55.33412101898501
generator critic loss:0.24611392688945352
average true_return: 0.002567299777306384
total discriminator loss:0.0028011252886475657


74it [05:31,  4.60s/it]

Training Step: 75
generator token loss:21.928803413723948
generator gaussian loss:33.03768080269702
generator total loss:54.966484216420966
generator critic loss:0.24238358982904043
average true_return: 0.005066392447258903
generator token loss:22.218242095336684
generator gaussian loss:32.37942663330702
generator total loss:54.597668728643704
generator critic loss:0.23747063781385536
average true_return: 0.010132407790184098
total discriminator loss:0.0013814266869134647


75it [05:36,  4.49s/it]

## 1000 steps: Loss over training

In [None]:
x = range(len(gen_token_loss_history))
plt.figure(dpi=100)
plt.plot(x, gen_token_loss_history)
plt.title('Generator Toke Loss History')
plt.xlabel('training steps')

x = range(len(gen_gaussian_loss_history))
plt.figure(dpi=100)
plt.plot(x, gen_gaussian_loss_history)
plt.title('Generator Gaussian Loss History')
plt.xlabel('training steps')

x = range(len(disc_ce_loss_history))
plt.figure(dpi=100)
plt.plot(x, disc_ce_loss_history)
plt.title('Discriminator CE Loss History')
plt.xlabel('training steps')

In [None]:
x = range(len(critic_network_loss_history))
plt.figure(dpi=100)
plt.plot(x, critic_network_loss_history)
plt.title('Critic Loss History')
plt.xlabel('training steps')


x = range(len(average_true_return_history))
plt.figure(dpi=100)
plt.plot(x, average_true_return_history)
plt.title('Average True Return History')
plt.xlabel('training steps')

In [None]:
loss_save_dir = './experiment_results/after_1000_steps/loss'
if not os.path.exists(loss_save_dir):
    os.makedirs(loss_save_dir)

with open(os.path.join(loss_save_dir, 'gen_token_loss_history.pickle'), 'wb') as f:
    pickle.dump(gen_token_loss_history, f)

with open(os.path.join(loss_save_dir, 'gen_gaussian_loss_history.pickle'), 'wb') as f:
    pickle.dump(gen_gaussian_loss_history, f)
    
with open(os.path.join(loss_save_dir, 'critic_network_loss_history.pickle'), 'wb') as f:
    pickle.dump(critic_network_loss_history, f)

with open(os.path.join(loss_save_dir, 'disc_ce_loss_history.pickle'), 'wb') as f:
    pickle.dump(disc_ce_loss_history, f)

## 1000 steps: Save G and D models

In [None]:
G_save_dir = './experiment_results/after_1000_steps/gen_weights'
if not os.path.exists(G_save_dir):
    os.makedirs(G_save_dir)
    
G_save_path = os.path.join(G_save_dir, 'gen_model.tf')
generator.save_weights(G_save_path)

In [None]:
!pwd

In [None]:
D_save_dir = './experiment_results/after_1000_steps/disc_weights'
if not os.path.exists(D_save_dir):
    os.makedirs(D_save_dir)
    
D_save_path = os.path.join(D_save_dir, 'disc_model.tf')
discriminator.save_weights(D_save_path)

## 1000 steps: Generate sequences 

In [None]:
N_gen = 10000
generated_seqs = generate_sequences(N_gen, generator, batch_size=BATCH_SIZE, T=T, recover_to_timestamp=True)

In [None]:
generated_seqs_save_dir = './experiment_results/after_1000_steps/generated_seqs'
if not os.path.exists(generated_seqs_save_dir):
    os.makedirs(generated_seqs_save_dir)

with open(os.path.join(generated_seqs_save_dir, 'generated_seqs.pickle'), 'wb') as f:
    pickle.dump(generated_seqs, f)

In [None]:
generated_seqs

## Debug Session 

In [None]:
batch_size = BATCH_SIZE
init_token = 1
end_token = 0
max_time = 1024

In [None]:
reload_G = build_G(
    batch_size=BATCH_SIZE,
    event_vocab_dim = EVENT_VOCAB_DIM,
    emb_dim = EMB_DIM,
    hidden_dim= HIDDEN_DIM)

reload_G.build(input_shape=((BATCH_SIZE, T, 1), (BATCH_SIZE, T, 1)))

G_save_path = './experiment_results/long_seqs_no_gm/init_pretrained/pretrained_gen_weights/model.tf'
reload_G.load_weights(G_save_path)

G = reload_G

In [None]:
# Begin from dummy init state (init_token=1, init_timestamp=0.0)
curr_state_et = np.zeros([batch_size, 1, 1])
curr_state_et[:, 0, 0] = init_token

curr_state_ts = np.zeros([batch_size, 1, 1])
curr_state_ts[:, 0, 0] = 0.0

all_state_et = curr_state_et
all_state_ts = curr_state_ts

episode_token_probs = tf.constant(1., dtype=tf.float64, shape=(batch_size, 1))
gaussian_log = tf.constant(0., dtype=tf.float64, shape=(batch_size, 1))

G.reset_states()

for step in range(1, T):  # sequence length
    token_prob, time_out = G([curr_state_et, curr_state_ts])

    sampled_et = tf.random.categorical(tf.math.log(token_prob), num_samples=1, dtype=tf.int32)
    sampled_et = tf.reshape(sampled_et, [batch_size, 1, 1]).numpy().astype(float)

    # get the chosen token probability per batch for each step
    sampled_et_indices = sampled_et.squeeze().astype(int).tolist()
    sampled_token_prob = token_prob.numpy()[np.arange(len(token_prob)), sampled_et_indices].reshape((batch_size, 1))
    episode_token_probs = tf.concat([episode_token_probs, sampled_token_prob], axis=1)

    # stop genererating once hit end_token
    cond_end_token = tf.equal(curr_state_et, end_token)
    curr_state_et = tf.where(cond_end_token, curr_state_et, sampled_et)
    all_state_et = tf.concat([all_state_et, curr_state_et], axis=1)

    # generate one timstamp using time_out
    sampled_ts_raw = time_out.sample()
    sampled_ts = tf.clip_by_value(tf.reshape(sampled_ts_raw, (batch_size, 1, 1))
                                  , clip_value_min=1, clip_value_max=max_time)

    # get the gaussian log likelihood for the sampled timestamps
    sampled_gaussian_log = time_out.log_prob(sampled_ts_raw)
    gaussian_log = tf.concat([gaussian_log, sampled_gaussian_log], axis=1)

    # stop generating once hit end_token
    curr_state_ts = tf.where(cond_end_token, curr_state_ts, sampled_ts)
    all_state_ts = tf.concat([all_state_ts, curr_state_ts], axis=1)