In [None]:
# Install the Hugging Face `datasets` package (imported below for load_dataset).
!pip install datasets
# NOTE(review): torchtext is not imported anywhere in the visible notebook — confirm it is still needed.
!pip install torchtext

In [None]:
from datasets import load_dataset

# Download (or load from cache) the WMT17 Russian-English parallel corpus.
dataset = load_dataset(
    "wmt17",
    "ru-en",
)

In [None]:
from transformers import AutoTokenizer

# English->Russian OPUS-MT checkpoint; this tokenizer object is reused by the cells below.
tokenizer = AutoTokenizer.from_pretrained(
    "Helsinki-NLP/opus-mt-en-ru",
    use_fast=True,
)


def collate_fn(batch):
    """Tokenize a list of translation examples into padded id tensors.

    Args:
        batch: list of dataset rows; each row provides
            ``item['translation']['en']`` (source) and
            ``item['translation']['ru']`` (target).

    Returns:
        tuple ``(input_ids, target_ids)`` of tensors shaped (batch, seq_len),
        each padded to the longest sequence in the batch and truncated to
        128 tokens.
    """
    src_texts = [item['translation']['en'] for item in batch]
    tgt_texts = [item['translation']['ru'] for item in batch]

    encode_kwargs = dict(
        truncation=True,
        padding="longest",
        max_length=128,
        return_tensors="pt",
        add_special_tokens=True,
    )

    src_encodings = tokenizer(src_texts, **encode_kwargs)
    # Translation tokenizers such as Marian keep a separate SentencePiece model
    # for the target language. Passing the Russian text as `text_target` makes
    # the tokenizer use it; encoding targets as ordinary input would segment
    # them with the English-side model.
    tgt_encodings = tokenizer(text_target=tgt_texts, **encode_kwargs)

    input_ids = src_encodings['input_ids']
    target_ids = tgt_encodings['input_ids']

    return input_ids, target_ids

# Show the tokenizer's vocabulary size.
print(f"{tokenizer.vocab_size}")

In [None]:
# Print the EOS and BOS ids (in that order), then display the special-token map.
for special_attr in ("eos_token_id", "bos_token_id"):
    print(getattr(tokenizer, special_attr))
tokenizer.special_tokens_map

In [None]:
from torch.utils.data import DataLoader
import numpy as np

BATCH_SIZE = 32

# Deterministically shuffle the training split (seed 42) and keep the first 500k pairs.
train_dataset = (
    dataset["train"]
    .shuffle(seed=42)
    .select(range(500000))
)

# Batches are tokenized on the fly by collate_fn in 4 worker processes.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=4,
)

In [None]:
import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to sequence-first embeddings.

    Even feature columns hold sines and odd columns hold cosines of the
    position scaled by geometrically spaced frequencies. The table is stored
    as the buffer ``pe`` of shape (max_len, d_model).

    Args:
        d_model: embedding width; both even and odd values are supported.
        max_len: number of positions precomputed in the table.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-(math.log(10000.0) / d_model)))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine column,
        # so trim the angle table to the number of odd-indexed columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add positional encodings to ``x`` of shape (seq_len, batch, d_model)."""
        seq_len = x.size(0)

        # (seq_len, 1, d_model): broadcast over the batch dimension.
        pe = self.pe[:seq_len].unsqueeze(1)
        return x + pe.to(x.device)


class TranslationTransformer(nn.Module):
    """Encoder-decoder Transformer for translation over a shared vocabulary.

    Source and target tokens have separate embedding tables, share one
    sinusoidal positional encoding, and are fed to ``nn.Transformer`` in
    sequence-first layout (``batch_first=False``). A final linear layer maps
    decoder states to vocabulary logits.

    Args:
        vocab_size: number of token ids (embedding rows and output logits).
        d_model: embedding / model width.
        nhead: number of attention heads.
        num_encoder_layers: encoder depth.
        num_decoder_layers: decoder depth.
        dim_feedforward: hidden width of the feed-forward sublayers.
        dropout: dropout probability inside ``nn.Transformer``.
        pad_token_id: id used as ``padding_idx`` for both embeddings. When
            omitted, falls back to the module-level ``tokenizer.pad_token_id``
            (the original behaviour).
    """

    def __init__(self, vocab_size, d_model=256, nhead=8, num_encoder_layers=3, num_decoder_layers=3, dim_feedforward=512, dropout=0.1, pad_token_id=None):
        super().__init__()
        if pad_token_id is None:
            # Backward-compatible default: read the pad id from the notebook's tokenizer.
            pad_token_id = tokenizer.pad_token_id

        self.d_model = d_model
        self.src_embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
        self.tgt_embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
        self.positional_encoding = PositionalEncoding(d_model)
        self.transformer = nn.Transformer(d_model=d_model, nhead=nhead,
                                          num_encoder_layers=num_encoder_layers,
                                          num_decoder_layers=num_decoder_layers,
                                          dim_feedforward=dim_feedforward,
                                          dropout=dropout,
                                          batch_first=False
                                         )
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, src, tgt, src_key_padding_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """Compute next-token logits.

        Args:
            src: source ids, shape (src_len, batch).
            tgt: decoder input ids, shape (tgt_len, batch).
            src_key_padding_mask: optional bool mask (batch, src_len), True at padding.
            tgt_key_padding_mask: optional bool mask (batch, tgt_len), True at padding.
            memory_key_padding_mask: optional bool mask (batch, src_len) for
                decoder cross-attention. Defaults to ``src_key_padding_mask``.

        Returns:
            Logits of shape (tgt_len, batch, vocab_size).
        """
        # Scale embeddings by sqrt(d_model) before adding positional encodings.
        src_emb = self.src_embedding(src) * math.sqrt(self.d_model)
        tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.d_model)

        src_emb = self.positional_encoding(src_emb)
        tgt_emb = self.positional_encoding(tgt_emb)

        # Causal mask: each target position may only attend to earlier positions.
        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt_emb.size(0)).to(src.device)

        # The encoder memory has the same padding layout as the source. Without
        # this default the decoder would cross-attend to padded source positions
        # whenever the caller supplies only src_key_padding_mask.
        if memory_key_padding_mask is None:
            memory_key_padding_mask = src_key_padding_mask

        out = self.transformer(src_emb, tgt_emb,
                               tgt_mask=tgt_mask,
                               src_key_padding_mask=src_key_padding_mask,
                               tgt_key_padding_mask=tgt_key_padding_mask,
                               memory_key_padding_mask=memory_key_padding_mask)
        return self.fc_out(out)

In [None]:
import torch.optim as optim
from tqdm import tqdm

from torch.cuda.amp import GradScaler, autocast

# Per-epoch loss histories.
train_losses, valid_losses = [], []

vocab_size = tokenizer.vocab_size

# Train on the GPU when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model = TranslationTransformer(vocab_size=vocab_size)
model = model.to(device)

# Padding positions in the target do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

# Mixed-precision gradient scaler; None means plain full-precision training.
scaler = None
if torch.cuda.is_available():
    scaler = GradScaler()

def make_padding_mask(seq):
    """Build a key-padding mask from sequence-first token ids.

    Takes ``seq`` laid out as (seq_len, batch) and returns a boolean tensor of
    shape (batch, seq_len) that is True wherever the token equals the
    tokenizer's pad id.
    """
    is_pad = seq == tokenizer.pad_token_id  # (seq_len, batch)
    return is_pad.transpose(0, 1)           # (batch, seq_len)

from torch.optim.lr_scheduler import LambdaLR

# def get_inverse_sqrt_scheduler(optimizer, d_model=256, warmup_steps=4000, scale=1e-3):
#     def lr_lambda(step):
#         step += 1
#         return (d_model ** -0.5) * min(step ** -0.5, step * warmup_steps ** -1.5) * scale 
#     return LambdaLR(optimizer, lr_lambda)

# warmup_steps = 2000
# scheduler = get_inverse_sqrt_scheduler(optimizer, warmup_steps=warmup_steps)

def _decoder_start_token_id():
    """Return the id that primes the decoder: bos, else cls, else pad (first one that is set)."""
    for candidate in (tokenizer.bos_token_id, tokenizer.cls_token_id, tokenizer.pad_token_id):
        # Compare against None explicitly: 0 is a valid token id.
        if candidate is not None:
            return candidate
    raise ValueError("tokenizer defines no bos, cls or pad token to start decoding from")


def train_epoch(model, dataloader, optimizer, criterion, device, scheduler=None, grad_clip=1.0):
    """Run one teacher-forced training pass over ``dataloader``.

    Args:
        model: network called as ``model(src, tgt_input, src_key_padding_mask=...,
            tgt_key_padding_mask=...)`` with sequence-first id tensors.
        dataloader: yields ``(src_batch, tgt_batch)`` id tensors of shape (batch, seq_len).
        optimizer: optimizer updating ``model``'s parameters.
        criterion: loss over (N, vocab) logits and (N,) target ids.
        device: device the batches are moved to.
        scheduler: optional LR scheduler, stepped once per batch.
        grad_clip: max gradient norm; ``None`` disables clipping.

    Returns:
        Mean loss per batch (0.0 if the dataloader yields no batches).

    Uses the module-level ``scaler``: when it is not None the step runs under
    autocast with gradient scaling.
    """
    model.train()
    start_id = _decoder_start_token_id()
    total_loss = 0.0
    num_batches = 0
    for src_batch, tgt_batch in tqdm(dataloader, desc="Training"):
        # src_batch, tgt_batch: (batch, seq_len)
        src = src_batch.transpose(0, 1).to(device)   # (src_len, batch)
        tgt = tgt_batch.transpose(0, 1).to(device)   # (tgt_len, batch)

        if bool((tgt[0] == start_id).all()):
            # The tokenizer already emitted the start token: classic one-step shift.
            tgt_input = tgt[:-1, :]    # (tgt_len-1, batch)
            tgt_expected = tgt[1:, :]  # (tgt_len-1, batch)
        else:
            # No start token in the data (e.g. Marian only appends EOS). Prepend
            # the same start id the decoding functions begin with, so the model
            # learns to predict the first target token too.
            start_row = torch.full((1, tgt.size(1)), start_id, dtype=tgt.dtype, device=tgt.device)
            tgt_input = torch.cat([start_row, tgt[:-1, :]], dim=0)  # (tgt_len, batch)
            tgt_expected = tgt                                       # (tgt_len, batch)

        src_key_padding_mask = make_padding_mask(src).to(device)        # (batch, src_len)
        tgt_key_padding_mask = make_padding_mask(tgt_input).to(device)  # (batch, tgt_in_len)
        # Position 0 is always real decoder input, even when the start id equals
        # the pad id; masking it would leave the first query with no visible key.
        tgt_key_padding_mask[:, 0] = False

        optimizer.zero_grad()
        if scaler is not None:
            with autocast():
                output = model(src, tgt_input, src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=tgt_key_padding_mask)
                # output: (tgt_in_len, batch, vocab)
                loss = criterion(output.reshape(-1, output.size(-1)), tgt_expected.reshape(-1))
            scaler.scale(loss).backward()
            if grad_clip is not None:
                # Gradients must be unscaled before their norm is measured.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            output = model(src, tgt_input, src_key_padding_mask=src_key_padding_mask, tgt_key_padding_mask=tgt_key_padding_mask)
            loss = criterion(output.reshape(-1, output.size(-1)), tgt_expected.reshape(-1))
            loss.backward()
            if grad_clip is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            optimizer.step()

        if scheduler is not None:
            scheduler.step()

        total_loss += loss.item()
        num_batches += 1

    return total_loss / max(num_batches, 1)


In [None]:
# Train for a fixed number of epochs, recording the mean loss of each one.
train_losses = []
num_epochs = 3
for epoch in range(num_epochs):
    loss = train_epoch(
        model,
        train_dataloader,
        optimizer,
        criterion,
        device,
    )
    train_losses += [loss]
    print("Epoch {}: Train Loss={:.4f}".format(epoch + 1, loss))

In [None]:
import matplotlib.pyplot as plt

# One point per epoch, using the explicit figure/axes interface.
fig, ax = plt.subplots()
epoch_numbers = range(1, num_epochs + 1)
ax.plot(epoch_numbers, train_losses, marker="o")
ax.set_xlabel("Epoch")
ax.set_ylabel("Train Loss")
ax.set_title("Training Loss")
ax.grid(True)
plt.show()

In [None]:
# Persist only the learned parameters (the state_dict), not the whole module object.
torch.save(
    model.state_dict(),
    '/kaggle/working/translation_transformer.pth',
)

In [None]:
def greedy_decode_batch(model, src, tokenizer, device, max_len=100):
    """Greedily translate a batch of source id sequences.

    Args:
        model: network called as ``model(src, tgt, src_key_padding_mask=...,
            tgt_key_padding_mask=...)`` with sequence-first ids, returning
            logits of shape (tgt_len, batch, vocab).
        src: source ids of shape (batch, src_len), padded with the tokenizer's pad id.
        tokenizer: provides pad/bos/cls/eos/sep token ids and ``decode``.
        device: device used for decoding.
        max_len: maximum number of tokens generated per sequence.

    Returns:
        List of decoded strings, one per source sequence.
    """
    def _first_set(*token_ids):
        # `a or b` would treat a legitimate token id of 0 as missing
        # (Marian's </s> is id 0), so test against None explicitly.
        for token_id in token_ids:
            if token_id is not None:
                return token_id
        return None

    model.eval()
    with torch.no_grad():
        batch_size = src.size(0)
        if batch_size == 0:
            return []

        pad = tokenizer.pad_token_id
        src = src.transpose(0, 1).to(device)                # (src_len, batch)
        src_key_padding_mask = src.transpose(0, 1) == pad   # (batch, src_len)

        bos = _first_set(tokenizer.bos_token_id, tokenizer.cls_token_id, pad)
        eos = _first_set(tokenizer.eos_token_id, tokenizer.sep_token_id, pad)
        # Filler appended to sequences that have already produced EOS.
        filler = pad if pad is not None else eos

        generated = torch.full((1, batch_size), bos, dtype=torch.long, device=device)  # (1, batch)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for _ in range(max_len):
            # No target padding mask: unfinished sequences contain no padding,
            # and the start token must stay visible even when it equals the pad
            # id (masking it leaves the first step with nothing to attend to).
            # The model builds its own causal mask.
            out = model(src, generated,
                        src_key_padding_mask=src_key_padding_mask,
                        tgt_key_padding_mask=None)

            next_tokens = out[-1].argmax(dim=-1)  # (batch,)
            # Sequences that already ended only receive filler from now on.
            next_tokens = torch.where(finished, torch.full_like(next_tokens, filler), next_tokens)
            generated = torch.cat([generated, next_tokens.unsqueeze(0)], dim=0)

            finished |= (next_tokens == eos)
            if finished.all():
                break

        results = []
        for seq in generated.transpose(0, 1).cpu().tolist():  # (batch, seq_len)
            seq = seq[1:]  # drop the start token
            if eos in seq:
                seq = seq[:seq.index(eos)]  # keep only what precedes the first EOS
            results.append(tokenizer.decode(seq, skip_special_tokens=True))
        return results


In [None]:
# sacrebleu supplies the corpus-level BLEU score computed in test_model below.
!pip install sacrebleu

In [None]:
import sacrebleu

def test_model(model, dataloader, tokenizer, device, max_len=100):
    """Translate every batch of ``dataloader`` and score the output with corpus BLEU.

    Hypotheses come from ``greedy_decode_batch``; references are the target
    ids of each batch decoded back to text with special tokens skipped.

    Returns:
        tuple ``(bleu_score, first_5_references, first_5_hypotheses)``.
    """
    ref_texts, hyp_texts = [], []

    progress = tqdm(dataloader, desc="Testing")
    for src_batch, tgt_batch in progress:
        decoded = greedy_decode_batch(model, src_batch, tokenizer, device, max_len=max_len)

        ref_texts.extend(
            [tokenizer.decode(ref_ids, skip_special_tokens=True) for ref_ids in tgt_batch.tolist()]
        )
        hyp_texts.extend(decoded)

    # sacrebleu expects a list of reference streams; there is a single one here.
    score = sacrebleu.corpus_bleu(hyp_texts, [ref_texts]).score

    return score, ref_texts[:5], hyp_texts[:5]

In [None]:
test_dataset = dataset['test']
# Corpus BLEU does not depend on example order, so evaluate without shuffling:
# the sample translations printed below are then the same on every run.
test_dataloader = DataLoader(test_dataset, batch_size=1,
                              shuffle=False, collate_fn=collate_fn, num_workers=4)

bleu, refs_sample, hyps_sample = test_model(model, test_dataloader, tokenizer, device, max_len=60)
print(f"\nFinal BLEU score on test set: {bleu:.2f}")

# test_model already returns at most five reference/hypothesis pairs.
print("\nПримеры перевода:")
for r, h in zip(refs_sample, hyps_sample):
    print(f"REF: {r}")
    print(f"HYP: {h}")

In [None]:
# Rebuild the model from a checkpoint only when this kernel has no `model` yet.
# (`if not model` raises NameError when `model` is undefined and is never true
# for an existing module.)
# NOTE(review): this path differs from the one the save cell writes to
# ('/kaggle/working/translation_transformer.pth') — confirm which is intended.
if 'model' not in globals():
    model = TranslationTransformer(tokenizer.vocab_size)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model.load_state_dict(torch.load('/content/my_model.pth', map_location=device))
    model = model.to(device)

In [None]:
def translate_sentence(model, tokenizer, sentence, device, max_len=50):
    """Greedily translate a single sentence.

    Args:
        model: network called as ``model(src, tgt, src_key_padding_mask=...,
            tgt_key_padding_mask=...)`` with sequence-first ids, returning
            logits of shape (tgt_len, batch, vocab).
        tokenizer: callable tokenizer that also provides the special token ids and ``decode``.
        sentence: source text.
        device: device used for decoding.
        max_len: maximum number of generated tokens.

    Returns:
        The decoded translation string.
    """
    def _first_set(*token_ids):
        # `a or b` would treat a legitimate token id of 0 as missing
        # (Marian's </s> is id 0), so test against None explicitly.
        for token_id in token_ids:
            if token_id is not None:
                return token_id
        return None

    model.eval()

    src_enc = tokenizer(
        [sentence],
        return_tensors="pt",
        truncation=True,
        padding="longest",
        max_length=128
    ).to(device)

    src_ids = src_enc["input_ids"].transpose(0, 1)  # [seq_len, batch]
    src_pad_mask = (src_ids == tokenizer.pad_token_id).transpose(0, 1)  # [batch, seq_len]

    start_token = _first_set(tokenizer.bos_token_id, tokenizer.cls_token_id, tokenizer.pad_token_id)
    eos_token = _first_set(tokenizer.eos_token_id, tokenizer.sep_token_id, tokenizer.pad_token_id)

    generated_tokens = [start_token]
    with torch.no_grad():
        for _ in range(max_len):
            tgt_input = torch.tensor(generated_tokens, dtype=torch.long, device=device).unsqueeze(1)

            # A single sequence has no target padding. Passing no mask also keeps
            # the start token visible when it equals the pad id (masking it
            # leaves the first step with nothing to attend to).
            output = model(src_ids, tgt_input,
                           src_key_padding_mask=src_pad_mask,
                           tgt_key_padding_mask=None)

            next_token = output[-1, 0].argmax().item()
            if next_token == eos_token:
                break
            generated_tokens.append(next_token)

    # Drop the start token before decoding.
    translation = tokenizer.decode(generated_tokens[1:], skip_special_tokens=True)
    return translation


In [None]:
# A few English prompts of varying length to check the trained model by eye.
example_sentences = [
    "Hello, how are you?",
    "The transformer model is very powerful for sequence-to-sequence tasks.",
    "In 2017, researchers proposed the architecture known as 'Attention Is All You Need'.",
    "Where are you from?",
]

for sentence in example_sentences:
    print("EN:", sentence)
    russian = translate_sentence(model, tokenizer, sentence, device)
    print("RU:", russian)
    print()