In [None]:
import os 
import sys
from tqdm import tqdm
import importlib
import numpy as np
import pickle
import tensorflow as tf
import matplotlib.pyplot as plt

# Put the local seq-gan checkout on the import path (presumably where the
# `sgtlstm` package imported below lives); skipped when it is already listed.
module_path = '/home/lun/project-basileus/seq-gan/'
if sys.path.count(module_path) == 0:
    sys.path += [module_path]

In [None]:
# Forget every cached `sgtlstm` module so that the imports below re-execute the
# package AND its submodules. `importlib.reload` on the top-level package alone
# does not reload already-imported submodules (sgtlstm.utils, sgtlstm.SeqGan, ...),
# so edits to them would not be picked up without a kernel restart.
for _cached_name in [name for name in sys.modules
                     if name == 'sgtlstm' or name.startswith('sgtlstm.')]:
    del sys.modules[_cached_name]

from sgtlstm.utils import create_dataset, recover_timedelta_to_timestamp
from sgtlstm.SeqGan import build_G, build_D, build_critic
from sgtlstm.TimeLSTM import TimeLSTM0, TimeLSTM1, TimeLSTM2, TimeLSTM3
# from sgtlstm.train import generate_batch_sequence_by_rollout, train_discriminator, train_generator

## Load data 

In [None]:
# Pickled sequence files of the long_seqs_v6 data set (positive / negative / all).
pos_data_path = (
    '/home/lun/project-basileus/seq-gan/data/long_seqs_v6/'
    'positive_long_sequences.pickle'
)
neg_data_path = (
    '/home/lun/project-basileus/seq-gan/data/long_seqs_v6/'
    'negative_long_sequences.pickle'
)
all_data_path = (
    '/home/lun/project-basileus/seq-gan/data/long_seqs_v6/'
    'all_long_sequences.pickle'
)

def load_sequence_from_pickle_to_numpy(pickle_file_path):
    """
        Read pickled sequences and split them into two float64 arrays.

        The pickle is expected to hold N equally long sequences of T
        (event_type, delta_time) pairs.
    :param pickle_file_path: path of the pickle file to read
    :return: (event_type_seqs, delta_time_seqs), each of shape (N, T, 1);
        two empty arrays when the file holds no sequences or the first one is empty
    """
    # NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
    with open(pickle_file_path, 'rb') as handle:
        sequences = pickle.load(handle)

    if not sequences or not sequences[0]:
        return np.array([]), np.array([])

    n_seqs, seq_len = len(sequences), len(sequences[0])
    stacked = np.array(sequences)

    def _feature_column(index):
        # pick one feature and restore the trailing singleton axis
        return stacked[:, :, index].astype(np.float64).reshape((n_seqs, seq_len, 1))

    return _feature_column(0), _feature_column(1)
    
# Load the three pickles in order (positive, negative, all) and unpack each
# result into its event-type and time arrays.
_pos_pair, _neg_pair, _all_pair = map(
    load_sequence_from_pickle_to_numpy,
    (pos_data_path, neg_data_path, all_data_path),
)
pos_event_type_seqs, pos_timestamp_seqs = _pos_pair
neg_event_type_seqs, neg_timestamp_seqs = _neg_pair
all_event_type_seqs, all_timestamp_seqs = _all_pair

## Global Variables 

In [None]:
# Batch size used when building the models and rolling out sequences.
BATCH_SIZE = 64
# Sequence length: 20 steps plus one extra position (presumably the leading INIT token).
T = 1 + 20
# Event vocabulary; index 0 is the END/PADDING entry, matching END_TOKEN below.
VOCAB = 'END/PADDING INIT start view click install'.split()
EVENT_VOCAB_DIM = len(VOCAB)
# Embedding size and hidden size handed to the model builders.
EMB_DIM, HIDDEN_DIM = 5, 32

# Token id that ends a sequence and the upper clip bound for sampled time values.
END_TOKEN, MAX_TIME = 0, 1024

## Load pretrained model

In [None]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras import Sequential
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.layers import Input, LSTM, Embedding, Reshape, Dense
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd
from tensorflow.keras import regularizers

from sgtlstm.TimeLSTM import TimeLSTM0, TimeLSTM1, TimeLSTM2, TimeLSTM3

# Make float64 the default Keras float type; the rollout code in this notebook
# builds its state tensors as tf.float64.
tf.keras.backend.set_floatx('float64')

def build_D_2(T, event_vocab_dim, emb_dim, hidden_dim=11):
    """
        Build a discriminator that maps one (event type, time delta) sequence pair
        to a single sigmoid score.

        The model has two inputs, 'event_type' and 'time_delta', each declared with
        shape (T, 1) per sample. Both go through the same Masking layer
        (mask_value=0.); the event types are embedded, concatenated with the time
        input along the last axis and fed to a TimeLSTM1 layer. The two tensors
        returned by TimeLSTM1 are concatenated and reduced by a Dense(1, sigmoid)
        layer with L1/L2 kernel regularization.
    :param T: length of the sequence
    :param event_vocab_dim: size of event vocabulary ['na', 'init', 'start', 'view', 'click', 'install']
    :param emb_dim: dimension of the embedding layer output for event type
    :param hidden_dim: number of units passed to the TimeLSTM1 layer
    :return: discriminator D, a Keras Model with inputs [event_type, time_delta]
        and one sigmoid output
    """
    # Time-LSTM:
    i_et = Input(shape=(T, 1), name='event_type')  # input of discrete feature event type
    i_ts = Input(shape=(T, 1), name='time_delta')  # input of continuous feature timestamp
    # one Masking layer instance is shared by both inputs; 0. marks padded steps
    mask_layer = tf.keras.layers.Masking(mask_value=0., input_shape=(T, 1))
    masked_ts = mask_layer(i_ts)
    masked_et = mask_layer(i_et)

    embed0 = Embedding(input_dim=event_vocab_dim, output_dim=emb_dim, input_length=T, mask_zero=True)(masked_et)
    # drop the extra axis introduced by embedding a (T, 1) input
    embed0 = Reshape((T, emb_dim))(embed0)  # shape=[Batch_size, T, emb_dim]
    merged0 = tf.keras.layers.concatenate([embed0, masked_ts], axis=2)  # shape=[Batch_size, T, emb_dim + time_dim]

    # TimeLSTM1 returns two tensors here; with return_sequences=False they are
    # presumably the final-step outputs — TODO confirm against sgtlstm.TimeLSTM
    hm, tm = TimeLSTM1(hidden_dim, activation='selu', name='time_lstm', return_sequences=False)(merged0)

    time_comb = tf.keras.layers.concatenate([hm, tm], axis=1)

    # predicted real prob (the layer keeps the name 'fraud_prob')
    real_prob = Dense(1, activation='sigmoid', name='fraud_prob', kernel_regularizer=regularizers.l1_l2(l1=1e-3, l2=1e-3))(
        time_comb)

    discriminator = Model(
        inputs=[i_et, i_ts],
        outputs=[real_prob])

    return discriminator

In [None]:
# Build the discriminator with the notebook-wide hyper-parameters, then restore
# the pretrained weights from disk.
_disc_input_shape = (BATCH_SIZE, T, 1)
discriminator = build_D_2(T=T, event_vocab_dim=EVENT_VOCAB_DIM, emb_dim=EMB_DIM, hidden_dim=HIDDEN_DIM)
discriminator.build(input_shape=(_disc_input_shape, _disc_input_shape))

D_save_path = (
    '/home/lun/project-basileus/seq-gan/experiment_results/long_seqs_v6/'
    'init_pretrained/pretrained_disc_weights/model_2.tf'
)
discriminator.load_weights(D_save_path)

In [None]:
# Build the generator with the notebook-wide hyper-parameters, then restore the
# pretrained weights from disk.
_gen_input_shape = (BATCH_SIZE, T, 1)
generator = build_G(batch_size=BATCH_SIZE, event_vocab_dim=EVENT_VOCAB_DIM, emb_dim=EMB_DIM, hidden_dim=HIDDEN_DIM)
generator.build(input_shape=(_gen_input_shape, _gen_input_shape))

G_save_path = (
    '/home/lun/project-basileus/seq-gan/experiment_results/long_seqs_v6/'
    'init_pretrained/pretrained_gen_weights/model.tf'
)
generator.load_weights(G_save_path)

## Create a critic network

In [None]:
# The critic is created fresh (no weights are loaded for it in this notebook).
_critic_kwargs = dict(
    batch_size=BATCH_SIZE,
    event_vocab_dim=EVENT_VOCAB_DIM,
    emb_dim=EMB_DIM,
    hidden_dim=HIDDEN_DIM,
)
critic = build_critic(**_critic_kwargs)

# Functions used in training

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras import Sequential
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.layers import Input, LSTM, Embedding, Reshape, Dense
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd

from sgtlstm.TimeLSTM import TimeLSTM0, TimeLSTM1, TimeLSTM2, TimeLSTM3

# Same setting as in the model-loading cell: float64 is the default Keras float
# type, matching the tf.float64 tensors built by the rollout function below.
tf.keras.backend.set_floatx('float64')

In [None]:
def generate_batch_sequence_by_rollout(
        G, batch_size, T, end_token=0, init_token=1.0, max_time=1024, verbose=False):
    """
        Sample one batch of sequences by feeding the generator its own output
        step by step.

        Every sequence starts from the state (init_token, 0.0). At each of the
        following T - 1 steps the generator is called on the current state, an
        event type is drawn from the returned token probabilities and a time
        value is drawn from the returned distribution and clipped to
        [1, max_time]. Once a row's current event equals `end_token` it stops
        changing.
    :param G: generator; called as G([event_type, time]) on inputs of shape
        (batch_size, 1, 1) and expected to return (token_prob, time_out), where
        time_out offers .sample() and .log_prob(). G.reset_states() is called first.
    :param batch_size: number of sequences sampled in parallel
    :param T: total sequence length, including the initial state
    :param end_token: event id after which a sequence is frozen
    :param init_token: event id of the initial state
    :param max_time: upper clip bound for sampled time values
    :param verbose: currently unused
    :return: (all_state_et, all_state_ts, episode_token_probs, gaussian_log);
        the first two have shape (batch_size, T, 1). The last two hold one column
        per step (assuming time_out.log_prob returns shape (batch_size, 1)), with
        a constant first column of 1. and 0. respectively for the initial state.
    """
    # Begin from dummy init state (init_token, init_timestamp=0.0)
    curr_state_et = tf.fill([batch_size, 1, 1], tf.constant(init_token, dtype=tf.float64))
    curr_state_ts = tf.zeros([batch_size, 1, 1], dtype=tf.float64)

    # Collect per-step tensors and concatenate once at the end instead of
    # re-copying a growing tensor on every step.
    et_steps = [curr_state_et]
    ts_steps = [curr_state_ts]
    token_prob_steps = [tf.constant(1., dtype=tf.float64, shape=(batch_size, 1))]
    gaussian_log_steps = [tf.constant(0., dtype=tf.float64, shape=(batch_size, 1))]

    # row indices used to pick each sample's chosen token probability
    batch_ind = tf.reshape(tf.range(0, batch_size), (batch_size, 1))

    G.reset_states()

    for _ in range(1, T):  # sequence length
        token_prob, time_out = G([curr_state_et, curr_state_ts])

        sampled_et = tf.random.categorical(tf.math.log(token_prob), num_samples=1, dtype=tf.int32)
        sampled_et = tf.reshape(sampled_et, [batch_size, 1, 1])

        # get the chosen token probability per batch for each step
        batch_sample_et = tf.reshape(sampled_et, (batch_size, 1))
        batch_sample_et_2d = tf.concat([batch_ind, batch_sample_et], axis=1)
        sampled_token_prob = tf.reshape(tf.gather_nd(token_prob, batch_sample_et_2d), (batch_size, 1))
        token_prob_steps.append(sampled_token_prob)

        # cast sampled_et into float
        sampled_et = tf.cast(sampled_et, dtype=tf.float64)

        # stop generating once hit end_token: the check uses the state BEFORE
        # this step, so the end token itself is still emitted once
        cond_end_token = tf.equal(curr_state_et, end_token)
        curr_state_et = tf.where(cond_end_token, curr_state_et, sampled_et)
        et_steps.append(curr_state_et)

        # generate one timestamp using time_out
        sampled_ts_raw = time_out.sample()
        sampled_ts = tf.clip_by_value(tf.reshape(sampled_ts_raw, (batch_size, 1, 1)),
                                      clip_value_min=1, clip_value_max=max_time)

        # get the gaussian log likelihood for the (unclipped) sampled timestamps
        gaussian_log_steps.append(time_out.log_prob(sampled_ts_raw))

        # finished rows take the event value as their time value, i.e. 0 for the
        # default end_token=0
        curr_state_ts = tf.where(cond_end_token, curr_state_et, sampled_ts)
        ts_steps.append(curr_state_ts)

    all_state_et = tf.concat(et_steps, axis=1)
    all_state_ts = tf.concat(ts_steps, axis=1)
    episode_token_probs = tf.concat(token_prob_steps, axis=1)
    gaussian_log = tf.concat(gaussian_log_steps, axis=1)

    return all_state_et, all_state_ts, episode_token_probs, gaussian_log


In [None]:
def generate_sequences(N_gen, generator, batch_size, T, recover_to_timestamp=True,
                       end_token=0, init_token=1.0, max_time=1024):
    """
        Generate sequences batch per batch
    :param N_gen: total number of seqs to be generated; values <= 0 yield []
    :param generator: generator handed to generate_batch_sequence_by_rollout
    :param batch_size: number of sequences produced per rollout, must be positive
    :param T: length of every generated sequence
    :param recover_to_timestamp: whether to recover time deltas to absolute timestamps
        (cumulative sum along the time axis)
    :param end_token: forwarded to the rollout
    :param init_token: forwarded to the rollout
    :param max_time: forwarded to the rollout
    :return: a python list of shape [N_gen, T, 2]; the last axis is (event type, time)
    :raises ValueError: if sequences are requested with a non-positive batch_size
    """
    if N_gen <= 0:
        return []
    if batch_size <= 0:
        # the loop below could never reach N_gen
        raise ValueError('batch_size must be positive, got {}'.format(batch_size))

    type_batches = []
    time_batches = []
    n_generated = 0

    while n_generated < N_gen:
        batch_state_et, batch_state_ts, _, _ = generate_batch_sequence_by_rollout(
            generator, batch_size, T,
            end_token=end_token, init_token=init_token, max_time=max_time, verbose=False)

        batch_time_seq = batch_state_ts.numpy()

        # recover time delta to time stamps
        if recover_to_timestamp:
            batch_time_seq = np.cumsum(batch_time_seq, axis=1)

        type_batches.append(batch_state_et.numpy())
        time_batches.append(batch_time_seq)
        n_generated += batch_size

    # stack the batches once, instead of re-copying the growing arrays every iteration
    all_type_seq = np.concatenate(type_batches, axis=0)
    all_time_seq = np.concatenate(time_batches, axis=0)

    # concat type and time in depth; whole batches may overshoot N_gen, so trim
    return np.concatenate([all_type_seq, all_time_seq], axis=2)[:N_gen].tolist()


In [None]:
def train_generator(generator, discriminator, critic_network, batch_size, T, verbose=False,
                    weight_gaussian_loss=1,
                    optimizer=Adam(learning_rate=0.001)):
    """
        Run one actor-critic update of the generator and the critic network.

        A batch of sequences is rolled out from the generator and scored as a
        whole by the discriminator (`true_return`). For every step the critic
        estimates a value; the generator loss combines the advantage-weighted and
        the plain negative log-likelihood of the sampled tokens and time values,
        and the critic is regressed towards `true_return`. Positions whose event
        type is 0 are excluded from the losses.
    :param generator: generator model, updated in place
    :param discriminator: model called on (event types, times) of the whole batch
    :param critic_network: per-step value model, updated in place; its states are
        reset at the start of the call
    :param batch_size: number of sequences rolled out
    :param T: sequence length
    :param verbose: print the individual losses when True
    :param weight_gaussian_loss: weight of the time loss inside the generator loss
    :param optimizer: optimizer applied to both the generator and the critic
    :return: (total_ce_loss, total_gaussian_loss, total_critic_loss, average_true_return)
    """
    # reset hidden states for critic network per batch
    critic_network.reset_states()

    with tf.GradientTape(persistent=True) as tape:
        states_et, states_ts, episode_token_probs, gaussian_log = generate_batch_sequence_by_rollout(
            generator, batch_size, T,
            end_token=0, init_token=1.0, max_time=1024, verbose=False)

        ce_loss_list = []
        gaussian_list = []
        critic_loss_list = []

        # run disc on whole sequence
        # true_return is the total reward for generating this seq
        true_return = discriminator((states_et, states_ts))

        # constant added to the token loss for a step at which every row is masked out
        ZERO_PENALTY = 10

        for i in range(0, T):
            # TODO: should we include the init token in loss?
            curr_state_et = states_et[:, i:i + 1, :]
            curr_state_ts = states_ts[:, i:i + 1, :]

            curr_token_prob = episode_token_probs[:, i:i + 1]
            curr_gaussian_log = gaussian_log[:, i:i + 1]

            q_value = critic_network([curr_state_et, curr_state_ts])
            advantage = true_return - q_value

            # At this point in history, the critic estimated that we would get a
            # total reward = `q_value` in the future. We took an action with log probability
            # of `log_prob` and ended up receiving a total reward = `true_return`.
            # The actor must be updated so that it predicts an action that leads to
            # high rewards (compared to critic's estimate) with high probability.

            # Keep only rows whose event type is non-zero. An explicit 1-D boolean
            # mask stays valid for batch_size == 1, where squeezing the state
            # would produce a scalar that tf.boolean_mask rejects.
            mask = tf.reshape(tf.not_equal(curr_state_et, 0), [-1])
            curr_token_prob = tf.boolean_mask(curr_token_prob, mask)
            curr_gaussian_log = tf.boolean_mask(curr_gaussian_log, mask)

            if curr_token_prob.shape[0] == 0:
                ce_loss_list.append(ZERO_PENALTY)
                continue

            masked_q_value = tf.boolean_mask(q_value, mask)
            masked_advantage = tf.boolean_mask(advantage, mask)
            masked_true_return = tf.boolean_mask(true_return, mask)

            # advantage-weighted negative log-likelihood ...
            ce_loss_list.append(-tf.reduce_mean(tf.math.log(curr_token_prob) * masked_advantage))
            gaussian_list.append(-tf.reduce_mean(curr_gaussian_log * masked_advantage))

            # ... plus the plain negative log-likelihood of the sampled actions
            ce_loss_list.append(-tf.reduce_mean(tf.math.log(curr_token_prob)))
            gaussian_list.append(-tf.reduce_mean(curr_gaussian_log))

            critic_loss_list.append(tf.reduce_mean(tf.keras.losses.MSE(masked_true_return, masked_q_value)))

        total_ce_loss = tf.reduce_sum(ce_loss_list)
        total_gaussian_loss = tf.reduce_sum(gaussian_list)
        total_critic_loss = tf.reduce_sum(critic_loss_list)
        total_generator_loss = total_ce_loss + weight_gaussian_loss * total_gaussian_loss

        average_true_return = tf.reduce_mean(true_return)

        if verbose:
            print('generator token loss:{}'.format(total_ce_loss))
            print('generator gaussian loss:{}'.format(total_gaussian_loss))
            print('generator total loss:{}'.format(total_generator_loss))
            print('generator critic loss:{}'.format(total_critic_loss))
            print('average true_return: {}'.format(average_true_return))
            print('-----------------------')

    # update generator
    generator_grads = tape.gradient(total_generator_loss, generator.trainable_variables)
    optimizer.apply_gradients(zip(generator_grads, generator.trainable_variables))

    # update critic network
    critic_grads = tape.gradient(total_critic_loss, critic_network.trainable_variables)
    optimizer.apply_gradients(zip(critic_grads, critic_network.trainable_variables))

    # explicitly drop tape because persistent=True
    del tape

    return total_ce_loss, total_gaussian_loss, total_critic_loss, average_true_return

In [None]:
def train_discriminator(features_batch, generator, discriminator, batch_size, T, verbose=False,
                        optimizer=Adam(learning_rate=0.001)):
    """
        Run one discriminator update on a batch of real and generated sequences.

        Real sequences get targets 1 + N(0, 0.1) and generated ones 0 + N(0, 0.1);
        the discriminator is trained with binary cross-entropy on the
        concatenation of both halves (generated first).
    :param features_batch: (real_et, real_ts) tensors of the real sequences
    :param generator: generator used to roll out the fake half of the batch
    :param discriminator: model to update in place
    :param batch_size: number of real and of generated sequences
    :param T: sequence length
    :param verbose: print the loss when True
    :param optimizer: optimizer applied to the discriminator variables
    :return: mean binary cross-entropy of the combined batch
    """
    real_et, real_ts = features_batch

    # (batch_size, 1) noisy targets; the clip to [0, 1] below is left disabled,
    # so targets can fall outside that range
    real_labels = tf.ones((batch_size, 1)) + tfd.Normal(loc=0, scale=0.1, name='normal_disturbance_true').sample(sample_shape=(batch_size, 1))
#         real_labels = tf.clip_by_value(real_labels, clip_value_min=0.0, clip_value_max=1.0)

    # The rollout only provides inputs for the discriminator, and gradients are
    # taken w.r.t. the discriminator alone, so it does not need to be recorded
    # on the tape.
    generated_et, generated_ts, _, _ = generate_batch_sequence_by_rollout(
        generator, batch_size, T,
        end_token=0, init_token=1.0, max_time=1024, verbose=False)

    generated_labels = tf.zeros((batch_size, 1)) + tfd.Normal(loc=0, scale=0.1, name='normal_disturbance_fake').sample(sample_shape=(batch_size, 1))
#         generated_labels = tf.clip_by_value(generated_labels, clip_value_min=0.0, clip_value_max=1.0)

    total_et = tf.concat([generated_et, real_et], axis=0)
    total_ts = tf.concat([generated_ts, real_ts], axis=0)
    total_labels = tf.concat([generated_labels, real_labels], axis=0)

    # train the discriminator
    with tf.GradientTape() as tape:
        pred_prob = discriminator((total_et, total_ts))

        # cross-entropy loss
        ce_loss = tf.reduce_mean(
            tf.keras.losses.binary_crossentropy(total_labels, pred_prob, from_logits=False))
        discriminator_loss = ce_loss

    if verbose:
        print('total discriminator loss:{}'.format(discriminator_loss))
        print('-----------------------')

    grads = tape.gradient(discriminator_loss, discriminator.trainable_variables)
    optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))

    return ce_loss

# Train G and D

In [None]:
# Only the positive sequences are used for training; every one is labelled 1.
train_et, train_ts = pos_event_type_seqs, pos_timestamp_seqs
N_train = train_et.shape[0]

train_labels = np.ones((N_train, 1))
train_features = (train_et, train_ts)

N_train

In [None]:
# `learning_rate` is the supported keyword; `lr` is a deprecated alias.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
# optimizer = SGD(learning_rate=1e-4)
EPOCHS = 2
# number of batches taken from the dataset by the training loop
_TOTAL_STEPS = int(EPOCHS * N_train / BATCH_SIZE)
# _TOTAL_STEPS = 1000

dataset = create_dataset(train_features,
                         train_labels,
                         batch_size=BATCH_SIZE,
                         epochs=EPOCHS,
                         buffer_size=N_train)

# per-update histories filled by the training loop
gen_token_loss_history = []
gen_gaussian_loss_history = []
critic_network_loss_history = []
average_true_return_history = []

disc_ce_loss_history = []

# weight of the time loss in the generator loss, and the number of generator /
# discriminator updates per training step
WEIGHT_GAUSSIAN_LOSS = 1
_G_STEPS = 2
_D_STEPS = 1

In [None]:
# Number of batches the training loop will take from the dataset.
_TOTAL_STEPS

In [None]:
# Alternate generator/critic and discriminator updates for _TOTAL_STEPS batches,
# recording every loss. `step` stays 0 if the dataset yields no batch.
step = 0

# total= lets tqdm show progress and an ETA; a `take` dataset is passed as a plain iterable
for step, (features_batch, _) in enumerate(tqdm(dataset.take(_TOTAL_STEPS), total=_TOTAL_STEPS), start=1):
    print('Training Step:', step)
    # train the generator
    for _g_step in range(_G_STEPS):
        gen_token_loss, gen_gaussian_loss, critic_network_loss, average_true_return = train_generator(
            generator, discriminator, critic,
            batch_size=BATCH_SIZE, T=T, verbose=True,
            weight_gaussian_loss=WEIGHT_GAUSSIAN_LOSS,
            optimizer=optimizer)
        gen_token_loss_history.append(gen_token_loss.numpy())
        gen_gaussian_loss_history.append(gen_gaussian_loss.numpy())
        critic_network_loss_history.append(critic_network_loss.numpy())
        average_true_return_history.append(average_true_return.numpy())

    # train the discriminator
    for _d_step in range(_D_STEPS):
        disc_ce_loss = train_discriminator(features_batch, generator, discriminator,
                                           batch_size=BATCH_SIZE, T=T, verbose=True,
                                           optimizer=optimizer)
        disc_ce_loss_history.append(disc_ce_loss.numpy())

## 1000 steps: Loss over training

In [None]:
# One figure per recorded loss history (x axis: index of the update).
for _history, _title in (
        (gen_token_loss_history, 'Generator Token Loss History'),
        (gen_gaussian_loss_history, 'Generator Gaussian Loss History'),
        (disc_ce_loss_history, 'Discriminator CE Loss History')):
    x = range(len(_history))
    plt.figure(dpi=100)
    plt.plot(x, _history)
    plt.title(_title)
    plt.xlabel('training steps')

In [None]:
# One figure each for the critic loss and the average discriminator return.
for _history, _title in (
        (critic_network_loss_history, 'Critic Loss History'),
        (average_true_return_history, 'Average True Return History')):
    x = range(len(_history))
    plt.figure(dpi=100)
    plt.plot(x, _history)
    plt.title(_title)
    plt.xlabel('training steps')

In [None]:
loss_save_dir = '/home/lun/project-basileus/seq-gan/experiment_results/v6/after_1_epoch/loss'
# exist_ok avoids the check-then-create race of exists() + makedirs()
os.makedirs(loss_save_dir, exist_ok=True)

# Every history is written to '<name>.pickle' inside loss_save_dir.
# average_true_return_history is saved as well: it is recorded and plotted
# alongside the losses but was not persisted before.
_histories_to_save = {
    'gen_token_loss_history': gen_token_loss_history,
    'gen_gaussian_loss_history': gen_gaussian_loss_history,
    'critic_network_loss_history': critic_network_loss_history,
    'disc_ce_loss_history': disc_ce_loss_history,
    'average_true_return_history': average_true_return_history,
}

for _history_name, _history in _histories_to_save.items():
    with open(os.path.join(loss_save_dir, _history_name + '.pickle'), 'wb') as f:
        pickle.dump(_history, f)

## 1000 steps: Save G and D models

In [None]:
# Save the trained generator weights (G_save_path is re-pointed from the
# pretrained checkpoint to the new location).
G_save_dir = '/home/lun/project-basileus/seq-gan/experiment_results/v6/after_1_epoch/gen_weights'
# exist_ok avoids the check-then-create race of exists() + makedirs()
os.makedirs(G_save_dir, exist_ok=True)

G_save_path = os.path.join(G_save_dir, 'gen_model.tf')
generator.save_weights(G_save_path)

In [None]:
# Show the kernel's working directory without shelling out (portable across platforms).
os.getcwd()

In [None]:
# Save the trained discriminator weights (D_save_path is re-pointed from the
# pretrained checkpoint to the new location).
D_save_dir = '/home/lun/project-basileus/seq-gan/experiment_results/v6/after_1_epoch/disc_weights'
# exist_ok avoids the check-then-create race of exists() + makedirs()
os.makedirs(D_save_dir, exist_ok=True)

D_save_path = os.path.join(D_save_dir, 'disc_model.tf')
discriminator.save_weights(D_save_path)

## 1000 steps: Generate sequences 

In [None]:
# Sample N_gen sequences from the trained generator, keeping the raw time
# values (no cumulative sum).
N_gen = 100
generated_seqs = generate_sequences(
    N_gen,
    generator,
    batch_size=BATCH_SIZE,
    T=T,
    recover_to_timestamp=False,
)

In [None]:
# generated_seqs_save_dir = './experiment_results/after_1000_steps/generated_seqs'
# if not os.path.exists(generated_seqs_save_dir):
#     os.makedirs(generated_seqs_save_dir)

# with open(os.path.join(generated_seqs_save_dir, 'generated_seqs.pickle'), 'wb') as f:
#     pickle.dump(generated_seqs, f)

In [None]:
# Preview a few generated sequences instead of dumping the whole nested list
# into the cell output.
generated_seqs[:3]

## Debug Session 

In [None]:
# Module-level stand-ins for the arguments of generate_batch_sequence_by_rollout,
# so its body can be stepped through cell by cell.
batch_size, init_token, end_token, max_time = BATCH_SIZE, 1, 0, 1024
G = generator

In [None]:
# Inline copy of the body of generate_batch_sequence_by_rollout, run at module
# level so that the loop's intermediate tensors (token_prob, sampled_et, ...)
# stay available for inspection in the cells below.
# Begin from dummy init state (init_token=1, init_timestamp=0.0)
curr_state_et = tf.ones([batch_size, 1, 1], dtype=tf.float64)
curr_state_ts = tf.zeros([batch_size, 1, 1], dtype=tf.float64)

all_state_et = curr_state_et
all_state_ts = curr_state_ts

# first column stands for the init state: probability 1, log-likelihood 0
episode_token_probs = tf.constant(1., dtype=tf.float64, shape=(batch_size, 1))
gaussian_log = tf.constant(0., dtype=tf.float64, shape=(batch_size, 1))

G.reset_states()

for step in range(1, T):  # sequence length
    token_prob, time_out = G([curr_state_et, curr_state_ts])

    sampled_et = tf.random.categorical(tf.math.log(token_prob), num_samples=1, dtype=tf.int32)
    sampled_et = tf.reshape(sampled_et, [batch_size, 1, 1])

    # get the chosen token probability per batch for each step
    batch_sample_et = tf.reshape(sampled_et, (batch_size, 1))
    batch_ind = tf.reshape(tf.range(0, batch_size), (batch_size, 1))
    batch_sample_et_2d = tf.concat([batch_ind, batch_sample_et], axis=1)

    sampled_token_prob = tf.reshape(tf.gather_nd(token_prob, batch_sample_et_2d), (batch_size, 1))
    episode_token_probs = tf.concat([episode_token_probs, sampled_token_prob], axis=1)
    
    # cast sampled_et into float (it is float64 from here on, also after the loop)
    sampled_et = tf.cast(sampled_et, dtype=tf.float64)
    
    # stop generating once hit end_token (checked on the state before this step)
    cond_end_token = tf.equal(curr_state_et, end_token)
    curr_state_et = tf.where(cond_end_token, curr_state_et, sampled_et)
    all_state_et = tf.concat([all_state_et, curr_state_et], axis=1)

    # generate one timestamp using time_out
    sampled_ts_raw = time_out.sample()
    sampled_ts = tf.clip_by_value(tf.reshape(sampled_ts_raw, (batch_size, 1, 1))
                                  , clip_value_min=1, clip_value_max=max_time)

    # get the gaussian log likelihood for the sampled timestamps
    sampled_gaussian_log = time_out.log_prob(sampled_ts_raw)
    gaussian_log = tf.concat([gaussian_log, sampled_gaussian_log], axis=1)

    # stop generating once hit end_token
    # NOTE(review): finished rows keep their previous time value here, whereas
    # generate_batch_sequence_by_rollout writes curr_state_et at this point —
    # confirm which of the two is intended.
    curr_state_ts = tf.where(cond_end_token, curr_state_ts, sampled_ts)
    all_state_ts = tf.concat([all_state_ts, curr_state_ts], axis=1)

In [None]:
# Event types accumulated over all steps of the debug rollout.
all_state_et

In [None]:
# Token probabilities returned by the generator at the last rollout step.
token_prob

In [None]:
# Event-type state after the last rollout step.
curr_state_et

In [None]:
# Event types sampled at the last rollout step (already cast to float64).
sampled_et

In [None]:
# sampled_et is float64 after the rollout loop, but tf.gather_nd needs integer
# indices, so cast back first.
# NOTE: a single index column selects whole rows of token_prob by token id; to
# pick one probability per sample, pair each token id with its batch index.
tf.gather_nd(token_prob, tf.reshape(tf.cast(sampled_et, tf.int32), (batch_size, 1)))

In [None]:
# Build [batch index, token id] pairs for tf.gather_nd. sampled_et is float64
# after the rollout loop while batch_ind is int32; tf.concat needs matching
# dtypes and the indices must be integers, so cast back to int32 first.
sample_et_2 = tf.reshape(tf.cast(sampled_et, tf.int32), (batch_size, 1))
batch_ind = tf.reshape(tf.range(0, batch_size), (batch_size, 1))
sample_ed_3 = tf.concat([batch_ind, sample_et_2], axis=1)

In [None]:
# Gather one entry of token_prob per [batch index, token id] row of sample_ed_3.
tf.gather_nd(token_prob, sample_ed_3)