In [13]:
import math
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from check import check

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

对于位置 pos 和维度 i：

- 当 i 是偶数时（从0开始）：
  $$
  PE(pos, i) = \sin\left(\frac{pos}{10000^{i/d_{\text{model}}}}\right)
  $$
  
- 当 i 是奇数时：
  $$
  PE(pos, i) = \cos\left(\frac{pos}{10000^{(i-1)/d_{\text{model}}}}\right)
  $$

其中：

- $pos$ 是序列中的位置

- $d_{\text{model}}$ 是模型的嵌入维度

- $i$ 是维度索引

In [14]:
# 位置编码（使模型能够利用序列顺序信息）
class posEncoder(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)  # 编码矩阵
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    # 将位置编码加到输入 x 上
    def forward(self, x):
        x = x + self.pe[:, : x.size(1)]     # 形状: [batch_size, seq_len, d_model]
        return x

In [15]:
class CRF_TF(nn.Module):
    def __init__(self, vocab_size, tag2idx, d_model=512, nhead=8, n_layers=6):
        super().__init__()
        self.tag2idx = tag2idx
        self.idx2tag = {v: k for k, v in tag2idx.items()}
        self.tagset_size = len(tag2idx)
        self.transitions = nn.Parameter(torch.randn(self.tagset_size, self.tagset_size))
        self.init_transitions()

        # transformer
        self.embedding = nn.Embedding(vocab_size, d_model)                              # 词嵌入层：将词ID映射为d_model维的向量
        self.pos_encoder = posEncoder(d_model)                                          # 位置编码层：为每个词加上位置信息
        encoder_layers = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True)   # Transformer编码器层
        self.transformer = nn.TransformerEncoder(encoder_layers, n_layers)

        self.hidden2tag = nn.Linear(d_model, self.tagset_size)                          # 线性层：将Transformer输出映射到标签空间

    def save_model(model, path):
        torch.save(model.state_dict(), path)
    
    def load_model(model, path, device):
        model.load_state_dict(torch.load(path, map_location=device))
        model.to(device)
        model.eval()

    def forward(self, x, mask):
        embeds = self.embedding(x)  # (batch, seq_len, d_model)
        embeds = self.pos_encoder(embeds)
        transformer_out = self.transformer(embeds, src_key_padding_mask=~mask)
        return self.hidden2tag(transformer_out)

    # 随机初始化转移矩阵
    def init_transitions(self):
        self.transitions = nn.Parameter(torch.randn(self.tagset_size, self.tagset_size))

        # 特殊标签约束（保证在解码和训练时不会生成非法的标签转移）
        self.transitions.data[self.tag2idx["<START>"], :] = -10000
        self.transitions.data[:, self.tag2idx["<STOP>"]] = -10000
        self.transitions.data[self.tag2idx["<PAD>"], :] = -10000 
        self.transitions.data[:, self.tag2idx["<PAD>"]] = -10000

    # 负对数似然损失
    def neg_log_likelihood(self, sentences, tags, masks):
        emissions = self.forward(sentences, masks)
        batch_size = sentences.size(0)

        # 计算真实路径分数
        score = torch.zeros(batch_size).to(sentences.device)
        for i in range(batch_size):
            # 获取有效长度
            length = masks[i].sum().item()
            if length == 0:
                continue

            # 添加第一个转移（START->第一个标签）
            score[i] = self.transitions[self.tag2idx["<START>"], tags[i, 0]]
            score[i] += emissions[i, 0, tags[i, 0]]

            # 累加中间转移和发射分数
            for t in range(1, length):
                score[i] += self.transitions[tags[i, t - 1], tags[i, t]] + emissions[i, t, tags[i, t]]

            # 添加最后一个标签到STOP的转移
            score[i] += self.transitions[tags[i, length - 1], self.tag2idx["<STOP>"]]

        # 计算配分函数
        log_Z = self.compute_log_partition(emissions, masks)

        return (log_Z - score).mean()

    # 计算配分函数（log_Z）
    def compute_log_partition(self, emissions, masks):
        batch_size, seq_len, _ = emissions.shape
        device = emissions.device

        # 初始化alpha
        alpha = torch.full((batch_size, self.tagset_size), -10000.0).to(device)
        alpha[:, self.tag2idx["<START>"]] = 0.0

        for t in range(seq_len):
            # 获取当前时间步的mask
            mask_t = masks[:, t].unsqueeze(1)  # (batch, 1)
            current_emissions = emissions[:, t]  # (batch, tag_size)

            # 计算alpha[t] = logsumexp(alpha[t-1] + transitions + emissions[t])
            alpha_expanded = alpha.unsqueeze(2)  # (batch, tag_size, 1)
            trans_expanded = self.transitions.unsqueeze(0)  # (1, tag_size, tag_size)

            log_prob = alpha_expanded + trans_expanded + current_emissions.unsqueeze(1)
            new_alpha = torch.logsumexp(log_prob, dim=1)

            # 更新alpha
            alpha = torch.where(mask_t, new_alpha, alpha)

        # 最后加上到STOP的转移
        alpha += self.transitions[:, self.tag2idx["<STOP>"]].unsqueeze(0)
        return torch.logsumexp(alpha, dim=1)

    # emissions: (batch_size, seq_len, tagset_size)
    # mask: (batch_size, seq_len)
    def viterbi_decode(self, emissions, mask):
        batch_size, seq_len, _ = emissions.shape
        device = emissions.device

        # 初始化viterbi变量和backpointers
        viterbi = torch.full((batch_size, self.tagset_size), -10000.0, device=device)
        viterbi[:, self.tag2idx["<START>"]] = 0.0
        backpointers = torch.zeros((batch_size, seq_len, self.tagset_size), dtype=torch.long, device=device)

        for t in range(seq_len):
            # 获取当前时间步的mask
            mask_t = mask[:, t].unsqueeze(1)  # (batch_size, 1)

            # 计算所有路径分数
            scores = viterbi.unsqueeze(2) + self.transitions.unsqueeze(0)  # (batch_size, tag_size, tag_size)
            scores += emissions[:, t].unsqueeze(1)  # 广播发射分数

            # 找到最佳路径
            best_scores, best_tags = torch.max(scores, dim=1)

            # 更新viterbi和backpointers
            viterbi = best_scores * mask_t + viterbi * (~mask_t)  # 仅更新非padding位置
            backpointers[:, t] = best_tags

        # 添加STOP转移
        scores = viterbi + self.transitions[:, self.tag2idx["<STOP>"]].unsqueeze(0)
        _, best_tags = torch.max(scores, dim=1)

        # 回溯路径
        best_paths = []
        for i in range(batch_size):
            path = [best_tags[i].item()]
            for t in reversed(range(seq_len)):
                if not mask[i, t]:
                    continue  # 跳过padding位置
                path.append(backpointers[i, t, path[-1]].item())
            path.reverse()
            best_paths.append(path[1:])  # 去除START标签

        return best_paths

In [16]:
def predict(model, sentences, masks, device):
    model.eval()
    with torch.no_grad():
        emissions = model(sentences.to(device), masks.to(device))
        return model.viterbi_decode(emissions, masks.to(device))

In [17]:
def train_model(model, dataloader, optimizer, device):
    model.train()
    total_loss = 0
    
    for batch in dataloader:
        sentences, tags, masks = batch
        sentences, tags, masks = sentences.to(device), tags.to(device), masks.to(device)
        
        optimizer.zero_grad()   # 梯度清零
        loss = model.neg_log_likelihood(sentences, tags, masks) #  计算 CRF 负对数似然损失
        loss.backward()         # 反向传播
        optimizer.step()        # 更新参数
        
        total_loss += loss.item()
    
    return total_loss / len(dataloader)

In [18]:
# 读取训练集
def prepare_dataset(train_file):
    train_seqs, train_tags = [], []
    with open(train_file, "r", encoding="utf-8") as f:
        cur_sentence = []
        for line in f:
            line = line.strip()
            if not line:
                if cur_sentence:
                    train_seqs.append([word for word, _ in cur_sentence])
                    train_tags.append([tag for _, tag in cur_sentence])
                    cur_sentence = []
            else:
                parts = line.split()
                cur_sentence.append((parts[0], parts[1]))

    return train_seqs, train_tags

In [19]:
# 序列准备函数
def prepare_sequence(seq, to_ix, is_tags=False):
    if is_tags:
        return torch.tensor([to_ix[w] for w in seq], dtype=torch.long)
    else:
        return torch.tensor([to_ix.get(w, to_ix["<UNK>"]) for w in seq], dtype=torch.long)

# 数据集类
class NERDataset(Dataset):
    def __init__(self, sentences, tags, vocab, tag_to_ix):
        self.sentences = [prepare_sequence(s, vocab) for s in sentences]
        self.tags = [prepare_sequence(t, tag_to_ix, is_tags=True) for t in tags]

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        return self.sentences[idx], self.tags[idx]

In [20]:
train_seqs, train_tags = prepare_dataset("./NER/Chinese/mytrain.txt")

# 构建词汇表和标签表
vocab = {"<PAD>": 0, "<UNK>": 1}
for sent in train_seqs:
    for word in sent:
        if word not in vocab:
            vocab[word] = len(vocab)

tag2idx = {"<PAD>": 0, "<START>": 1, "<STOP>": 2}
for tag_seq in train_tags:
    for tag in tag_seq:
        if tag not in tag2idx:
            tag2idx[tag] = len(tag2idx)

dataset = NERDataset(train_seqs, train_tags, vocab, tag2idx)

In [21]:
# 将一个 batch 的样本组合成一个 batch 的张量
# # 对句子和标签进行padding，使它们长度一致（补<PAD>到最大长度）
def collate_fn(batch):
    seqs, tags = zip(*batch)
    lens = torch.tensor([len(s) for s in seqs])
    seqs_padded = pad_sequence(seqs, batch_first=True, padding_value=vocab["<PAD>"])
    tags_padded = pad_sequence(tags, batch_first=True, padding_value=tag2idx["<PAD>"])
    masks = seqs_padded != vocab["<PAD>"]
    return seqs_padded, tags_padded, masks, lens

dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

In [22]:
model = CRF_TF(vocab_size=len(vocab), tag2idx=tag2idx, d_model=256, nhead=8, n_layers=2).to(device)

In [23]:
n_epochs = 10
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# 训练循环
for epoch in range(n_epochs):
    total_loss = 0
    for train_seqs, train_tags, masks, _ in dataloader:
        train_seqs = train_seqs.to(device)
        train_tags = train_tags.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        loss = model.neg_log_likelihood(train_seqs, train_tags, masks)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {total_loss/len(dataloader):.4f}")

Epoch 1, Loss: 83.5791
Epoch 2, Loss: 62.0539
Epoch 3, Loss: 50.5144
Epoch 4, Loss: 43.0408
Epoch 5, Loss: 37.0852
Epoch 6, Loss: 32.9054
Epoch 7, Loss: 27.1854
Epoch 8, Loss: 23.3263
Epoch 9, Loss: 19.9205
Epoch 10, Loss: 17.8583


In [None]:
torch.save(model.state_dict(), "model/crf_tf.pth")

In [None]:
#测试模型
model = CRF_TF(vocab_size=len(vocab), tag2idx=tag2idx, d_model=256, nhead=8, n_layers=2).to(device)
model.load_state_dict(torch.load("crf_tf.pth", map_location=device))    # 加载模型参数
model.eval()

# 读取测试集
test_seqs, test_tags = prepare_dataset("test/chinese_test.txt")

# 预测并写入文件
output_path = "output/crf_tf_test_output.txt"
with open(output_path, "w", encoding="utf-8") as fout:
    for seq in test_seqs:
        seq_tensor = prepare_sequence(seq, vocab).unsqueeze(0).to(device)
        mask = (seq_tensor != vocab["<PAD>"])
        emissions = model(seq_tensor, mask)
        pred_ids = model.viterbi_decode(emissions, mask)[0]
        pred_tags = [list(tag2idx.keys())[list(tag2idx.values()).index(idx)] for idx in pred_ids[:len(seq)]]
        for word, tag in zip(seq, pred_tags):
            fout.write(f"{word} {tag}\n")
        fout.write("\n")

# 评测
check(language="Chinese", gold_path="pj2_test/chinese_test.txt", my_path=output_path)


              precision    recall  f1-score   support

      B-NAME     0.6429    0.4018    0.4945       112
      M-NAME     0.3818    0.2561    0.3066        82
      E-NAME     0.3684    0.0625    0.1069       112
      S-NAME     0.0000    0.0000    0.0000         0
      B-CONT     0.0000    0.0000    0.0000        28
      M-CONT     0.4767    0.7736    0.5899        53
      E-CONT     1.0000    0.5000    0.6667        28
      S-CONT     0.0000    0.0000    0.0000         0
       B-EDU     0.2821    0.0982    0.1457       112
       M-EDU     0.6364    0.0391    0.0737       179
       E-EDU     0.1471    0.0446    0.0685       112
       S-EDU     0.0000    0.0000    0.0000         0
     B-TITLE     0.1844    0.3675    0.2456       770
     M-TITLE     0.5833    0.5102    0.5443      1921
     E-TITLE     0.6417    0.7000    0.6696       770
     S-TITLE     0.0000    0.0000    0.0000         0
       B-ORG     0.2305    0.1449    0.1780       552
       M-ORG     0.6964    