In [1]:
import sys
import os

In [2]:
import dataloader
import vanilla_model
import attention_model

In [3]:
os.listdir('../data')

['hi.translit.sampled.dev.tsv',
 'hi.translit.sampled.test.tsv',
 'hi.translit.sampled.train.tsv']

In [4]:
BASE_PATH = '../data'

In [5]:
ddm = dataloader.DakshinaDataModule(
    train_file = os.path.join(BASE_PATH, 'hi.translit.sampled.train.tsv'),
    val_file = os.path.join(BASE_PATH, 'hi.translit.sampled.dev.tsv'),
    test_file = os.path.join(BASE_PATH, 'hi.translit.sampled.test.tsv'),
)

In [6]:
ddm.setup()

In [7]:
ddm.train_dataset[0]

{'src': tensor([1, 3, 4, 2]), 'tgt': tensor([1, 3, 4, 2])}

In [8]:
for i, batch in enumerate(ddm.train_dataloader()):
    if i == 2:
        break
    print(f'Batch {i}:')
    for k, v in batch.items():
        print(f'  {k}: {v.shape}')

Batch 0:
  src_input: torch.Size([32, 16])
  src_len: torch.Size([32])
  tgt_input: torch.Size([32, 17])
  tgt_len: torch.Size([32])
  tgt_output: torch.Size([32, 17])
Batch 1:
  src_input: torch.Size([32, 15])
  src_len: torch.Size([32])
  tgt_input: torch.Size([32, 16])
  tgt_len: torch.Size([32])
  tgt_output: torch.Size([32, 16])


In [9]:
batch0 = next(iter(ddm.train_dataloader()))
batch0['src_input']

tensor([[ 1, 14, 12, 18, 13, 20, 57,  2,  0,  0,  0,  0,  0,  0,  0],
        [ 1,  5, 20, 31, 18, 23, 20,  9, 18, 28,  5,  2,  0,  0,  0],
        [ 1, 28, 24, 28, 40,  9, 18,  9, 21,  2,  0,  0,  0,  0,  0],
        [ 1, 29, 18, 12,  8, 34, 10,  2,  0,  0,  0,  0,  0,  0,  0],
        [ 1, 36,  4, 26, 21, 31,  2,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 1, 33, 42, 36, 18, 34,  8, 29, 10,  2,  0,  0,  0,  0,  0],
        [ 1, 36, 28, 36, 18, 23, 20,  6, 18, 12, 36, 18,  9,  2,  0],
        [ 1,  6, 11,  4, 30, 20,  2,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 1, 23, 20, 15, 16,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 1, 14, 18, 31,  8, 36, 18,  9, 12, 21, 23,  2,  0,  0,  0],
        [ 1, 36, 20, 23, 28, 15,  2,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 1, 36, 11, 27,  8,  4,  9,  8,  9,  2,  0,  0,  0,  0,  0],
        [ 1, 34, 11, 12, 20,  9, 20,  9, 18,  9, 18, 31,  8,  5,  2],
        [ 1, 27, 25, 28, 15, 16,  2,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 1, 28, 33,

In [10]:
seq_model = vanilla_model.VanillaSeq2Seq(
    input_vocab_size = len(ddm.input_vocab.idx2char),
    target_vocab_size = len(ddm.target_vocab.idx2char),
    embedding_dim = 256,
    hidden_dim = 512,
    encoder_layers = 2,
    decoder_layers = 2,
    encoder_dropout = 0.0,
    decoder_dropout = 0.0,
    encoding_unit = 'rnn',
    decoding_unit = 'rnn',
    lr = 1e-3,
    optimizer = 'adam',
)

In [11]:
logits = seq_model(batch0['src_input'], batch0['src_len'], batch0['tgt_input'])
logits

tensor([[[-9.7585e-04,  6.1579e-02, -1.2477e-01,  ...,  6.8627e-02,
           7.5931e-02, -9.4983e-02],
         [ 1.0894e-02, -2.6695e-01,  1.7374e-01,  ...,  3.1440e-02,
           5.1228e-02,  2.4525e-01],
         [ 6.9800e-02,  1.9067e-02, -5.2559e-02,  ...,  3.5842e-02,
           1.3402e-01, -1.3506e-01],
         ...,
         [-3.3987e-01, -3.5471e-01, -9.3559e-03,  ..., -1.4440e-02,
           4.7809e-02, -1.1106e-01],
         [-4.8556e-01, -3.2035e-01, -3.3951e-03,  ..., -1.2127e-03,
           9.7428e-02, -1.2021e-01],
         [-5.3128e-01, -2.9397e-01,  2.1655e-02,  ..., -1.5243e-02,
           1.1639e-01, -1.4193e-01]],

        [[-9.8269e-02, -3.2638e-04, -5.2699e-02,  ...,  8.6578e-02,
          -4.7785e-02,  3.6666e-02],
         [-1.7827e-01, -3.0714e-01,  9.4189e-02,  ...,  2.8649e-01,
           6.8893e-02,  4.0426e-01],
         [-1.1548e-01,  1.1968e-01,  5.4504e-02,  ..., -2.8071e-01,
          -1.4401e-02, -8.0366e-02],
         ...,
         [-3.5817e-01, -4

In [12]:
logits.shape

torch.Size([32, 14, 29])

In [33]:
loss, acc, f1 = seq_model._compute_loss_and_metrics(batch0)

ValueError: too many values to unpack (expected 3)

In [13]:
preds = seq_model.predict_step(batch0, 0)

In [14]:
decoded_preds = [ddm.target_vocab.decode(seq) for seq in preds]

In [15]:
ddm.target_vocab.char2idx

{'<pad>': 0,
 '<sos>': 1,
 '<eos>': 2,
 'a': 3,
 'n': 4,
 'k': 5,
 'g': 6,
 'i': 7,
 't': 8,
 'u': 9,
 'c': 10,
 'l': 11,
 'e': 12,
 'r': 13,
 's': 14,
 'h': 15,
 'd': 16,
 'b': 17,
 'y': 18,
 'o': 19,
 'j': 20,
 'z': 21,
 'm': 22,
 'v': 23,
 'w': 24,
 'p': 25,
 'f': 26,
 'x': 27,
 'q': 28}

In [16]:
decoded_preds

['eqk',
 'eqk',
 'eqk',
 'eqk',
 'eqk',
 'eqk',
 'eqk',
 'eqkqwlsrwbsrwbsrwbsrwbsrwbsrwbsrwbsrwbsrwbsrwbsrwb',
 'nueqkcwlsrwbsrwbsrwbsrwbsrwbsrwbsrwbsrwbsrwbsrwbs',
 'eqk',
 'ifbkqwlsrwbsrwbsrwbsrwbsrwbsrwbsrwbsrwbsrwbsrwbsrw',
 'eqk',
 'eqk',
 'nueqkcwlsrwbsrwbsrwbsrwbsrwbsrwbsrwbsrwbsrwbsrwbs',
 'eqk',
 'eqk',
 'eqk',
 'eqk',
 'eqk',
 'eqk',
 'eqk',
 'hbkqwpbkqwpbkqwpbkqwpbkqwpbkqwpbkqwpbkqwpbkqwpbkqw',
 'nueqkcwlsrwbsrwbsrwbsrwbsrwbsrwbsrwbsrwbsrwbsrwbs',
 'ifbkqbkqwlsrwbsrwbsrwbsrwbsrwbsrwbsrwbsrwbsrwbsrwb',
 'eqk',
 'ifbkqwlsrwbsrwbsrwbsrwbsrwbsrwbsrwbsrwbsrwbsrwbsrw',
 'eqk',
 'eqk',
 'eqk',
 'eqk',
 'eqk',
 'eqk']

In [17]:
attn_model = attention_model.AttentionSeq2Seq(
    input_vocab_size = len(ddm.input_vocab.idx2char),
    target_vocab_size = len(ddm.target_vocab.idx2char),
    embedding_dim = 256,
    hidden_dim = 256,
    encoder_layers = 1,
    decoder_layers = 1,
    encoder_dropout = 0.0,
    decoder_dropout = 0.0,
    encoding_unit = 'gru',
    decoding_unit = 'gru',
    max_len = 50,
    beam_width = 5,
    lr = 1e-3,
    optimizer = 'adam',
)

In [18]:
logits, hidden, attn_weights = attn_model(batch0['src_input'], batch0['src_len'], batch0['tgt_input'])

In [19]:
logits.shape

torch.Size([32, 14, 29])

In [20]:
hidden.shape

torch.Size([1, 32, 256])

In [21]:
attn_weights.shape

torch.Size([32, 14, 15])

In [23]:
loss, acc, f1 = attn_model._compute_loss_and_metrics(batch0)

In [24]:
loss

tensor(3.3544, grad_fn=<NllLossBackward0>)

In [25]:
acc

tensor(0.0495)

In [26]:
f1

tensor(0.0219)

In [27]:
preds, attns = attn_model.predict_step(batch0, 0)

In [31]:
decoded_preds = [ddm.target_vocab.decode(seq) for seq in preds]

In [32]:
decoded_preds

['hrkovgg',
 'oraurrrrrrrrrrrrrrd',
 's',
 's',
 's',
 'hrqiz',
 's',
 'qpa',
 'iuutttttqoautttttttqoautttttttqoauttttt',
 'sfffva',
 'hru',
 's',
 'oraurrd',
 'hrururorwur',
 'hrqxgggg',
 'hrdjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjj',
 'ruddjyell',
 's',
 'hra',
 'iuuuttttctttcttcttcttcttcttcttcttcttcttcttcttcttct',
 'qoauuruautttttqoorauuutttttqooruuattt',
 'iuuuttttqoiizuuuuttttqoiizuuuuttttqoiizuu',
 's',
 'rd',
 'ss',
 's',
 'ruutttkooreiuuuttttkoreiuuuuttttkoorauu',
 'hhqpa',
 's',
 'orffffljjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjjj',
 's',
 'iuud']

In [28]:
attns[2].shape[1]

10

In [29]:
batch0['src_len']

tensor([ 8, 12, 10,  8,  7, 10, 14,  7,  6, 12,  7, 10, 15,  7,  9,  9, 11,  8,
        11,  8,  6,  7,  8, 10,  7,  8, 10,  7,  7,  9, 12,  9])

In [30]:
for k in range(len(attns)):
    print(f'attention shape: {attns[k].shape}')
    print(f'input src len: {batch0["src_len"][k]}')
    print()

attention shape: torch.Size([11, 8])
input src len: 8

attention shape: torch.Size([50, 12])
input src len: 12

attention shape: torch.Size([2, 10])
input src len: 10

attention shape: torch.Size([2, 8])
input src len: 8

attention shape: torch.Size([2, 7])
input src len: 7

attention shape: torch.Size([6, 10])
input src len: 10

attention shape: torch.Size([2, 14])
input src len: 14

attention shape: torch.Size([4, 7])
input src len: 7

attention shape: torch.Size([50, 6])
input src len: 6

attention shape: torch.Size([7, 12])
input src len: 12

attention shape: torch.Size([8, 7])
input src len: 7

attention shape: torch.Size([2, 10])
input src len: 10

attention shape: torch.Size([14, 15])
input src len: 15

attention shape: torch.Size([28, 7])
input src len: 7

attention shape: torch.Size([9, 9])
input src len: 9

attention shape: torch.Size([50, 9])
input src len: 9

attention shape: torch.Size([13, 11])
input src len: 11

attention shape: torch.Size([2, 8])
input src len: 8

atten