In [None]:
import math

import torch
import torch.nn as nn
from torch.nn.functional import softmax

import os
import spacy
import urllib.request
import zipfile
from torch.utils.data import Dataset, DataLoader

In [None]:
# try to finish this function on your own
def scaled_dot_product_attention(query, key, value, mask=None):
    """Compute scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Args:
        query: (batch_size, num_heads, seq_len_q, d_k)
        key: (batch_size, num_heads, seq_len_k, d_k)
        value: (batch_size, num_heads, seq_len_k, d_v)
        mask: Optional mask. Positions where ``mask == 0`` (or ``False``) are
            excluded from attention. A 3-D mask gets a head axis inserted at
            dim 1; otherwise the mask must broadcast against the score tensor
            of shape (batch_size, num_heads, seq_len_q, seq_len_k). If either
            of the last two mask dimensions is larger than the scores, it is
            cropped to fit (e.g. a mask built from an un-shifted target).

    Returns:
        Tensor of shape (batch_size, num_heads, seq_len_q, d_v).

    Raises:
        RuntimeError: if the mask cannot be broadcast to the score tensor.
    """
    d_k = query.shape[-1]

    # (..., seq_len_q, d_k) @ (..., d_k, seq_len_k) -> (..., seq_len_q, seq_len_k)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    if mask is not None:
        if mask.dim() == 3:
            # (batch_size, rows, seq_len_k) -> (batch_size, 1, rows, seq_len_k)
            mask = mask.unsqueeze(1)

        # Crop oversized masks; size-1 axes are untouched and are handled by
        # broadcasting inside masked_fill.
        mask = mask[..., :scores.size(-2), :scores.size(-1)]

        # masked_fill is out-of-place: the result must be assigned back,
        # otherwise the mask silently has no effect.
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Normalize over the key axis only.
    attention_weights = softmax(scores, dim=-1)

    return torch.matmul(attention_weights, value)


In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention with learned Q/K/V/output projections.

    Args:
        d_model: Model (embedding) dimension; must be divisible by num_heads.
        num_heads: Number of attention heads; each head works on
            d_k = d_model // num_heads features.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        # Each head receives an equal slice of the projected features.
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Learnable projections; d_model -> d_model because the heads are
        # carved out of the projected tensor by reshaping in forward().
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    @staticmethod
    def scaled_dot_product_attention(query, key, value, mask=None):
        """Compute softmax(Q K^T / sqrt(d_k)) V.

        Args:
            query: (batch_size, num_heads, seq_len_q, d_k)
            key: (batch_size, num_heads, seq_len_k, d_k)
            value: (batch_size, num_heads, seq_len_k, d_v)
            mask: Optional mask; positions where ``mask == 0`` are excluded.
                A 3-D mask gets a head axis inserted at dim 1. Trailing
                dimensions larger than the score tensor are cropped; size-1
                dimensions broadcast.

        Returns:
            Tensor of shape (batch_size, num_heads, seq_len_q, d_v).
        """
        d_k = query.shape[-1]
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)
            # Crop oversized masks; size-1 axes are left for broadcasting.
            mask = mask[..., :scores.size(-2), :scores.size(-1)]
            # masked_fill returns a new tensor, so the result must be kept.
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attention_weights = softmax(scores, dim=-1)
        return torch.matmul(attention_weights, value)

    def _split_heads(self, projected):
        """Reshape (batch, seq_len, d_model) -> (batch, num_heads, seq_len, d_k)."""
        batch_size, seq_len = projected.shape[0], projected.shape[1]
        return projected.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, query, key, value, mask=None):
        """Apply multi-head attention.

        Args:
            query: (batch_size, seq_len_q, d_model)
            key: (batch_size, seq_len_k, d_model)
            value: (batch_size, seq_len_k, d_model)
            mask: Optional mask forwarded to scaled_dot_product_attention.

        Returns:
            Tensor of shape (batch_size, seq_len_q, d_model).
        """
        batch_size = query.shape[0]
        q_seq_len = query.shape[1]

        # 1-2. Linear projections, then split into heads.
        Q = self._split_heads(self.W_q(query))
        K = self._split_heads(self.W_k(key))
        V = self._split_heads(self.W_v(value))

        # 3. Attention. Uses the class's own implementation so the module does
        # not rely on a same-named function in the global namespace.
        output = self.scaled_dot_product_attention(Q, K, V, mask)

        # 4. Concatenate heads; transpose makes the tensor non-contiguous, so
        # contiguous() is required before view().
        output = output.transpose(1, 2).contiguous().view(batch_size, q_seq_len, self.d_model)

        # 5. Final projection.
        return self.W_o(output)
        

In [None]:
class FeedForwardNetwork(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Dropout -> Linear -> Dropout.

    Args:
        d_model: input/output dimension
        d_ff: hidden dimension
        dropout: dropout rate (default=0.1)
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        # Layer order (and therefore the state_dict indices 0..4) is fixed.
        stages = [
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        ]
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Run the feed-forward stack.

        Args:
            x: Input tensor of shape (batch_size, seq_len, d_model)
        Returns:
            Output tensor of shape (batch_size, seq_len, d_model)
        """
        transformed = self.model(x)
        return transformed

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence of embeddings.

    Even feature indices hold sin(pos * w_i) and odd indices hold cos(pos * w_i),
    with w_i = 10000^(-2i / d_model). Both even and odd d_model are supported.

    Args:
        d_model: Feature dimension of the embeddings.
        max_seq_length: Number of positions to precompute (default=5000).
    """

    def __init__(self, d_model, max_seq_length=5000):
        super().__init__()

        # Table of shape (max_seq_length, d_model).
        pe = torch.zeros(max_seq_length, d_model)

        # Column vector of positions: (max_seq_length, 1).
        position = torch.arange(0, max_seq_length, dtype=torch.float32).unsqueeze(1)

        # One frequency per even feature index: ceil(d_model / 2) entries.
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
        )

        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns, so
        # the cosine block is trimmed to floor(d_model / 2) frequencies.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Buffer: follows the module across devices and is saved in the
        # state_dict, but is not a trainable parameter.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """
        Args:
            x: Tensor shape (batch_size, seq_len, d_model)
        Returns:
            Tensor of the same shape with the first seq_len encodings added.
        """
        return x + self.pe[:, :x.size(1)]

In [None]:
class EncoderLayer(nn.Module):
    """One encoder block: self-attention and feed-forward sub-layers.

    Each sub-layer is wrapped as LayerNorm(x + Dropout(Sublayer(x))).

    Args:
        d_model: Model dimension.
        num_heads: Number of attention heads.
        d_ff: Hidden dimension of the feed-forward network.
        dropout: Dropout rate (default=0.1).
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        # 1. Multi-head self-attention
        self.mha = MultiHeadAttention(d_model, num_heads)

        # 2. Layer normalization after the attention sub-layer
        self.layer_norm_1 = nn.LayerNorm(d_model)

        # 3. Feed forward
        self.ff = FeedForwardNetwork(d_model, d_ff, dropout)

        # 4. Layer normalization after the feed-forward sub-layer
        self.layer_norm_2 = nn.LayerNorm(d_model)

        # 5. Dropout applied to the attention output
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """
        Args:
            x: Input tensor of shape (batch_size, seq_len, d_model)
            mask: Optional mask for padding
        Returns:
            x: Output tensor of shape (batch_size, seq_len, d_model)
        """
        # 1. Self-attention. Dropout is applied to the sub-layer output only,
        # so the residual path itself is never zeroed out.
        attn_output = self.mha(x, x, x, mask)
        x = self.layer_norm_1(x + self.dropout(attn_output))

        # 2. Feed forward. FeedForwardNetwork already ends with its own
        # Dropout layer, so no second dropout is applied here.
        ff_output = self.ff(x)
        x = self.layer_norm_2(x + ff_output)

        return x


In [None]:
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention, cross-attention, feed-forward.

    Each sub-layer is wrapped as LayerNorm(x + Dropout(Sublayer(x))).

    Args:
        d_model: Model dimension.
        num_heads: Number of attention heads.
        d_ff: Hidden dimension of the feed-forward network.
        dropout: Dropout rate (default=0.1).
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()

        # 1. Masked multi-head self-attention
        self.mha_1 = MultiHeadAttention(d_model, num_heads)

        # 2. Layer norm for first sub-layer
        self.layer_norm_1 = nn.LayerNorm(d_model)

        # 3. Cross-attention: queries come from the decoder, keys and values
        # from the encoder output
        self.mha_2 = MultiHeadAttention(d_model, num_heads)

        # 4. Layer norm for second sub-layer
        self.layer_norm_2 = nn.LayerNorm(d_model)

        # 5. Feed forward network
        self.ff = FeedForwardNetwork(d_model, d_ff, dropout)

        # 6. Layer norm for third sub-layer
        self.layer_norm_3 = nn.LayerNorm(d_model)

        # 7. Dropout applied to the attention outputs
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        """
        Args:
            x: Target sequence embedding (batch_size, target_seq_len, d_model)
            encoder_output: Output from encoder (batch_size, source_seq_len, d_model)
            src_mask: Mask for source padding
            tgt_mask: Mask for target padding and future positions
        Returns:
            Tensor of shape (batch_size, target_seq_len, d_model)
        """
        # 1. Masked self-attention: query, key and value are all x. Dropout is
        # applied to the sub-layer output only, not to the residual path.
        self_attn_output = self.mha_1(x, x, x, tgt_mask)
        x = self.layer_norm_1(x + self.dropout(self_attn_output))

        # 2. Cross-attention over the encoder output.
        cross_attn_output = self.mha_2(x, encoder_output, encoder_output, src_mask)
        x = self.layer_norm_2(x + self.dropout(cross_attn_output))

        # 3. Feed forward. FeedForwardNetwork already ends with its own
        # Dropout layer, so no second dropout is applied here.
        ff_output = self.ff(x)
        x = self.layer_norm_3(x + ff_output)

        return x

In [None]:
class Encoder(nn.Module):
    """Transformer encoder: scaled token embedding, positional encoding, N layers.

    Args:
        vocab_size: Size of the source vocabulary.
        d_model: Model dimension.
        num_layers: Number of stacked encoder layers (default=6).
        num_heads: Number of attention heads per layer (default=8).
        d_ff: Hidden dimension of each feed-forward network (default=2048).
        dropout: Dropout rate used for the embeddings and every layer (default=0.1).
        max_seq_length: Number of precomputed positional encodings (default=5000).
    """

    def __init__(self,
                 vocab_size,
                 d_model,
                 num_layers=6,
                 num_heads=8,
                 d_ff=2048,
                 dropout=0.1,
                 max_seq_length=5000):
        super().__init__()

        # 1. Input embedding, scaled by sqrt(d_model) in forward()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.scale = math.sqrt(d_model)

        # 2. Positional encoding
        self.pe = PositionalEncoding(d_model, max_seq_length)

        # 3. Dropout on the embedded input
        self.dropout = nn.Dropout(dropout)

        # 4. Stack of N encoder layers. The dropout rate is forwarded so the
        # constructor argument also controls the layers, not only the embedding.
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)
        ])

    def forward(self, x, mask=None):
        """
        Args:
            x: Input tokens (batch_size, seq_len)
            mask: Mask for padding positions
        Returns:
            encoder_output: (batch_size, seq_len, d_model)
        """
        # 1. Embed and scale
        x = self.embedding(x) * self.scale

        # 2. Add positional encoding and apply dropout
        x = self.dropout(self.pe(x))

        # 3. Pass through each encoder layer
        for layer in self.encoder_layers:
            x = layer(x, mask)

        return x

In [None]:
class Decoder(nn.Module):
    """Transformer decoder: scaled token embedding, positional encoding, N layers.

    Args:
        vocab_size: Size of the target vocabulary.
        d_model: Model dimension.
        num_layers: Number of stacked decoder layers (default=6).
        num_heads: Number of attention heads per layer (default=8).
        d_ff: Hidden dimension of each feed-forward network (default=2048).
        dropout: Dropout rate (default=0.1).
        max_seq_length: Number of precomputed positional encodings (default=5000).
    """

    def __init__(self,
                 vocab_size,
                 d_model,
                 num_layers=6,
                 num_heads=8,
                 d_ff=2048,
                 dropout=0.1,
                 max_seq_length=5000):
        super().__init__()

        # Token embedding plus the sqrt(d_model) factor applied in forward().
        self.embeddings = nn.Embedding(vocab_size, d_model)
        self.scale = math.sqrt(d_model)

        # Positional information and input dropout.
        self.pe = PositionalEncoding(d_model, max_seq_length)
        self.dropout = nn.Dropout(dropout)

        # N identical decoder layers, built one at a time.
        self.decoder_layers = nn.ModuleList()
        for _ in range(num_layers):
            self.decoder_layers.append(DecoderLayer(d_model, num_heads, d_ff, dropout))

    def forward(self, x, encoder_output, src_mask=None, tgt_mask=None):
        """
        Args:
            x: Target tokens (batch_size, target_seq_len)
            encoder_output: Output from encoder (batch_size, source_seq_len, d_model)
            src_mask: Mask for source padding
            tgt_mask: Mask for target padding and future positions
        Returns:
            decoder_output: (batch_size, target_seq_len, d_model)
        """
        # Embed, scale, add positions, then apply dropout.
        embedded = self.embeddings(x) * self.scale
        hidden = self.dropout(self.pe(embedded))

        # Feed the running representation through every layer in order.
        for decoder_layer in self.decoder_layers:
            hidden = decoder_layer(hidden, encoder_output, src_mask, tgt_mask)

        return hidden

In [None]:
def create_padding_mask(seq, pad_idx=2):
    """Create a mask that is True for real tokens and False for padding.

    The attention code blocks positions where ``mask == 0``, so padding must
    be marked False. The default ``pad_idx`` of 2 matches the ``<blank>``
    vocabulary entry and the ``padding_value`` used when batches are collated;
    index 0 is the ``<s>`` token and must not be treated as padding.

    Args:
        seq: Input sequence tensor (batch_size, seq_len)
        pad_idx: Token id used for padding (default=2)
    Returns:
        mask: Boolean padding mask (batch_size, 1, 1, seq_len)
    """
    batch_size, seq_len = seq.shape
    keep = seq != pad_idx
    return keep.reshape(batch_size, 1, 1, seq_len)

def create_future_mask(size, device):
    """Build a causal mask: entry (i, j) is True when position j <= i.

    Args:
        size: Size of square mask (target_seq_len)
        device: Device on which the mask is created
    Returns:
        mask: Boolean future mask (1, 1, size, size)
    """
    positions = torch.arange(size, device=device)
    # Row index = query position, column index = key position; a key is
    # visible when it does not lie after the query.
    visible = positions.unsqueeze(0) <= positions.unsqueeze(1)
    return visible.unsqueeze(0).unsqueeze(0)

def create_masks(src, tgt):
    """Build the encoder and decoder attention masks for one batch.

    Args:
        src: Source sequence (batch_size, src_len)
        tgt: Target sequence (batch_size, tgt_len)
    Returns:
        src_mask: Padding mask for encoder
        tgt_mask: Combined padding and future mask for decoder
    """
    # Source side: padding mask only.
    src_mask = create_padding_mask(src)

    # Target side: padding mask AND-ed with the future mask, so a position
    # stays True only if both masks allow it.
    tgt_padding = create_padding_mask(tgt)
    tgt_future = create_future_mask(tgt.size(1), src.device)
    tgt_mask = tgt_padding & tgt_future

    return src_mask, tgt_mask

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer with a final projection to target-vocab logits.

    Args:
        src_vocab_size: Size of the source vocabulary.
        tgt_vocab_size: Size of the target vocabulary.
        d_model: Model dimension.
        num_layers: Number of encoder layers and of decoder layers (default=6).
        num_heads: Number of attention heads (default=8).
        d_ff: Hidden dimension of the feed-forward networks (default=2048).
        dropout: Dropout rate (default=0.1).
        max_seq_length: Number of precomputed positional encodings (default=5000).
    """

    def __init__(self,
                 src_vocab_size,
                 tgt_vocab_size,
                 d_model,
                 num_layers=6,
                 num_heads=8,
                 d_ff=2048,
                 dropout=0.1,
                 max_seq_length=5000):
        super().__init__()

        self.encoder = Encoder(
            src_vocab_size,
            d_model,
            num_layers,
            num_heads,
            d_ff,
            dropout,
            max_seq_length
        )

        self.decoder = Decoder(
            tgt_vocab_size,
            d_model,
            num_layers,
            num_heads,
            d_ff,
            dropout,
            max_seq_length
        )

        # Project decoder features from d_model to tgt_vocab_size logits.
        self.final_layer = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        """Run the full model.

        Args:
            src: Source tokens (batch_size, src_len)
            tgt: Target tokens (batch_size, tgt_len)
            src_mask: Optional source mask; built from ``src`` when omitted.
            tgt_mask: Optional target mask; built from ``tgt`` when omitted.
        Returns:
            Logits of shape (batch_size, tgt_len, tgt_vocab_size).
        """
        device = src.device

        # Only fill in the masks the caller did not supply; a mask that was
        # passed explicitly is never overwritten.
        if src_mask is None or tgt_mask is None:
            auto_src_mask, auto_tgt_mask = create_masks(src, tgt)
            if src_mask is None:
                src_mask = auto_src_mask
            if tgt_mask is None:
                tgt_mask = auto_tgt_mask

        # Make sure the masks live on the same device as the inputs.
        src_mask = src_mask.to(device)
        tgt_mask = tgt_mask.to(device)

        encoder_output = self.encoder(src, src_mask)
        decoder_output = self.decoder(tgt, encoder_output, src_mask, tgt_mask)
        output = self.final_layer(decoder_output)

        return output

In [None]:
# class TransformerLRScheduler:
#     def __init__(self, optimizer, d_model, warmup_steps):
#         """
#         Args:
#             optimizer: Optimizer to adjust learning rate for
#             d_model: Model dimensionality
#             warmup_steps: Number of warmup steps
#         """
#         self.optimizer = optimizer
#         self.d_model = d_model
#         self.warmup_steps = warmup_steps


#     def step(self, step_num):
#         """
#         Update learning rate based on step number
#         """
#         # lrate = d_model^(-0.5) * min(step_num^(-0.5), step_num * warmup_steps^(-1.5))
#         #YOUR CODE HERE
#         lrate = torch.pow(self.d_model, -0.5) * torch.min(torch.pow(step_num, -0.5),
#                                                           torch.tensor(step_num) * 
#                                                           torch.pow(self.warmup_steps, -1.5))
class TransformerLRScheduler:
    """Warm-up / inverse-square-root learning-rate schedule.

    lr = d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5))
    """

    def __init__(self, optimizer, d_model, warmup_steps):
        """
        Args:
            optimizer: Optimizer to adjust learning rate for
            d_model: Model dimensionality
            warmup_steps: Number of warmup steps
        """
        self.optimizer = optimizer
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        # Step counter kept internally; the first call to step() makes it 1.
        self.step_num = 0

    def step(self):
        """Advance one step and write the new rate into every parameter group."""
        # Incrementing first keeps the counter >= 1, so step_num ** -0.5 is
        # never evaluated at zero.
        self.step_num += 1

        decay_term = self.step_num ** -0.5
        warmup_term = self.step_num * (self.warmup_steps ** -1.5)
        lr = (self.d_model ** -0.5) * min(decay_term, warmup_term)

        # Apply to all parameter groups of the optimizer.
        for group in self.optimizer.param_groups:
            group["lr"] = lr

class LabelSmoothing(nn.Module):
    """Cross-entropy against a label-smoothed target distribution.

    The true class receives probability ``1 - smoothing``; the remaining
    ``smoothing`` mass is spread evenly over the other classes.

    Args:
        smoothing: Probability mass moved away from the true class (default=0.1).
        ignore_index: Optional target id (e.g. the padding id) whose positions
            are excluded from the loss. ``None`` (default) averages over every
            position.
    """

    def __init__(self, smoothing=0.1, ignore_index=None):
        super().__init__()
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing
        self.ignore_index = ignore_index

    def forward(self, logits, target):
        """
        Args:
            logits: Model predictions (batch_size, vocab_size), unnormalized scores
            target: True labels (batch_size), index of the correct class per row
        Returns:
            Scalar loss: mean over the rows that are not ignored (0 if all rows
            are ignored).
        """
        vocab_size = logits.size(-1)
        log_probs = torch.log_softmax(logits, dim=-1)

        # The soft targets are constants; no gradient should flow through them.
        with torch.no_grad():
            if self.ignore_index is None:
                keep = None
                safe_target = target
            else:
                keep = target != self.ignore_index
                # Ignored rows get a dummy in-range index so scatter_ stays
                # valid even when ignore_index is not a real class id.
                safe_target = target.masked_fill(~keep, 0)

            # max(..., 1) avoids a division by zero for a single-class vocabulary.
            off_value = self.smoothing / max(vocab_size - 1, 1)
            true_dist = torch.full_like(log_probs, off_value)
            true_dist.scatter_(1, safe_target.unsqueeze(1), self.confidence)

        per_row = torch.sum(-true_dist * log_probs, dim=-1)

        if keep is None:
            return torch.mean(per_row)

        # Average over kept rows only; clamp so an all-ignored batch yields 0.
        return torch.sum(per_row * keep) / keep.sum().clamp(min=1)

In [None]:
def train_transformer(model, train_dataloader, criterion, optimizer, scheduler, num_epochs, device='cuda'):
    """Training loop for the transformer with teacher forcing.

    For each batch the decoder input is the target without its last token and
    the expected output is the target without its first token. A checkpoint
    ``checkpoint_<epoch>.pt`` is written after every epoch.

    Args:
        model: Transformer model
        train_dataloader: DataLoader yielding dicts with 'src' and 'tgt' tensors
        criterion: Loss function (with label smoothing)
        optimizer: Optimizer
        scheduler: Learning rate scheduler; its step() takes no arguments and
            sets the learning rate for the update that follows
        num_epochs: Number of training epochs
        device: Device to train on (default='cuda')

    Returns:
        List with the average loss of every epoch.
    """
    # 1. Setup
    model = model.to(device)
    model.train()

    all_losses = []

    # 2. Training loop
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        epoch_loss = 0

        for batch_idx, batch in enumerate(train_dataloader):
            src = batch['src'].to(device)
            tgt = batch['tgt'].to(device)

            # Shift the target: the model sees tokens [0, n-1) and has to
            # predict tokens [1, n).
            tgt_input = tgt[:, :-1]
            tgt_output = tgt[:, 1:]

            # Build the masks from the sequence that is actually fed to the
            # decoder so their shapes match the attention scores.
            src_mask, tgt_mask = create_masks(src, tgt_input)

            optimizer.zero_grad()
            outputs = model(src, tgt_input, src_mask, tgt_mask)

            # Flatten to (batch * seq_len, vocab) / (batch * seq_len,)
            outputs = outputs.reshape(-1, outputs.size(-1))
            tgt_output = tgt_output.reshape(-1)

            loss = criterion(outputs, tgt_output)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            # Set the learning rate BEFORE the parameter update. The scheduler
            # only writes the rate inside step(), so stepping the optimizer
            # first would apply the first update with the optimizer's
            # constructor learning rate instead of the warm-up value.
            # NOTE: torch.optim.lr_scheduler classes expect the opposite order.
            scheduler.step()
            optimizer.step()

            epoch_loss += loss.item()

            if batch_idx % 100 == 0:
                print(f"Batch {batch_idx}, Loss: {loss.item():.4f}")

        # max(..., 1) keeps an empty dataloader from raising ZeroDivisionError.
        avg_epoch_loss = epoch_loss / max(len(train_dataloader), 1)
        all_losses.append(avg_epoch_loss)
        print(f"Epoch {epoch + 1} Loss: {avg_epoch_loss:.4f}")

        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_epoch_loss,
        }, f'checkpoint_{epoch + 1}.pt')

    return all_losses
            

In [None]:

def download_multi30k():
    """Download and decompress the Multi30k task-1 text files into ``data/``.

    Each remote ``.gz`` archive is saved as ``data/<name>.gz``, decompressed to
    ``data/<name>`` and then removed. Files that already exist are skipped.
    """
    # Local imports keep this cell self-contained.
    import gzip
    import shutil

    os.makedirs('data', exist_ok=True)

    base_url = "https://raw.githubusercontent.com/multi30k/dataset/master/data/task1/raw/"
    files = {
        "train.de": "train.de.gz",
        "train.en": "train.en.gz",
        "val.de": "val.de.gz",
        "val.en": "val.en.gz",
        "test.de": "test_2016_flickr.de.gz",
        "test.en": "test_2016_flickr.en.gz",
    }

    for local_name, remote_name in files.items():
        filepath = os.path.join('data', local_name)
        if os.path.exists(filepath):
            continue

        # Save the archive under its own .gz name so the plain-text path only
        # ever holds decompressed content.
        archive_path = filepath + '.gz'
        urllib.request.urlretrieve(base_url + remote_name, archive_path)

        # Decompress in pure Python (no shell, works without gunzip). Writing
        # to a temporary name first means an interrupted run never leaves a
        # truncated file that the existence check above would accept.
        partial_path = filepath + '.part'
        with gzip.open(archive_path, 'rb') as compressed, open(partial_path, 'wb') as extracted:
            shutil.copyfileobj(compressed, extracted)
        os.replace(partial_path, filepath)
        os.remove(archive_path)
        
def load_data(filename):
    """Read a text file (plain or gzip-compressed) into a list of stripped lines.

    The first two bytes are checked for the gzip magic number to choose the
    reader. On any error a message is printed and an empty list is returned.
    """
    try:
        # Peek at the header to find out whether the file is gzip-compressed.
        with open(filename, 'rb') as probe:
            header = probe.read(2)

        if header == b'\x1f\x8b':
            import gzip
            opener, mode = gzip.open, 'rt'
        else:
            opener, mode = open, 'r'

        lines = []
        with opener(filename, mode, encoding='utf-8') as handle:
            for raw_line in handle:
                lines.append(raw_line.strip())
        return lines
    except Exception as e:
        print(f"读取文件 {filename} 时出错: {e}")
        return []
    
def create_dataset():
    """Download Multi30k if needed and return the train/validation sentence lists.

    Returns:
        ((train_de, train_en), (val_de, val_en)), each a list of lines.
    """
    download_multi30k()

    # Load the four files in a fixed order: train.de, train.en, val.de, val.en.
    names = ('train.de', 'train.en', 'val.de', 'val.en')
    train_de, train_en, val_de, val_en = [load_data('data/' + name) for name in names]

    return (train_de, train_en), (val_de, val_en)


class TranslationDataset(Dataset):
    """Parallel-sentence dataset that yields token-id tensors.

    Every item is a dict with 'src' and 'tgt' tensors; each sequence is wrapped
    in the "<s>" / "</s>" ids, and tokens missing from the vocabulary map to
    the "<unk>" id.
    """

    def __init__(self, src_texts, tgt_texts, src_vocab, tgt_vocab, src_tokenizer, tgt_tokenizer):
        self.src_texts, self.tgt_texts = src_texts, tgt_texts
        self.src_vocab, self.tgt_vocab = src_vocab, tgt_vocab
        self.src_tokenizer, self.tgt_tokenizer = src_tokenizer, tgt_tokenizer

    def __len__(self):
        # The dataset length follows the source side.
        return len(self.src_texts)

    @staticmethod
    def _encode(text, tokenizer, vocab):
        """Tokenize ``text`` and map it to ids with start/end markers."""
        unk_id = vocab["<unk>"]
        ids = [vocab["<s>"]]
        for tok in tokenizer(text):
            ids.append(vocab.get(tok.text, unk_id))
        ids.append(vocab["</s>"])
        return ids

    def __getitem__(self, idx):
        src_ids = self._encode(self.src_texts[idx], self.src_tokenizer, self.src_vocab)
        tgt_ids = self._encode(self.tgt_texts[idx], self.tgt_tokenizer, self.tgt_vocab)

        return {
            'src': torch.tensor(src_ids),
            'tgt': torch.tensor(tgt_ids),
        }
    

def build_vocab_from_texts(texts, tokenizer, min_freq=2):
    """Build a token -> id mapping from tokenized texts.

    Ids 0-3 are reserved for "<s>", "</s>", "<blank>" and "<unk>". Every token
    seen at least ``min_freq`` times gets the next free id, in order of first
    appearance.
    """
    frequencies = {}
    for text in texts:
        for tok in tokenizer(text):
            word = tok.text
            if word in frequencies:
                frequencies[word] += 1
            else:
                frequencies[word] = 1

    vocab = {"<s>": 0, "</s>": 1, "<blank>": 2, "<unk>": 3}
    frequent_words = [word for word, count in frequencies.items() if count >= min_freq]
    for next_id, word in enumerate(frequent_words, start=4):
        vocab[word] = next_id
    return vocab


def create_dataloaders(batch_size=32):
    """Build the training and validation DataLoaders for German -> English.

    Vocabularies are built from the training split only and shared with the
    validation dataset.

    Args:
        batch_size: Number of sentence pairs per batch (default=32).

    Returns:
        (train_dataloader, val_dataloader, vocab_src, vocab_tgt)
    """
    # Only tokenization is needed. Calling the pipeline object itself would
    # also run every other pipeline component on each sentence, so the
    # tokenizer component is used directly; it returns a Doc whose tokens
    # expose `.text`, just like the full pipeline.
    tokenize_de = spacy.load("de_core_news_sm").tokenizer
    tokenize_en = spacy.load("en_core_web_sm").tokenizer

    (train_de, train_en), (val_de, val_en) = create_dataset()

    vocab_src = build_vocab_from_texts(train_de, tokenize_de)
    vocab_tgt = build_vocab_from_texts(train_en, tokenize_en)

    train_dataset = TranslationDataset(
        train_de, train_en,
        vocab_src, vocab_tgt,
        tokenize_de, tokenize_en
    )

    val_dataset = TranslationDataset(
        val_de, val_en,
        vocab_src, vocab_tgt,
        tokenize_de, tokenize_en
    )

    # Shuffle the training data only; validation keeps a fixed order.
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_batch
    )

    val_dataloader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=collate_batch
    )

    return train_dataloader, val_dataloader, vocab_src, vocab_tgt

def collate_batch(batch):
    """Pad the 'src' and 'tgt' sequences of a batch to a common length.

    Shorter sequences are right-padded with id 2 (the "<blank>" vocabulary
    entry); the results have shape (batch_size, max_len).
    """
    return {
        field: torch.nn.utils.rnn.pad_sequence(
            [example[field] for example in batch],
            batch_first=True,
            padding_value=2,
        )
        for field in ('src', 'tgt')
    }

# Build the loaders and vocabularies once at cell level; the model and
# training cell below reads these four names from the notebook namespace.
train_dataloader, val_dataloader, vocab_src, vocab_tgt = create_dataloaders(batch_size=32)

In [None]:
# Instantiate the model; the vocabulary sizes come from the vocabularies built
# by create_dataloaders() above.
model = Transformer(
    src_vocab_size=len(vocab_src),
    tgt_vocab_size=len(vocab_tgt),
    d_model=512,
    num_layers=6,
    num_heads=8,
    d_ff=2048,
    dropout=0.1
)
# Use the GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
criterion = LabelSmoothing(smoothing=0.1).to(device)
# lr=1.0 is only a placeholder: TransformerLRScheduler.step() overwrites the
# learning rate of every parameter group each time it is called.
optimizer = torch.optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
# d_model here must match the d_model passed to Transformer above.
scheduler = TransformerLRScheduler(optimizer, d_model=512, warmup_steps=4000)

# Train for 10 epochs; train_transformer moves the model to `device`, writes a
# checkpoint file per epoch and returns the list of average epoch losses.
losses = train_transformer(
    model=model,
    train_dataloader=train_dataloader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=10,
    device=device
)