In [1]:
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras import layers
from tensorflow.keras.callbacks import *
from tensorflow.keras.models import Model
from sklearn.model_selection import train_test_split

In [2]:
import os
import sys
import time
import pandas as pd
import numpy as np
from tqdm import tqdm
from pickle import dump, load
import gc

In [3]:
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
    # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        print(e)

1 Physical GPUs, 1 Logical GPUs


In [4]:
BATCH_SIZE = 512
WIN_SIZE = 100

In [5]:
def read_feather(file_name = "train.feather"):
    data = pd.read_feather(file_name)
    return data

In [6]:
user_dict = load(open('/content/drive/MyDrive/transformer_train/user_dict.pkl', 'rb'))

In [7]:
tdf = read_feather('/content/drive/MyDrive/transformer_train/all_train_dat_plus.feather')

In [8]:
tdf.numlect = tdf.numlect.astype(np.uint8)
tdf.task_container_id[tdf.task_container_id > 2000] = 2000

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  


In [9]:
tdf.drop(columns = ['timestamp', 'user_id'], inplace = True)

In [10]:
def rolling_window(a, w):
    s0, s1 = a.strides
    m, n = a.shape
    return np.lib.stride_tricks.as_strided(a, 
                                           shape=(m-w+1, w, n), 
                                           strides=(s0, s0, s1))

def make_time_series(x, windows_size, pad_size=0):
    x = np.pad(x, [[ windows_size-pad_size-1, 0], [0, 0]], constant_values=0)
    x = rolling_window(x, windows_size)
    return list(x)

def shift_answer(df):
    # We add one to the column in order to have zeros as padding values
    # Start Of Sentence (SOS) token will be 3. 
    df['answered_correctly'] = df['answered_correctly'].shift(fill_value=2)+1
    return df

In [11]:
def dat_generator(df, user_dict, for_train = True, val_p1 = None, val_p2 = None, win_size = 128, batch_size = 256):
    pos_to_uid = list(user_dict.keys())
    pos_to_uid.sort()
    if val_p2 is None: val_p2 = int(len(user_dict) * 0.9)
    if val_p1 is None: val_p1 = int(val_p2 * 0.5)
    if for_train:
        remain_uid = list(range(val_p2))
        rem_uid = None
        use_dat = None
        rem_uidx = None
        while True:
            X = []
            y = []
            remain_batch = batch_size
            if rem_uid is not None: 
                if uid < pos_to_uid[val_p1]: utotal_dat = min(3000, len(user_dict[rem_uid]))
                else: utotal_dat = min(3000, int(len(user_dict[rem_uid])*0.8))
                u_dat = utotal_dat - use_dat
                if u_dat <= remain_batch:
                    y.extend(df.answered_correctly[user_dict[rem_uid][use_dat]:user_dict[rem_uid][utotal_dat-1]+1])
                    t_use_dat = max(0, use_dat - win_size + 1)
                    t_use = max(0, t_use_dat - 1)
                    pad_size = use_dat
                    if t_use_dat != 0: pad_size = win_size - 1
                    udf = df[user_dict[rem_uid][t_use]:user_dict[rem_uid][utotal_dat-1]+1].copy()
                    udf = shift_answer(udf)
                    if t_use_dat != 0: udf = udf[1:].copy()
                    X.extend(make_time_series(udf, win_size, pad_size = pad_size))
                    remain_batch -= u_dat
                    remain_uid.pop(rem_uidx)
                    rem_uid = None
                    rem_uidx = None
                    use_dat = None
                else:
                    y.extend(df.answered_correctly[user_dict[rem_uid][use_dat]:user_dict[rem_uid][use_dat+remain_batch]])
                    t_use_dat = max(0, use_dat - win_size + 1)
                    t_use = max(0, t_use_dat - 1)
                    pad_size = use_dat
                    if t_use_dat != 0: pad_size = win_size - 1
                    udf = df[user_dict[rem_uid][t_use]:user_dict[rem_uid][use_dat+remain_batch]].copy()
                    udf = shift_answer(udf)
                    if t_use_dat != 0: udf = udf[1:].copy()
                    X.extend(make_time_series(udf, win_size, pad_size))
                    use_dat += remain_batch
                    remain_batch = 0
            while remain_batch > 0:
                if len(remain_uid)==0: remain_uid = list(range(val_p2))
                uidx = np.random.choice(len(remain_uid), 1)[0]
                uid = pos_to_uid[remain_uid[uidx]]
                u_dat = min(3000, len(user_dict[uid]))
                if u_dat < 20:
                    remain_uid.pop(uidx)
                    continue
                if uid > pos_to_uid[val_p1]: u_dat = min(3000, int(len(user_dict[uid])*0.8))
                if u_dat <= remain_batch:
                    y.extend(df.answered_correctly[user_dict[uid][0]:user_dict[uid][u_dat-1]+1])
                    udf = df[user_dict[uid][0]:user_dict[uid][u_dat-1]+1].copy()
                    udf = shift_answer(udf)
                    X.extend(make_time_series(udf, win_size))
                    remain_batch -= u_dat
                    remain_uid.pop(uidx)
                else:
                    y.extend(df.answered_correctly[user_dict[uid][0]:user_dict[uid][remain_batch]])
                    udf = df[user_dict[uid][0]:user_dict[uid][remain_batch]].copy()
                    udf = shift_answer(udf)
                    X.extend(make_time_series(udf, win_size))
                    rem_uid = uid
                    rem_uidx = uidx
                    use_dat = remain_batch
                    remain_batch = 0
            yield np.asarray(X).astype(np.float32), np.asarray(y).astype(np.float32)
    else:
        remain_uid = []
        rem_uid = None
        use_dat = None
        rem_uidx = None
        while True:
            X = []
            y = []
            if len(remain_uid) == 0: remain_uid = list(range(val_p1, len(user_dict)))
            uidx = np.random.choice([0, len(remain_uid) - 1], 1)[0]
            uid = pos_to_uid[remain_uid[uidx]]
            if uid > pos_to_uid[val_p2]:
                u_dat = len(user_dict[uid])
                start_token = - u_dat
                end_token = -1
                if u_dat > BATCH_SIZE: end_token = -u_dat + BATCH_SIZE
            else:
                if len(user_dict[uid]) < 10:
                    remain_uid.pop(uidx)
                    continue 
                u_dat = int(len(user_dict[uid])*0.2)
                start_token = - u_dat - 1
                end_token = -1
                if u_dat > BATCH_SIZE: end_token = -u_dat + BATCH_SIZE
            y.extend(df.answered_correctly[user_dict[uid][-u_dat]:user_dict[uid][end_token]+1])
            udf = df[user_dict[uid][start_token]:user_dict[uid][end_token]+1].copy()
            udf = shift_answer(udf)
            if udf.shape[0] != len(y): udf = udf[1:].copy()
            X.extend(make_time_series(udf, win_size))
            remain_uid.pop(uidx)
            yield np.asarray(X).astype(np.float32), np.asarray(y).astype(np.float32)

In [12]:
class MultiHeadAttention(layers.Layer):
    def __init__(self, embed_dim, num_heads=8, **kwargs):
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        if self.embed_dim % self.num_heads != 0:
            raise ValueError(
                f"embedding dimension = {self.embed_dim} should be divisible by number of heads = {self.num_heads}"
            )
        self.projection_dim = self.embed_dim // self.num_heads
        self.query_dense = layers.Dense(embed_dim)
        self.key_dense = layers.Dense(embed_dim)
        self.value_dense = layers.Dense(embed_dim)
        self.combine_heads = layers.Dense(embed_dim)
    
    def get_config(self):
        cfg = super().get_config()
        cfg.update({
            'embed_dim': self.embed_dim,
            'num_heads': self.num_heads,
        })
        return cfg

    def attention(self, query, key, value, mask):
        score = tf.matmul(query, key, transpose_b=True)
        dim_key = tf.cast(tf.shape(key)[-1], tf.float32)
        scaled_score = score / tf.math.sqrt(dim_key)
        if mask is not None:
            scaled_score += (mask * -1e9)
        weights = tf.nn.softmax(scaled_score, axis=-1)
        output = tf.matmul(weights, value)
        return output, weights

    def separate_heads(self, x, batch_size):
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.projection_dim))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, q , k ,v, mask):
        batch_size = tf.shape(q)[0]
        print("batch size", batch_size)
        query = self.query_dense(q)  # (batch_size, seq_len, embed_dim)
        key = self.key_dense(k)  # (batch_size, seq_len, embed_dim)
        value = self.value_dense(v)  # (batch_size, seq_len, embed_dim)
        query = self.separate_heads(query, batch_size)  # (batch_size, num_heads, seq_len, projection_dim)
        key = self.separate_heads(key, batch_size)  # (batch_size, num_heads, seq_len, projection_dim)
        value = self.separate_heads(value, batch_size)  # (batch_size, num_heads, seq_len, projection_dim)
        attention, weights = self.attention(query, key, value, mask)
        attention = tf.transpose(attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len, num_heads, projection_dim)
        concat_attention = tf.reshape(attention, (batch_size, -1, self.embed_dim))  # (batch_size, seq_len, embed_dim)
        output = self.combine_heads(concat_attention)  # (batch_size, seq_len, embed_dim)
        return output # can return weights

"""
Encoder block as a layer
"""

class EncoderBlock(layers.Layer):
    def __init__(self, embed_dim, num_heads, ff_dim = None, rate=0.1, **kwargs):
        super(EncoderBlock, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate
        self.att = MultiHeadAttention(self.embed_dim, self.num_heads)
        if self.ff_dim is None: self.ff_dim = 2*self.embed_dim
        self.ffn = tf.keras.Sequential(
            [layers.Dense(self.ff_dim, activation="relu"), layers.Dense(self.embed_dim),]
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(self.rate)
        self.dropout2 = layers.Dropout(self.rate)
        
    def get_config(self):
        cfg = super().get_config()
        cfg.update({
            'embed_dim': self.embed_dim,
            'num_heads': self.num_heads,
            'ff_dim': self.ff_dim,
            'rate': self.rate
        })
        return cfg

    def call(self, x, y, padding_mask, training):
        attn_output = self.att(x, y, y, padding_mask)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(x + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

"""
Decoder block as a layer
"""

class DecoderBlock(layers.Layer):
    def __init__(self, embed_dim, num_heads, ff_dim = None, rate = 0.1, **kwargs):
        super(DecoderBlock, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate
        self.att1 = MultiHeadAttention(self.embed_dim, self.num_heads)
        self.att2 = MultiHeadAttention(self.embed_dim, self.num_heads)
        self.ffn = tf.keras.Sequential(
            [layers.Dense(self.ff_dim, activation="relu"), layers.Dense(self.embed_dim),]
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm3 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(self.rate)
        self.dropout2 = layers.Dropout(self.rate)
        self.dropout3 = layers.Dropout(self.rate)
        
    def get_config(self):
        cfg = super().get_config()
        cfg.update({
            'embed_dim': self.embed_dim,
            'num_heads': self.num_heads,
            'ff_dim': self.ff_dim,
            'rate': self.rate
        })
        return cfg
    
    def call(self, x, enc_output, look_ahead_mask, padding_mask, training):
        attn1 = self.att1(x, x, x, look_ahead_mask)
        attn1 = self.dropout1(attn1, training = training)
        out1 = self.layernorm1(attn1 + x)
        
        attn2 = self.att2(out1, enc_output, enc_output, padding_mask)
        attn2 = self.dropout2(attn2, training = training)
        out2 = self.layernorm2(attn2 + out1)
        
        ffn_output = self.ffn(out2)
        ffn_output = self.dropout3(ffn_output, training = training)
        return self.layernorm3(ffn_output + out2)

In [13]:
def create_padding_mask(seqs):
    mask = tf.cast(tf.reduce_all(tf.math.equal(seqs, 0), axis=-1), tf.float32)
    return mask[:, tf.newaxis, tf.newaxis, :]

def create_look_ahead_mask(size):
    mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    return mask  # (seq_len, seq_len)

def get_angles(pos, i, embed_dim):
    angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(embed_dim))
    return pos * angle_rates

def positional_encoding(position, embed_dim):
    angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                            np.arange(embed_dim)[np.newaxis, :],
                            embed_dim)
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
    pos_encoding = angle_rads[np.newaxis, ...]
    return tf.cast(pos_encoding, dtype=tf.float32)

In [14]:
def custom_transformer_model(feature_dim, window_size, q_size=13524, embed_dim = 256, num_heads = 16, dense_dim = 1024):
    inputs = layers.Input(shape=(window_size, feature_dim), name = "enc_input")
    min_delta = inputs[...,0]
    day_delta = inputs[...,1]
    month_delta = inputs[...,2]
    cid = inputs[...,3]
    tid = inputs[...,4]
    prior_elapsed = inputs[...,5]
    prior_explained = inputs[...,6]
    is_with = inputs[...,7]
    num_lect = inputs[...,-1,8]
    lec_type = inputs[...,-1,9:13]
    lec_h_past = inputs[...,-1,13]
    c_part = inputs[...,14:23]
    tag1 = inputs[...,23]
    tag2 = inputs[...,24]
    tag3 = inputs[...,25]
    tag4 = inputs[...,26]
    tag5 = inputs[...,27]
    tag6 = inputs[...,28]
    prev_answered_correct = inputs[...,29]
    
    #====Excercise====
    min_delta = layers.Embedding(input_dim=1443, output_dim=embed_dim//8, input_length=window_size,
                                 embeddings_initializer = 'glorot_uniform')(min_delta)
    day_delta = layers.Embedding(input_dim=33, output_dim=embed_dim//16, input_length=window_size,
                                 embeddings_initializer = 'glorot_uniform')(day_delta)
    month_delta = layers.Embedding(input_dim=9, output_dim=embed_dim//16, input_length=window_size,
                                   embeddings_initializer = 'glorot_uniform')(month_delta)
    cid = layers.Embedding(input_dim=q_size, output_dim=embed_dim, input_length=window_size,
                           embeddings_initializer = 'glorot_uniform')(cid)
    tid = layers.Embedding(input_dim=2001, output_dim=embed_dim//16, input_length=window_size,
                           embeddings_initializer = 'glorot_uniform')(tid)
    is_with = layers.Embedding(input_dim=3, output_dim=2, input_length=window_size,
                               embeddings_initializer = 'glorot_uniform')(is_with)
    c_part = layers.Dense(embed_dim//4, activation = 'relu', use_bias=False)(c_part)
#     tag_emb = layers.Embedding(input_dim=189, output_dim=embed_dim//4)
    tag1 = layers.Embedding(input_dim=189, output_dim=embed_dim//8, input_length=window_size,
                            embeddings_initializer = 'glorot_uniform')(tag1)
    tag2 = layers.Embedding(input_dim=179, output_dim=embed_dim//8, input_length=window_size,
                            embeddings_initializer = 'glorot_uniform')(tag2)
    tag3 = layers.Embedding(input_dim=162, output_dim=embed_dim//8, input_length=window_size,
                            embeddings_initializer = 'glorot_uniform')(tag3)
#     tag4 = tag_emb(tag4)
#     tag5 = tag_emb(tag5)
#     tag6 = tag_emb(tag6)
    enc_ex = layers.Concatenate()([min_delta, day_delta, month_delta, tid, c_part,
                                 tag1, tag2, tag3, is_with]) #tag4, tag5, tag6
    enc_ex = layers.Dense(embed_dim, activation = 'relu')(enc_ex)
    
    #====Lecture====
    num_lect = layers.Embedding(input_dim=160, output_dim=embed_dim//16,
                                embeddings_initializer = 'glorot_uniform')(num_lect)
    lec_type = layers.Dense(embed_dim//8, activation = 'relu', use_bias=False)(lec_type)
    lec_h_past = layers.Embedding(input_dim=724, output_dim=embed_dim//8,
                                  embeddings_initializer = 'glorot_uniform')(lec_h_past)
    enc_lec = layers.Concatenate()([num_lect, lec_type, lec_h_past])
    enc_lec = layers.Dense(embed_dim//2, activation = 'relu')(enc_lec)
    enc_lec = layers.Dropout(0.1)(enc_lec)

    #====Response====
    prev_answered_correct = layers.Embedding(input_dim=4, output_dim=embed_dim, input_length=window_size,
                                             embeddings_initializer = 'glorot_uniform')(prev_answered_correct)
    prior_elapsed = layers.Embedding(input_dim=302, output_dim=embed_dim//4, input_length=window_size,
                                     embeddings_initializer = 'glorot_uniform')(prior_elapsed)
    prior_explained = layers.Embedding(input_dim=3, output_dim=embed_dim//4, input_length=window_size,
                                       embeddings_initializer = 'glorot_uniform')(prior_explained)
    prior_inter = layers.Concatenate()([prior_elapsed, prior_explained])
    prior_inter = layers.Dense(embed_dim, activation = 'relu')(prior_inter)
    
    #====Mask====
    padding_mask = create_padding_mask(inputs)
    look_ahead_mask = create_look_ahead_mask(window_size)
    dec_combined_mask = tf.maximum(padding_mask, look_ahead_mask)
    pos_enc = positional_encoding(window_size, embed_dim)
    
    #++++Model++++
    e_enc_input = layers.Add()([cid, pos_enc, enc_ex])
    dec_input = layers.Add()([prev_answered_correct, pos_enc, prior_inter])
    
    x1 = EncoderBlock(embed_dim, num_heads, ff_dim = dense_dim)(e_enc_input, e_enc_input, padding_mask)
    x1 = EncoderBlock(embed_dim, num_heads, ff_dim = dense_dim)(x1, x1, padding_mask)
    x1 = layers.Add()([e_enc_input, x1])
    x3 = DecoderBlock(embed_dim, num_heads, ff_dim = dense_dim)(dec_input, x1,
                                                                dec_combined_mask, padding_mask)
    x3 = DecoderBlock(embed_dim, num_heads, ff_dim = dense_dim)(x3, x1,
                                                                dec_combined_mask, padding_mask)
    x = x3[:, -1, :]
    x = layers.Concatenate()([x, enc_lec])
    x = layers.Dense(embed_dim, activation="relu")(x)
#     x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.2)(x)
    outputs = layers.Dense(1, activation="sigmoid",
                           kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
                           name = "output")(x)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    return model

In [15]:
val_step = int(tdf.shape[0]*0.2 / BATCH_SIZE)
train_step = int(tdf.shape[0]*0.8 / BATCH_SIZE)

In [16]:
model = custom_transformer_model(30, WIN_SIZE, embed_dim = 128, dense_dim = 512, num_heads=8)
model.summary()
model.compile(loss=tf.keras.losses.BinaryCrossentropy(), optimizer='adam',
              run_eagerly=True, metrics=['AUC', 'acc'])

batch size Tensor("encoder_block/multi_head_attention/strided_slice:0", shape=(), dtype=int32)
batch size Tensor("encoder_block_1/multi_head_attention_1/strided_slice:0", shape=(), dtype=int32)
batch size Tensor("decoder_block/multi_head_attention_2/strided_slice:0", shape=(), dtype=int32)
batch size Tensor("decoder_block/multi_head_attention_3/strided_slice:0", shape=(), dtype=int32)
batch size Tensor("decoder_block_1/multi_head_attention_4/strided_slice:0", shape=(), dtype=int32)
batch size Tensor("decoder_block_1/multi_head_attention_5/strided_slice:0", shape=(), dtype=int32)
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
enc_input (InputLayer)          [(None, 100, 30)]    0                                            
__________________________________________________________________________________________________
tf.

In [17]:
model_name = "SSAKT_2"
model_folder = "ckp_dir/" + model_name
tblogs = "tbl_dir"
if not os.path.exists(model_folder):
    os.makedirs(model_folder)
if not os.path.exists(tblogs):
    os.mkdir(tblogs)

# model_checkpoints = ModelCheckpoint('{}/{}'.format(model_folder, model_name)+'-{epoch:02d}_{loss:.4f}_{auc:.4f}_{val_loss:.4f}_{val_auc:.4f}.h5', save_best_only=True, monitor='val_auc', mode='max')
model_checkpoints = ModelCheckpoint('{}/{}'.format(model_folder, model_name)+'.{epoch:02d}_{loss:.4f}_{auc:.4f}_{val_loss:.4f}_{val_auc:.4f}.h5')
lr_auto = ReduceLROnPlateau(monitor = "val_loss", factor = 0.1, patience = 7, mode = "min", min_delta = 0.0001, min_lr = 0.0000001, verbose = 1)
log_dir = "{}/{}-{}".format(tblogs, model_name, time.time())
tensorboard = TensorBoard(log_dir=log_dir)

In [18]:
train_gen = dat_generator(tdf, user_dict, win_size = WIN_SIZE, batch_size=BATCH_SIZE)
val_gen = dat_generator(tdf, user_dict, for_train = False, win_size = WIN_SIZE, batch_size=BATCH_SIZE)

In [19]:
model.fit(train_gen, epochs = 5, initial_epoch = 0,
          callbacks=[model_checkpoints, lr_auto, tensorboard],
          validation_data = val_gen, steps_per_epoch = train_step, validation_steps = val_step,
          )#class_weight=class_weight

Epoch 1/3
batch size tf.Tensor(512, shape=(), dtype=int32)
batch size tf.Tensor(512, shape=(), dtype=int32)
batch size tf.Tensor(512, shape=(), dtype=int32)
batch size tf.Tensor(512, shape=(), dtype=int32)
batch size tf.Tensor(512, shape=(), dtype=int32)
batch size tf.Tensor(512, shape=(), dtype=int32)
     1/155111 [..............................] - ETA: 68:16:58 - loss: 0.7120 - auc: 0.5034 - acc: 0.3906batch size tf.Tensor(512, shape=(), dtype=int32)
batch size tf.Tensor(512, shape=(), dtype=int32)
batch size tf.Tensor(512, shape=(), dtype=int32)
batch size tf.Tensor(512, shape=(), dtype=int32)
batch size tf.Tensor(512, shape=(), dtype=int32)
batch size tf.Tensor(512, shape=(), dtype=int32)
     2/155111 [..............................] - ETA: 19:12:40 - loss: 0.7043 - auc: 0.4875 - acc: 0.4438batch size tf.Tensor(512, shape=(), dtype=int32)
batch size tf.Tensor(512, shape=(), dtype=int32)
batch size tf.Tensor(512, shape=(), dtype=int32)
batch size tf.Tensor(512, shape=(), dtype=int

KeyboardInterrupt: ignored