In [1]:
import ujson as json
import tensorflow as tf
import numpy as np
from text_gan.data.qgen_data import QuestionContextPairs, CONFIG
from text_gan.layers import FixedEmbedding
from text_gan.models.squad_qgan import get_model


# Log the device each op is placed on. This is a debugging aid and is very
# chatty: it produces the "Executing op ... in device ..." lines seen in the
# outputs of later cells.
tf.debugging.set_log_device_placement(True)

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
  try:
    # Let TensorFlow grow its GPU allocation on demand instead of reserving
    # memory up front. The setting has to be identical on every visible GPU,
    # so apply it to each device rather than only to the first one.
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
  except RuntimeError as e:
    # Memory growth must be configured before the GPUs have been initialized.
    print(e)

In [2]:
# Load the saved QuestionContextPairs object from the location configured in
# CONFIG.SAVELOC. No explicit device scope is used here; the ops logged in the
# output below were placed on the CPU.
data = QuestionContextPairs.load(CONFIG.SAVELOC)

Executing op Reshape in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op TensorSliceDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op FlatMapDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op TensorSliceDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op FlatMapDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op ParallelMapDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op ParallelMapDataset in device /job:localhost/replica:0/task:0/device:CPU:0


In [3]:
# Display the training split; its repr lists the element shapes and dtypes.
data.train

<ParallelMapDataset shapes: (((256,), (256,), (32,), (16,)), (16,)), types: ((tf.int32, tf.uint8, tf.float32, tf.int32), tf.int32)>

In [8]:
# Group the training split into mini-batches of two examples. Prefetching is
# deliberately not applied here; it is added after the copy to the GPU.
train = data.train.batch(batch_size=2)

Executing op BatchDatasetV2 in device /job:localhost/replica:0/task:0/device:CPU:0


In [9]:
# Move the batched dataset onto the first GPU and keep two batches buffered
# there. The prefetch is created inside the GPU device scope so the buffer
# lives on the same device the data was copied to (see the logged
# GeneratorDataset / PrefetchDataset placements in the output below).
# NOTE: this cell rebinds `train`, so re-running it stacks another copy and
# prefetch stage on top of the previous one.
GPU_DEVICE = "/gpu:0"

to_gpu = tf.data.experimental.copy_to_device(GPU_DEVICE)
train = train.apply(to_gpu)
with tf.device(GPU_DEVICE):
    train = train.prefetch(buffer_size=2)

Executing op GeneratorDataset in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op PrefetchDataset in device /job:localhost/replica:0/task:0/device:GPU:0


In [6]:
# Read the word-embedding matrix stored as JSON at CONFIG.WORDEMBMAT and turn
# it into a NumPy array. The context manager closes the file handle as soon as
# parsing is done instead of leaving it open until garbage collection.
with open(CONFIG.WORDEMBMAT, "r") as emb_file:
    word_emb_mat = np.array(json.load(emb_file))

In [11]:
# Hand-built encoder/decoder question generator.
# Shorthand and sizes shared by the layers below.
keras_layers = tf.keras.layers
HIDDEN_UNITS = 32      # width of every GRU and of the decoder state input
QUESTION_EMB_DIM = 32  # width of the learned question-token embedding

# --- Inputs describing the context -----------------------------------------
context = keras_layers.Input(
    shape=(CONFIG.MAX_CONTEXT_LEN,), name="Context-Tokens")
# NOTE(review): discourse_markers is declared as an input of the encoder and
# training models, but no layer in this cell consumes it.
discourse_markers = keras_layers.Input(
    shape=(CONFIG.MAX_CONTEXT_LEN,), name="Context-Discourse-Markers")
latent_vector = keras_layers.Input(
    shape=(CONFIG.LATENT_DIM,), name="Latent-Vector")

# --- Encoder: embed the context, run two stacked GRUs, keep the final state -
context_emb = FixedEmbedding(
    word_emb_mat, CONFIG.MAX_CONTEXT_LEN, name="Context-Embedding")(context)
enc_x1 = keras_layers.GRU(
    units=HIDDEN_UNITS, return_sequences=True, name="Context-Encoder-1")(context_emb)
enc_x1, enc_x1_state = keras_layers.GRU(
    units=HIDDEN_UNITS, return_state=True, name="Context-Encoder-2")(enc_x1)

# Element-wise product of the final encoder state and the latent vector; this
# is the state handed to the decoder in the training model.
enc_x = keras_layers.Multiply()([enc_x1_state, latent_vector])

encoder_model = tf.keras.Model(
    inputs=[context, discourse_markers, latent_vector],
    outputs=enc_x,
    name="QGAN-Enc")

# --- Decoder layers, shared by the stand-alone decoder and the trainer ------
question_tokens = keras_layers.Input(
    shape=(CONFIG.MAX_QLEN,), name="Question-Tokens")
decoder_state_input = keras_layers.Input(
    shape=(HIDDEN_UNITS,), name="Decoder-state")

question_emb = keras_layers.Embedding(
    input_dim=CONFIG.QVOCAB_SIZE,
    output_dim=QUESTION_EMB_DIM,
    name="Question-Embedding")(question_tokens)
decoder_1 = keras_layers.GRU(
    units=HIDDEN_UNITS, return_sequences=True, name="GRU-Decoder-1")
decoder_2 = keras_layers.GRU(
    units=HIDDEN_UNITS, return_sequences=True, return_state=True,
    name="GRU-Decoder-2")
decoder_dense = keras_layers.Dense(
    units=CONFIG.QVOCAB_SIZE, activation='softmax', name="Dense-Decoder")

# Stand-alone decoder: starts GRU-Decoder-1 from an externally supplied state.
# NOTE(review): decoder_dense is not applied on this path, so the first output
# is the GRU-Decoder-2 sequence rather than vocabulary probabilities, and the
# returned state comes from GRU-Decoder-2 while the state input feeds
# GRU-Decoder-1 -- confirm both are intended.
dec_ypred = decoder_1(question_emb, initial_state=decoder_state_input)
dec_ypred, dec_ypred_state = decoder_2(dec_ypred)
decoder = tf.keras.Model(
    inputs=[question_tokens, decoder_state_input],
    outputs=[dec_ypred, dec_ypred_state],
    name="QGAN-Dec")

# --- Training model: decoder seeded with the encoder output -----------------
dec_y = decoder_1(question_emb, initial_state=enc_x)
dec_y, _ = decoder_2(dec_y)
y = decoder_dense(dec_y)
model = tf.keras.Model(
    inputs=[context, discourse_markers, latent_vector, question_tokens],
    outputs=y,
    name="QGAN-Trainer")
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss='sparse_categorical_crossentropy')
model.summary()

Executing op DeleteIterator in device /job:localhost/replica:0/task:0/device:CPU:0
Model: "QGAN-Trainer"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
Context-Tokens (InputLayer)     [(None, 256)]        0                                            
__________________________________________________________________________________________________
Context-Embedding (FixedEmbeddi (None, 256, 300)     26624700    Context-Tokens[0][0]             
__________________________________________________________________________________________________
Context-Encoder-1 (GRU)         (None, 256, 32)      32064       Context-Embedding[0][0]          
__________________________________________________________________________________________________
Question-Tokens (InputLayer)    [(None, 16)]         0                                            
____

In [12]:
# Train the model currently bound to `model` on the GPU-staged batches for two
# epochs. With device-placement logging enabled this prints one line per op.
model.fit(train, epochs=2)

Executing op ParallelMapDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Epoch 1/2
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op VarHandleOp in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op VarHandleO

KeyboardInterrupt: 

In [14]:
# Build the packaged model from text_gan.models.squad_qgan. This rebinds
# `model`, so the hand-built QGAN-Trainer is no longer reachable by that name.
# NOTE(review): 1e-3 is presumably the learning rate -- confirm against
# get_model's signature.
model = get_model(CONFIG, 1e-3)

In [15]:
# Print the layer-by-layer summary of the model returned by get_model.
model.summary()

Model: "SQuAD-QGAN"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
Context-IDs (InputLayer)        [(None, 256)]        0                                            
__________________________________________________________________________________________________
Question-so-far (InputLayer)    [(None, None)]       0                                            
__________________________________________________________________________________________________
Glove-embeddings (FixedEmbeddin (None, 256, 300)     26624700    Context-IDs[0][0]                
__________________________________________________________________________________________________
Ques-embs (Embedding)           (None, None, 16)     80000       Question-so-far[0][0]            
_________________________________________________________________________________________

In [16]:
# Train the model returned by get_model on the same GPU-staged batches, for up
# to ten epochs.
model.fit(train, epochs=10)

Epoch 1/10
   3042/Unknown - 86s 28ms/step - loss: 4.5765

KeyboardInterrupt: 

In [17]:
# Persist the current model; the log below shows an assets directory being
# written under this path.
# NOTE(review): the destination is a hard-coded absolute path -- consider
# deriving it from CONFIG so the notebook runs outside this container.
model.save("/tf/data/model.tf")

Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:tensorflow:Assets written to: /tf/data/model.tf/assets
