In [1]:
import pandas as pd
from pathlib import Path
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from en_indic_transformer import TranslationDataset, TranslationDataLoader, Tokenizer

In [2]:
# Project root is taken to be the parent of the current working directory.
path = Path()
load_dir = Path.cwd().parent

In [3]:
# Path of the English/Hindi CSV, resolved relative to the project root.
filename = load_dir.joinpath('data/eng_hindi.csv')

In [4]:
# Display the resolved path as a quick check that it points at the CSV.
filename

PosixPath('/Users/sameergururajmathad/en-indic-transformer/data/eng_hindi.csv')

In [5]:
# Read the sentence-pair CSV into a DataFrame.
df = pd.read_csv(filename)

In [6]:
# Pull the two text columns out as plain Python lists: the English column is
# the translation source, the Hindi column is the target.
source, target = (
    df[column].tolist() for column in ("english_sentence", "hindi_sentence")
)

In [7]:
# Number of sentence pairs loaded (length of the English list).
length = len(source)
length

127705

In [8]:
# Use the project's Tokenizer class (an earlier version of this cell called
# tiktoken.get_encoding('gpt2') directly).
tokenizer = Tokenizer()

In [9]:
# Wrap the two sentence lists in the project's dataset class. Each side is
# given a language tag to prepend: '<|english|>' for the source and
# '<|hindi|>' for the target.
dataset = TranslationDataset(
    src=source,
    target=target,
    tokenizer=tokenizer,
    src_prepend_value='<|english|>',
    target_prepend_value='<|hindi|>',
)

In [10]:
def custom_collate_fn(batch, pad_token_id=50256, ignore_index=-100):
    """Collate variable-length translation examples into padded batch tensors.

    Args:
        batch: Non-empty sequence of ``(source, target_in, target_out)``
            triples. Each element is a tensor of token ids; lengths may differ
            between examples.
        pad_token_id: Value used to right-pad ``source`` and ``target_in``.
            Defaults to 50256, the value previously hard-coded here (the
            GPT-2 ``<|endoftext|>`` id; confirm it matches the tokenizer in
            use before changing vocabularies).
        ignore_index: Value used to right-pad ``target_out``. Defaults to
            -100, the value previously hard-coded here. It equals the default
            ``ignore_index`` of ``torch.nn.CrossEntropyLoss``, so padded
            positions are presumably meant to be skipped by the loss.

    Returns:
        A tuple ``(source_padded, target_in_padded, target_out_padded)``.
        Each tensor is batch-first and padded to the longest sequence of its
        own group.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        # Fail early with a clear message instead of an opaque error from
        # pad_sequence on an empty list.
        raise ValueError("custom_collate_fn received an empty batch")

    # Transpose [(src, tgt_in, tgt_out), ...] into three parallel lists.
    sources, target_ins, target_outs = map(list, zip(*batch))

    source_padded = pad_sequence(sources, batch_first=True, padding_value=pad_token_id)
    target_in_padded = pad_sequence(target_ins, batch_first=True, padding_value=pad_token_id)
    # The expected output uses a separate padding value so that padded
    # positions can be told apart from real tokens.
    target_out_padded = pad_sequence(target_outs, batch_first=True, padding_value=ignore_index)

    return source_padded, target_in_padded, target_out_padded
    

In [11]:
# Batches of 16 examples with shuffling enabled. The project's
# TranslationDataLoader is used here instead of a plain torch DataLoader
# combined with custom_collate_fn.
dataloader = TranslationDataLoader(
    dataset=dataset,
    batch_size=16,
    shuffle=True,
)

In [12]:
# Create an iterator over the loader so batches can be pulled one at a time.
data = iter(dataloader)

In [13]:
# Pull the next batch from the iterator (the first one on a fresh run).
# Shuffling is enabled on the loader, so the contents likely differ per run.
first = next(data)

In [14]:
# Pick the example at index 2 of the batch for inspection. Going by the names
# used here, the batch holds (source, decoder input, decoder target).
# NOTE: `source` rebinds the name that held the list of English sentences in
# an earlier cell; re-running the dataset cell after this one would pass this
# single example as `src` instead of that list.
source, target_in, target_out = first[0][2], first[1][2], first[2][2]

In [15]:
# Decode the selected example's source ids back to text. In the displayed
# result the padding shows up as a run of repeated <|endoftext|> tokens.
tokenizer.decode(source)

'<|english|>even if they were able to take back the country,<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>'

In [16]:
# Decode the selected example's decoder input. The displayed text starts with
# the <|hindi|> tag and is padded with repeated <|endoftext|> tokens.
tokenizer.decode(target_in)

'<|hindi|>भले ही वे वापस देश लेने के लिए सक्षम थे,<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><

In [17]:
# Decode the selected example's expected decoder output. The displayed text
# has no <|hindi|> tag and ends with a single <|endoftext|>; how decode treats
# padded positions is not visible here — TODO confirm.
tokenizer.decode(target_out)

'भले ही वे वापस देश लेने के लिए सक्षम थे,<|endoftext|>'