In [None]:
!pip install transformers datasets accelerate peft  # 安装核心库

In [ ]:
import json
import torch
from datasets import load_dataset

from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForSeq2Seq
)



def preprocess_data(split):
    """Build a per-example mapper that turns a raw poem record into a prompt record.

    Args:
        split: ``"train"`` or ``"validation"``; selects which raw record
            layout the returned function reads and which fields it emits.

    Returns:
        A function that takes one example (a mapping) and returns a dict:

        * ``"train"``: ``instruction``, ``input`` (always ``""``) and
          ``output``, a JSON string holding ``keywords``, ``translation``
          and ``emotion`` (only the first entry of a ``、``-separated
          emotion list is kept).
        * ``"validation"``: ``instruction``, ``input`` and ``idx``.

    Raises:
        ValueError: If ``split`` is not one of the two supported names.
    """
    # Fail fast: an unknown split used to produce a mapper that silently
    # returned None for every example.
    if split not in ("train", "validation"):
        raise ValueError(
            f"unsupported split {split!r}; expected 'train' or 'validation'"
        )

    def process(example):
        # Training records: the full analysis task with a JSON reference answer.
        # The title is wrapped in 《》 only; stray "[](@replace=...)" link
        # placeholders that had leaked into the templates are no longer emitted.
        if split == "train":
            instruction = (
                f"分析诗歌《{example['title']}》：\n{example['content']}\n"
                "请完成：1.解释关键词 2.翻译全诗 3.识别情感"
            )
            output = {
                "keywords": example["keywords"],
                "translation": example["trans"],
                "emotion": example["emotion"].split("、")[0],
            }
            return {
                "instruction": instruction,
                "input": "",
                "output": json.dumps(output, ensure_ascii=False),
            }

        # Validation records: three question lists plus the poem as input.
        words = "、".join(example["qa_words"])
        sentences = "；".join(example["qa_sents"])
        choices = " | ".join(f"{k}:{v}" for k, v in example["choose"].items())
        instruction = (
            f"请完成以下任务：\n1.解释词语：{words}\n"
            f"2.翻译句子：{sentences}\n3.选择情感：{choices}"
        )
        return {
            "instruction": instruction,
            "input": f"《{example['title']}》\n{example['content']}",
            "idx": example["idx"],
        }

    return process

# Custom data loading
def load_custom_data(train_path, val_path):
    """Read the train/validation JSON files and convert them to prompt records.

    Both files are loaded with the ``json`` dataset builder, requesting the
    ``"train"`` split in each case. Every raw column is then dropped in favour
    of the fields produced by ``preprocess_data``.

    Args:
        train_path: Path of the JSON file holding training examples.
        val_path: Path of the JSON file holding validation examples.

    Returns:
        A ``(train_dataset, validation_dataset)`` tuple of mapped datasets.
    """
    raw_train, raw_val = (
        load_dataset("json", data_files=path, split="train")
        for path in (train_path, val_path)
    )

    prepared_train = raw_train.map(
        preprocess_data("train"),
        remove_columns=raw_train.column_names,
    )
    prepared_val = raw_val.map(
        preprocess_data("validation"),
        remove_columns=raw_val.column_names,
    )
    return prepared_train, prepared_val

# Model and tokenizer initialisation
model_name = "THUDM/chatglm3-6b"  # Chinese-capable ChatGLM3 checkpoint
# ChatGLM3 ships its tokenizer and model classes as custom code inside the hub
# repository, so BOTH loaders need trust_remote_code=True; the tokenizer call
# previously omitted it. The checkpoint name is taken from `model_name` in both
# calls so the two can never drift apart.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(model_name, trust_remote_code=True)

# Batch collator: pads the features of each batch, rounding the padded length
# up to a multiple of 8.
collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    pad_to_multiple_of=8,
    max_length=1024,
    padding=True,
)

# Training hyper-parameters
training_args = TrainingArguments(
    output_dir="./output",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,  # 4 * 8 = 32 sequences per optimizer step per device
    learning_rate=2e-5,
    num_train_epochs=3,
    fp16=True,
    # NOTE(review): newer transformers releases spell this argument
    # `eval_strategy`; confirm against the installed version.
    evaluation_strategy="steps",
    eval_steps=200,
    logging_steps=50,
    save_strategy="steps",
    save_steps=300,
    report_to="tensorboard",
    gradient_checkpointing=True,  # GPU-memory optimisation
    dataloader_num_workers=4,
    # `predict_with_generate` was dropped: it is defined only on
    # Seq2SeqTrainingArguments, so plain TrainingArguments rejects it with a
    # TypeError. Generation is done explicitly in generate_predictions instead.
)

# Custom generation-based evaluation
def _extract_first_json_object(text):
    """Return the first JSON object embedded in *text*, or ``None`` if none parses."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses one complete value starting at `start` and
            # ignores trailing text, so nested objects are handled correctly.
            candidate, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate
        start = text.find("{", start + 1)
    return None


def generate_predictions(model, tokenizer, val_data):
    """Greedy-decode an answer for every validation example and parse it as JSON.

    Args:
        model: Model exposing ``generate``, ``device``, ``eval`` and ``train``.
        tokenizer: Tokenizer used both to encode the prompt and to decode
            the generated ids.
        val_data: Iterable of records with ``instruction``, ``input`` and
            ``idx`` fields.

    Returns:
        A list with one dict per example: the example's ``idx`` followed by the
        keys of the first JSON object found in the generated text. When no
        JSON object can be parsed, the empty defaults ``ans_qa_words``,
        ``ans_qa_sents`` and ``choose_id`` are used instead.
    """
    # The model usually arrives straight from training; switch to eval mode for
    # generation and restore the caller's mode afterwards.
    was_training = getattr(model, "training", False)
    model.eval()

    # Decoder-only models return prompt + continuation from generate(); only
    # the continuation should be parsed.
    is_encoder_decoder = getattr(
        getattr(model, "config", None), "is_encoder_decoder", False
    )

    results = []
    try:
        for example in val_data:
            inputs = tokenizer(
                f"{example['instruction']}\n{example['input']}",
                return_tensors="pt",
                max_length=1024,
                truncation=True,
            ).to(model.device)

            # Greedy decoding; `temperature` has no effect when
            # do_sample=False, so it is no longer passed.
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=False,
            )

            generated = outputs[0]
            if not is_encoder_decoder:
                generated = generated[inputs["input_ids"].shape[-1]:]
            decoded = tokenizer.decode(generated, skip_special_tokens=True)

            # Parse the answer; fall back to empty fields when it is not JSON.
            result = _extract_first_json_object(decoded)
            if result is None:
                result = {"ans_qa_words": {}, "ans_qa_sents": {}, "choose_id": ""}

            # The example's own idx always wins over any "idx" the model emitted.
            results.append({
                "idx": example["idx"],
                **{key: value for key, value in result.items() if key != "idx"},
            })
    finally:
        model.train(was_training)
    return results

# Training pipeline
def _tokenize_train_example(example, max_length=1024):
    """Turn one training prompt record into ``input_ids``/``attention_mask``/``labels``.

    The prompt text is laid out the same way ``generate_predictions`` builds it
    at inference time. Prompt positions are labelled -100 so that only the
    answer tokens (plus the end-of-sequence token) contribute to the loss.
    """
    prompt_ids = tokenizer(
        f"{example['instruction']}\n{example['input']}",
        max_length=max_length,
        truncation=True,
    )["input_ids"]
    answer_ids = tokenizer(example["output"], add_special_tokens=False)["input_ids"]
    answer_ids = answer_ids + [tokenizer.eos_token_id]

    input_ids = (prompt_ids + answer_ids)[:max_length]
    labels = ([-100] * len(prompt_ids) + answer_ids)[:max_length]
    return {
        "input_ids": input_ids,
        "attention_mask": [1] * len(input_ids),
        "labels": labels,
    }


def _tokenize_eval_example(example, max_length=1024):
    """Turn one validation prompt record into ``input_ids``/``attention_mask``."""
    encoded = tokenizer(
        f"{example['instruction']}\n{example['input']}",
        max_length=max_length,
        truncation=True,
    )
    return {
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
    }


if __name__ == "__main__":
    # Load the raw JSON files and reshape them into prompt records.
    train_data, val_data = load_custom_data("train.json", "val.json")

    # Trainer consumes token ids, not text: the prompt records only hold
    # strings, so they are tokenised here. The text versions are kept because
    # generate_predictions needs the raw prompt and idx fields.
    train_features = train_data.map(
        _tokenize_train_example, remove_columns=train_data.column_names
    )
    # NOTE(review): validation records carry no reference output, so there are
    # no labels for in-training evaluation to score; confirm whether periodic
    # evaluation is wanted at all.
    eval_features = val_data.map(
        _tokenize_eval_example, remove_columns=val_data.column_names
    )

    # Set up the Trainer.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_features,
        eval_dataset=eval_features,
        data_collator=collator,
    )

    # Run fine-tuning.
    trainer.train()

    # Generate answers for the validation set.
    predictions = generate_predictions(trainer.model, tokenizer, val_data)

    # Save the results. The JSON is written with ensure_ascii=False, so the
    # file encoding is pinned to UTF-8 instead of the platform default.
    with open("predictions.json", "w", encoding="utf-8") as f:
        json.dump(predictions, f, indent=2, ensure_ascii=False)