In [1]:
# https://www.kaggle.com/datasets/dhruvildave/en-fr-translation-dataset

from typing import NamedTuple, Generator
import pandas as pd
import torch
import torch.nn.functional as F
from torch import Tensor, optim, nn
from tqdm import tqdm
from transformers import AutoTokenizer
from src.transformer import Transformer
from transformers import get_linear_schedule_with_warmup
from collections import deque


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
class PairedSentences(NamedTuple):
    fr: str
    en: str

class ListPairedSentences(NamedTuple):
    fr: list[str]
    en: list[str]

    def __getitem__(self, index: int) -> PairedSentences:
        return PairedSentences(self.fr[index], self.en[index])

class TrainingBatch(NamedTuple):
    input_ids: Tensor
    """In our case french"""
    
    encoder_mask: Tensor
    """basically padding tokens"""
    
    output_ids: Tensor
    """In our case english"""
    
    decoder_mask: Tensor
    """Padding tokens, don't forget to add causal mask during training"""

    def __repr__(self):
        return f"TrainingBatch(x.shape={self.input_ids.shape}, y.shape={self.output_ids.shape})"

In [3]:
class RollingAverage:
    def __init__(self, window_size=1000):
        self.window_size = window_size
        self.losses = deque(maxlen=window_size)
        self.sum = 0.0
    
    def add(self, loss_value):
        """Add a new loss value"""
        # Convert tensor to float if needed
        if torch.is_tensor(loss_value):
            loss_value = loss_value.item()
        
        # If we're at capacity, subtract the value that will be removed
        if len(self.losses) == self.window_size:
            self.sum -= self.losses[0]
        
        # Add new value
        self.losses.append(loss_value)
        self.sum += loss_value
    
    def avg(self, last_n=None):
        """Get average of last n values (or all if n is None)"""
        if not self.losses:
            return 0.0
        
        if last_n is None:
            return self.sum / len(self.losses)
        
        # Get last n values
        n = min(last_n, len(self.losses))
        last_values = list(self.losses)[-n:]
        return sum(last_values) / n
    
    def __len__(self):
        return len(self.losses)

In [4]:
class Processor:
    def __init__(self, sequence_length: int, tokenizer_name: str) -> None:
        self._tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self._seq_length = sequence_length
        
    @property
    def tokenizer(self) -> AutoTokenizer:
        return self._tokenizer

    @property
    def sequence_length(self) -> int:
        return self._seq_length

    @property
    def vocab_size(self) -> int:
        return self._tokenizer.vocab_size

    def tokenize(self, text: str, padding: str="max_length", truncation:bool=True, extra: int = 0):
        return self._tokenizer(
            text,
            return_tensors="pt",
            max_length=self._seq_length + extra,
            padding=padding,
            truncation=truncation)

    def decode(self, token_ids: Tensor, **kwargs) -> str:
        return self._tokenizer.decode(token_ids, **kwargs)

    def make_batch(self, paired_sentences: ListPairedSentences, dtype=torch.float32) -> TrainingBatch:
        # Tokenize each sentence in the 'fr' and 'en' lists
        fr_sentences = [self.tokenize(sentence) for sentence in paired_sentences.fr]
        en_sentences = [self.tokenize(sentence, extra=1) for sentence in paired_sentences.en]

        # Stack tokenized tensors for batching
        X_batch = torch.stack([x['input_ids'].squeeze(0) for x in fr_sentences])
        Y_batch = torch.stack([y['input_ids'].squeeze(0) for y in en_sentences])

        # Create encoder and decoder padding mask: 1 for real tokens, 0 for padding
        encoder_mask = torch.stack([x['attention_mask'].squeeze(0) for x in fr_sentences]) \
            .unsqueeze(1).unsqueeze(2)
        decoder_mask = torch.stack([y['attention_mask'].squeeze(0) for y in en_sentences]) \
            .unsqueeze(1).unsqueeze(2)

        return TrainingBatch(
            input_ids=X_batch,
            output_ids=Y_batch,
            encoder_mask=encoder_mask.to(dtype),
            decoder_mask=decoder_mask.to(dtype))

def get_first_masked_token(mask: torch.Tensor) -> Tensor:
    squeezed_mask = mask.squeeze(1).squeeze(1) # mask is shaped (bs, 1, 1, sequence_length)
    first_masked_indices = (squeezed_mask == 0).int().argmax(dim=1)
    first_masked_indices[squeezed_mask.sum(dim=1) == squeezed_mask.size(1)] = squeezed_mask.size(1)
    return first_masked_indices.to(dtype=torch.int32)

def mask_last_token(current_mask: torch.Tensor) -> torch.Tensor:
    first_masked_indices = get_first_masked_token(current_mask) # get the index of first masked token
    last_token_indices = torch.clamp(first_masked_indices - 1, min=0) # to avoid negative indices
    current_mask[torch.arange(current_mask.size(0)), 0, 0, last_token_indices] = 0 # set the last 1 token to 0
    return current_mask

In [5]:
def get_page(csv_path: str, page: int, rows_per_page: int):
    return pd.read_csv(csv_path, skiprows = 1 + page * rows_per_page, nrows=rows_per_page, header=None, names=["en", "fr"])

def make_generator(csv_path: str, rows_per_page: int) -> Generator[ListPairedSentences, None, None]:
    i = 0
    while True:
        page = get_page(csv_path, i, rows_per_page)
        fr_sentences = page["fr"].to_list()
        en_sentences = page["en"].to_list()
        yield ListPairedSentences(fr_sentences, en_sentences)
        i += 1
        
def make_generator_v2(csv_path: str, rows_per_page: int):
    for chunk in pd.read_csv(csv_path, chunksize=rows_per_page):
        fr_sentences = chunk["fr"].to_list()
        en_sentences = chunk["en"].to_list()
        yield ListPairedSentences(fr_sentences, en_sentences)

In [6]:
dtype=torch.bfloat16
csv_path = "archive/en-fr.csv"
processor = Processor(200, "bert-base-uncased")
d_model = 512
model = Transformer(vocab_size=processor.vocab_size, max_sequence_len=processor.sequence_length, d_model=d_model).to(DEVICE).to(dtype)

# opt and loss
learning_rate = 3e-4
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss(ignore_index=processor.tokenizer.pad_token_id).to(DEVICE)

# training loop qty
num_epochs = 1
batch_size = 48
break_at = None # Will never break
# get_num_steps(csv_path, batch_size) running this function is too long ... just use the cached value
num_steps = 22520376 // batch_size

num_training_steps = num_epochs * num_steps
num_warmup_steps = 10_000

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps
)
loss_tracker = RollingAverage(window_size=1000)

In [7]:
# Training loop
for epoch in range(num_epochs):
    model.train()
    model.to(DEVICE)
    epoch_loss = 0.0
    loop_error = 0.0

    # Reset generator at the start of each epoch
    progress_bar = tqdm(make_generator_v2(csv_path, batch_size), total=num_steps)
    progress_bar.set_description(f"Epoch {epoch + 1}")

    for step, raw_batch in enumerate(progress_bar, start=1):
        try:
            training_batch = processor.make_batch(raw_batch, dtype)        # Converts batch to `TrainingBatch` format
            input_ids = training_batch.input_ids.to(DEVICE)            # Encoder input (source sentence)
            target_ids = training_batch.output_ids.to(DEVICE)          # Decoder target sequence
            encoder_mask = training_batch.encoder_mask.to(DEVICE)   # Mask for encoder

            # adjust decoder input
            decoder_input_ids = target_ids[:, :-1]
            target_ids_flat = target_ids[:, 1:].contiguous().view(-1)

            # adjust decoder mask too
            decoder_mask = training_batch.decoder_mask[:, :, :, :-1].to(DEVICE)
            
            # We need to create a causal mask too
            seq_len = decoder_input_ids.shape[1]
            causal_mask = torch.tril(torch.ones((seq_len, seq_len))).to(DEVICE).to(dtype)
            
            # final decoder mask as prod of padding * causal
            final_decoder_mask = decoder_mask * causal_mask.unsqueeze(0)

        except Exception as e:
            print(f"Error in batch {step}: {e}")
            loop_error += 1
            continue

        # Forward pass
        optimizer.zero_grad()
        output_probs = model(
            input_ids,
            decoder_input_ids,
            encoder_mask=encoder_mask,
            decoder_mask=final_decoder_mask)

        # flatten target and outputprobs to compute cce loss
        output_probs_flat = output_probs.view(-1, output_probs.size(-1))

        # Calculate the loss
        loss = loss_fn(output_probs_flat, target_ids_flat)
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Track loss
        loss_tracker.add(loss.item())
        epoch_loss += loss.item()

        # Optionally, print progress
        progress_bar.set_postfix_str(f"epoch loss : {epoch_loss / step:.4f} | "
                                     f"last {loss_tracker.window_size} steps loss: {loss_tracker.avg()} | "
                                     f"loop error : {loop_error} | "
                                     f"lr : {scheduler.get_lr()}")
        
        
        if step == break_at: # train only on a subset for now
            break

    # with torch.no_grad():
    #     sample_output = output_probs.argmax(dim=-1)
    #     print("Predicted tokens:", sample_output[0, :10].squeeze())  
    #     print("Target tokens:   ", target_ids[0, 1:11].squeeze())  

    # Print average loss per epoch
    # print(f"Epoch [{epoch+1}/{num_epochs}] completed, Average Loss: {epoch_loss / num_steps:.4f}")

Epoch 1:   0%|          | 60/469174 [00:08<14:19:41,  9.09it/s, epoch loss : 10.3271 | last 1000 steps loss: 10.502118644067796 | loop error : 1.0 | lr : [1.7699999999999998e-06]]

Error in batch 59: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   0%|          | 1038/469174 [02:27<14:12:15,  9.15it/s, epoch loss : 9.2600 | last 1000 steps loss: 9.23346875 | loop error : 2.0 | lr : [3.1079999999999994e-05]]        

Error in batch 1037: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   1%|          | 2667/469174 [06:21<14:36:57,  8.87it/s, epoch loss : 7.7967 | last 1000 steps loss: 6.44809375 | loop error : 3.0 | lr : [7.992000000000001e-05]] 

Error in batch 2666: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   1%|▏         | 6074/469174 [14:28<14:43:13,  8.74it/s, epoch loss : 6.7251 | last 1000 steps loss: 5.49184375 | loop error : 4.0 | lr : [0.00018209999999999998]] 

Error in batch 6073: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   1%|▏         | 6400/469174 [15:15<13:59:34,  9.19it/s, epoch loss : 6.6573 | last 1000 steps loss: 5.373828125 | loop error : 5.0 | lr : [0.00019184999999999997]]

Error in batch 6399: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   1%|▏         | 6439/469174 [15:20<14:05:48,  9.12it/s, epoch loss : 6.6573 | last 1000 steps loss: 5.414390625 | loop error : 6.0 | lr : [0.00019298999999999998]]

Error in batch 6438: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   1%|▏         | 6446/469174 [15:21<11:13:31, 11.45it/s, epoch loss : 6.6546 | last 1000 steps loss: 5.416171875 | loop error : 8.0 | lr : [0.00019313999999999998]]

Error in batch 6444: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).
Error in batch 6445: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   1%|▏         | 6528/469174 [15:33<14:03:25,  9.14it/s, epoch loss : 6.6473 | last 1000 steps loss: 5.491828125 | loop error : 9.0 | lr : [0.00019557]]            

Error in batch 6527: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   1%|▏         | 6756/469174 [16:05<14:09:57,  9.07it/s, epoch loss : 6.6064 | last 1000 steps loss: 5.54071875 | loop error : 10.0 | lr : [0.00020237999999999997]]

Error in batch 6755: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   2%|▏         | 7795/469174 [18:35<14:40:20,  8.73it/s, epoch loss : 6.3901 | last 1000 steps loss: 4.986578125 | loop error : 11.0 | lr : [0.00023351999999999997]]

Error in batch 7794: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   2%|▏         | 7807/469174 [18:36<13:59:23,  9.16it/s, epoch loss : 6.3867 | last 1000 steps loss: 4.98290625 | loop error : 12.0 | lr : [0.00023384999999999997]] 

Error in batch 7806: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   2%|▏         | 8525/469174 [20:20<14:15:40,  8.97it/s, epoch loss : 6.2945 | last 1000 steps loss: 5.16434375 | loop error : 13.0 | lr : [0.00025535999999999994]] 

Error in batch 8524: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   2%|▏         | 11519/469174 [27:31<14:06:22,  9.01it/s, epoch loss : 5.9613 | last 1000 steps loss: 4.855953125 | loop error : 14.0 | lr : [0.0002990167126187458]] 

Error in batch 11518: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   3%|▎         | 15810/469174 [37:48<14:02:19,  8.97it/s, epoch loss : 5.5930 | last 1000 steps loss: 4.769203125 | loop error : 15.0 | lr : [0.00029621385357184854]] 

Error in batch 15809: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   3%|▎         | 15815/469174 [37:49<13:33:01,  9.29it/s, epoch loss : 5.5921 | last 1000 steps loss: 4.766734375 | loop error : 16.0 | lr : [0.00029621124018345987]]

Error in batch 15814: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   4%|▍         | 17738/469174 [42:24<13:26:28,  9.33it/s, epoch loss : 5.4325 | last 1000 steps loss: 3.83525 | loop error : 17.0 | lr : [0.0002949555070626821]]     

Error in batch 17737: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   4%|▍         | 17776/469174 [42:30<14:08:23,  8.87it/s, epoch loss : 5.4290 | last 1000 steps loss: 3.82659375 | loop error : 18.0 | lr : [0.00029493133322008646]] 

Error in batch 17775: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   4%|▍         | 18229/469174 [43:34<11:34:36, 10.82it/s, epoch loss : 5.3924 | last 1000 steps loss: 3.927421875 | loop error : 20.0 | lr : [0.00029463667367925884]]

Error in batch 18227: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).
Error in batch 18228: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   4%|▍         | 19464/469174 [46:32<13:44:04,  9.10it/s, epoch loss : 5.2774 | last 1000 steps loss: 3.522390625 | loop error : 21.0 | lr : [0.0002938304433613401]] 

Error in batch 19463: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   5%|▍         | 22759/469174 [54:27<13:57:50,  8.88it/s, epoch loss : 5.1063 | last 1000 steps loss: 4.18596875 | loop error : 22.0 | lr : [0.000291678318023233]]    

Error in batch 22758: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   5%|▍         | 22768/469174 [54:28<14:06:15,  8.79it/s, epoch loss : 5.1056 | last 1000 steps loss: 4.18396875 | loop error : 23.0 | lr : [0.0002916730912464556]]  

Error in batch 22767: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   5%|▍         | 23210/469174 [55:32<13:34:03,  9.13it/s, epoch loss : 5.0849 | last 1000 steps loss: 4.1025 | loop error : 24.0 | lr : [0.00029138496517659967]]     

Error in batch 23209: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).
Error in batch 23211: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   5%|▌         | 24012/469174 [57:31<13:45:42,  8.99it/s, epoch loss : 5.0545 | last 1000 steps loss: 4.15996875 | loop error : 26.0 | lr : [0.0002908622874988566]]  

Error in batch 24011: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   5%|▌         | 24270/469174 [58:08<13:41:15,  9.03it/s, epoch loss : 5.0408 | last 1000 steps loss: 4.06459375 | loop error : 27.0 | lr : [0.0002906943772948817]]  

Error in batch 24269: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   5%|▌         | 24291/469174 [58:11<13:22:56,  9.23it/s, epoch loss : 5.0404 | last 1000 steps loss: 4.07509375 | loop error : 28.0 | lr : [0.0002906813103529381]]  

Error in batch 24290: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   5%|▌         | 24325/469174 [58:15<13:23:15,  9.23it/s, epoch loss : 5.0403 | last 1000 steps loss: 4.080078125 | loop error : 29.0 | lr : [0.0002906597498987312]]

Error in batch 24324: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   5%|▌         | 24342/469174 [58:17<13:35:50,  9.09it/s, epoch loss : 5.0396 | last 1000 steps loss: 4.070765625 | loop error : 30.0 | lr : [0.0002906492963451763]]

Error in batch 24341: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   5%|▌         | 24359/469174 [58:20<13:45:49,  8.98it/s, epoch loss : 5.0391 | last 1000 steps loss: 4.0708125 | loop error : 31.0 | lr : [0.00029063884279162145]]  

Error in batch 24358: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   5%|▌         | 24372/469174 [58:21<13:44:55,  8.99it/s, epoch loss : 5.0386 | last 1000 steps loss: 4.069984375 | loop error : 32.0 | lr : [0.0002906310026264553]] 

Error in batch 24371: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   5%|▌         | 24429/469174 [58:29<13:28:37,  9.17it/s, epoch loss : 5.0381 | last 1000 steps loss: 4.071890625 | loop error : 33.0 | lr : [0.0002905944151890133]] 

Error in batch 24428: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   5%|▌         | 24842/469174 [59:29<18:26:05,  6.70it/s, epoch loss : 5.0185 | last 1000 steps loss: 4.029265625 | loop error : 34.0 | lr : [0.0002903252361849756]] 

Error in batch 24841: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   6%|▌         | 26812/469174 [1:04:15<13:32:34,  9.07it/s, epoch loss : 4.8976 | last 1000 steps loss: 3.3587109375 | loop error : 35.0 | lr : [0.00028903879575063046]]

Error in batch 26811: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:   9%|▉         | 43692/469174 [1:44:23<13:14:18,  8.93it/s, epoch loss : 4.4112 | last 1000 steps loss: 4.507796875 | loop error : 36.0 | lr : [0.0002780109500973487]]  

Error in batch 43691: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:  10%|▉         | 44942/469174 [1:47:21<12:23:48,  9.51it/s, epoch loss : 4.3937 | last 1000 steps loss: 3.752875 | loop error : 37.0 | lr : [0.0002771949195729723]]    

Error in batch 44941: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:  10%|█         | 48774/469174 [1:56:35<13:02:23,  8.96it/s, epoch loss : 4.3726 | last 1000 steps loss: 4.05690625 | loop error : 38.0 | lr : [0.00027469194684368015]] 

Error in batch 48773: text input must be of type `str` (single example), `list[str]` (batch or single pretokenized example) or `list[list[str]]` (batch of pretokenized examples).


Epoch 1:  11%|█▏        | 52849/469174 [2:06:19<16:35:09,  6.97it/s, epoch loss : 4.3253 | last 1000 steps loss: 4.2942890625 | loop error : 38.0 | lr : [0.00027202955742267634]]


KeyboardInterrupt: 

In [8]:
@torch.no_grad()
def infer(model: nn.Module, 
          processor: Processor, 
          french_sentence: str, 
          max_length: int | None = None, 
          skip_special_tokens: bool = True,
          dtype = torch.float32) -> str:
    # first model as eval (we don't train here)
    model.eval()
    sequence_length = max_length if max_length else processor.sequence_length

    # Tokenize the input sequence in french
    tokens = processor.tokenize(french_sentence)

    # create the encoder input and mask
    encoder_ids = tokens["input_ids"].int().to(DEVICE)
    encoder_mask = tokens["attention_mask"].unsqueeze(0).to(DEVICE)
    
    # Create padding mask for current sequence length
    padding_mask = torch.zeros((1, 1, 1, sequence_length), dtype=dtype).to(DEVICE)
    causal_mask = (
        torch.tril(torch.ones((sequence_length, sequence_length), dtype=dtype))
        .unsqueeze(0)
        .to(DEVICE)
    )

    # Initialize decoder input with just the start token
    generated_ids = torch.zeros((1, sequence_length), dtype=torch.int).to(DEVICE)
    generated_ids[0, 0] = processor.tokenizer.cls_token_id
    
    # loop to generate output ids
    for idx in range(1, sequence_length):
        padding_mask[0, 0, :idx, :idx + 1] = 1.0  # All positions up to current length are unmasked
        final_mask = padding_mask * causal_mask
        # return
        output_probs = model(
            encoder_ids,
            generated_ids,
            encoder_mask=encoder_mask,
            decoder_mask=final_mask)

        # Select next token from the LAST position (current_len - 1)
        next_token_id = output_probs[0, idx-1, :].argmax(dim=-1)  # Last position, all vocab
        
        # Append the new token
        generated_ids[0, idx] = next_token_id

        # early stop when encounter sep_token_id
        if next_token_id.item() == processor.tokenizer.sep_token_id:
            break

    return processor.decode(generated_ids[0], skip_special_tokens=skip_special_tokens)

In [9]:
sentences = make_generator(csv_path, 1)

In [62]:
sentence = next(sentences)
print(sentence.fr) 
print(sentence.en)
print(infer(model, processor, sentence.fr[0], dtype=dtype))

["Observatoires Depuis des milliers d'années, les autochtones observent les étoiles pour se repérer dans l'espace et le temps."]
['Observatories For thousands of years, native people use the stars to navigate and to monitor the passage of time.']
since the last year, the aboriginal peoples has been able to provide a variety of information on the energy and energy efficiency of the environment.


In [66]:
sentence = input("Any sentence: ")
print(infer(model, processor, sentence, dtype=dtype))

hsi
