In [1]:
import tensorflow as tf
import numpy as np
import ujson as json

from text_gan.data import QuestionContextPairs, CONFIG
from text_gan.models import AttnGen
from text_gan.layers import Encoder, Decoder

In [2]:
# Enable on-demand GPU memory allocation so TensorFlow does not reserve all
# device memory up front. TensorFlow requires the memory-growth setting to be
# identical across all visible GPUs, so apply it to every GPU (not just gpus[0]);
# otherwise multi-GPU hosts fail when the devices are initialized.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Raised when the GPUs were already initialized (e.g. cell re-run after use).
        print(e)

In [3]:
# Load the saved question/context pair data from CONFIG.SAVELOC.
data = QuestionContextPairs.load(CONFIG.SAVELOC)
# Echo the loaded object; its default repr (type + address) says little about contents.
data

<text_gan.data.qgen_data.QuestionContextPairs at 0x7f2248898908>

In [4]:
# Input pipeline: batch the training split, then (when a GPU is present) copy
# batches to the GPU and prefetch there so transfers overlap with training steps.
BATCH_SIZE = 8        # leading (batch) dimension of every tensor the model sees
PREFETCH_BUFFER = 2   # number of batches buffered ahead of the training loop
DEVICE = "/gpu:0"     # device the batches are staged on

train = data.train.batch(BATCH_SIZE)
if tf.config.experimental.list_physical_devices('GPU'):
    train = train.apply(tf.data.experimental.copy_to_device(DEVICE))
    # Keep the prefetch under the same device scope the elements were copied to.
    with tf.device(DEVICE):
        train = train.prefetch(PREFETCH_BUFFER)
else:
    # CPU-only host: staging on /gpu:0 is impossible, so just prefetch on the host.
    train = train.prefetch(PREFETCH_BUFFER)

In [5]:
# Load the embedding matrices and the question vocabulary.
# Word embeddings (for the Encoder) are stored as JSON; question-word embeddings
# (for the Decoder) as a .npy file. `with` ensures the JSON file handles are closed.
with open(CONFIG.WORDEMBMAT, "r") as fh:
    word_emb_mat = np.array(json.load(fh))
qword_emb_mat = np.load(f"{CONFIG.QWORDEMBMAT}.npy")
with open(CONFIG.QWORD2IDX, "r") as fh:
    qword2idx = json.load(fh)

# Invert qword2idx into an index -> word lookup array; slots with no word keep "<UNK>".
idx2qword = np.full(CONFIG.QVOCAB_SIZE, "<UNK>", dtype='object')
for word, idx in qword2idx.items():
    idx2qword[idx] = word


In [6]:
# Build the encoder and decoder layers on their own (outside AttnGen) so their
# output shapes can be checked directly on a single batch in the next cell.
encoder = Encoder(word_emb_mat, CONFIG)
decoder = Decoder(qword_emb_mat, CONFIG)

In [7]:
# Shape smoke test: run one batch through the standalone encoder and decoder.
# Labels are kept as `y_true` so the decoder output no longer overwrites them,
# and next(iter(...)) avoids relying on loop variables leaking out of a `for`.
X, y_true = next(iter(train.take(1)))
s0, hd = encoder((X[0], X[1], X[2]))
y_pred, st = decoder((X[3], s0, hd))
s0.shape, hd.shape, y_pred.shape, st.shape

(TensorShape([8, 32]),
 TensorShape([8, 256, 32]),
 TensorShape([8, 16, 5000]),
 TensorShape([8, 32]))

In [6]:
# Build the full generator from both embedding matrices and both directions of
# the question-vocabulary mapping (word -> index and index -> word).
model = AttnGen(word_emb_mat, qword_emb_mat, qword2idx, idx2qword, CONFIG)

In [7]:
# Print the layer summary of the Keras model wrapped by AttnGen (exposed as `.model`).
model.model.summary()

Model: "Attn-Gen"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
Context-Tokens (InputLayer)     [(None, 256)]        0                                            
__________________________________________________________________________________________________
Context-Discourse-Markers (Inpu [(None, 256)]        0                                            
__________________________________________________________________________________________________
Latent-Vector (InputLayer)      [(None, 32)]         0                                            
__________________________________________________________________________________________________
Question-Tokens (InputLayer)    [(None, None)]       0                                            
___________________________________________________________________________________________

In [9]:
# Compile the wrapped Keras model with Adam (lr = 1e-3) and sparse categorical
# cross-entropy over question-token ids.
# NOTE(review): from_logits=True assumes the model emits raw scores (no softmax) — confirm in AttnGen.
model.model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

In [10]:
# `!export AUTOGRAPH_VERBOSITY=10` ran in a throwaway subshell, so the variable
# never reached this kernel. Set AutoGraph's verbosity through its Python API.
# (Pass alsologtostdout=True to also echo the messages into the notebook output.)
tf.autograph.set_verbosity(10)

In [11]:
# Train for one epoch on the batched `train` pipeline.
# NOTE(review): compile() was called on `model.model`, but fit() is called on the
# AttnGen wrapper — confirm AttnGen.fit delegates to the compiled inner model.
# The recorded run shows "N/Unknown" steps, presumably because Keras cannot infer
# the dataset's length; passing steps_per_epoch would give the progress bar a total.
model.fit(train, epochs=1)

   1332/Unknown - 71s 53ms/step - loss: 4.0758

KeyboardInterrupt: 