In [47]:
import numpy as np
import json
import torch
import itertools
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

In [51]:
# dataset = np.load('../data/train.npz')

In [52]:
qw_idxs = torch.load('../data_dev/qw_idxs.pt') # torch.Size([64, 29])

In [53]:
# batch_size, len, 
qw_idxs.shape

torch.Size([64, 29])

In [54]:
with open('../data/word_idx2syll_idx.json') as json_file:  
    word_idx2syll_idx = json.load(json_file)
word_idx2syll_idx['534']

[2200]

In [55]:
with open('../data/syll2idx.json') as json_file:  
    syll2idx = json.load(json_file)

In [56]:
qw_idxs.shape

torch.Size([64, 29])

In [57]:
orig_len = max(len(sent) for sent in qw_idxs)

29

In [58]:
qw_idxs

tensor([[    1,   191,   534,  ...,     0,     0,     0],
        [    1,   191,    12,  ...,     0,     0,     0],
        [    1,    24, 51294,  ...,     0,     0,     0],
        ...,
        [    1,   999,  2150,  ...,     0,     0,     0],
        [    1,  2461,  4003,  ...,     0,     0,     0],
        [    1,   999,  1172,  ...,     0,     0,     0]])

In [60]:
q_mask = torch.zeros_like(qw_idxs) != qw_idxs
q_len = q_mask.sum(-1)
q_len

tensor([ 8,  7, 11, 15, 11, 13,  9,  9,  9, 13, 14, 10, 18,  6, 11, 13,  8,  8,
        11, 19, 11, 16, 13, 12, 12, 10,  9,  8, 13, 14, 16, 10, 17, 29, 11, 13,
        12, 20, 10,  6, 18,  7, 12,  7, 10, 17, 16, 10, 19, 11, 14, 10,  9, 13,
        15,  9, 14, 15, 16, 13, 13, 14, 12,  7])

In [61]:
import itertools

def word2syll_idxs(w_idxs, word_idx2syll_idx, pad_idx=0, unk_idx=1):
    """
    Convert a list tensor of word indices to syllable indices.
    @param w_idxs (Tensor): batch of word indices (context or question)
    @param word_idx2syll_idx (Dict[str->list[int]]): mapping vocabulary index to syllable index
    @param unk_idx (int): index to padding token (0)
    """
    
    syll_idxs = []
    max_word_len = 0
    for sent in w_idxs:
        syll_idx = [word_idx2syll_idx[str(i)] if str(i) in word_idx2syll_idx else [pad_idx] if str(i) == '0' else [unk_idx] for i in sent.tolist()]
        syll_idxs.append(syll_idx)
        max_tmp = max(len(idxs) for idxs in syll_idx)
        max_word_len = max(max_word_len, max_tmp)
    
    # pad word to max word length (measured in syllables)
    for i, sent in enumerate(syll_idxs):
        for j, word in enumerate(sent):
            syll_idxs[i][j] += [pad_idx] * (max_word_len - len(word))
    
    # pad sentence to max sentence length (measured in words)
    max_sen_len = max(len(sent) for sent in syll_idxs)
    for i, sent in enumerate(syll_idxs):
        for _ in range(max_sen_len - len(sent)):
            syll_idxs[i].append([pad_idx] * max_word_len)
    return torch.LongTensor(syll_idxs)

syl_tensor = word2syll_idxs(qw_idxs, word_idx2syll_idx)
syl_tensor.shape

torch.Size([64, 29, 6])

In [62]:
syl_tensor.shape

torch.Size([64, 29, 6])

In [72]:
# Sort by length and pack sequence for RNN
lengths, sort_idx = q_len.sort(0, descending=True)
x = syl_tensor[sort_idx]     # (batch_size, seq_len, input_size)
        
x = pack_padded_sequence(x, lengths, batch_first=True)

In [73]:
x

PackedSequence(data=tensor([[ 1,  0,  0,  0,  0,  0],
        [ 1,  0,  0,  0,  0,  0],
        [ 1,  0,  0,  0,  0,  0],
        ...,
        [ 1,  0,  0,  0,  0,  0],
        [10, 11, 12,  0,  0,  0],
        [ 1,  0,  0,  0,  0,  0]]), batch_sizes=tensor([64, 64, 64, 64, 64, 64, 62, 58, 54, 48, 41, 34, 29, 20, 15, 12,  8,  6,
         4,  2,  1,  1,  1,  1,  1,  1,  1,  1,  1]))

In [74]:
x

PackedSequence(data=tensor([[ 1,  0,  0,  0,  0,  0],
        [ 1,  0,  0,  0,  0,  0],
        [ 1,  0,  0,  0,  0,  0],
        ...,
        [ 1,  0,  0,  0,  0,  0],
        [10, 11, 12,  0,  0,  0],
        [ 1,  0,  0,  0,  0,  0]]), batch_sizes=tensor([64, 64, 64, 64, 64, 64, 62, 58, 54, 48, 41, 34, 29, 20, 15, 12,  8,  6,
         4,  2,  1,  1,  1,  1,  1,  1,  1,  1,  1]))

In [75]:
x, _ = pad_packed_sequence(x, batch_first=True, total_length=29)
_, unsort_idx = sort_idx.sort(0)
x = x[unsort_idx]
x.shape

torch.Size([64, 29, 6])

In [76]:
x[0]

tensor([[   1,    0,    0,    0,    0,    0],
        [  71,    0,    0,    0,    0,    0],
        [2200,    0,    0,    0,    0,    0],
        [  56,    0,    0,    0,    0,    0],
        [ 152,  190,   25,    0,    0,    0],
        [   4,    0,    0,    0,    0,    0],
        [   1, 2465,    0,    0,    0,    0],
        [   1,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,  

In [77]:
syl_tensor[0]

tensor([[   1,    0,    0,    0,    0,    0],
        [  71,    0,    0,    0,    0,    0],
        [2200,    0,    0,    0,    0,    0],
        [  56,    0,    0,    0,    0,    0],
        [ 152,  190,   25,    0,    0,    0],
        [   4,    0,    0,    0,    0,    0],
        [   1, 2465,    0,    0,    0,    0],
        [   1,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,    0,    0],
        [   0,    0,    0,    0,  

#### Debug

In [35]:
with open('../data/word2idx.json') as json_file:  
    word2idx = json.load(json_file)

In [36]:
idx2syll = {val : key for key, val in syll2idx.items()}
idx2word = {val : key for key, val in word2idx.items()}

In [37]:
word_idx2syll_idx['534']

[2200, 0, 0, 0, 0, 0]

In [38]:
sent = qw_idxs[0]

In [39]:
word2idx['type']

534

In [40]:
idx2word[534]

'type'

In [41]:
[idx2word[i] for i in sent.tolist()]

['--OOV--',
 'What',
 'type',
 'of',
 'area',
 'is',
 'Sichuan',
 '?',
 '--NULL--',
 '--NULL--',
 '--NULL--',
 '--NULL--',
 '--NULL--',
 '--NULL--',
 '--NULL--',
 '--NULL--',
 '--NULL--',
 '--NULL--',
 '--NULL--',
 '--NULL--',
 '--NULL--',
 '--NULL--',
 '--NULL--',
 '--NULL--',
 '--NULL--',
 '--NULL--',
 '--NULL--',
 '--NULL--',
 '--NULL--']

In [42]:
syll_idxs = [word_idx2syll_idx[str(i)] for i in sent.tolist() if str(i) in word_idx2syll_idx]
# syll_idxs = list(itertools.chain.from_iterable(syll_idxs))
syll_idxs

[[71, 0, 0, 0, 0, 0],
 [2200, 0, 0, 0, 0, 0],
 [56, 0, 0, 0, 0, 0],
 [152, 190, 25, 0, 0, 0],
 [4, 0, 0, 0, 0, 0],
 [1, 2465, 0, 0, 0, 0]]

In [43]:
[[idx2syll[syll_idx] for syll_idx in word_idx2syll_idx[str(i)]] for i in sent.tolist() if str(i) in word_idx2syll_idx]

[['waht', '--NULL--', '--NULL--', '--NULL--', '--NULL--', '--NULL--'],
 ['tayp', '--NULL--', '--NULL--', '--NULL--', '--NULL--', '--NULL--'],
 ['ahv', '--NULL--', '--NULL--', '--NULL--', '--NULL--', '--NULL--'],
 ['eh', 'riy', 'ah', '--NULL--', '--NULL--', '--NULL--'],
 ['ihz', '--NULL--', '--NULL--', '--NULL--', '--NULL--', '--NULL--'],
 ['--OOV--', 'waan', '--NULL--', '--NULL--', '--NULL--', '--NULL--']]

In [44]:
qs_idxs = []
for sent in qw_idxs:
    syll_idxs = [word_idx2syll_idx[str(i)] for i in sent.tolist() if str(i) in word_idx2syll_idx]
#     syll_idxs = list(itertools.chain.from_iterable(syll_idxs))
    qs_idxs.append(syll_idxs)

In [45]:
max_len = max(len(idxs) for idxs in qs_idxs)
max_len

24