In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append("../src")

In [3]:
from fluidvec import *

In [4]:
import pickle
import torch
from torch.utils.data import Dataset

In [5]:
import numpy as np
class TrainDataset:
    def __init__(self, fpath):
        fin = open(fpath, "rb")
        self.data = pickle.load(fin)
        fin.close()
        
    def __len__(self):
        return len(self.data)
    
    def pad_tensor(self, values):
        n = max(len(x) for x in values)
        w = len(values)
        t = np.zeros((w, n))
        for i, v in enumerate(values):
            t[i, :len(v)] = v
        return torch.tensor(t, dtype=torch.int32)
    
    def __getitem__(self, idx):
        tgt, ctx = self.data[idx]
        return {
            "tgt_word": torch.tensor(tgt["word"]).reshape(1, 1),
            "tgt_chars": torch.tensor(tgt["chars"]).reshape(1, -1),
            "tgt_compos": torch.tensor(tgt["compos"]).reshape(1, -1),
            "ctx_word": torch.tensor([x["word"] for x in ctx]).reshape(len(ctx), -1),
            "ctx_chars": self.pad_tensor([x["chars"] for x in ctx]),
            "ctx_compos": self.pad_tensor([x["compos"] for x in ctx])
        }

In [6]:
ds = TrainDataset("../data/train_items/train_items_001.pkl")

In [7]:
torch.tensor([1,2,3]).reshape(1, 3)

tensor([[1, 2, 3]])

In [8]:
ds[12345]

{'tgt_word': tensor([[105]]),
 'tgt_chars': tensor([[187]]),
 'tgt_compos': tensor([[8]]),
 'ctx_word': tensor([[1454],
         [2884],
         [1164],
         [2884]]),
 'ctx_chars': tensor([[1112, 1383],
         [  29, 1679],
         [  20,   75],
         [  29, 1679]], dtype=torch.int32),
 'ctx_compos': tensor([[ 94,  16,  50, 462],
         [  8,  46,  47,   0],
         [ 32,  33,  13, 100],
         [  8,  46,  47,   0]], dtype=torch.int32)}

In [9]:
from torch.utils.data import DataLoader

In [10]:
ds[12345]["ctx_compos"].size(1)

4

In [11]:
from torch.nn.utils.rnn import pad_sequence
def collate_fn(data_list):
    x0 = data_list[0]
    n_batch = len(data_list)
    collated = {}
    
    for k in x0.keys():
        seqs = [x[k].permute(1, 0) for x in data_list]        
        collated[k] = pad_sequence(seqs, batch_first=True, padding_value=1).permute(0, 2, 1)        
    return collated

In [12]:
loader = DataLoader(ds, batch_size=8, collate_fn=collate_fn)

In [13]:
for k, v in next(iter(loader)).items():
    print(k, v.shape)

tgt_word torch.Size([8, 1, 1])
tgt_chars torch.Size([8, 1, 3])
tgt_compos torch.Size([8, 1, 6])
ctx_word torch.Size([8, 4, 1])
ctx_chars torch.Size([8, 4, 3])
ctx_compos torch.Size([8, 4, 6])


In [18]:
ds[30]

{'tgt_word': tensor([[24]]),
 'tgt_chars': tensor([[43, 44]]),
 'tgt_compos': tensor([[28, 62,  8]]),
 'ctx_word': tensor([[ 0],
         [23],
         [25],
         [26]]),
 'ctx_chars': tensor([[ 0,  0],
         [42,  0],
         [45, 46],
         [47, 48]], dtype=torch.int32),
 'ctx_compos': tensor([[ 0,  0,  0,  0],
         [60, 61,  0,  0],
         [63, 64, 65, 66],
         [67, 68, 69, 70]], dtype=torch.int32)}

In [14]:
next(iter(loader))

{'tgt_word': tensor([[[4]],
 
         [[0]],
 
         [[5]],
 
         [[6]],
 
         [[0]],
 
         [[7]],
 
         [[0]],
 
         [[8]]]),
 'tgt_chars': tensor([[[ 2,  3,  1]],
 
         [[ 0,  1,  1]],
 
         [[ 4,  5,  1]],
 
         [[ 6,  7,  1]],
 
         [[ 0,  1,  1]],
 
         [[ 8,  9, 10]],
 
         [[ 0,  1,  1]],
 
         [[11, 12,  1]]]),
 'tgt_compos': tensor([[[ 2,  3,  4,  5,  1,  1]],
 
         [[ 0,  1,  1,  1,  1,  1]],
 
         [[ 6,  7,  8,  1,  1,  1]],
 
         [[ 8,  8,  1,  1,  1,  1]],
 
         [[ 0,  1,  1,  1,  1,  1]],
 
         [[ 9, 10, 11, 12, 13, 14]],
 
         [[ 0,  1,  1,  1,  1,  1]],
 
         [[15, 16, 17, 18,  1,  1]]]),
 'ctx_word': tensor([[[1],
          [1],
          [0],
          [1]],
 
         [[1],
          [4],
          [1],
          [1]],
 
         [[1],
          [1],
          [6],
          [0]],
 
         [[1],
          [5],
          [0],
          [7]],
 
         [[5],
          