In [1]:
import torch
import torch.nn as nn
import math
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorForSeq2Seq
from torch.utils.data import DataLoader
import time
import torch.nn.functional as F
from tqdm.auto import tqdm

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
D_MODEL = 256
N_HEAD = 8
N_LAYERS = 4
D_FF = 1024
DROPOUT = 0.1
BATCH_SIZE = 64
LEARNING_RATE = 0.0001
EPOCHS = 70
MAX_LEN = 128
PATIENCE = 10

In [2]:
class Transformer(nn.Module):
    def __init__(self,vocab_size, d_model,  d_ff, n_head, n_layers, dropout):
        super(Transformer,self).__init__()
        self.embedding = Embeddings(d_model, vocab_size)
        self.pos_encoder = PositionalEncoding(d_model,dropout)
        self.encoder = Encoder(d_model, d_ff, n_head, n_layers, dropout)
        self.decoder = Decoder(d_model,  d_ff, n_head, n_layers, dropout)
        self.fc = nn.Linear(d_model, vocab_size)
        self._reset_parameters()

    def forward(self,inputs, outputs, src_mask, tgt_mask):
        embedded_inputs = self.embedding(inputs)
        embedded_outputs = self.embedding(outputs)
        embedded_inputs = self.pos_encoder(embedded_inputs)
        embedded_outputs = self.pos_encoder(embedded_outputs)
        encoding = self.encoder(embedded_inputs, src_mask)
        decoding = self.decoder(embedded_outputs, encoding, tgt_mask, src_mask)
        return self.fc(decoding)
        
    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p) 

            
        
        

In [3]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p = dropout)
        position = torch.arange(max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0,d_model,2).float()*(-math.log(10000.0)/d_model))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position*div_term)
        pe[:, 1::2] = torch.cos(position*div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
 
    def forward(self,x):
        x = x + self.pe[:,:x.size(1)]
        return self.dropout(x)

In [4]:
class Embeddings(nn.Module):
    def __init__(self, d_model, vocab):
        super(Embeddings, self).__init__()
        self.lut = nn.Embedding(vocab, d_model)
        self.d_model = d_model

    def forward(self, x):
        return self.lut(x) * math.sqrt(self.d_model)

In [5]:
class Decoder(nn.Module):
    def __init__(self, d_model,  d_ff, n_head, n_layers, dropout):
        super(Decoder,self).__init__()
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, d_ff, n_head, dropout) for i in range(n_layers)])

    def forward(self, x, encoder_output, tgt_mask,src_mask):
        for layer in self.decoder_layers:
            x = layer(x, encoder_output, tgt_mask, src_mask)
        return x

In [6]:
class Encoder(nn.Module):
    def __init__(self, d_model,  d_ff, n_head, n_layers, dropout):
        super(Encoder,self).__init__()
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model,  d_ff,n_head, dropout) for _ in range(n_layers)])

    def forward(self, src, src_mask): 
        x = src
        for layer in self.encoder_layers:
            x = layer(x, src_mask)
        return x       
        
        
        

In [7]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, d_ff, n_head, dropout):
        super(DecoderLayer, self).__init__()
        self.self_attention = AttentionLayer(d_model, n_head, dropout)
        self.ffn = FFNLayer(d_model, d_ff, dropout)
        self.cross_attention = AttentionLayer(d_model, n_head, dropout)
        

    def forward(self, x,encoder_output, tgt_mask, src_mask):
        x = self.self_attention(x,x, x, tgt_mask)
        x = self.cross_attention(x, encoder_output, encoder_output, src_mask)
        x = self.ffn(x)
        return x
        

In [8]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model,  d_ff, n_head, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attention = AttentionLayer(d_model, n_head, dropout)
        self.ffn = FFNLayer(d_model, d_ff, dropout)
        

    def forward(self, x, src_mask):
        x = self.self_attention(x,x,x, src_mask)
        x = self.ffn(x)
        return x
        

In [9]:
class FFNLayer(nn.Module):
    def __init__(self, d_model, d_ff, dropout):
        super(FFNLayer,self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(d_ff,d_model)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.dropout(out)
        out = out +x 
        out = self.layer_norm(out)
        return out
        

In [10]:
class AttentionLayer(nn.Module):
    def __init__(self,d_model, n_head, dropout):
        super(AttentionLayer, self).__init__()
        self.mha = MultiHeadAttention(d_model, n_head)
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q,k,v, mask):
        out = self.mha(q,k,v, mask)
        out = self.dropout(out)
        out = out+q
        out = self.layer_norm(out)
        return out
       
       

In [11]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_head):
        super().__init__()
        self.n_head = n_head
        self.d_k = d_model // n_head
        
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, mask):
        batch_size = Q.size(0)

        q = self.q_proj(Q).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
        k = self.k_proj(K).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
        v = self.v_proj(V).view(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask, -1e9)
        p_attn = scores.softmax(dim=-1)
        attention_output = torch.matmul(p_attn, v)
        attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, -1, self.n_head * self.d_k)
        return self.out_proj(attention_output)

        

In [12]:
def create_masks(src_batch, tgt_batch, pad_token_id, device):
    src_mask = (src_batch == pad_token_id).unsqueeze(1).unsqueeze(2)
    tgt_pad_mask = (tgt_batch == pad_token_id).unsqueeze(1).unsqueeze(2)
    tgt_len = tgt_batch.size(1)
    tgt_causal_mask = torch.triu(torch.ones((tgt_len, tgt_len), device=device), diagonal=1).bool()
    tgt_mask = tgt_pad_mask | tgt_causal_mask
    return src_mask.to(device), tgt_mask.to(device)

In [13]:
def train_epoch(model, dataloader, optimizer, criterion, pad_token_id, device):
    model.train()
    total_loss = 0
    progress_bar = tqdm(dataloader, desc='Training', leave=False)
    for batch in progress_bar:
        batch = {k: v.to(device) for k, v in batch.items()}
        src = batch['input_ids']
        labels = batch['labels']
        
        tgt = labels.clone()
        tgt[tgt == -100] = pad_token_id 
        shifted_tgt = torch.full_like(tgt, pad_token_id)
        shifted_tgt[:, 1:] = tgt[:, :-1]
        tgt = shifted_tgt
        
        src_mask, tgt_mask = create_masks(src, tgt, pad_token_id, device)
        
        optimizer.zero_grad()
        output = model(src, tgt, src_mask, tgt_mask)
        
        output_flat = output.view(-1, output.shape[-1])
        labels_flat = labels.view(-1)
        loss = criterion(output_flat, labels_flat)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        total_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())
        
    return total_loss / len(dataloader)

In [14]:
def evaluate(model, dataloader, criterion, pad_token_id, device):
    model.eval()
    total_loss = 0
    progress_bar = tqdm(dataloader, desc='Evaluating', leave=False)
    with torch.no_grad():
        for batch in progress_bar:
            batch = {k: v.to(device) for k, v in batch.items()}
            src = batch['input_ids']
            labels = batch['labels']
            
            tgt = labels.clone()
            tgt[tgt == -100] = pad_token_id 
            shifted_tgt = torch.full_like(tgt, pad_token_id)
            shifted_tgt[:, 1:] = tgt[:, :-1]
            tgt = shifted_tgt
            
            src_mask, tgt_mask = create_masks(src, tgt, pad_token_id, device)
            
            output = model(src, tgt, src_mask, tgt_mask)
            output_flat = output.view(-1, output.shape[-1])
            labels_flat = labels.view(-1)
            loss = criterion(output_flat, labels_flat)
            total_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())
            
    return total_loss / len(dataloader)    

In [15]:
dataset = load_dataset("iwslt2017", "iwslt2017-en-it")

tokenizer = AutoTokenizer.from_pretrained("t5-small")


In [16]:
def preprocess(examples):
    src_texts = [ex['en'] for ex in examples['translation']]
    tgt_texts = [ex['it'] for ex in examples['translation']]
    
    model_inputs = tokenizer(
        src_texts,
        text_target=tgt_texts,
        truncation=True,
        max_length=MAX_LEN
    )
    return model_inputs

tokenized_dataset = dataset.map(preprocess, batched=True, remove_columns=dataset["train"].column_names)


In [17]:
data_collator = DataCollatorForSeq2Seq(tokenizer, padding=True, return_tensors = "pt")

train_data = tokenized_dataset["train"].shuffle(seed=42)
valid_data = tokenized_dataset["validation"]

train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, collate_fn=data_collator)
valid_loader = DataLoader(valid_data, batch_size=BATCH_SIZE, collate_fn=data_collator)

In [18]:
VOCAB_SIZE = tokenizer.vocab_size
model = Transformer(VOCAB_SIZE, D_MODEL, D_FF, N_HEAD, N_LAYERS, DROPOUT).to(DEVICE)
    
criterion = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,      
    'min',          
    factor=0.5,     
    patience=5,     
)

print(f"Model has {sum(p.numel() for p in model.parameters() if p.requires_grad):,} trainable parameters.")

print("Starting training...")
best_valid_loss = float('inf')
epochs_no_improve = 0

for epoch in tqdm(range(1, EPOCHS + 1), desc="Epochs"):
    start_time = time.time()
    
    train_loss = train_epoch(model, train_loader, optimizer, criterion, tokenizer.pad_token_id, DEVICE)
    valid_loss = evaluate(model, valid_loader, criterion, tokenizer.pad_token_id, DEVICE)

    scheduler.step(valid_loss) 
    
    end_time = time.time()
    epoch_mins, epoch_secs = divmod(end_time - start_time, 60)
    
    print(f"--- Epoch {epoch}/{EPOCHS} | Time: {int(epoch_mins)}m {int(epoch_secs)}s ---")
    print(f"\tTrain Loss: {train_loss:.4f}")
    print(f"\tValid Loss: {valid_loss:.4f}")

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), "transformer_en_it.pt")
        print("\t-> Saved best model (based on validation loss)")
        epochs_no_improve = 0
    else:
        epochs_no_improve += 1
        print(f"\t-> No improvement in validation loss for {epochs_no_improve} epoch(s).")

    if epochs_no_improve >= PATIENCE:
        print(f"\nEarly stopping triggered after {PATIENCE} epochs with no improvement.")
        break

print("Training finished.")

Model has 23,840,100 trainable parameters.
Starting training...




Epochs:   0%|          | 0/70 [00:00<?, ?it/s]

Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 1/70 | Time: 11m 52s ---
	Train Loss: 3.7457
	Valid Loss: 2.9405
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 2/70 | Time: 11m 57s ---
	Train Loss: 2.7747
	Valid Loss: 2.5316
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 3/70 | Time: 12m 5s ---
	Train Loss: 2.3847
	Valid Loss: 2.1488
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 4/70 | Time: 12m 6s ---
	Train Loss: 2.0488
	Valid Loss: 1.8958
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 5/70 | Time: 12m 5s ---
	Train Loss: 1.8070
	Valid Loss: 1.7021
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 6/70 | Time: 12m 6s ---
	Train Loss: 1.6399
	Valid Loss: 1.5903
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 7/70 | Time: 12m 13s ---
	Train Loss: 1.5215
	Valid Loss: 1.4908
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 8/70 | Time: 12m 25s ---
	Train Loss: 1.4320
	Valid Loss: 1.4284
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 9/70 | Time: 12m 52s ---
	Train Loss: 1.3616
	Valid Loss: 1.3831
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 10/70 | Time: 13m 1s ---
	Train Loss: 1.3038
	Valid Loss: 1.3437
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 11/70 | Time: 12m 39s ---
	Train Loss: 1.2555
	Valid Loss: 1.3109
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 12/70 | Time: 12m 28s ---
	Train Loss: 1.2142
	Valid Loss: 1.2838
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 13/70 | Time: 12m 20s ---
	Train Loss: 1.1782
	Valid Loss: 1.2572
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 14/70 | Time: 12m 15s ---
	Train Loss: 1.1471
	Valid Loss: 1.2389
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 15/70 | Time: 12m 14s ---
	Train Loss: 1.1197
	Valid Loss: 1.2258
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 16/70 | Time: 12m 12s ---
	Train Loss: 1.0964
	Valid Loss: 1.2165
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 17/70 | Time: 12m 9s ---
	Train Loss: 1.0743
	Valid Loss: 1.2060
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 18/70 | Time: 12m 10s ---
	Train Loss: 1.0544
	Valid Loss: 1.1907
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 19/70 | Time: 12m 11s ---
	Train Loss: 1.0380
	Valid Loss: 1.1805
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 20/70 | Time: 12m 11s ---
	Train Loss: 1.0212
	Valid Loss: 1.1717
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 21/70 | Time: 12m 12s ---
	Train Loss: 1.0055
	Valid Loss: 1.1662
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 22/70 | Time: 12m 13s ---
	Train Loss: 0.9903
	Valid Loss: 1.1581
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 23/70 | Time: 12m 13s ---
	Train Loss: 0.9768
	Valid Loss: 1.1518
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 24/70 | Time: 12m 16s ---
	Train Loss: 0.9642
	Valid Loss: 1.1509
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 25/70 | Time: 12m 33s ---
	Train Loss: 0.9526
	Valid Loss: 1.1465
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 26/70 | Time: 12m 33s ---
	Train Loss: 0.9404
	Valid Loss: 1.1421
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 27/70 | Time: 12m 42s ---
	Train Loss: 0.9298
	Valid Loss: 1.1381
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 28/70 | Time: 12m 47s ---
	Train Loss: 0.9202
	Valid Loss: 1.1331
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 29/70 | Time: 12m 52s ---
	Train Loss: 0.9109
	Valid Loss: 1.1350
	-> No improvement in validation loss for 1 epoch(s).


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 30/70 | Time: 12m 33s ---
	Train Loss: 0.9013
	Valid Loss: 1.1282
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 31/70 | Time: 12m 15s ---
	Train Loss: 0.8929
	Valid Loss: 1.1262
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 32/70 | Time: 12m 10s ---
	Train Loss: 0.8843
	Valid Loss: 1.1211
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 33/70 | Time: 12m 10s ---
	Train Loss: 0.8762
	Valid Loss: 1.1228
	-> No improvement in validation loss for 1 epoch(s).


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 34/70 | Time: 12m 15s ---
	Train Loss: 0.8692
	Valid Loss: 1.1255
	-> No improvement in validation loss for 2 epoch(s).


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 35/70 | Time: 12m 32s ---
	Train Loss: 0.8622
	Valid Loss: 1.1224
	-> No improvement in validation loss for 3 epoch(s).


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 36/70 | Time: 12m 37s ---
	Train Loss: 0.8548
	Valid Loss: 1.1200
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 37/70 | Time: 12m 44s ---
	Train Loss: 0.8482
	Valid Loss: 1.1206
	-> No improvement in validation loss for 1 epoch(s).


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 38/70 | Time: 12m 50s ---
	Train Loss: 0.8423
	Valid Loss: 1.1186
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 39/70 | Time: 13m 22s ---
	Train Loss: 0.8362
	Valid Loss: 1.1173
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 40/70 | Time: 13m 12s ---
	Train Loss: 0.8301
	Valid Loss: 1.1238
	-> No improvement in validation loss for 1 epoch(s).


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 41/70 | Time: 12m 56s ---
	Train Loss: 0.8247
	Valid Loss: 1.1091
	-> Saved best model (based on validation loss)


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 42/70 | Time: 12m 54s ---
	Train Loss: 0.8183
	Valid Loss: 1.1184
	-> No improvement in validation loss for 1 epoch(s).


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 43/70 | Time: 13m 10s ---
	Train Loss: 0.8138
	Valid Loss: 1.1134
	-> No improvement in validation loss for 2 epoch(s).


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 44/70 | Time: 13m 7s ---
	Train Loss: 0.8081
	Valid Loss: 1.1127
	-> No improvement in validation loss for 3 epoch(s).


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 45/70 | Time: 13m 3s ---
	Train Loss: 0.8035
	Valid Loss: 1.1171
	-> No improvement in validation loss for 4 epoch(s).


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/15 [00:00<?, ?it/s]

--- Epoch 46/70 | Time: 12m 55s ---
	Train Loss: 0.7985
	Valid Loss: 1.1141
	-> No improvement in validation loss for 5 epoch(s).


Training:   0%|          | 0/3620 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [19]:
import torch.nn.functional as F

def inference(model, tokenizer, input_text, max_len=128, temperature=0.5, device='cpu'):
    model.to(device)
    model.eval()
    tokenized_input = tokenizer(input_text, return_tensors="pt")
    src = tokenized_input['input_ids'].to(device)

    src_mask = (src == tokenizer.pad_token_id).unsqueeze(1).unsqueeze(2).to(device)
    
    output_tokens = torch.tensor([[tokenizer.pad_token_id]], device=device).long()

    with torch.no_grad():
        embedded_src = model.pos_encoder(model.embedding(src))
        memory = model.encoder(embedded_src, src_mask)

        for i in range(max_len - 1):
            tgt_len = output_tokens.size(1)
            tgt_mask = torch.triu(torch.ones((tgt_len, tgt_len), device=device), diagonal=1).bool()

            embedded_tgt = model.pos_encoder(model.embedding(output_tokens))
            decoding_output = model.decoder(embedded_tgt, memory, tgt_mask, src_mask)

            logits = model.fc(decoding_output)
            last_token_logits = logits[:, -1, :] 

            #probs = F.softmax(last_token_logits / temperature, dim=-1)
            #pred_token = torch.multinomial(probs, num_samples=1).item()
            pred_token = torch.argmax(last_token_logits, dim=-1).item() 
            
            
            if pred_token == tokenizer.eos_token_id:
                break
            output_tokens = torch.cat(
                (output_tokens, torch.tensor([[pred_token]], device=device).long()), dim=1
            )
    return tokenizer.decode(output_tokens[0], skip_special_tokens=True)

In [20]:
inference_model = Transformer(VOCAB_SIZE, D_MODEL, D_FF, N_HEAD, N_LAYERS, DROPOUT)
inference_model.load_state_dict(torch.load("transformer_en_it.pt", map_location=DEVICE))
english = "this is a great achievement"
italian = inference(inference_model, tokenizer, english,device = DEVICE)
print(italian)

è un gran risultato.
