In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

# HuggingFace datasets and tokenizers
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

from tdg_dataset import BilingualDataset

from config import get_config

from pathlib import Path

In [3]:
# Obtain the settings object from get_config() (imported from config.py);
# later cells index it by key, e.g. config['seq_len'] and config['batch_size'].
config = get_config()

In [4]:
config

{'batch_size': 8,
 'num_epochs': 20,
 'lr': 0.0001,
 'seq_len': 350,
 'd_model': 512,
 'datasource': 'opus_books',
 'lang_src': 'en',
 'lang_tgt': 'it',
 'model_folder': 'weights',
 'model_basename': 'tmodel_',
 'preload': 'latest',
 'tokenizer_file': 'tokenizer_{0}.json',
 'experiment_name': 'runs/tmodel'}

In [6]:
# Fetch the 'train' split of the corpus named in the config; the subset name is
# "<source language>-<target language>".
ds_raw = load_dataset(
    path=str(config['datasource']),
    name="{}-{}".format(config['lang_src'], config['lang_tgt']),
    split='train',
)

In [7]:
ds_raw

Dataset({
    features: ['id', 'translation'],
    num_rows: 32332
})

In [16]:
# Row count of the dataset. Asking the dataset itself avoids materialising the
# whole 'id' column as a Python list just to measure its length.
len(ds_raw)

32332

In [17]:
type(ds_raw['translation'])

list

In [18]:
# First translation pair. Indexing the row first reads a single record instead
# of materialising the entire 'translation' column and then taking element 0.
ds_raw[0]['translation']

{'en': 'Source: Project Gutenberg',
 'it': 'Source: www.liberliber.it/Audiobook available here'}

In [13]:

def get_all_sentences(ds, lang):
    """Lazily yield the `lang` side of every translation pair in `ds`.

    Each element of `ds` is read as element['translation'][lang]; nothing is
    evaluated until the returned generator is iterated.
    """
    yield from (record['translation'][lang] for record in ds)

In [11]:
def get_or_build_tokenizer(config, ds, lang):
    """Return the word-level tokenizer for `lang`, training it only when needed.

    The file location is config['tokenizer_file'] formatted with `lang`. If a
    file already exists there it is loaded; otherwise a new tokenizer is trained
    on the sentences produced by get_all_sentences(ds, lang), saved to that
    location and returned. The resolved path is printed in both cases.
    """
    tokenizer_path = Path(config['tokenizer_file'].format(lang))
    print(tokenizer_path)

    # Fast path: a tokenizer was saved by an earlier run.
    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    # Training recipe follows https://huggingface.co/docs/tokenizers/quicktour
    special_tokens = ["[UNK]", "[PAD]", "[SOS]", "[EOS]"]
    word_trainer = WordLevelTrainer(special_tokens=special_tokens, min_frequency=2)

    fresh_tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
    fresh_tokenizer.pre_tokenizer = Whitespace()
    fresh_tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=word_trainer)
    fresh_tokenizer.save(str(tokenizer_path))
    return fresh_tokenizer

In [19]:
# Build (or load from disk) one tokenizer per language, source language first
tokenizer_src, tokenizer_tgt = (
    get_or_build_tokenizer(config, ds_raw, config[lang_key])
    for lang_key in ('lang_src', 'lang_tgt')
)

tokenizer_en.json
tokenizer_it.json


In [21]:
ds_raw

Dataset({
    features: ['id', 'translation'],
    num_rows: 32332
})

In [22]:
# Keep 90% of the sentence pairs for training; the remainder is the validation set.
train_ds_size = int(0.9 * len(ds_raw))
val_ds_size = len(ds_raw) - train_ds_size

# Draw the split from a dedicated, seeded generator so the same pairs land in
# the same split on every run, independent of the global RNG state.
split_generator = torch.Generator().manual_seed(42)
train_ds_raw, val_ds_raw = random_split(
    ds_raw, [train_ds_size, val_ds_size], generator=split_generator
)

In [23]:
train_ds_size

29098

In [24]:
val_ds_size

3234

In [25]:
config

{'batch_size': 8,
 'num_epochs': 20,
 'lr': 0.0001,
 'seq_len': 350,
 'd_model': 512,
 'datasource': 'opus_books',
 'lang_src': 'en',
 'lang_tgt': 'it',
 'model_folder': 'weights',
 'model_basename': 'tmodel_',
 'preload': 'latest',
 'tokenizer_file': 'tokenizer_{0}.json',
 'experiment_name': 'runs/tmodel'}

In [29]:
# Look at how the source tokenizer encodes the first sentence only
# (the loop stops after one iteration).
for item in ds_raw:
    src_sentence = item['translation'][config['lang_src']]
    # Encode once and reuse the result instead of re-encoding for every print.
    src_encoding = tokenizer_src.encode(src_sentence)
    print(src_sentence)
    print(src_encoding)
    print(src_encoding.tokens)
    print(src_encoding.ids)
    break

Source: Project Gutenberg
Encoding(num_tokens=4, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])
['Source', ':', 'Project', '[UNK]']
[5781, 38, 7699, 0]


In [30]:
# Look at how the target tokenizer encodes the first sentence only
# (the loop stops after one iteration).
for item in ds_raw:
    tgt_sentence = item['translation'][config['lang_tgt']]
    # Encode once and reuse the result instead of re-encoding for every print.
    tgt_encoding = tokenizer_tgt.encode(tgt_sentence)
    print(tgt_sentence)
    print(tgt_encoding)
    print(tgt_encoding.tokens)
    print(tgt_encoding.ids)
    break

Source: www.liberliber.it/Audiobook available here
Encoding(num_tokens=11, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])
['Source', ':', '[UNK]', '.', 'liberliber', '.', 'it', '/', '[UNK]', 'available', 'here']
[8161, 43, 0, 5, 19606, 5, 19516, 10657, 0, 13463, 14295]


In [31]:
# Longest tokenised sentence on each side of the corpus, counted in token ids
# as returned by encode(...).ids.
src_lang, tgt_lang = config['lang_src'], config['lang_tgt']
src_lengths, tgt_lengths = [], []

for pair in ds_raw:
    sentences = pair['translation']
    src_lengths.append(len(tokenizer_src.encode(sentences[src_lang]).ids))
    tgt_lengths.append(len(tokenizer_tgt.encode(sentences[tgt_lang]).ids))

# default=0 keeps the result at 0 when the dataset has no rows
max_len_src = max(src_lengths, default=0)
max_len_tgt = max(tgt_lengths, default=0)

print(f'Max length of source sentence: {max_len_src}')
print(f'Max length of target sentence: {max_len_tgt}')

Max length of source sentence: 309
Max length of target sentence: 274


In [32]:
config

{'batch_size': 8,
 'num_epochs': 20,
 'lr': 0.0001,
 'seq_len': 350,
 'd_model': 512,
 'datasource': 'opus_books',
 'lang_src': 'en',
 'lang_tgt': 'it',
 'model_folder': 'weights',
 'model_basename': 'tmodel_',
 'preload': 'latest',
 'tokenizer_file': 'tokenizer_{0}.json',
 'experiment_name': 'runs/tmodel'}

In [33]:
config['seq_len']

350

In [34]:
config['batch_size']

8

In [35]:
# Wrap both raw splits with the same tokenizers, language keys and sequence length
train_ds, val_ds = (
    BilingualDataset(
        raw_split,
        tokenizer_src,
        tokenizer_tgt,
        config['lang_src'],
        config['lang_tgt'],
        config['seq_len'],
    )
    for raw_split in (train_ds_raw, val_ds_raw)
)

In [36]:
train_ds

<tdg_dataset.BilingualDataset at 0x1fa50a20f70>

In [37]:
torch.tensor(
            [tokenizer_tgt.token_to_id("[SOS]")], dtype=torch.int64
        )

tensor([2])

In [38]:
tokenizer_tgt.token_to_id("[SOS]")

2

In [39]:
tokenizer_src.token_to_id("[SOS]")

2

In [40]:
tokenizer_tgt.token_to_id("[EOS]")

3

In [41]:
tokenizer_src.token_to_id("[EOS]")

3

In [42]:
tokenizer_src.token_to_id("[PAD]")

1

In [43]:
tokenizer_src.token_to_id("[UNK]")

0

In [44]:
ds_raw[0]

{'id': '0',
 'translation': {'en': 'Source: Project Gutenberg',
  'it': 'Source: www.liberliber.it/Audiobook available here'}}

In [45]:
ds_raw[100]

{'id': '100',
 'translation': {'en': 'Miss Abbot joined in-- "And you ought not to think yourself on an equality with the Misses Reed and Master Reed, because Missis kindly allows you to be brought up with them.',
  'it': 'La signorina Abbot soggiunse: — Spero che non vi crederete eguale alle signorine e al signor Reed, perché la signora è così buona da farvi educare insieme con loro.'}}

In [47]:
# Pick row 100 as the example to walk through the preprocessing by hand.
idx = 100
# Length every sequence is padded to below; taken from config['seq_len'].
seg_len = config['seq_len']
# The selected record: a dict with 'id' and 'translation' keys.
src_target_pair = ds_raw[idx]


In [48]:
src_target_pair

{'id': '100',
 'translation': {'en': 'Miss Abbot joined in-- "And you ought not to think yourself on an equality with the Misses Reed and Master Reed, because Missis kindly allows you to be brought up with them.',
  'it': 'La signorina Abbot soggiunse: — Spero che non vi crederete eguale alle signorine e al signor Reed, perché la signora è così buona da farvi educare insieme con loro.'}}

In [49]:
# Pull the two sides of the selected pair
src_text = src_target_pair["translation"][config['lang_src']]
tgt_text = src_target_pair["translation"][config['lang_tgt']]

# Transform the text into token ids
enc_input_tokens = tokenizer_src.encode(src_text).ids
dec_input_tokens = tokenizer_tgt.encode(tgt_text).ids

# Number of [PAD] ids needed to bring each sequence up to seg_len.
# The encoder sequence receives both [SOS] and [EOS], hence the -2.
enc_num_padding_token = seg_len - len(enc_input_tokens) - 2
# The decoder input receives only [SOS] and the label only [EOS], hence the -1.
dec_num_padding_token = seg_len - len(dec_input_tokens) - 1

# A negative count means the sentence does not fit into seg_len. Fail loudly
# here rather than letting a later cell build a sequence of the wrong length.
if enc_num_padding_token < 0 or dec_num_padding_token < 0:
    raise ValueError("Sentence is too long")

In [55]:
len(dec_input_tokens)

32

In [51]:
tgt_text

'La signorina Abbot soggiunse: — Spero che non vi crederete eguale alle signorine e al signor Reed, perché la signora è così buona da farvi educare insieme con loro.'

In [56]:
# One-element int64 tensors holding the target tokenizer's ids for the three
# special tokens, in the order [SOS], [EOS], [PAD].
sos_token, eos_token, pad_token = (
    torch.tensor([tokenizer_tgt.token_to_id(marker)], dtype=torch.int64)
    for marker in ("[SOS]", "[EOS]", "[PAD]")
)

In [57]:
print(sos_token, eos_token, pad_token)

tensor([2]) tensor([3]) tensor([1])


In [58]:
enc_input_tokens

[317,
 2570,
 1535,
 13,
 78,
 35,
 91,
 24,
 388,
 21,
 8,
 153,
 544,
 36,
 65,
 8008,
 22,
 5,
 5762,
 698,
 6,
 3443,
 698,
 4,
 173,
 3324,
 1536,
 8843,
 24,
 8,
 37,
 291,
 56,
 22,
 50,
 7]

In [59]:
# Encoder sequence: [SOS] + source ids + [EOS] + [PAD] * enc_num_padding_token
encoder_input = torch.cat(
    [
        sos_token,
        torch.tensor(enc_input_tokens, dtype=torch.int64),
        eos_token,
        # Allocate the padding in one call instead of converting a Python list
        # of one-element tensors element by element. torch.full also rejects a
        # negative count instead of silently producing no padding.
        torch.full((enc_num_padding_token,), int(pad_token), dtype=torch.int64),
    ],
    dim=0,
)

In [60]:
encoder_input

tensor([   2,  317, 2570, 1535,   13,   78,   35,   91,   24,  388,   21,    8,
         153,  544,   36,   65, 8008,   22,    5, 5762,  698,    6, 3443,  698,
           4,  173, 3324, 1536, 8843,   24,    8,   37,  291,   56,   22,   50,
           7,    3,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,   

In [61]:
# Decoder input: [SOS] + target ids + [PAD] * dec_num_padding_token (no [EOS])
decoder_input = torch.cat(
    [
        sos_token,
        torch.tensor(dec_input_tokens, dtype=torch.int64),
        # Allocate the padding in one call instead of converting a Python list
        # of one-element tensors element by element. torch.full also rejects a
        # negative count instead of silently producing no padding.
        torch.full((dec_num_padding_token,), int(pad_token), dtype=torch.int64),
    ],
    dim=0,
)

In [62]:
decoder_input

tensor([    2,    79,   357,  2496,   911,    43,     9,  2274,     8,    12,
           64, 18158,  2365,   162,  2568,     6,    39,   175,   746,     4,
           70,    11,   209,    27,    55,   400,    28,  2527,  7550,   347,
           20,    69,     5,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1, 

In [63]:
# Label: target ids + [EOS] + [PAD] * dec_num_padding_token (no [SOS])
label = torch.cat(
    [
        torch.tensor(dec_input_tokens, dtype=torch.int64),
        eos_token,
        # Allocate the padding in one call instead of converting a Python list
        # of one-element tensors element by element. torch.full also rejects a
        # negative count instead of silently producing no padding.
        torch.full((dec_num_padding_token,), int(pad_token), dtype=torch.int64),
    ],
    dim=0,
)

In [64]:
label

tensor([   79,   357,  2496,   911,    43,     9,  2274,     8,    12,    64,
        18158,  2365,   162,  2568,     6,    39,   175,   746,     4,    70,
           11,   209,    27,    55,   400,    28,  2527,  7550,   347,    20,
           69,     5,     3,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1, 

In [66]:
def causal_mask(size):
    """Return a (1, size, size) boolean mask that is True on and below the diagonal.

    Entry (0, i, j) is True when j <= i and False when j > i.
    """
    ones = torch.ones((1, size, size))
    # Keep only the entries strictly above the main diagonal ...
    strictly_upper = torch.triu(ones, diagonal=1).to(torch.int)
    # ... and flag every position where that upper triangle is zero.
    return strictly_upper.eq(0)

In [80]:
# Assemble one sample for the hand-built example: both id sequences, their
# masks, the label sequence and the original sentences.
output = {
            "encoder_input": encoder_input,  # Shape: (seq_len,)
            "decoder_input": decoder_input,  # Shape: (seq_len,)
            # 1 where the encoder position holds a real token, 0 where it holds [PAD]
            "encoder_mask": (encoder_input != pad_token)
            .unsqueeze(0)
            .unsqueeze(0)
            .int(),  # Shape: (1, 1, seq_len)
            # Padding mask broadcast against the causal mask: an entry is 1 only
            # when the column is not [PAD] and is not ahead of the row position
            "decoder_mask": (decoder_input != pad_token).unsqueeze(0).int()
            & causal_mask(
                decoder_input.shape[0]
            ),  # Shape: (1, seq_len) & (1, seq_len, seq_len) -> (1, seq_len, seq_len)
            "label": label,  # Shape: (seq_len,)
            "src_text": src_text,
            "tgt_text": tgt_text,
        }

In [92]:
label.shape

torch.Size([350])

In [70]:
pad_token

tensor([1])

In [72]:
(encoder_input != pad_token).shape

torch.Size([350])

In [74]:
(encoder_input != pad_token)

tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, 

In [73]:
(encoder_input != pad_token).unsqueeze(0).unsqueeze(0).int()

tensor([[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [75]:
(encoder_input != pad_token).unsqueeze(0).unsqueeze(0).int().shape

torch.Size([1, 1, 350])

In [76]:
(decoder_input != pad_token).unsqueeze(0).int()

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0

In [77]:
decoder_input.shape[0]

350

In [78]:
causal_mask(decoder_input.shape[0])

tensor([[[ True, False, False,  ..., False, False, False],
         [ True,  True, False,  ..., False, False, False],
         [ True,  True,  True,  ..., False, False, False],
         ...,
         [ True,  True,  True,  ...,  True, False, False],
         [ True,  True,  True,  ...,  True,  True, False],
         [ True,  True,  True,  ...,  True,  True,  True]]])

In [81]:
output['decoder_mask']

tensor([[[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]], dtype=torch.int32)

In [82]:
len(output['decoder_mask'])

1

In [83]:
import torch

def causal_mask(size):
    """Return a (1, size, size) boolean mask that is True on and below the diagonal.

    Entry (0, i, j) is True when j <= i; entries strictly above the main
    diagonal are False.
    """
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    return mask == 0

# Toy example. It uses demo_* names so it does not overwrite the real
# decoder_input / pad_token built earlier in the notebook.
demo_decoder_input = torch.tensor([1, 2, 3, 0, 0])  # 0 is the pad id in this toy example
demo_pad_id = 0

# 1 for real tokens, 0 for padding; shape (1, 5)
non_pad_mask = (demo_decoder_input != demo_pad_id).unsqueeze(0).int()

# Store the result under a different name than the function: rebinding
# `causal_mask` to a tensor would make every later causal_mask(...) call fail.
demo_causal = causal_mask(demo_decoder_input.shape[0])

# Combine the two masks: (1, 5) broadcasts against (1, 5, 5), so padded columns
# are zeroed in every row.
combined_mask = non_pad_mask & demo_causal

print(combined_mask)

tensor([[[1, 0, 0, 0, 0],
         [1, 1, 0, 0, 0],
         [1, 1, 1, 0, 0],
         [1, 1, 1, 0, 0],
         [1, 1, 1, 0, 0]]], dtype=torch.int32)


In [84]:
non_pad_mask

tensor([[1, 1, 1, 0, 0]], dtype=torch.int32)

In [85]:
# Show the causal pattern used in the toy example. It is rebuilt from
# combined_mask's shape so this cell does not depend on the name `causal_mask`
# referring to a tensor rather than to the helper function.
torch.tril(torch.ones_like(combined_mask)).bool()

tensor([[[ True, False, False, False, False],
         [ True,  True, False, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True]]])

In [86]:
combined_mask

tensor([[[1, 0, 0, 0, 0],
         [1, 1, 0, 0, 0],
         [1, 1, 1, 0, 0],
         [1, 1, 1, 0, 0],
         [1, 1, 1, 0, 0]]], dtype=torch.int32)

In [87]:
output['encoder_mask']

tensor([[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [88]:
output['decoder_mask']

tensor([[[1, 0, 0,  ..., 0, 0, 0],
         [1, 1, 0,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]], dtype=torch.int32)

In [89]:
print(output['encoder_mask'].shape)

torch.Size([1, 1, 350])


In [90]:
print(output['decoder_mask'].shape)

torch.Size([1, 350, 350])


In [93]:
# Batch both splits with the configured batch size, with shuffling enabled
train_dataloader, val_dataloader = (
    DataLoader(split, batch_size=config['batch_size'], shuffle=True)
    for split in (train_ds, val_ds)
)

In [94]:
train_dataloader

<torch.utils.data.dataloader.DataLoader at 0x1fa50aa0a60>

In [95]:
val_dataloader

<torch.utils.data.dataloader.DataLoader at 0x1fa50aa33d0>

In [96]:
train_ds

<tdg_dataset.BilingualDataset at 0x1fa50a20f70>