In [2]:
import sys
sys.path.append('..')

import torch
import argparse
import numpy as np
from tqdm import tqdm
import pytorch_lightning as pl
from transformers import AutoTokenizer
from transformer_pl import TransformerPL
from translate.storage.tmx import tmxfile
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    save_top_k=1,
    monitor="train_loss",
    mode="min",
    dirpath="lightning_logs/",
    filename="voiceformer-{epoch:02d}-{train_loss:.5f}",
    save_on_train_epoch_end = True
)

def EnIt_collate(batch): 
    src = []; tgt = []
    for item in batch:
        src.append(item[0])
        tgt.append(item[1])
    return torch.cat(src,dim=0), torch.cat(tgt,dim=0)

class EnIt(Dataset):
    def __init__(self, corpus_dir, split, split_val=.1, reduction=0):
        with open(corpus_dir, 'rb') as fin:
            tmx_file = tmxfile(fin, 'en', 'ar')

        corpus = list(tmx_file.unit_iter())
        if reduction:
            corpus = corpus[:reduction]
        
        # split train/val     
        if split == "train":
            corpus = corpus[:int(len(corpus)*(1-split_val))]
        else: 
            corpus = corpus[int(len(corpus)*(1-split_val)):]

        self.tokenizer_en = AutoTokenizer.from_pretrained("bert-base-cased")
        self.tokenizer_it = AutoTokenizer.from_pretrained("dbmdz/bert-base-italian-cased")
        
        print('\ncompute tokens ids for src')
        src_tokens = [self.tokenizer_en(doc.source)['input_ids'] for doc in tqdm(corpus)]
        print('compute tokens ids for tgt\n')
        tgt_tokens = [self.tokenizer_it(doc.target)['input_ids'] for doc in tqdm(corpus)]
        # filter for samples with length <= 512 tokens
        src_mask = np.where(np.array([512 <= len(tokens) for tokens in src_tokens]).astype(int)*-1+1)[0]
        tgt_mask = np.where(np.array([512 <= len(tokens) for tokens in tgt_tokens]).astype(int)*-1+1)[0]
        mask = np.array(list(set(src_mask) & set(tgt_mask)))
        src_tokens = list(np.array(src_tokens)[mask])
        tgt_tokens = list(np.array(tgt_tokens)[mask])
        assert len(src_tokens) == len(tgt_tokens)
        print(f"{len(src_tokens)} samples in {split} set")

        self.corpus = {
            'src': src_tokens,
            'tgt': tgt_tokens
        }
        self.max_tgt_len = self.max_src_len = 512
        self.pad_tgt_value = self.tokenizer_it.pad_token_id
        self.pad_src_value = self.tokenizer_en.pad_token_id
    
    def __len__(self):
        return len(self.corpus['src'])

    def __getitem__(self, idx):
        src, tgt = torch.tensor(self.corpus['src'][idx]), torch.tensor(self.corpus['tgt'][idx])
        # pad src
        if src.size(-1) < self.max_src_len:
            src = torch.nn.ConstantPad1d((0, self.max_src_len - src.size(-1)), self.pad_src_value)(src)
        # eos tgt
        if tgt.size(-1) < self.max_tgt_len:
            tgt = torch.nn.ConstantPad1d((0, self.max_tgt_len - tgt.size(-1)), self.pad_tgt_value)(tgt)
        
        return src.unsqueeze(0), tgt.unsqueeze(0)

In [3]:
training_data = EnIt(corpus_dir="data/en-it.tmx", split="train", reduction=1000)
train_dataloader = DataLoader(training_data, 
                                           batch_size = 4, 
                                           drop_last = True,
                                           shuffle=True,
                                           collate_fn=EnIt_collate)


compute tokens ids for src


100%|██████████| 900/900 [00:00<00:00, 11259.31it/s]


compute tokens ids for tgt



100%|██████████| 900/900 [00:00<00:00, 10743.73it/s]

900 samples in train set



  src_tokens = list(np.array(src_tokens)[mask])
  tgt_tokens = list(np.array(tgt_tokens)[mask])


In [5]:
src,tgt = next(iter(train_dataloader))

In [6]:
src

tensor([[  101,  2408,  8708,  ...,     0,     0,     0],
        [  101,  5316,   119,  ...,     0,     0,     0],
        [  101,  1109, 11336,  ...,     0,     0,     0],
        [  101,  1109,  3442,  ...,     0,     0,     0]])

In [7]:
tgt

tensor([[ 102, 4445, 4445,  ...,    0,    0,    0],
        [ 102, 2287,  697,  ...,    0,    0,    0],
        [ 102,  329,  533,  ...,    0,    0,    0],
        [ 102,  966, 3935,  ...,    0,    0,    0]])

In [11]:
training_data.tokenizer_en.pad_token_id

0