In [1]:
import pandas as pd
import json
from datasets import Dataset, concatenate_datasets
from unsloth import FastLanguageModel
from transformers import TrainingArguments, DataCollatorForSeq2Seq
from trl import SFTTrainer
from unsloth.chat_templates import get_chat_template, train_on_responses_only

# ‚úÖ 1. ËØªÂèñ Excel Êï∞ÊçÆÈõÜ
xlsx_path = "/root/autodl-tmp/result_with_accuracy.xlsx"
df_xlsx = pd.read_excel(xlsx_path, usecols=["Disease", "Symptoms", "Treatments"])
df_xlsx.fillna("", inplace=True)  # ‚úÖ Â§ÑÁêÜÁº∫Â§±ÂÄº

dataset_xlsx = Dataset.from_pandas(df_xlsx)

# ‚úÖ 2. ËØªÂèñ JSON Êï∞ÊçÆÈõÜ
json_path = "/root/autodl-tmp/updated_data.json"
with open(json_path, "r", encoding="utf-8") as f:
    json_data = json.load(f)

formatted_qa = []
formatted_textbook = []

for item in json_data:
    if item["type"] == "qa":
        user_input = f"<|start_header_id|>user<|end_header_id|>\n\n{item['question']}\n\n"
        assistant_response = f"<|start_header_id|>assistant<|end_header_id|>\n\n{item['answer']}\n\n"
        formatted_qa.append({"text": user_input + assistant_response})
    
    elif item["type"] == "textbook":
        assistant_response = f"<|start_header_id|>assistant<|end_header_id|>\n\n{item['text']}\n\n"
        formatted_textbook.append({"text": assistant_response})

dataset_qa = Dataset.from_list(formatted_qa)
dataset_textbook = Dataset.from_list(formatted_textbook)

# ‚úÖ 3. Â§ÑÁêÜ Excel Êï∞ÊçÆ
def formatting_prompts_func(examples):
    formatted_texts = []
    for disease, symptoms, treatments in zip(examples["Disease"], examples["Symptoms"], examples["Treatments"]):
        user_input = f"<|start_header_id|>user<|end_header_id|>\n\n{disease}\n\n"
        assistant_response = f"<|start_header_id|>assistant<|end_header_id|>\n\n"
        assistant_response += f"Symptoms:\n{symptoms}\n\n" if symptoms else ""
        assistant_response += f"Treatments:\n{treatments}\n\n" if treatments else ""
        formatted_texts.append(user_input + assistant_response)
    return {"text": formatted_texts}

dataset_xlsx = dataset_xlsx.map(formatting_prompts_func, batched=True)

# ‚úÖ 4. ÂêàÂπ∂Êï∞ÊçÆÈõÜ
dataset_combined = concatenate_datasets([dataset_xlsx, dataset_qa, dataset_textbook])

# ‚úÖ 5. Âä†ËΩΩ `unsloth/Llama-3.1-8B-Instruct-bnb-4bit`
model_name = "/root/autodl-tmp/Llama-3.1-8B-Instruct-bnb-4bit"
max_seq_length = 2048

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_name,
    max_seq_length=max_seq_length,
    load_in_4bit=True,
)

# ‚úÖ 6. Â∫îÁî® `LoRA` ËÆ≠ÁªÉÂèÇÊï∞ÔºàÂáèÂ∞ëÂΩ±ÂìçÔºâ
model = FastLanguageModel.get_peft_model(
    model,
    r=8,  # ‚¨áÔ∏è Èôç‰Ωé LoRA ÈÄÇÈÖçÂΩ±ÂìçÔºà‰ªé 16 ÈôçÂà∞ 8Ôºâ
    target_modules=["q_proj", "v_proj", "o_proj"],  # ‚¨áÔ∏è Âè™ÈÄÇÈÖçËøô‰∏â‰∏™ÂÖ≥ÈîÆÂ±Ç
    lora_alpha=8,  # ‚¨áÔ∏è Èôç‰Ωé LoRA ÊùÉÈáçÔºà‰ªé 16 ÈôçÂà∞ 8Ôºâ
    lora_dropout=0.05,  # ÈÄÇÂΩìÂä†‰∏ÄÁÇπ dropout Èò≤Ê≠¢ËøáÊãüÂêà
    bias="none",
    use_gradient_checkpointing="unsloth",
)

# ËÆæÂÆöÂØπËØùÊ®°Êùø
tokenizer = get_chat_template(tokenizer, chat_template="llama-3.1")

# ‚úÖ 7. ËÆ≠ÁªÉÂô®ÔºàÂáèÂ∞ëÂ≠¶‰π†ÁéáÂΩ±ÂìçÔºâ
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset_combined,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
    dataset_num_proc=2,
    packing=False,
    args=TrainingArguments(
        per_device_train_batch_size=8,  # ‚¨ÜÔ∏è Â¢ûÂ§ß batch sizeÔºåÂáèÂ∞ëÊ¢ØÂ∫¶Ê≥¢Âä®
        gradient_accumulation_steps=1,  # ‚¨áÔ∏è Èôç‰ΩéÁ¥ØËÆ°Ê¢ØÂ∫¶ÔºåÂáèÂ∞ë LoRA ÂΩ±Âìç
        warmup_steps=10,  # ÈÄÇÂΩìÂ¢ûÂä† warmup
        max_steps=500,  # ‚¨áÔ∏è ÂáèÂ∞ëËÆ≠ÁªÉÊ≠•Êï∞ÔºåÈÅøÂÖçËøáÊãüÂêà
        learning_rate=5e-5,  # ‚¨áÔ∏è Èôç‰ΩéÂ≠¶‰π†ÁéáÔºåÂáèÂ∞ëÂØπÂéüÂßãÊ®°ÂûãÁöÑÂΩ±Âìç
        fp16=False,  
        bf16=True,  
        logging_steps=1,
        optim="adamw_8bit",
        weight_decay=0.01,
        lr_scheduler_type="cosine",  # ‚úÖ Êîπ‰∏∫ `cosine` ËÆ©Â≠¶‰π†ÁéáÂπ≥Êªë‰∏ãÈôç
        seed=3407,
        output_dir="trained_model",
        report_to="none",
    ),
)

# ‚úÖ 8. ËÆ≠ÁªÉ‰ªÖÂÖ≥Ê≥® `assistant` ÂõûÁ≠îÈÉ®ÂàÜ
trainer = train_on_responses_only(
    trainer,
    instruction_part="<|start_header_id|>user<|end_header_id|>\n\n",
    response_part="<|start_header_id|>assistant<|end_header_id|>\n\n",
)

# ‚úÖ 9. ÂºÄÂßãËÆ≠ÁªÉ
trainer.train()

# ‚úÖ 10. ‰øùÂ≠òÊ®°Âûã
trainer.model.save_pretrained("/root/autodl-tmp/trained_model")
trainer.tokenizer.save_pretrained("/root/autodl-tmp/trained_model")


ü¶• Unsloth: Will patch your computer to enable 2x faster free finetuning.
ü¶• Unsloth Zoo will now patch everything to make training faster!


Map:   0%|          | 0/410 [00:00<?, ? examples/s]

==((====))==  Unsloth 2025.2.5: Fast Llama patching. Transformers: 4.48.3.
   \\   /|    GPU: NVIDIA GeForce RTX 4090. Max memory: 23.643 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.9. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post2. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = 0.05.
Unsloth will patch all other layers, except LoRA matrices, causing a performance hit.
Unsloth 2025.2.5 patched 32 layers with 0 QKV layers, 0 O layers and 0 MLP layers.


Map (num_proc=2):   0%|          | 0/276485 [00:00<?, ? examples/s]

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


Map:   0%|          | 0/276485 [00:00<?, ? examples/s]

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 276,485 | Num Epochs = 1
O^O/ \_/ \    Batch size per device = 8 | Gradient Accumulation steps = 1
\        /    Total batch size = 8 | Total steps = 500
 "-____-"     Number of trainable parameters = 5,505,024


Step,Training Loss
1,2.737
2,2.505
3,2.6225
4,2.3369
5,2.3538
6,2.2634
7,2.2464
8,2.11
9,2.6174
10,2.196


Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.


('/root/autodl-tmp/trained_model/tokenizer_config.json',
 '/root/autodl-tmp/trained_model/special_tokens_map.json',
 '/root/autodl-tmp/trained_model/tokenizer.json')