微调一个 Marian 翻译模型进行汉英翻译


# 准备数据


### 构建数据集


In [1]:
from torch.utils.data import Dataset, random_split
import json

# Upper bound on the number of samples read from a single data file.
max_dataset_size = 22000
# Sizes of the train / validation subsets carved out of the loaded data.
train_set_size = 20000
valid_set_size = 2000


class TRANS(Dataset):
    """Map-style dataset over a JSON-lines translation corpus.

    Every non-blank line of the file is parsed as one JSON document and
    stored under a consecutive integer index starting at 0. At most
    ``max_dataset_size`` samples are kept.
    """

    def __init__(self, data_file):
        self.data = self.load_data(data_file)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def load_data(self, data_file):
        """Read up to ``max_dataset_size`` samples from ``data_file``.

        Args:
            data_file: Path of a text file holding one JSON document per line.

        Returns:
            A dict mapping ``0..n-1`` to the parsed samples, in file order.

        Raises:
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        data = {}
        # The corpus holds Chinese text, so decode explicitly as UTF-8 instead
        # of relying on the platform's default encoding.
        with open(data_file, "rt", encoding="utf-8") as f:
            for line in f:
                if len(data) == max_dataset_size:
                    break
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                # Index by the number of samples kept so keys stay contiguous.
                data[len(data)] = json.loads(line)
        return data


# Load the training file; TRANS stops reading after max_dataset_size samples.
dataset = TRANS("translation2019zh/translation2019zh_train.json")
# Randomly partition the loaded samples into train and validation subsets of
# the configured sizes.
train_dataset, valid_dataset = random_split(dataset, [train_set_size, valid_set_size])
# The corpus's "valid" file is used here as the held-out test set.
test_dataset = TRANS("translation2019zh/translation2019zh_valid.json")

# Notebook cell expression: display the first test sample.
test_dataset[0]

{'english': 'Slowly and not without struggle, America began to listen.',
 'chinese': '美国缓慢地开始倾听，但并非没有艰难曲折。'}

### 数据预处理

将每一个 batch 中的数据处理为该模型可接受的格式：一个包含 'attention_mask'、'input_ids'、'labels' 和 'decoder_input_ids' 键的字典。


默认情况下分词器会采用源语言的设定来编码文本，要编码目标语言则需要通过上下文管理器 as_target_tokenizer()：


对于翻译任务，标签序列就是目标语言的 token ID 序列。与序列标注任务类似，在模型预测出的标签序列与答案标签序列之间计算损失来调整模型参数，因此同样需要将填充的 pad 字符设置为 -100，以便在使用交叉熵计算序列损失时将它们忽略：


解码器通常需要两种类型的输入：

Encoder 的输出：这些是编码器处理输入数据后的一系列隐藏状态，它们为解码器提供了有关输入序列的上下文信息。

Decoder input IDs：标签序列整体右移一位后的结果，即在序列的开始位置增加了一个特殊的“序列起始符”。
在训练过程中，模型会基于 decoder input IDs 和 attention mask 来确保在预测某个 token 时不会使用到该 token 及其之后的 token 的信息。考虑到不同模型的移位操作可能存在差异，我们通过模型自带的 `prepare_decoder_input_ids_from_labels` 函数来完成。


In [2]:
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# Pretrained zh->en Marian checkpoint shared by the tokenizer and the model.
model_checkpoint = "Helsinki-NLP/opus-mt-zh-en"

# Truncation limits applied when tokenizing source / target sentences.
max_input_length, max_target_length = 128, 128

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# Load the pretrained weights and move the model to the chosen device.
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(device)


def collote_fn(batch_samples):
    """Collate raw translation samples into one padded model-input batch.

    Args:
        batch_samples: Iterable of dicts, each holding a ``"chinese"`` source
            sentence and an ``"english"`` target sentence.

    Returns:
        The tokenizer output for the source sentences, extended with
        ``"decoder_input_ids"`` and ``"labels"``. Padded label positions are
        set to -100 so the loss ignores them.
    """
    batch_inputs = [sample["chinese"] for sample in batch_samples]
    batch_targets = [sample["english"] for sample in batch_samples]
    batch_data = tokenizer(
        batch_inputs,
        padding=True,
        max_length=max_input_length,
        truncation=True,
        return_tensors="pt",
    )
    # Only the tokenizer call itself has to run in target-language mode.
    with tokenizer.as_target_tokenizer():
        target_encoding = tokenizer(
            batch_targets,
            padding=True,
            max_length=max_target_length,
            truncation=True,
            return_tensors="pt",
        )
    labels = target_encoding["input_ids"]
    # Derive the decoder inputs from the still-padded labels, before masking.
    batch_data["decoder_input_ids"] = model.prepare_decoder_input_ids_from_labels(
        labels
    )
    # Mask padding through the tokenizer's own attention mask. Locating EOS
    # tokens and pairing them with rows by position breaks as soon as a row
    # contains zero or several EOS tokens.
    labels[target_encoding["attention_mask"] == 0] = -100
    batch_data["labels"] = labels
    return batch_data


# Training batches are reshuffled on every pass; collote_fn tokenizes and pads
# each batch of raw samples.
train_dataloader = DataLoader(
    train_dataset, batch_size=32, shuffle=True, collate_fn=collote_fn
)
# Validation batches keep a fixed order.
valid_dataloader = DataLoader(
    valid_dataset, batch_size=32, shuffle=False, collate_fn=collote_fn
)

# Sanity check: pull one collated batch and list the keys it contains.
batch = next(iter(train_dataloader))
print(batch.keys())



dict_keys(['input_ids', 'attention_mask', 'decoder_input_ids', 'labels'])




# 训练模型


### 优化模型参数


In [3]:
def train_loop(data_loader, model, optimizer, lr_scheduler, epoch, total_loss):
    """Run one full pass over ``data_loader``, updating the model per batch.

    Args:
        data_loader: Iterable of collated batches that support ``.to(device)``.
        model: Model called as ``model(**batch)``; its output exposes ``.loss``.
        optimizer: Optimizer stepped once per batch.
        lr_scheduler: Learning-rate scheduler stepped once per batch.
        epoch: Epoch number; accepted but not used inside the loop.
        total_loss: Loss accumulated before this call.

    Returns:
        ``total_loss`` plus the sum of this pass's per-batch losses.
    """
    model.train()
    running_loss = total_loss
    for inputs in data_loader:
        inputs = inputs.to(device)
        step_loss = model(**inputs).loss

        # Clear stale gradients, backpropagate, then advance the optimizer
        # and the learning-rate schedule.
        optimizer.zero_grad()
        step_loss.backward()
        optimizer.step()
        lr_scheduler.step()

        running_loss += step_loss.item()

    return running_loss

验证/测试循环负责评估模型的性能。这里使用 BLEU 来度量模型译文与参考译文这两个词语序列之间的一致性。


In [4]:
from sacrebleu.metrics import BLEU
import numpy as np
# Import the progress-bar class itself: `import tqdm.auto as tqdm` would bind
# the module, and calling a module raises TypeError.
from tqdm.auto import tqdm

# Shared BLEU scorer used by the evaluation code.
bleu = BLEU()


def test_loop(dataloader, model):
    """Translate every batch of ``dataloader`` and report corpus-level BLEU.

    Args:
        dataloader: Iterable of collated batches providing ``"input_ids"``,
            ``"attention_mask"`` and ``"labels"``.
        model: Sequence-to-sequence model exposing ``generate``.

    Returns:
        The BLEU score of the generated translations against the labels.
    """
    # Bind the progress-bar class itself; the `tqdm.auto` module object is not
    # callable.
    from tqdm.auto import tqdm

    preds, references = [], []

    model.eval()
    for batch_data in tqdm(dataloader):
        batch_data = batch_data.to(device)
        with torch.no_grad():
            generated_tokens = (
                model.generate(
                    batch_data["input_ids"],
                    attention_mask=batch_data["attention_mask"],
                    max_length=max_target_length,
                )
                .cpu()
                .numpy()
            )

        decoded_preds = tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )

        label_tokens = batch_data["labels"].cpu().numpy()
        # Replace the -100 loss mask with the pad token ID so the tokenizer
        # can decode the label sequences.
        label_tokens = np.where(
            label_tokens != -100, label_tokens, tokenizer.pad_token_id
        )
        decoded_labels = tokenizer.batch_decode(label_tokens, skip_special_tokens=True)

        preds += [pred.strip() for pred in decoded_preds]
        references += [label.strip() for label in decoded_labels]
    # sacreBLEU takes references as a list of reference streams, each aligned
    # one-to-one with the hypotheses. With a single reference per sentence
    # that is one stream: [references]. Passing [[ref], [ref], ...] would be
    # read as many streams of length one and score only the first hypothesis.
    bleu_score = bleu.corpus_score(preds, [references]).score
    print(f"BLEU: {bleu_score:>0.2f}\n")
    return bleu_score

在开始训练之前，先评估一下没有微调的模型在测试集上的性能。


In [None]:
# Build the test set and an unshuffled loader over it.
test_data = TRANS("translation2019zh/translation2019zh_valid.json")
test_dataloader = DataLoader(
    test_data, batch_size=32, shuffle=False, collate_fn=collote_fn
)
# Inspect one collated batch: its keys, the shape of each entry, and the raw
# contents.
batch = next(iter(test_dataloader))
print(batch.keys())
print("batch shape:", {k: v.shape for k, v in batch.items()})
print(batch)

# Score the model on the test set before any fine-tuning step has run.
test_loop(test_dataloader, model)

### 保存模型


In [None]:
# transformers.AdamW is deprecated and absent from recent releases, so use the
# PyTorch implementation instead.
from torch.optim import AdamW
from transformers import get_scheduler

learning_rate = 2e-5
epoch_num = 3

# eps and weight_decay are pinned to the defaults of the former
# transformers.AdamW so the optimization hyperparameters stay the same.
optimizer = AdamW(model.parameters(), lr=learning_rate, eps=1e-6, weight_decay=0.0)
# Linear decay without warmup, spread over every optimizer step of the run.
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=epoch_num * len(train_dataloader),
)
total_loss = 0.0
best_bleu = 0.0

for t in range(epoch_num):
    total_loss = train_loop(
        train_dataloader, model, optimizer, lr_scheduler, t + 1, total_loss
    )
    valid_bleu = test_loop(valid_dataloader, model)
    # Write a checkpoint only when validation BLEU improves on the best so far.
    if valid_bleu > best_bleu:
        best_bleu = valid_bleu
        torch.save(
            model.state_dict(),
            f"epoch_{t+1}_valid_bleu_{valid_bleu:0.2f}_model_weights.bin",
        )

# 测试模型


In [None]:
import json

# Bind the progress-bar class itself; the `tqdm.auto` module object is not
# callable.
from tqdm.auto import tqdm

test_data = TRANS("translation2019zh/translation2019zh_valid.json")
test_dataloader = DataLoader(
    test_data, batch_size=32, shuffle=False, collate_fn=collote_fn
)

# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU can still be restored on a CPU-only machine.
model.load_state_dict(
    torch.load("epoch_1_valid_bleu_53.38_model_weights.bin", map_location=device)
)

model.eval()
with torch.no_grad():
    print("evaluating on test set...")
    sources, preds, labels = [], [], []
    for batch_data in tqdm(test_dataloader):
        batch_data = batch_data.to(device)
        generated_tokens = (
            model.generate(
                batch_data["input_ids"],
                attention_mask=batch_data["attention_mask"],
                max_length=max_target_length,
            )
            .cpu()
            .numpy()
        )
        label_tokens = batch_data["labels"].cpu().numpy()

        decoded_sources = tokenizer.batch_decode(
            batch_data["input_ids"].cpu().numpy(),
            skip_special_tokens=True,
            use_source_tokenizer=True,
        )
        decoded_preds = tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )
        # Replace the -100 loss mask with the pad token ID so the tokenizer
        # can decode the label sequences.
        label_tokens = np.where(
            label_tokens != -100, label_tokens, tokenizer.pad_token_id
        )
        decoded_labels = tokenizer.batch_decode(label_tokens, skip_special_tokens=True)

        sources += [source.strip() for source in decoded_sources]
        preds += [pred.strip() for pred in decoded_preds]
        labels += [[label.strip()] for label in decoded_labels]
    # sacreBLEU takes references as a list of reference streams, each aligned
    # one-to-one with the hypotheses. Flatten the per-sentence lists into the
    # single stream it needs; passing `labels` as-is would score only the
    # first hypothesis.
    bleu_score = bleu.corpus_score(preds, [[refs[0] for refs in labels]]).score
    print(f"Test BLEU: {bleu_score:>0.2f}\n")
    results = []
    print("saving predicted results...")
    for source, pred, label in zip(sources, preds, labels):
        results.append(
            {"sentence": source, "prediction": pred, "translation": label[0]}
        )
    # One JSON document per line; ensure_ascii=False keeps the Chinese text
    # readable in the output file.
    with open("test_data_pred.json", "wt", encoding="utf-8") as f:
        for example_result in results:
            f.write(json.dumps(example_result, ensure_ascii=False) + "\n")