In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "5,6"
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "../models/flan-t5-base"  # ✅ 你的本地模型路径
tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)  
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, local_files_only=True)
from datasets import load_dataset  
dataset = load_dataset("json", data_files={ 
    "train": "../datasets/pdtb3_train_level1_T5_mask.jsonl",
    "validation": "../datasets/pdtb3_dev_level1_T5_mask.jsonl",
    "test": "../datasets/pdtb3_test_level1_T5.jsonl"
})
dataset["train"][0]

def preprocess(example):
    input_text = (
        f"Arg1: {example['Arg1']}  "
        f"Arg2: {example['Arg2']}  "
        "What is the discourse relation?"
    )
    target_text = example["Label"]

    model_inputs = tokenizer(
        input_text, max_length=512, truncation=True
    )
    labels = tokenizer(
        target_text, max_length=8, truncation=True
    )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# 4️⃣ 批量映射


In [2]:
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq

# 定义 data_collator
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# 设置训练参数
training_args = Seq2SeqTrainingArguments(
    output_dir="../checkpoints/flan-t5-pdtb3-level1",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=5,
    predict_with_generate=True,
    fp16=True,
    logging_dir="../logs",
    logging_steps=50,
    push_to_hub=False,  
    report_to="none",  
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)

trainer.train()

  trainer = Seq2SeqTrainer(
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Epoch,Training Loss,Validation Loss
1,0.2654,0.240837
2,0.2136,0.221133
3,0.2213,0.218476
4,0.2089,0.217022
5,0.2002,0.220555




TrainOutput(global_step=5575, training_loss=0.241820055871801, metrics={'train_runtime': 1582.4673, 'train_samples_per_second': 56.355, 'train_steps_per_second': 3.523, 'total_flos': 1.2293464797499392e+16, 'train_loss': 0.241820055871801, 'epoch': 5.0})

In [4]:
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from sklearn.metrics import accuracy_score, f1_score, classification_report

# 1. 加载模型和分词器
model_name = "/data/ykyang/checkpoints/flan-t5-pdtb3-level1/checkpoint-3345"  # 你的保存路径
tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)

model = AutoModelForSeq2SeqLM.from_pretrained(model_name, local_files_only=True).cuda()
model.eval()

# 2. 准备数据
test_set = dataset["test"]

# 3. 推理
preds, refs = [], []
for example in tqdm(test_set):
    # 拼接输入
    input_text = f"Arg1: {example['Arg1']} Arg2: {example['Arg2']}"
    ref = example["Label"].strip()

    input_ids = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=512).input_ids.cuda()

    outputs = model.generate(input_ids, max_new_tokens=10)
    pred = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    preds.append(pred)    
    refs.append(ref)

# 4. 计算指标
print("Accuracy:", accuracy_score(refs, preds)) 
print("F1 (macro):", f1_score(refs, preds, average='macro'))
print("\nClassification report:\n", classification_report(refs, preds))


100%|██████████| 1010/1010 [01:08<00:00, 14.70it/s]

Accuracy: 0.692079207920792
F1 (macro): 0.5901672740966485

Classification report:
               precision    recall  f1-score   support

  Comparison       0.68      0.48      0.56       138
 Contingency       0.75      0.61      0.67       340
   Expansion       0.68      0.85      0.76       480
    Temporal       0.46      0.31      0.37        52

    accuracy                           0.69      1010
   macro avg       0.64      0.56      0.59      1010
weighted avg       0.69      0.69      0.68      1010




