In [1]:
import torch
from transformers import MBart50TokenizerFast
import pandas as pd
import torch.optim as optim
from utils.rlst import RLST, train_step, translate, TranslationDataset
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
tokenizer.src_lang = "en_XX"
tokenizer.tgt_lang = "pl_PL"

chunk_size = 100_000

file_reader = pd.read_csv('data/en_pl.csv', chunksize=chunk_size)
first_chunk = next(file_reader) 
print(type(first_chunk))
print(first_chunk.head())

train_data, test_data = train_test_split(first_chunk, test_size=0.1)

train_dataset = TranslationDataset(train_data, tokenizer)
train_dataloader = DataLoader(train_dataset, batch_size=256, shuffle=False)

test_dataset = TranslationDataset(test_data, tokenizer)
test_dataloader = DataLoader(test_dataset, batch_size=256, shuffle=True)



<class 'pandas.core.frame.DataFrame'>
                                english  \
0      Previously on "The Blacklist"...   
1        - You want to call your daddy?   
2  - Yeah, I want to tell him I'm okay.   
3                                 Okay.   
4  Lizzy... Be careful of your husband.   

                                   polish  
0             /W poprzednich odcinkach: /  
1             - Chcesz zadzwonić do taty?  
2  - Tak, powiem, że wszystko w porządku.  
3                                 Dobrze.  
4                  Lizzy, uważaj na męża.  


In [4]:

input_dim = output_dim = 256  # Placeholder: input_ids must be embedded properly
model = RLST(input_dim, output_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0003)
loss_fn_ce = nn.MSELoss()
loss_fn_mse = nn.MSELoss()
epoch_len = 10

# Training loop
for epoch in tqdm(range(epoch_len), desc="Nauka", position=0):
    total_loss = 0.0
    # for batch in tqdm(train_dataloader, desc='Batche', position=1, leave=False):
    try:
        for batch_idx, batch in enumerate(train_dataloader):
            x_seq, y_seq = batch
            loss = train_step(model, optimizer, x_seq[0], y_seq[0], loss_fn_ce, loss_fn_mse, device)
            total_loss += loss
            if batch_idx % 20 == 0:
                print(f"Epoch {epoch+1}, Batch {batch_idx}, Loss: {loss:.4f}")
    except:
        pass
    print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

# Inference example
for batch in test_dataloader:
    x_seq, _ = batch
    prediction = translate(model, x_seq[0], device)
    token_ids = [int(torch.argmax(torch.tensor(vec))) for vec in prediction]
    sentence = tokenizer.decode(token_ids, skip_special_tokens=True)
    print("Translated token vectors (first sentence):", sentence)
    break

Nauka:   0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1, Batch 0, Loss: 1018.8916
Epoch 1, Batch 20, Loss: 50.8331
Epoch 1, Batch 40, Loss: 45.8219
Epoch 1, Loss: 4624.7250
Epoch 2, Batch 0, Loss: 50.5770
Epoch 2, Batch 20, Loss: 46.1649
Epoch 2, Batch 40, Loss: 43.3482
Epoch 2, Loss: 2027.1185
Epoch 3, Batch 0, Loss: 47.7526
Epoch 3, Batch 20, Loss: 37.3129
Epoch 3, Batch 40, Loss: 43.9011
Epoch 3, Loss: 1825.9355
Epoch 4, Batch 0, Loss: 39.7344
Epoch 4, Batch 20, Loss: 31.2087
Epoch 4, Batch 40, Loss: 29.2064
Epoch 4, Loss: 1574.1606
Epoch 5, Batch 0, Loss: 27.5856
Epoch 5, Batch 20, Loss: 201.0647
Epoch 5, Batch 40, Loss: 27.0681
Epoch 5, Loss: 2405.8388
Epoch 6, Batch 0, Loss: 48.3826
Epoch 6, Batch 20, Loss: 126.9132
Epoch 6, Batch 40, Loss: 51.0956
Epoch 6, Loss: 2654.1612
Epoch 7, Batch 0, Loss: 109.8427
Epoch 7, Batch 20, Loss: 56.1180
Epoch 7, Batch 40, Loss: 27.7556
Epoch 7, Loss: 1886.2472
Epoch 8, Batch 0, Loss: 28.2765
Epoch 8, Batch 20, Loss: 53.9597
Epoch 8, Batch 40, Loss: 18.9189
Epoch 8, Loss: 2259.8187
Epoch 9, Ba

  x_token = torch.tensor(input_seq[x_i], dtype=torch.float32).to(device)
  x_token = torch.tensor(input_seq[x_i], dtype=torch.float32).to(device)
