In [None]:
# Label id assigned to padded positions.
pad_token_label_id = -100
# Number of full passes over the training data.
epochs = 3
# ignore_index is supplied at construction time so that targets equal to
# pad_token_label_id are skipped when the loss is computed.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=pad_token_label_id)

def get_embeddings(tokens):
    """Encode pre-split tokens with the module-level tokenizer and ``bert_model``.

    Args:
        tokens: Input handed to ``tokenizer`` with ``is_split_into_words=True``.

    Returns:
        A ``(last_hidden_state, attention_mask)`` tuple: the model's last
        hidden state and the attention mask produced by the tokenizer.
    """
    encoded = tokenizer(
        tokens,
        return_tensors="pt",
        is_split_into_words=True,
        padding=True,
        truncation=True,
    )
    # The forward pass runs without gradient tracking.
    with torch.no_grad():
        model_output = bert_model(**encoded)
    hidden_states = model_output.last_hidden_state
    mask = encoded['attention_mask']
    return hidden_states, mask

def pad_labels(labels, max_length, pad_token_label_id):
    """Return a new list of labels that is exactly ``max_length`` items long.

    Shorter inputs are right-padded with ``pad_token_label_id``; longer inputs
    are cut off after ``max_length`` items so the result always has the
    requested length. The input sequence is not modified.

    Args:
        labels: Sequence of label ids (list or tuple).
        max_length: Target length; expected to be non-negative.
        pad_token_label_id: Value appended to fill the missing positions.

    Returns:
        A list of length ``max_length``.
    """
    # Slice first so an over-long input cannot yield more than max_length items.
    fitted = list(labels[:max_length])
    fitted.extend([pad_token_label_id] * (max_length - len(fitted)))
    return fitted

# Training loop: one optimizer step per item yielded by train_dataset.
for epoch in range(epochs):
    # enumerate(start=1) provides the 1-based batch number used in the log line.
    for batch_counter, batch in enumerate(train_dataset, start=1):
        tokens = batch['tokens']
        labels = batch['ner_tags']

        # Encode the tokens; the attention mask is returned too but is not used below.
        embeddings, attention_mask = get_embeddings(tokens)

        # Pad the labels to the sequence length of the embeddings (dimension 1).
        # NOTE(review): labels are only padded at the end, so they line up with
        # the encoded positions only if the tokenizer emits exactly one position
        # per word and no leading special token — verify the alignment
        # (e.g. against the tokenizer's word ids) before trusting the loss.
        padded_labels = pad_labels(labels, embeddings.size(1), pad_token_label_id)

        # Flatten predictions to [batch_size * seq_length, num_labels] and
        # labels to [batch_size * seq_length] for the loss function.
        outputs = model(embeddings)
        outputs = outputs.view(-1, num_labels)
        padded_labels = torch.tensor(padded_labels).view(-1)

        # ignore_index was set when loss_fn was constructed, so it is not passed here.
        loss = loss_fn(outputs, padded_labels)

        # Backpropagate and update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report the running batch number and the loss for this step.
        print(f'Epoch {epoch+1}, Batch {batch_counter}, Loss: {loss.item()}')

    # Report once at the end of every epoch.
    print(f"Epoch {epoch+1} completed.")
