In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import json
import numpy as np
from sklearn.metrics import f1_score
from tqdm import tqdm
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_jsonl(path):
    """Read a JSON Lines file and return its records as a list.

    Each non-blank line of the file is parsed with ``json.loads``. Blank or
    whitespace-only lines (e.g. a trailing empty line at the end of the file)
    are skipped instead of raising a decode error.

    Args:
        path: Path of the UTF-8 encoded ``.jsonl`` file to read.

    Returns:
        A list with one parsed JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate empty lines; json.loads("") would raise.
                continue
            records.append(json.loads(line))
    return records

# Load the training and validation splits from their JSONL files.
train_data = load_jsonl("train_data.jsonl")
val_data = load_jsonl("val_data.jsonl")

# Collect every label tag seen in either split, sorted so that the
# tag <-> index mapping is deterministic across runs.
all_tags = sorted(set().union(*(rec["label_tags"] for rec in train_data + val_data)))
tag2id = {name: index for index, name in enumerate(all_tags)}
id2tag = dict(enumerate(all_tags))

In [3]:
class PromptDataset(Dataset):
    """Multi-label dataset pairing a text prompt with a multi-hot tag vector.

    For every record the input text is the comma-joined ``context_tags``
    followed by ``" [SEP] "`` and the ``target_character``. The label is a
    multi-hot vector of length ``len(tag2id)`` with a 1 at the index of every
    tag in ``label_tags`` that is present in ``tag2id``; unknown tags are
    ignored.

    Args:
        data: Iterable of dict records with the keys ``"context_tags"``
            (iterable of str), ``"target_character"`` (str) and
            ``"label_tags"`` (iterable of tags).
        tokenizer: Callable invoked as
            ``tokenizer(text, truncation=..., padding=..., max_length=...,
            return_tensors="pt")`` that returns a mapping containing
            ``"input_ids"`` and ``"attention_mask"`` tensors.
        tag2id: Mapping from tag to its column index in the label vector.
        max_length: Sequence length every example is padded/truncated to.
            Defaults to 64, the value previously hard-coded.
    """

    def __init__(self, data, tokenizer, tag2id, max_length=64):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.inputs = []
        self.labels = []
        for d in data:
            text = ", ".join(d["context_tags"]) + " [SEP] " + d["target_character"]
            self.inputs.append(text)
            label_vec = [0] * len(tag2id)
            for tag in d["label_tags"]:
                # Tags outside the vocabulary are silently dropped.
                if tag in tag2id:
                    label_vec[tag2id[tag]] = 1
            self.labels.append(label_vec)

    def __len__(self):
        """Return the number of examples."""
        return len(self.inputs)

    def __getitem__(self, idx):
        """Tokenize example ``idx`` and return its tensors.

        Returns:
            Dict with ``"input_ids"`` and ``"attention_mask"`` (the tokenizer
            output with its leading dimension removed) and ``"labels"`` (a
            float32 multi-hot vector).
        """
        encoded = self.tokenizer(
            self.inputs[idx],
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt"
        )
        # squeeze(0) removes only the leading batch dimension; a bare
        # squeeze() would also collapse the sequence axis when its size is 1.
        # Assumes the tokenizer returns tensors shaped (1, max_length).
        return {
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "labels": torch.tensor(self.labels[idx], dtype=torch.float32)
        }

In [5]:
class BERTTagRecommender(nn.Module):
    """Transformer encoder with a two-layer head that emits one logit per tag.

    The encoder is loaded with ``AutoModel.from_pretrained(model_name)``. The
    head is ``Linear(hidden_size, 512) -> ReLU -> Linear(512, output_dim)``.
    ``forward`` returns raw logits; no sigmoid is applied inside the module.

    Args:
        model_name: Identifier or path passed to ``AutoModel.from_pretrained``.
        output_dim: Number of output logits (one per tag).
    """

    def __init__(self, model_name, output_dim):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.classifier = nn.Sequential(
            nn.Linear(self.encoder.config.hidden_size, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim)
        )

    def forward(self, input_ids, attention_mask):
        """Encode the batch and return the classifier logits.

        The encoder's ``pooler_output`` is used when it is available. If the
        encoder output has no ``pooler_output`` (attribute missing or
        ``None``), the sentence representation falls back to the
        attention-masked mean of ``last_hidden_state``.

        Args:
            input_ids: Token id tensor forwarded to the encoder.
            attention_mask: Mask tensor forwarded to the encoder; in the
                fallback path it is assumed to be (batch, seq) with 1 for
                real tokens and 0 for padding.

        Returns:
            Tensor of logits produced by the classifier head.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = getattr(outputs, "pooler_output", None)
        if pooled is None:
            # Encoders without a pooling head: average the token embeddings,
            # ignoring padded positions.
            hidden = outputs.last_hidden_state
            mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
            # clamp avoids a division by zero for an all-padding row.
            pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        return self.classifier(pooled)

In [None]:
def train_model(train_data, val_data, tag2id, num_epochs=50, resume=False):
    """Fine-tune a ``BERTTagRecommender`` for multi-label tag prediction.

    Builds the tokenizer, datasets and loaders, trains with AdamW and
    ``BCEWithLogitsLoss``, evaluates micro/macro F1 on the validation split
    after every epoch (sigmoid probabilities thresholded at 0.5), and writes a
    checkpoint to the working directory after every epoch.

    Checkpoint files (relative to the current working directory):
        * ``checkpoint_model.pt``  - model ``state_dict``
        * ``checkpoint_optim.pt``  - optimizer ``state_dict``
        * ``checkpoint_meta.json`` - ``{"epoch": <number of completed epochs>}``

    Args:
        train_data: Records used for training (consumed by ``PromptDataset``).
        val_data: Records used for validation (consumed by ``PromptDataset``).
        tag2id: Mapping from tag to label index; its length is the number of
            output logits.
        num_epochs: Total number of epochs to reach. When resuming, training
            continues from the epoch recorded in ``checkpoint_meta.json`` up
            to this number; if that file is absent (e.g. a checkpoint written
            before it existed) training starts again from epoch 1 with the
            loaded weights.
        resume: If True and ``checkpoint_model.pt`` exists, load the saved
            model (and optimizer, when present) before training.

    Returns:
        Tuple ``(model, tokenizer)``.
    """
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    train_ds = PromptDataset(train_data, tokenizer, tag2id)
    val_ds = PromptDataset(val_data, tokenizer, tag2id)
    train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=32)

    model = BERTTagRecommender(model_name, len(tag2id))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    criterion = nn.BCEWithLogitsLoss()

    start_epoch = 0
    model_file = "checkpoint_model.pt"
    optim_file = "checkpoint_optim.pt"
    meta_file = "checkpoint_meta.json"

    if resume and os.path.exists(model_file):
        # map_location lets a checkpoint saved on GPU be restored on a
        # CPU-only machine (and vice versa).
        model.load_state_dict(torch.load(model_file, map_location=device))
        if os.path.exists(optim_file):
            optimizer.load_state_dict(torch.load(optim_file, map_location=device))
            print("✅ 已加载上次保存的模型与优化器状态。")
        else:
            print("⚠️ 未找到优化器检查点，仅加载模型权重，优化器将重新初始化。")
        if os.path.exists(meta_file):
            # Continue from the last completed epoch instead of restarting.
            with open(meta_file, "r", encoding="utf-8") as f:
                start_epoch = int(json.load(f).get("epoch", 0))

    for epoch in range(start_epoch, num_epochs):
        model.train()
        # Sum of per-batch losses over the epoch (not an average).
        total_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
        for batch in progress_bar:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad()
            logits = model(input_ids, attention_mask)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            progress_bar.set_postfix(loss=loss.item())

        model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for batch in val_loader:
                input_ids = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                labels = batch["labels"].cpu().numpy()
                logits = model(input_ids, attention_mask)
                probs = torch.sigmoid(logits).cpu().numpy()
                preds = (probs > 0.5).astype(int)
                all_preds.append(preds)
                all_labels.append(labels)

        if all_preds:
            all_preds = np.vstack(all_preds)
            all_labels = np.vstack(all_labels)
            micro_f1 = f1_score(all_labels, all_preds, average="micro", zero_division=0)
            macro_f1 = f1_score(all_labels, all_preds, average="macro", zero_division=0)
        else:
            # Empty validation set: np.vstack([]) would raise, so report NaN.
            micro_f1 = macro_f1 = float("nan")
        print(f"✅ Epoch {epoch+1} | Loss: {total_loss:.4f} | Micro-F1: {micro_f1:.4f} | Macro-F1: {macro_f1:.4f}")

        torch.save(model.state_dict(), model_file)
        torch.save(optimizer.state_dict(), optim_file)
        # Record how many epochs are complete so a later resume can continue.
        with open(meta_file, "w", encoding="utf-8") as f:
            json.dump({"epoch": epoch + 1}, f)
        print("✅ 模型已保存：checkpoint_model.pt")

    return model, tokenizer

In [None]:
# Run training for 30 epochs, picking up an existing checkpoint if one is
# present in the working directory.
trained_model, trained_tokenizer = train_model(
    train_data=train_data,
    val_data=val_data,
    tag2id=tag2id,
    num_epochs=30,
    resume=True,
)

Epoch 1: 100%|██████████| 1931/1931 [00:26<00:00, 73.74it/s, loss=0.0731]


✅ Epoch 1 | Loss: 213.5485 | Micro-F1: 0.0000 | Macro-F1: 0.0000
✅ 模型已保存：checkpoint_model.pt


Epoch 2: 100%|██████████| 1931/1931 [00:26<00:00, 73.81it/s, loss=0.0751]


✅ Epoch 2 | Loss: 143.7601 | Micro-F1: 0.0000 | Macro-F1: 0.0000
✅ 模型已保存：checkpoint_model.pt


Epoch 3: 100%|██████████| 1931/1931 [00:25<00:00, 74.75it/s, loss=0.0708]


✅ Epoch 3 | Loss: 140.0060 | Micro-F1: 0.0003 | Macro-F1: 0.0000
✅ 模型已保存：checkpoint_model.pt


Epoch 4: 100%|██████████| 1931/1931 [00:25<00:00, 75.24it/s, loss=0.0705]


✅ Epoch 4 | Loss: 131.7316 | Micro-F1: 0.0322 | Macro-F1: 0.0015
✅ 模型已保存：checkpoint_model.pt


Epoch 5: 100%|██████████| 1931/1931 [00:25<00:00, 75.39it/s, loss=0.0545]


✅ Epoch 5 | Loss: 120.8631 | Micro-F1: 0.0941 | Macro-F1: 0.0073
✅ 模型已保存：checkpoint_model.pt


Epoch 6: 100%|██████████| 1931/1931 [00:25<00:00, 75.10it/s, loss=0.0456]


✅ Epoch 6 | Loss: 100.0581 | Micro-F1: 0.1742 | Macro-F1: 0.0292
✅ 模型已保存：checkpoint_model.pt


Epoch 7: 100%|██████████| 1931/1931 [00:25<00:00, 74.40it/s, loss=0.039] 


✅ Epoch 7 | Loss: 82.3788 | Micro-F1: 0.3733 | Macro-F1: 0.2843
✅ 模型已保存：checkpoint_model.pt


Epoch 8: 100%|██████████| 1931/1931 [00:25<00:00, 74.78it/s, loss=0.0439]


✅ Epoch 8 | Loss: 67.0728 | Micro-F1: 0.6032 | Macro-F1: 0.7193
✅ 模型已保存：checkpoint_model.pt


Epoch 9: 100%|██████████| 1931/1931 [00:25<00:00, 74.61it/s, loss=0.0169]


✅ Epoch 9 | Loss: 54.8396 | Micro-F1: 0.7307 | Macro-F1: 0.8606
✅ 模型已保存：checkpoint_model.pt


Epoch 10: 100%|██████████| 1931/1931 [00:25<00:00, 75.05it/s, loss=0.0151]


✅ Epoch 10 | Loss: 45.9179 | Micro-F1: 0.7998 | Macro-F1: 0.9030
✅ 模型已保存：checkpoint_model.pt


Epoch 11: 100%|██████████| 1931/1931 [00:25<00:00, 75.43it/s, loss=0.0202]


✅ Epoch 11 | Loss: 39.5459 | Micro-F1: 0.8359 | Macro-F1: 0.9181
✅ 模型已保存：checkpoint_model.pt


Epoch 12: 100%|██████████| 1931/1931 [00:25<00:00, 75.32it/s, loss=0.0275] 


✅ Epoch 12 | Loss: 34.9753 | Micro-F1: 0.8639 | Macro-F1: 0.9252
✅ 模型已保存：checkpoint_model.pt


Epoch 13: 100%|██████████| 1931/1931 [00:25<00:00, 75.04it/s, loss=0.0159] 


✅ Epoch 13 | Loss: 31.5951 | Micro-F1: 0.8801 | Macro-F1: 0.9277
✅ 模型已保存：checkpoint_model.pt


Epoch 14: 100%|██████████| 1931/1931 [00:25<00:00, 74.98it/s, loss=0.0175] 


✅ Epoch 14 | Loss: 29.1183 | Micro-F1: 0.8911 | Macro-F1: 0.9295
✅ 模型已保存：checkpoint_model.pt


Epoch 15: 100%|██████████| 1931/1931 [00:25<00:00, 75.29it/s, loss=0.00948]


✅ Epoch 15 | Loss: 27.1900 | Micro-F1: 0.8992 | Macro-F1: 0.9304
✅ 模型已保存：checkpoint_model.pt


Epoch 16: 100%|██████████| 1931/1931 [00:25<00:00, 75.18it/s, loss=0.00526]


✅ Epoch 16 | Loss: 25.6163 | Micro-F1: 0.9054 | Macro-F1: 0.9305
✅ 模型已保存：checkpoint_model.pt


Epoch 17: 100%|██████████| 1931/1931 [00:25<00:00, 75.59it/s, loss=0.0156] 


✅ Epoch 17 | Loss: 24.3403 | Micro-F1: 0.9097 | Macro-F1: 0.9314
✅ 模型已保存：checkpoint_model.pt


Epoch 18: 100%|██████████| 1931/1931 [00:25<00:00, 76.51it/s, loss=0.0131] 


✅ Epoch 18 | Loss: 23.2832 | Micro-F1: 0.9139 | Macro-F1: 0.9319
✅ 模型已保存：checkpoint_model.pt


Epoch 19: 100%|██████████| 1931/1931 [00:25<00:00, 74.78it/s, loss=0.0127] 


✅ Epoch 19 | Loss: 22.3899 | Micro-F1: 0.9160 | Macro-F1: 0.9321
✅ 模型已保存：checkpoint_model.pt


Epoch 20: 100%|██████████| 1931/1931 [00:25<00:00, 74.99it/s, loss=0.0115] 


✅ Epoch 20 | Loss: 21.6269 | Micro-F1: 0.9184 | Macro-F1: 0.9320
✅ 模型已保存：checkpoint_model.pt


Epoch 21: 100%|██████████| 1931/1931 [00:25<00:00, 75.31it/s, loss=0.012]  


✅ Epoch 21 | Loss: 21.0044 | Micro-F1: 0.9205 | Macro-F1: 0.9324
✅ 模型已保存：checkpoint_model.pt


Epoch 22: 100%|██████████| 1931/1931 [00:25<00:00, 74.97it/s, loss=0.00337]


✅ Epoch 22 | Loss: 20.4378 | Micro-F1: 0.9212 | Macro-F1: 0.9323
✅ 模型已保存：checkpoint_model.pt


Epoch 23: 100%|██████████| 1931/1931 [00:25<00:00, 75.19it/s, loss=0.00205]


✅ Epoch 23 | Loss: 19.9769 | Micro-F1: 0.9230 | Macro-F1: 0.9328
✅ 模型已保存：checkpoint_model.pt


Epoch 24: 100%|██████████| 1931/1931 [00:25<00:00, 76.15it/s, loss=0.0171] 


✅ Epoch 24 | Loss: 19.5475 | Micro-F1: 0.9242 | Macro-F1: 0.9328
✅ 模型已保存：checkpoint_model.pt


Epoch 25: 100%|██████████| 1931/1931 [00:25<00:00, 75.91it/s, loss=0.0123] 


✅ Epoch 25 | Loss: 19.1805 | Micro-F1: 0.9249 | Macro-F1: 0.9325
✅ 模型已保存：checkpoint_model.pt


Epoch 26: 100%|██████████| 1931/1931 [00:25<00:00, 75.68it/s, loss=0.00661]


✅ Epoch 26 | Loss: 18.8407 | Micro-F1: 0.9257 | Macro-F1: 0.9326
✅ 模型已保存：checkpoint_model.pt


Epoch 27: 100%|██████████| 1931/1931 [00:26<00:00, 72.66it/s, loss=0.00926]


✅ Epoch 27 | Loss: 18.5574 | Micro-F1: 0.9265 | Macro-F1: 0.9325
✅ 模型已保存：checkpoint_model.pt


Epoch 28: 100%|██████████| 1931/1931 [00:26<00:00, 72.50it/s, loss=0.00946] 


✅ Epoch 28 | Loss: 18.2824 | Micro-F1: 0.9275 | Macro-F1: 0.9329
✅ 模型已保存：checkpoint_model.pt


Epoch 29: 100%|██████████| 1931/1931 [00:26<00:00, 74.23it/s, loss=0.00172] 


✅ Epoch 29 | Loss: 18.0547 | Micro-F1: 0.9279 | Macro-F1: 0.9326
✅ 模型已保存：checkpoint_model.pt


Epoch 30: 100%|██████████| 1931/1931 [00:26<00:00, 72.58it/s, loss=0.00112] 


✅ Epoch 30 | Loss: 17.8195 | Micro-F1: 0.9286 | Macro-F1: 0.9327
✅ 模型已保存：checkpoint_model.pt
