In [1]:
from pathlib import Path

import pandas as pd
import tiktoken
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

from en_indic_transformer import TranslationDataset

In [2]:
# Relative anchor for the current working directory.
path = Path()
# Data lives one level above the directory the notebook is run from.
load_dir = Path.cwd().parent

In [3]:
# Location of the English-Hindi CSV, relative to the load directory.
filename = load_dir.joinpath('data/eng_hindi.csv')

In [4]:
# Display the resolved CSV path as a quick sanity check.
filename

PosixPath('/Users/sameergururajmathad/en-indic-transformer/data/eng_hindi.csv')

In [5]:
# Read the CSV at `filename` into a DataFrame.
df = pd.read_csv(filename)

In [6]:
# Pull the two sentence columns out of the frame as plain Python lists.
source, target = (
    df[column].tolist() for column in ("english_sentence", "hindi_sentence")
)

In [7]:
# Number of English sentences taken from the CSV (one per row).
length = len(source)
length

127705

In [8]:
# Load tiktoken's 'gpt2' encoding; the same tokenizer is handed to the dataset below.
tokenizer = tiktoken.get_encoding('gpt2')

In [9]:
# Wrap the sentence lists in the project's dataset class. The two prepend
# values are presumably language tags placed at the start of each sequence
# (the decoded samples further down begin with them) -- confirm against
# TranslationDataset.
dataset = TranslationDataset(
    src=source,
    target=target,
    tokenizer=tokenizer,
    src_prepend_value='<|english|>',
    target_prepend_value='<|hindi|>',
)

In [30]:
def custom_collate_fn(batch, pad_token_id=50256, ignore_index=-100):
    """Collate (source, target_in, target_out) samples into padded batch tensors.

    Args:
        batch: Iterable of ``(source, target_in, target_out)`` triples. Each
            element must be a tensor accepted by ``pad_sequence`` (presumably
            1-D token-id tensors from ``TranslationDataset`` -- TODO confirm).
        pad_token_id: Fill value used to pad the source and target-input
            sequences. Defaults to 50256, which the 'gpt2' tokenizer decodes
            as ``<|endoftext|>``.
        ignore_index: Fill value used to pad the target-output sequences.
            Defaults to -100, the default ``ignore_index`` of
            ``torch.nn.CrossEntropyLoss``, so a loss configured that way skips
            padded positions (assumes the training loss uses it -- confirm).

    Returns:
        Tuple ``(source_padded, target_in_padded, target_out_padded)``. Each
        tensor is batch-first and padded to the longest sequence within its
        own group.

    Raises:
        ValueError: If ``batch`` contains no samples.
    """
    samples = list(batch)
    if not samples:
        raise ValueError("custom_collate_fn received an empty batch")

    # Transpose [(src, tgt_in, tgt_out), ...] into three parallel lists.
    sources, target_ins, target_outs = (list(group) for group in zip(*samples))

    source_padded = pad_sequence(sources, batch_first=True, padding_value=pad_token_id)
    target_in_padded = pad_sequence(target_ins, batch_first=True, padding_value=pad_token_id)
    # Labels get a separate fill value so padded positions can be told apart
    # from real token ids.
    target_out_padded = pad_sequence(target_outs, batch_first=True, padding_value=ignore_index)

    return source_padded, target_in_padded, target_out_padded
    

In [31]:
# Shuffle with a seeded generator so a fresh "Restart & Run All" draws the
# same batches (and therefore shows the same sample below) every time.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=custom_collate_fn,
    generator=torch.Generator().manual_seed(42),
)

In [32]:
# Create an explicit iterator so individual batches can be pulled with next().
data = iter(dataloader)

In [33]:
# Fetch one batch: the (source, target_in, target_out) padded tensors returned by custom_collate_fn.
first = next(data)

In [38]:
# Inspect the sample at index 2 of the batch. `.tolist()` converts each row to
# plain Python ints, whereas list(tensor) yields a list of 0-d tensors.
# NOTE: this rebinds `source`, replacing the list of English sentences built
# earlier in the notebook; re-run from the top before rebuilding the dataset.
source = first[0][2].tolist()
target_in = first[1][2].tolist()
target_out = first[2][2].tolist()

In [39]:
# The source row is padded with 50256 by custom_collate_fn, so the -100 guard
# never triggers here and the padding stays visible as <|endoftext|> in the
# decoded text.
tokenizer.decode(
    [token_id for token_id in source if token_id != -100]
)

"<|english|>so there's holes in the opposite corners, there's a little hole over here.<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>"

In [40]:
# The target-input row is padded with 50256 by custom_collate_fn, so the -100
# guard never triggers here and the padding stays visible as <|endoftext|> in
# the decoded text.
tokenizer.decode(
    [token_id for token_id in target_in if token_id != -100]
)

'<|hindi|>तो इसमे उल्टे कोनों पर छेद हैं, और यहाँ छोटा सा छेद है।<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|

In [41]:
# The target-output row is padded with -100 by custom_collate_fn; those
# positions are filtered out before the remaining ids are decoded.
tokenizer.decode(
    [token_id for token_id in target_out if token_id != -100]
)

'तो इसमे उल्टे कोनों पर छेद हैं, और यहाँ छोटा सा छेद है।<|endoftext|>'