In [1]:
import tensorflow as tf
import argparse
import numpy as np
import ujson as json
import logging

from text_gan import cfg, cfg_from_file
from text_gan.data.squad1_ca_q import Squad1_CA_Q
from text_gan.features import FastText, GloVe, NERTagger, PosTagger
from text_gan.models import QGAN, AttnGen, CA_Q_AttnQGen, CAZ_Q_Attn, CANPZ_Q

In [2]:
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        tf.config.experimental.set_memory_growth(gpus[0], True)
    except RuntimeError as e:
        print(e)

In [3]:
logging.basicConfig(
    level=cfg.LOG_LVL,
    filename=cfg.LOG_FILENAME,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

In [32]:
def canpz_q():
    RNG_SEED = 11
    data = Squad1_CA_Q()
    data = data.train.shuffle(
        buffer_size=10000, seed=RNG_SEED, reshuffle_each_iteration=False)
    to_gpu = tf.data.experimental.copy_to_device("/gpu:0")
    train = data.skip(1010).take(50)\
        .shuffle(buffer_size=100, seed=RNG_SEED)\
        .batch(1).apply(to_gpu)
    val = data.take(1000).batch(10).apply(to_gpu)
    with tf.device("/gpu:0"):
        train = train.prefetch(2)
        val = val.prefetch(1)

    if cfg.EMBS_TYPE == 'glove':
        cembs = GloVe.load(cfg.EMBS_FILE, cfg.CSEQ_LEN, cfg.EMBS_CVOCAB)
        qembs = GloVe.load(
            cfg.EMBS_FILE, cfg.QSEQ_LEN, cfg.EMBS_QVOCAB, cembs.data)
    elif cfg.EMBS_TYPE == 'fasttext':
        cembs = FastText.load(cfg.EMBS_FILE, cfg.CSEQ_LEN, cfg.EMBS_CVOCAB)
        qembs = FastText.load(
            cfg.EMBS_FILE, cfg.QSEQ_LEN, cfg.EMBS_QVOCAB, cembs.data)
    else:
        raise ValueError(f"Unsupported embeddings type {cfg.EMBS_TYPE}")
    ner = NERTagger(cfg.NER_TAGS_FILE, cfg.CSEQ_LEN)
    pos = PosTagger(cfg.POS_TAGS_FILE, cfg.CSEQ_LEN)

    model = CANPZ_Q(cembs, ner, pos, qembs)
    model.load('/tf/data/canpz_q/')
    pred, attn_weights = model.predict(train)
    i = 0
    cont = []
    for X, y in train:
        context = list(cembs.inverse_transform(X[0].numpy())[0])
        answer = tf.reshape(X[0]*tf.cast(X[1], tf.int32), (-1,))
        ogques = qembs.inverse_transform(y.numpy())[0]
        ans = ''
        for ai in answer:
            if ai == 0:
                continue
            ans += cembs.inverse.get(ai.numpy(), cembs.UNK) + ' '
        # context = list(filter(
        #     lambda w: w != cembs.PAD, context))
        try:
            ogques = ogques[:ogques.index(qembs.END)]
        except:
            pass
        ques = qembs.inverse_transform([pred[i].numpy()])[0]
        try:
            ques = ques[:ques.index(qembs.END)]
        except:
            pass
        print(f"Context:- {' '.join(context)}")
        print(f"Answer:- {ans}")
        print(f"OG Question:- {' '.join(ogques)}")
        print(f"Question:- {' '.join(ques)}")
        print(f"Attention Weights:- {attn_weights[i].numpy()}")
        print("")
        cont.append(context)
        i += 1
    return attn_weights, cont

In [33]:
attn_weights, cont = canpz_q()

hat may or may not exist today . Ethnohistory uses both historical and ethnographic data as its foundation . Its historical methods and materials go beyond the standard use of documents and manuscripts . Practitioners recognize the utility of such source material as maps , music , paintings , photography , folklore , oral tradition , site exploration , archaeological materials , museum collections , enduring customs , language , and place names . EOS PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD PAD P

In [37]:
attn_weights[0]

<tf.Tensor: shape=(21, 250), dtype=float32, numpy=
array([[0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.00380346, 0.00375432, 0.00374713, ..., 0.0030784 , 0.00304931,
        0.00529373],
       [0.00397695, 0.0039255 , 0.00391625, ..., 0.00347874, 0.00347357,
        0.00543915],
       ...,
       [0.00396752, 0.0039199 , 0.00392636, ..., 0.00317718, 0.0031203 ,
        0.00481249],
       [0.00391109, 0.00386831, 0.00385672, ..., 0.00351122, 0.00350127,
        0.00537726],
       [0.00396752, 0.0039199 , 0.00392636, ..., 0.00317719, 0.00312031,
        0.00481251]], dtype=float32)>

In [65]:
for i in range(1, 20):
    indices = attn_weights[0][i].numpy().argsort()[::-1]
    print(attn_weights[0][i].numpy()[indices[:10]], np.array(cont[i])[indices[:10]])

[0.00529373 0.00468969 0.00464408 0.00461993 0.00459947 0.00458514
 0.00458396 0.00458133 0.00458111 0.00457985] ['PAD' 'Pima' ',' 'PAD' 'and' '.' ',' 'Lakecrest' 'Neighborhoods' 'The']
[0.00543915 0.00446042 0.00445024 0.00438882 0.00438667 0.00438279
 0.00438044 0.00437524 0.00436722 0.00436267] ['PAD' 'PAD' 'PAD' 'PAD' 'PAD' 'PAD' '.' 'Opération' '2013' 'in']
[0.00438997 0.00436893 0.00435375 0.00431877 0.00431229 0.00431209
 0.00430922 0.00430878 0.00430542 0.00430114] ['as' 'known' 'overhunts' 'This' '.' 'phenomenon' 'of' 'coextinction' 'a'
 'the']
[0.00555489 0.00459274 0.00458958 0.00454261 0.00452969 0.00449071
 0.00448709 0.00447945 0.00445896 0.00444994] ['PAD' 'PAD' 'PAD' 'PAD' 'PAD' 'PAD' 'PAD' 'PAD' 'PAD' 'PAD']
[0.00447028 0.00446819 0.0044499  0.00444041 0.00443309 0.00442436
 0.00441406 0.00441297 0.00440605 0.00440161] ['oriented' 'a' 'the' 'sizeable' 'Sufi' 'BJP' 'portion' 'up' 'ruling' 'of']
[0.0045096  0.00450464 0.0044787  0.0044557  0.00444778 0.00444106
 0.004439

In [53]:
attn_weights[0].shape

TensorShape([21, 250])

In [51]:
indices.shape

(21, 250)

In [31]:
cont[0]

['Several',
 'Islamic',
 'kingdoms',
 '(',
 'sultanates',
 ')',
 'under',
 'both',
 'foreign',
 'and',
 ',',
 'newly',
 'converted',
 ',',
 'Rajput',
 'rulers',
 'were',
 'established',
 'across',
 'the',
 'north',
 'western',
 'subcontinent',
 '(',
 'Afghanistan',
 'and',
 'Pakistan',
 ')',
 'over',
 'a',
 'period',
 'of',
 'a',
 'few',
 'centuries',
 '.',
 'From',
 'the',
 '10th',
 'century',
 ',',
 'Sindh',
 'was',
 'ruled',
 'by',
 'the',
 'Rajput',
 'Soomra',
 'dynasty',
 ',',
 'and',
 'later',
 ',',
 'in',
 'the',
 'mid-13th',
 'century',
 'by',
 'the',
 'Rajput',
 'Samma',
 'dynasty',
 '.',
 'Additionally',
 ',',
 'Muslim',
 'trading',
 'communities',
 'flourished',
 'throughout',
 'coastal',
 'south',
 'India',
 ',',
 'particularly',
 'on',
 'the',
 'western',
 'coast',
 'where',
 'Muslim',
 'traders',
 'arrived',
 'in',
 'small',
 'numbers',
 ',',
 'mainly',
 'from',
 'the',
 'Arabian',
 'peninsula',
 '.',
 'This',
 'marked',
 'the',
 'introduction',
 'of',
 'a',
 'third',
 'A