In [1]:
!pip install -U datasets huggingface_hub

Collecting datasets
  Downloading datasets-4.0.0-py3-none-any.whl.metadata (19 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-4.0.0-py3-none-any.whl (494 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m494.8/494.8 kB[0m [31m16.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2025.3.0-py3-none-any.whl (193 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.6/193.6 kB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fsspec, datasets
  Attempting uninstall: fsspec
    Found existing installation: fsspec 2025.3.2
    Uninstalling fsspec-2025.3.2:
      Successfully uninstalled fsspec-2025.3.2
  Attempting uninstall: datasets
    Found existing installation: datasets 2.14.4
    Uninstalling datasets-2.14.4:
      Successfully uninstalled datasets-2.14.4
[31mERROR: pip's dependency r

In [2]:
!pip install evaluate

Collecting evaluate
  Downloading evaluate-0.4.5-py3-none-any.whl.metadata (9.5 kB)
Downloading evaluate-0.4.5-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evaluate
Successfully installed evaluate-0.4.5


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from collections import Counter
import random
import re
import datasets
import tqdm
import math
from functools import partial
import math
import argparse
import os
import collections
import json
import sentencepiece
import shutil
import copy
import multiprocessing
import transformers
from dataclasses import dataclass, field
from evaluate import load

# set "high" if you have a GPU with compute capability >= 8.0 else "highest"
torch.set_float32_matmul_precision("high")
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True

# Training config

In [6]:
## you can modify some options such as batch_size, depending on your environments

training_config = {
    "batch_size": 4,
    "epochs": 1,
    "lr": 1e-4,
    "warmup_steps": 50,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "gradient_accumulate_steps": 1,
}

# Dataset load

In [7]:
dataset = datasets.load_dataset("lemon-mint/korean_english_parallel_wiki_augmented_v1",split="train",download_mode="force_redownload")
dataset = dataset.filter(lambda x: len(x['english']) < 8192 and len(x['english']) > 128 and len(x['korean']) < 8192 and len(x['korean']) > 128)
valid_set = dataset.select(range(10000))
train_set = dataset.select(range(10000, 110000))

data/train-00000-of-00002.parquet:   0%|          | 0.00/285M [00:00<?, ?B/s]

data/train-00001-of-00002.parquet:   0%|          | 0.00/285M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/503245 [00:00<?, ? examples/s]

Filter:   0%|          | 0/503245 [00:00<?, ? examples/s]

In [8]:
print(dataset[0])

{'english': "5059 aluminum alloy is an aluminum-magnesium alloy, primarily alloyed with magnesium. It is not strengthened by heat treatment, instead becoming stronger due to strain hardening, or cold mechanical working of the material.\n\nSince heat treatment doesn't strongly affect the strength, 5059 can be readily welded and retain most of its mechanical strength.\n\n5059 alloy was derived from closely related 5083 aluminum alloy by researchers at Corus Aluminium in 1999.", 'korean': '5059 알루미늄 합금은 주로 마그네슘으로 합금된 알루미늄-마그네슘 합금입니다. 열처리로 강화되지 않고, 대신 재료의 변형 경화 또는 냉간 기계 가공으로 강해집니다.\n\n열처리가 강도에 큰 영향을 미치지 않기 때문에 5059는 용접이 용이하고 기계적 강도를 대부분 유지할 수 있습니다.\n\n5059 합금은 1999년 코러스 알루미늄의 연구원들에 의해 밀접하게 관련된 5083 알루미늄 합금에서 유래했습니다.', 'score': 0.9080972430930964}


In [9]:
tokenizer = transformers.AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ko-en")
additional_special_tokens = {}
if tokenizer.pad_token is None:
    additional_special_tokens["pad_token"] = "<pad>"
if tokenizer.eos_token is None:
    additional_special_tokens["eos_token"] = "</s>"
if tokenizer.bos_token is None:
    additional_special_tokens["bos_token"] = "<s>"
tokenizer.add_special_tokens(additional_special_tokens)

def collate_fn(batch):
    english_corpus = [item["english"] for item in batch]
    korean_corpus = [item["korean"] for item in batch]
    english_corpus = tokenizer(english_corpus, padding=True, truncation=True, return_tensors="pt", max_length=512, pad_to_multiple_of=64)
    korean_corpus = tokenizer(korean_corpus, padding=True, truncation=True, return_tensors="pt", max_length=512, pad_to_multiple_of=64)
    labels = korean_corpus["input_ids"].clone()
    labels[korean_corpus['attention_mask'].eq(0)] = -100


    return {
        "encoder_input_ids": english_corpus["input_ids"],
        "encoder_attention_mask": english_corpus["attention_mask"],
        "decoder_input_ids": korean_corpus["input_ids"],
        "labels": korean_corpus["input_ids"],
    }



tokenizer_config.json:   0%|          | 0.00/44.0 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

source.spm:   0%|          | 0.00/842k [00:00<?, ?B/s]

target.spm:   0%|          | 0.00/813k [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]



### Model implement

In [28]:
@dataclass
class ModelConfig(object):
    vocab_size: int = field(default=50000)
    encoder_hidden_dim: int = field(default=512) # hidden dimention of encoder lstm
    decoder_hidden_dim: int = field(default=512) # hidden dimention of decoder lstm
    hidden_dim: int = field(default=512) # hidden dimention of other module like attention
    embed_dim: int = field(default=512) # embedding dimention
    pad_idx: int = field(default=0)
    sos_idx: int = field(default=1)
    eos_idx: int = field(default=2)
    n_layers: int = field(default=1)
    dropout: float = field(default=0.1)

    attention_type:str = field(default="global")
    window_size: int = field(default=10)
    sigma_ratio: float = field(default=2.0)

    do_input_feeding: bool = field(default=True)

class GlobalAttention(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        self.query_proj = nn.Linear(config.decoder_hidden_dim, config.hidden_dim, bias=False)
        self.key_proj = nn.Linear(config.encoder_hidden_dim * 2, config.hidden_dim, bias=False)
        self.value_proj = nn.Linear(config.encoder_hidden_dim * 2, config.hidden_dim, bias=False)
        self.output_proj = nn.Linear(config.hidden_dim, config.decoder_hidden_dim, bias=False)

        self.dropout = nn.Dropout(config.dropout)
        self.scale = np.sqrt(config.hidden_dim)

    def forward(self, decoder_hidden_query, encoder_outputs, encoder_attention_mask):
        query = self.query_proj(decoder_hidden_query)
        key = self.key_proj(encoder_outputs)
        value = self.value_proj(encoder_outputs)

        # fill here for global attention forward
        # shape hint:
        # context: (batch, 1, hidden_dim)
        ######

        ## YOUR CODES
        # 1. reshape query
        query = self.query_proj(decoder_hidden_query)
        if query.dim() == 2:
            query = query.unsqueeze(1)
        print("decoder_hidden_query.shape:", decoder_hidden_query.shape)
        print("query.shape:", query.shape)




        # 2. attention score
        score = torch.bmm(query, key.transpose(1, 2))           # (batch, 1, src_len)
        score = score / self.scale

        if encoder_attention_mask is not None:
            # encoder_attention_mask: (B, src_len)
            score = score.masked_fill(encoder_attention_mask.unsqueeze(1) == 0, float('-inf'))

                                      # 4. softmax + dropout
        attn_weights = F.softmax(score, dim=-1)                 # (batch, 1, src_len)
        attn_weights = self.dropout(attn_weights)

        # 5. weighted sum of value
        context = torch.bmm(attn_weights, value)                # (batch, 1, hidden_dim)


        ######
        output_context = self.output_proj(context)

        return output_context

class LocalAttention(GlobalAttention):
    def __init__(self, config: ModelConfig):
        super().__init__(config)
        self.window_size = config.window_size
        self.location_proj_up = nn.Linear(config.decoder_hidden_dim, config.hidden_dim, bias=False)
        self.location_proj_down = nn.Linear(config.hidden_dim, 1, bias=False)
        self.sigma = self.window_size / config.sigma_ratio

    def forward(self, decoder_hidden_query, encoder_outputs, encoder_attention_mask):
        key, value, attn_mask, gaussian_penalty = self._gather_local_context(decoder_hidden_query, encoder_outputs, encoder_attention_mask)
        query = self.query_proj(decoder_hidden_query)
        key = self.key_proj(key)
        value = self.value_proj(value)

        # fill here for local attention forward
        # shape hint:
        # context: (batch, 1, hidden_dim)
        ######
        ## YOUR CODES

        # (B, H) → (B, 1, H) for bmm
        query = self.query_proj(decoder_hidden_query)
        query = query.view(query.size(0), 1, -1)  # 확실하게 (B, 1, H)
        print("decoder_hidden_query.shape:", decoder_hidden_query.shape)
        print("query.shape:", query.shape)



        # dot-product attention: (B, 1, H) @ (B, H, S_local) → (B, 1, S_local)
        score = torch.bmm(query, key.transpose(1, 2))
        score = score / self.scale

        # 마스크가 있다면 (-inf로 채워서 softmax에서 무시)
        score = score.masked_fill(attn_mask.unsqueeze(1) == 0, float('-inf'))

        # Gaussian penalty 추가 (penalty는 0~1 → log space로 더하는 느낌)
        score = score + gaussian_penalty.unsqueeze(1)

        # softmax로 attention weights 계산: (B, 1, S_local)
        attn_weights = F.softmax(score, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # attention-weighted sum으로 context 생성: (B, 1, H)
        context = torch.bmm(attn_weights, value)

        ######
        output_context = self.output_proj(context)

        return output_context

    def _gather_local_context(self, decoder_hidden_query, encoder_outputs, encoder_attention_mask):
        device = encoder_outputs.device
        src_len = encoder_attention_mask.sum(dim=-1).unsqueeze(-1)

        # fill here for local context window
        # shape hint:
        # local_key: (batch, window_size * 2 + 1, hidden_dim)
        # local_value: (batch, window_size * 2 + 1, hidden_dim)
        # local_attn_mask: (batch, window_size * 2 + 1)
        # gaussian_penalty: (batch, window_size * 2 + 1)
        ######

        ## YOUR CODES
        # decoder로부터 현재 시점의 중심 위치 예측
        batch_size, src_seq_len, _ = encoder_outputs.size()

        hidden = self.location_proj_up(decoder_hidden_query)              # (B, hidden_dim)
        hidden = torch.tanh(hidden)
        center_pos = self.location_proj_down(hidden).squeeze(-1)          # (B,)

        # center_pos를 encoder 범위 안으로 제한
        center_pos = center_pos.clamp(min=0, max=src_seq_len - 1)

        # 정수 인덱스로 변환
        center_pos = center_pos.round().long()  # (B,)

        # 슬라이싱 window
        window = self.window_size
        idxs = torch.arange(-window, window + 1, device=device).view(1, -1)  # (1, 2w+1)
        # (B, 2w+1)
        local_idxs = center_pos.unsqueeze(1) + idxs  # 각 배치별 중심 기준 window 인덱스

        # 인덱스가 0~src_seq_len 사이로만 나오도록 clamp
        local_idxs = local_idxs.clamp(0, src_seq_len - 1)  # (B, 2w+1)

        # batch gather
        batch_idxs = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(local_idxs)

        # gather local encoder outputs
        local_key = encoder_outputs[batch_idxs, local_idxs]   # (B, 2w+1, hidden)
        local_value = local_key.clone()

        local_attn_mask = encoder_attention_mask[batch_idxs, local_idxs]  # (B, 2w+1)

        # Gaussian Penalty
        relative_pos = idxs.expand_as(local_idxs).float()  # (B, 2w+1)
        gaussian_penalty = - (relative_pos ** 2) / (2 * (self.sigma ** 2))  # (B, 2w+1)


        ######

        return local_key, local_value, local_attn_mask, gaussian_penalty

class Encoder(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        self.encoder = nn.LSTM(
            input_size=config.embed_dim,
            hidden_size=config.encoder_hidden_dim,
            num_layers=config.n_layers,
            dropout=config.dropout if config.n_layers > 1 else 0,
            bidirectional=True,
            batch_first=True
        )

        self.h_dec_proj = nn.Linear(config.encoder_hidden_dim * 2, config.decoder_hidden_dim)
        self.c_dec_proj = nn.Linear(config.encoder_hidden_dim * 2, config.decoder_hidden_dim)

    def forward(self, input_embeds, attention_mask):

        # Fill here for encoder forward
        # shape hint
        # input_embeds: (batch, src_seq_len, embed_dim)
        # attention_mask: (batch, src_seq_len)
        # encoder_output: (batch, src_seq_len, hidden_dim)
        # h_enc: (n_layers, batch, decoder_hidden_dim)
        # c_enc: (n_layers, batch, decoder_hidden_dim)
        # hint for implementation
        # 1. use nn.utils.rnn.pack_padded_sequence to packing inputs for rnn series, see https://pytorch.org/docs/stable/generated/torch.nn.utils.rnn.pack_padded_sequence.html
        #    failure to properly handle padding will result in a penalty.
        # 2. lstm cell state and hidden state will be doubled because of bidirectional lstm.
        #    decoder will be unidirectional for causal language modeling.
        #    handle the hidden state and cell state to be same as decoder.
        ######

        ## YOUR CODES
        #attention_mask = (input_embeds.abs().sum(dim=-1) != 0).long()
        lengths = attention_mask.sum(dim=1).cpu()

        # 2. 패킹: RNN이 불필요한 PAD 토큰을 계산하지 않게 하기 위해 실제 데이터 길이만큼 계산하도록 만들어줌
        packed_input = nn.utils.rnn.pack_padded_sequence(
            input_embeds,
            lengths,
            batch_first=True,
            enforce_sorted=False
        )

        # 3. LSTM 통과
        packed_output, (h, c) = self.encoder(packed_input)

        # 4. 다시 unpack
        encoder_output, _ = nn.utils.rnn.pad_packed_sequence(
            packed_output, batch_first=True
        )  # shape: (batch, src_seq_len, hidden_dim * 2)
        # 5. Bidirectional LSTM → decoder용 hidden state로 변환
        # h, c: (num_layers * 2, batch, encoder_hidden_dim) → 2 방향 concat

        batch_size = input_embeds.size(0)
        n_layers = self.config.n_layers

        # (n_layers * 2, batch, hidden_dim) -> (n_layers, 2, batch, hidden_dim)
        h_final = h.view(n_layers, 2, batch_size, self.config.encoder_hidden_dim)
        c_final = c.view(n_layers, 2, batch_size, self.config.encoder_hidden_dim)

        h_combined = torch.cat([h_final[:, 0, :, :], h_final[:, 1, :, :]], dim=2)
        c_combined = torch.cat([c_final[:, 0, :, :], c_final[:, 1, :, :]], dim=2)

        # 6. decoder에 맞게 projection
        h_enc = self.h_dec_proj(h_combined)
        c_enc = self.c_dec_proj(c_combined)

        ######

        return encoder_output, (h_enc, c_enc)

class Decoder(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        self.decoder = nn.LSTM(
            input_size=config.embed_dim + config.hidden_dim if config.do_input_feeding else config.embed_dim,
            hidden_size=config.decoder_hidden_dim,
            num_layers=config.n_layers,
            dropout=config.dropout if config.n_layers > 1 else 0,
            batch_first=True
        )
        match config.attention_type:
            case "local":
                self.attention = LocalAttention(config)
            case "global":
                self.attention = GlobalAttention(config)
            case _:
                raise ValueError(f"Unknown attention type: {config.attention_type}")
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, input_embeds, encoder_outputs, h_enc, c_enc, attention_mask):
        decoder_output, (h_dec, c_dec) = self.decoder(input_embeds, (h_enc, c_enc))
        attention_context = self.attention(decoder_output, encoder_outputs, attention_mask)
        decoder_output = decoder_output + attention_context

        return decoder_output, attention_context, (h_dec, c_dec)

class Seq2Seq(nn.Module):
    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

        self.embedding = nn.Embedding(config.vocab_size, config.embed_dim, padding_idx=config.pad_idx)

        self.encoder = Encoder(config)
        self.decoder = Decoder(config)

        self.lm_head = nn.Linear(config.hidden_dim, config.vocab_size)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, encoder_input_ids, encoder_attention_mask, decoder_input_ids, labels=None, cache=None):
        if cache is None:
            encoder_input_embeds = self.embedding(encoder_input_ids)
            encoder_outputs, (h_enc, c_enc) = self.encoder(encoder_input_embeds, encoder_attention_mask)

            current_h_dec, current_c_dec = h_enc, c_enc
            prev_attn_context = None
        else:
            encoder_outputs, current_h_dec, current_c_dec, prev_attn_context = cache

        batch_size, tgt_len = decoder_input_ids.shape
        decoder_input_embeds = self.embedding(decoder_input_ids)

        if prev_attn_context is None:
            prev_attn_context = torch.zeros((batch_size, 1, self.config.decoder_hidden_dim)).to(decoder_input_embeds)

        outputs = []

        for t in range(tgt_len):
            # fill here for decoder forward
            ######

            ## YOUR CODES
            if self.config.do_input_feeding:
                rnn_input = torch.cat([decoder_input_embeds[:, t:t+1], prev_attn_context], dim=-1)
            else:
                rnn_input = decoder_input_embeds[:, t:t+1]

            decoder_output, attn_context, (current_h_dec, current_c_dec) = self.decoder(
                rnn_input, encoder_outputs, current_h_dec, current_c_dec, encoder_attention_mask
            )

            prev_attn_context = attn_context  # 다음 timestep에 전달


            ######
            outputs.append(decoder_output)


        outputs = torch.cat(outputs, dim=1)

        lm_logits = self.lm_head(outputs)

        loss = None
        if labels is not None:
            # for cross entropy loss
            # loss must be scalar

            labels_for_loss = labels[:, 1:].contiguous()
            lm_logits_for_loss = lm_logits[:, :-1, :].contiguous()
            loss = F.cross_entropy(lm_logits_for_loss.view(-1, self.config.vocab_size), labels_for_loss.view(-1))

            return loss
        else:
            return lm_logits, (encoder_outputs, current_h_dec, current_c_dec, prev_attn_context)

    @torch.no_grad()
    def generate(
        self,
        encoder_input_ids: torch.LongTensor,
        encoder_attention_mask: torch.LongTensor,
        max_new_tokens: int = 256,
    ):
        batch_size, _ = encoder_input_ids.shape
        device = encoder_input_ids.device
        eos = self.config.eos_idx

        unfinish_flag = torch.ones(batch_size, dtype=torch.long, device=device)
        cache = None
        decoder_input_ids = torch.full((batch_size, 1), self.config.sos_idx, dtype=torch.long, device=device)

        for _ in range(max_new_tokens):
            # fill here for causal generation
           ######

            ## YOUR CODES
            # decoder input의 임베딩
            decoder_input_embeds = self.embedding(decoder_input_ids[:, -1:])  # 마지막 토큰만

            if cache is None:
                encoder_input_embeds = self.embedding(encoder_input_ids)
                encoder_outputs, (h_enc, c_enc) = self.encoder(encoder_input_embeds, encoder_attention_mask)
                current_h_dec, current_c_dec = h_enc, c_enc
                prev_attn_context = torch.zeros((batch_size, 1, self.config.decoder_hidden_dim), device=device)
            else:
                encoder_outputs, current_h_dec, current_c_dec, prev_attn_context = cache

            if self.config.do_input_feeding:
                rnn_input = torch.cat([decoder_input_embeds, prev_attn_context], dim=-1)
            else:
                rnn_input = decoder_input_embeds

            decoder_output, attn_context, (current_h_dec, current_c_dec) = self.decoder(
                rnn_input, encoder_outputs, current_h_dec, current_c_dec, encoder_attention_mask
            )

            logits = self.lm_head(decoder_output)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)

            next_token = next_token * unfinish_flag.unsqueeze(1) + eos * (1 - unfinish_flag).unsqueeze(1)
            decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=1)
            unfinish_flag = unfinish_flag * (next_token != eos).squeeze(1)

            prev_attn_context = attn_context
            cache = (encoder_outputs, current_h_dec, current_c_dec, prev_attn_context)

            if unfinish_flag.max() == 0:
                break


            ######
        return decoder_input_ids


In [29]:
def train(model, train_dataset, valid_dataset, collate_fn, train_args, prefix):
    optimizer = optim.Adam(model.parameters(), lr=train_args["lr"])

    train_dataloader = DataLoader(train_dataset, batch_size=train_args['batch_size'], shuffle=True, collate_fn=collate_fn, num_workers=os.cpu_count())
    valid_dataloader = DataLoader(valid_dataset, batch_size=train_args['batch_size'], shuffle=False, collate_fn=collate_fn, num_workers=os.cpu_count())

    total_steps = len(train_dataloader) * train_args['epochs']

    num_training_steps = train_args['epochs'] * (len(train_dataloader) // train_args['gradient_accumulate_steps'])
    scheduler = transformers.get_scheduler(
        name="cosine",
        optimizer=optimizer,
        num_warmup_steps=train_args['warmup_steps'],
        num_training_steps=num_training_steps
    )

    best_loss = 987654321
    optimizer.zero_grad()

    output_path = os.path.join("output", prefix)
    os.makedirs(output_path, exist_ok=True)
    with open(os.path.join(output_path, "train_args.json"), "w") as f:
        json.dump(train_args, f)

    pbar = tqdm.tqdm(total=total_steps, desc="training")
    for epoch in range(train_args['epochs']):
        pbar.set_description(f"Epoch {epoch+1}/{train_args['epochs']}")
        move_avg_loss = []
        model.train()
        for i, batch in enumerate(train_dataloader):
            batch = {k:v.to(train_args['device']) if isinstance(v,torch.Tensor) else v for k,v in batch.items()}

            loss = model(**batch)
            loss = loss / train_args['gradient_accumulate_steps']
            if loss.size() != torch.Size([]):
                loss = loss.mean()
            loss.backward()

            if (i+1) % train_args['gradient_accumulate_steps'] == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()

            move_avg_loss.append(loss.item())
            if len(move_avg_loss) > 100: move_avg_loss.pop(0)
            pbar.set_postfix_str(f"loss: {sum(move_avg_loss)/len(move_avg_loss):.04f} lr: {optimizer.param_groups[0]['lr']:.2e}")
            pbar.update(1)

        model.eval()
        with torch.no_grad():
            eval_loss = 0
            for i, batch in enumerate(valid_dataloader):
                batch = {k:v.to(train_args['device']) if isinstance(v,torch.Tensor) else v for k,v in batch.items()}
                loss_val = model(**batch)
                if loss_val.size() != torch.Size([]):
                    loss_val = loss_val.mean()
                eval_loss += loss_val.item()
                pbar.set_postfix_str(f"val_loss: {eval_loss / (i+1):.04f}")
        eval_loss /= len(valid_dataloader)
        pbar.write(f"Validation Loss: {eval_loss:.04f}")

        if eval_loss < best_loss:
            best_loss = eval_loss

            torch.save(model.state_dict(), os.path.join(output_path,"best_model.pth"))
            pbar.write(f"Model Saved best loss: {best_loss:.04f}")

    pbar.close()

def evaluate(model, dataset, tokenizer, collate_fn, train_args):
    model.eval()
    dataloader = DataLoader(dataset, batch_size=train_args['batch_size'], shuffle=False, collate_fn=collate_fn, num_workers=os.cpu_count())

    answers = []
    predicts = []
    for i, batch in enumerate(tqdm.tqdm(dataloader, desc="Evaluating")):
        batch = {k:v.to(train_args['device']) if isinstance(v,torch.Tensor) else v for k,v in batch.items()}
        gen_output = model.generate(
            encoder_input_ids=batch["encoder_input_ids"],
            encoder_attention_mask=batch["encoder_attention_mask"],
            max_new_tokens=512
        )
        pred = tokenizer.batch_decode(gen_output, skip_special_tokens=True)
        ans = tokenizer.batch_decode(batch["labels"], skip_special_tokens=True)
        answers.extend(ans)
        predicts.extend(pred)

    bleu = load("bleu")
    result = bleu.compute(predictions=predicts, references=answers)
    print(f"BLEU: {result['bleu']:.4f}")

In [None]:
config = ModelConfig(
    vocab_size=len(tokenizer),
    pad_idx=tokenizer.pad_token_id,
    sos_idx=tokenizer.bos_token_id,
    eos_idx=tokenizer.eos_token_id,
    n_layers=2,
    dropout=0.1,

    attention_type="global",
    do_input_feeding=False,
)

model = Seq2Seq(config).to(training_config["device"])
model = model.to(torch.bfloat16)
model.compile()
print(model)

train(
    model,
    train_set,
    valid_set,
    collate_fn,
    training_config,
    prefix="seq2seq_global_attention_no_input_feeding"
)

model.load_state_dict(torch.load(os.path.join("output", "seq2seq_global_attention_no_input_feeding", "best_model.pth")))
evaluate(
    model,
    valid_set,
    tokenizer,
    collate_fn,
    training_config
)

del model
torch.cuda.empty_cache()

Seq2Seq(
  (embedding): Embedding(65002, 512, padding_idx=65000)
  (encoder): Encoder(
    (encoder): LSTM(512, 512, num_layers=2, batch_first=True, dropout=0.1, bidirectional=True)
    (h_dec_proj): Linear(in_features=1024, out_features=512, bias=True)
    (c_dec_proj): Linear(in_features=1024, out_features=512, bias=True)
  )
  (decoder): Decoder(
    (decoder): LSTM(512, 512, num_layers=2, batch_first=True, dropout=0.1)
    (attention): GlobalAttention(
      (query_proj): Linear(in_features=512, out_features=512, bias=False)
      (key_proj): Linear(in_features=1024, out_features=512, bias=False)
      (value_proj): Linear(in_features=1024, out_features=512, bias=False)
      (output_proj): Linear(in_features=512, out_features=512, bias=False)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (lm_head): Linear(in_features=512, out_features=65002, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)









training:   0%|          | 0/25000 [00:00<?, ?it/s][A[A[A[A[A[A[A






Epoch 1/1:   0%|          | 0/25000 [00:00<?, ?it/s][A[A[A[A[A[A[A

decoder_hidden_query.shape: torch.Size([4, 1, 512])
query.shape: torch.Size([4, 1, 512])
decoder_hidden_query.shape: torch.Size([4, 1, 512])
query.shape: torch.Size([4, 1, 512])
decoder_hidden_query.shape: torch.Size([4, 1, 512])
query.shape: torch.Size([4, 1, 512])
decoder_hidden_query.shape: torch.Size([4, 1, 512])
query.shape: torch.Size([4, 1, 512])
decoder_hidden_query.shape: torch.Size([4, 1, 512])
query.shape: torch.Size([4, 1, 512])
decoder_hidden_query.shape: torch.Size([4, 1, 512])
query.shape: torch.Size([4, 1, 512])
decoder_hidden_query.shape: torch.Size([4, 1, 512])
query.shape: torch.Size([4, 1, 512])
decoder_hidden_query.shape: torch.Size([4, 1, 512])
query.shape: torch.Size([4, 1, 512])
decoder_hidden_query.shape: torch.Size([4, 1, 512])
query.shape: torch.Size([4, 1, 512])
decoder_hidden_query.shape: torch.Size([4, 1, 512])
query.shape: torch.Size([4, 1, 512])
decoder_hidden_query.shape: torch.Size([4, 1, 512])
query.shape: torch.Size([4, 1, 512])
decoder_hidden_query.

In [None]:
config = ModelConfig(
    vocab_size=len(tokenizer),
    pad_idx=tokenizer.pad_token_id,
    sos_idx=tokenizer.bos_token_id,
    eos_idx=tokenizer.eos_token_id,
    n_layers=2,
    dropout=0.1,

    attention_type="global",
)

model = Seq2Seq(config).to(training_config["device"])
model = model.to(torch.bfloat16)
model.compile()
print(model)

train(
    model,
    train_set,
    valid_set,
    collate_fn,
    training_config,
    prefix="seq2seq_global_attention"
)

model.load_state_dict(torch.load(os.path.join("output", "seq2seq_global_attention", "best_model.pth")))
evaluate(
    model,
    valid_set,
    tokenizer,
    collate_fn,
    training_config
)

del model
torch.cuda.empty_cache()

In [None]:
config = ModelConfig(
    vocab_size=len(tokenizer),
    pad_idx=tokenizer.pad_token_id,
    sos_idx=tokenizer.bos_token_id,
    eos_idx=tokenizer.eos_token_id,
    n_layers=2,
    dropout=0.1,

    attention_type="local",
)

model = Seq2Seq(config).to(training_config["device"])
model = model.to(torch.bfloat16)
model.compile()
print(model)

train(
    model,
    train_set,
    valid_set,
    collate_fn,
    training_config,
    prefix="seq2seq_local_attention"
)

model.load_state_dict(torch.load(os.path.join("output", "seq2seq_local_attention", "best_model.pth")))
evaluate(
    model,
    valid_set,
    tokenizer,
    collate_fn,
    training_config
)

del model
torch.cuda.empty_cache()