In [None]:
import os

from huggingface_hub import login
import wandb

# Authenticate with Hugging Face and Weights & Biases.
# Credentials are read from the environment (HF_TOKEN / WANDB_API_KEY) instead of
# being hard-coded in the notebook, where they would leak whenever the file is shared.
# If a variable is unset, None is passed and each library falls back to its own
# default login flow (saved credentials or an interactive prompt).
login(token = os.environ.get("HF_TOKEN"))
wandb.login(key = os.environ.get("WANDB_API_KEY"))
run = wandb.init(
    project='Fine-tune gemma-3n-e4b',
    job_type="training",
    anonymous="allow"
)

In [None]:
# Run configuration: training CSV (columns original / summary / claude_output),
# the Unsloth base checkpoint, and the name used for output / saved artifacts.
csv_file_name, base_model_name, new_model_name = (
    "consult_validation_cleaned.csv",
    "unsloth/gemma-3n-E4B-it",
    "gemma-3n-privnurse-consult-validation-v1",
)

In [None]:
from unsloth import FastModel
import torch  # also used later for CUDA memory stats and the dynamo workaround

# Load the base checkpoint through Unsloth; returns the model and its tokenizer.
model, tokenizer = FastModel.from_pretrained(
    model_name = base_model_name,
    dtype = None, # None lets Unsloth auto-detect the dtype
    max_seq_length = 8192, # Context length; keep in sync with max_seq_length in the SFTConfig cell
    load_in_4bit = True,  # 4 bit quantization to reduce memory
    full_finetuning = False, # False: only the LoRA adapters added below are trained
)

In [None]:
###########################
# Let's finetune Gemma 3N #
###########################

In [None]:
# Wrap the base model with LoRA adapters; only the adapter weights are trained.
model = FastModel.get_peft_model(
    model,
    finetune_vision_layers     = False, # Text-only task: keep the vision layers frozen
    finetune_language_layers   = True,  # Should leave on!
    finetune_attention_modules = True,  # Adapt the attention projections too
    finetune_mlp_modules       = True,  # Should leave on always!

    r = 32,           # LoRA rank: larger = more capacity, but might overfit
    lora_alpha = 64,  # LoRA scaling; recommended alpha >= r (here alpha = 2 * r)
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
)

In [None]:
from unsloth.chat_templates import get_chat_template
# Attach Unsloth's "gemma-3" chat template. The <start_of_turn>user / <start_of_turn>model
# markers it renders are what train_on_responses_only matches further down.
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "gemma-3",
)

In [None]:
from datasets import load_dataset
# Load every row of the CSV as one Dataset (split="all" combines all splits).
dataset = load_dataset('csv', data_files=csv_file_name, split="all")
# Fixed seed so the row order (and therefore the row peeked at below) is reproducible.
dataset = dataset.shuffle(seed=56)
# Peek at one raw row; assumes the CSV has more than 100 rows.
dataset[100]

In [None]:
# System instruction shared by training (formatting_prompts_func) and inference
# (process_consultation). NOTE(review): the text says "JSON array", but the expected
# answers are an object {"relevant_text": [...]}. The wording is left unchanged,
# presumably because the training targets were generated with this exact prompt.
system_prompt = (
    "Extract the exact phrases or sentences from the [#會診申請單] text "
    "that correspond to the information summarized in the [#護理師確認結果]. "
    "Present these extracted phrases as an array of strings in JSON format, "
    "with the key 'relevant_text'. "
    "Do not output anything other than the JSON array."
)

In [None]:
def formatting_prompts_func(examples):
    """Turn a batch of CSV rows into Gemma-3 chat-formatted training strings.

    Called through ``Dataset.map(..., batched=True)``, so every column in
    ``examples`` arrives as a list with one entry per row.

    Args:
        examples: batch with the columns ``original`` (consultation request
            text), ``summary`` (nurse confirmation result) and
            ``claude_output`` (target answer for the assistant turn).

    Returns:
        ``{"text": [...]}`` holding one rendered conversation per row, in
        row order.
    """
    def render(idx, original):
        # One three-turn conversation: task instructions, both source texts
        # under their section headers, then the reference answer.
        user_text = f"#會診申請單：\n{original}\n\n#護理師確認結果：\n{examples['summary'][idx]}"
        conversation = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_text},
            {"role": "assistant", "content": examples["claude_output"][idx]},
        ]
        rendered = tokenizer.apply_chat_template(
            conversation, tokenize=False, add_generation_prompt=False
        )
        # Drop the leading <bos>. Presumably the trainer adds its own when it
        # tokenizes the text, and the model should see only one.
        return rendered.removeprefix('<bos>')

    return {"text": [render(idx, original) for idx, original in enumerate(examples["original"])]}

# Use batched processing: adds the "text" column that SFTConfig.dataset_text_field points at.
dataset = dataset.map(formatting_prompts_func, batched=True)

In [None]:
# Inspect the chat-formatted training text produced for the same row shown earlier.
dataset[100]["text"]

In [None]:
###################
# Train the model #
###################

In [None]:
from trl import SFTTrainer, SFTConfig
# Supervised fine-tuning hyper-parameters. Values marked "template" are the author's
# usual starting points, not necessarily the TrainingArguments library defaults.
training_arguments = SFTConfig(
    dataset_text_field="text",         # column added by formatting_prompts_func
    output_dir=new_model_name,         # checkpoints are written here (save_strategy="epoch")
    per_device_train_batch_size=6,     # template: 2
    gradient_accumulation_steps=8,     # template: 8 -> effective batch = 6 * 8 = 48 per device
    optim="adamw_torch_fused",         # Options: adamw_hf, adamw_torch, adamw_torch_fused, adamw_8bit
    num_train_epochs=6,                
    ###############
    # Evaluation is disabled: the trainer is built with eval_dataset=None.
    # per_device_eval_batch_size=2,    # template: 2
    # evaluation_strategy="steps",
    eval_strategy='no',
    # eval_steps=10,
    ###############
    lr_scheduler_type = "linear",      
    max_grad_norm=0.3,                 # template: 0.3 (TrainingArguments default is 1.0)
    warmup_ratio=0.1,                 # template: 0.03
    # warmup_steps=30,                 # template: 30
    learning_rate=1e-3,                # template: 2e-4. NOTE(review): 1e-3 is high for LoRA SFT — confirm intended
    # weight_decay=0.01,               # [new] default: 0
    # adam_beta1=0.9,                  # [new] default: 0.9
    # adam_beta2=0.95,                 # [new] default: 0.999
    # seed=1205,
    logging_strategy="steps",
    logging_steps=4,                # When batch size is big, logging steps should be reduced to 2-5 not 10
    save_strategy="epoch",
    fp16=False,
    bf16=True,                         # needs a bf16-capable GPU
    # group_by_length=True,
    report_to="wandb",                 # logs into the wandb run started in the first cell
    max_seq_length=8192,               # keep equal to max_seq_length used when loading the model
    seed = 3407,
    # packing= False,
)

In [None]:
# Build the SFT trainer over the chat-formatted dataset ("text" column).
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    eval_dataset = None, # No evaluation split (eval_strategy='no' in the config)
    args = training_arguments,
)

In [None]:
from unsloth.chat_templates import train_on_responses_only
# Compute the loss only on the model's answer. Prompt tokens (everything outside the
# "<start_of_turn>model\n" turns) get label -100, as the label-decoding check below shows.
# Both markers must match the Gemma-3 chat template attached earlier.
trainer = train_on_responses_only(
    trainer,
    instruction_part = "<start_of_turn>user\n",
    response_part = "<start_of_turn>model\n",
)

In [None]:
# Sanity check: decode one tokenized training example (full prompt + answer).
tokenizer.decode(trainer.train_dataset[100]["input_ids"])

In [None]:
# Sanity check on the labels of the same example. Masked positions (-100) are mapped to
# the pad token and then blanked, so only the trained answer text should remain visible.
tokenizer.decode([tokenizer.pad_token_id if x == -100 else x for x in trainer.train_dataset[100]["labels"]]).replace(tokenizer.pad_token, " ")

In [None]:
# Workaround for a TorchInductor failure:
#   InductorError: RuntimeError: Failed to run autotuning code block: PY_SSIZE_T_CLEAN macro must be defined for '#' formats
#   (Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace; TORCH_LOGS="+dynamo" adds more developer context.)
# Fix: disable TorchDynamo for this kernel.
# NOTE: this cell runs BEFORE trainer.train(), so compilation is disabled for training as
# well as inference. The flag is a module-level setting and stays set for the kernel session.
import torch
torch._dynamo.config.disable = True
# Alternative: keep compiling, but fall back to eager execution on compile errors
# torch._dynamo.config.suppress_errors = True
# The error is attributed to Dynamo compiling Gemma 3's vision components.
# NOTE(review): disabling it is expected to affect speed only, not results — confirm.

In [None]:
# @title Show current memory stats
# Record GPU 0's capacity and the memory already reserved (model load) before training;
# the post-training stats cell subtracts this baseline.
bytes_per_gib = 1024 ** 3
gpu_stats = torch.cuda.get_device_properties(0)
max_memory = round(gpu_stats.total_memory / bytes_per_gib, 3)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / bytes_per_gib, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

In [None]:
# Run fine-tuning; the returned metrics (e.g. 'train_runtime') are reported in the next cell.
trainer_stats = trainer.train()

In [None]:
# @title Show final memory and time stats
# Peak reserved memory for the session, minus the baseline taken before training,
# approximates what the training run itself needed.
train_runtime = trainer_stats.metrics['train_runtime']
used_memory = round(torch.cuda.max_memory_reserved() / (1024 ** 3), 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
for report_line in (
    f"{train_runtime} seconds used for training.",
    f"{round(train_runtime/60, 2)} minutes used for training.",
    f"Peak reserved memory = {used_memory} GB.",
    f"Peak reserved memory for training = {used_memory_for_lora} GB.",
    f"Peak reserved memory % of max memory = {used_percentage} %.",
    f"Peak reserved memory for training % of max memory = {lora_percentage} %.",
):
    print(report_line)

In [None]:
# Finish the wandb run started in the first cell; training logging is complete.
wandb.finish()

In [None]:
#############
# Inference #
#############

In [None]:
# Workaround for a TorchInductor failure during inference:
#   InductorError: RuntimeError: Failed to run autotuning code block: PY_SSIZE_T_CLEAN macro must be defined for '#' formats
#   (Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace; TORCH_LOGS="+dynamo" adds more developer context.)
# Fix: disable TorchDynamo before running generation.
# This repeats the identical setting made before training; setting it again is harmless
# and keeps this section runnable on its own.
import torch
torch._dynamo.config.disable = True
# Alternative: keep compiling, but fall back to eager execution on compile errors
# torch._dynamo.config.suppress_errors = True
# The error is attributed to Dynamo compiling Gemma 3's vision components.
# NOTE(review): disabling it is expected to affect speed only, not results — confirm.

In [None]:
from unsloth.chat_templates import get_chat_template
from transformers import TextStreamer

def process_consultation(model, tokenizer, consultation_content, max_tokens=2048,
                         temperature=0.1, top_p=0.95, top_k=64, stream=True):
    """Run the fine-tuned model on one consultation and return its answer.

    Args:
        model: the fine-tuned (LoRA-wrapped) Gemma 3n model.
        tokenizer: the tokenizer/processor returned by ``FastModel.from_pretrained``.
        consultation_content: the user message. For best results it should follow
            the layout used in training by ``formatting_prompts_func``: the
            "#會診申請單：" header and request text, a blank line, then the
            "#護理師確認結果：" header and the nurse summary.
        max_tokens: maximum number of new tokens to generate.
        temperature, top_p, top_k: sampling settings forwarded to ``generate``.
            Gemma-3's recommended settings are temperature=1.0, top_p=0.95,
            top_k=64; a low temperature is used for steadier extraction.
            NOTE(review): these only apply when sampling is enabled in the
            model's generation config.
        stream: if True, print tokens to stdout as they are generated.

    Returns:
        The generated answer text (prompt and special tokens removed),
        expected to be the ``{"relevant_text": [...]}`` JSON.
    """
    # Make sure the Gemma-3 template is attached even if the caller passes a
    # tokenizer that has not gone through get_chat_template.
    tokenizer = get_chat_template(
        tokenizer,
        chat_template = "gemma-3",
    )
    messages = [
        {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
        {"role": "user", "content": [{"type": "text", "text": consultation_content}]},
    ]
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt = True, # Must add for generation
        return_tensors = "pt",
        tokenize = True,
        return_dict = True,
    ).to(model.device)  # follow the model's device instead of hard-coding "cuda"

    outputs = model.generate(
        **inputs,
        max_new_tokens = max_tokens,
        temperature = temperature,
        top_p = top_p,
        top_k = top_k,
        streamer = TextStreamer(tokenizer, skip_prompt = True) if stream else None,
    )
    # generate() returns prompt + completion; decode only the newly generated tokens.
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens = True).strip()

In [None]:
# Example 1: infection consult (antibiotic continuation).
# The headers must match the training layout built in formatting_prompts_func
# ("#會診申請單：" ... blank line ... "#護理師確認結果："). The nurse-result header
# previously lacked the leading "#".
# The nurse summary also contained the consulting physician's full name; it is now
# redacted to match the 回覆醫師 field.
consultation_content = """
#會診申請單：
申請會診內容：
Dear Dr
This 45 years old woman admitted due to influenza A infection, beaause of progressive desaturation with fever on and off, follow CXR increased bilateral infiltration, under tapimycin, we need your consult for continue Tapimycin therapy. thanks a lot.

被會診諮詢的科別：感染科

回覆醫師：楊XX

感染科回覆會診內容：
Agree with continuing Tapimycin 4.5 gm iv q8h. (8523)

#護理師確認結果：
病患因流感A感染入院，因持續低氧及間歇性發燒，胸部X光顯示雙側浸潤增加，申請感染科會診，感染科楊XX醫師回覆，同意繼續使用Tapimycin 4.5 gm iv q8h。
"""
process_consultation(model, tokenizer, consultation_content)
# Previously observed output:
# {"relevant_text": ["This 45 years old woman admitted due to influenza A infection, beaause of progressive desaturation 
# with fever on and off", "follow CXR increased bilateral infiltration", "we need your consult for continue Tapimycin therapy", 
# "Agree with continuing Tapimycin 4.5 gm iv q8h. (8523)"]}

In [None]:
# Example 2: consult triggered by antibiotic expiry.
# The headers must match the training layout built in formatting_prompts_func
# ("#會診申請單：" ... blank line ... "#護理師確認結果："). The nurse-result header
# previously lacked the leading "#".
consultation_content = """
#會診申請單：
申請會診內容：
Dear doctor:today anti maturity 
This is 67y/o femlae becasue fell accident after discharge on 7/14. Now, Severe low back pain(VAS 8) and cannot walk now. Wheel chair+ .patient severe lower back pain and right leg soreness .follow L+T spine MRI revealed suspect L3-4 infection .today lab data get worse .spsepct sepsis .so we need your evalaution for antib use  .
Thanks.
8/13: anti 到期 thanks 

被會診諮詢的科別：感染科

回覆醫師：楊XX

感染科回覆會診內容：
Agree with continuing Mepem 500 mg iv Q12h + Tecopin 400 mg iv Q3d. (8523)

#護理師確認結果：
病患因今日抗生素到期，申請感染科會診。感染科楊XX醫師回覆，同意繼續使用Mepem 500 mg iv Q12h及Tecopin 400 mg iv Q3d。
"""
process_consultation(model, tokenizer, consultation_content)
# Previously observed output:
# {"relevant_text": ["8/13: anti 到期 thanks", "Agree with continuing Mepem 500 mg iv Q12h + Tecopin 400 mg iv Q3d. (8523)"]}

In [None]:
# Example 3: gastroenterology consult with a structured reply.
# The headers must match the training layout built in formatting_prompts_func
# ("#會診申請單：" ... blank line ... "#護理師確認結果："). The nurse-result header
# was previously written as "# 護理師確認結果：" (extra space).
consultation_content = """
#會診申請單：
申請會診內容：
Dear Dr 
   This 68 y/o man is a case of bilateral pneumonia,abdomen fullness and Ileus 
noted 
we need your expertise for this patient 
Thank you very much  
113/08/02
stool impaction noted,we need your expertise for this patient 
Thank you very much 

被會診諮詢的科別：肝膽腸胃科

回覆醫師：許XX

肝膽腸胃科回覆會診內容：
Dear : Doctor in charge.


Impression :
1. Suspect sigmoid volvulus due to chronic severe constipation.
2. Constipation.


Suggest : 
1. Consider to do colonoscopy with endoscopic detorsion first.
2. Consider to do surgery if not improved with endoscopic detrosion.


Thank you for yours consultation !!!

#護理師確認結果：
病患因雙側肺炎、腹脹及腸阻塞，申請肝膽腸胃科會診。肝膽腸胃科許XX醫師回覆診斷為可能為乙狀結腸扭轉伴隨慢性嚴重便秘，並建議考慮進行結腸鏡內視鏡復位，若內視鏡復位無改善則考慮手術。
"""
process_consultation(model, tokenizer, consultation_content)
# Previously observed output:
# {"relevant_text": ["This 68 y/o man is a case of bilateral pneumonia,abdomen fullness and Ileus noted", 
# "stool impaction noted,we need your expertise for this patient", "Suspect sigmoid volvulus due to chronic severe constipation.", 
# "Consider to do colonoscopy with endoscopic detorsion first.", "Consider to do surgery if not improved with endoscopic detrosion."]}

In [None]:
# Example 4: consult with no reply yet (無 = none).
# The headers must match the training layout built in formatting_prompts_func
# ("#會診申請單：" ... blank line ... "#護理師確認結果："). The nurse-result header
# was previously written as "# 護理師確認結果：" (extra space).
# NOTE(review): the "## " sub-headers below differ from the other examples —
# confirm they match how the `original` column is formatted in the training CSV.
consultation_content = """
#會診申請單：
## 申請會診內容：
Dear Dr.:
  This is a 77 years old male patient who was admitted due to fever. Home test COVID Ag: (+) at home. He has history of 1. Type 2 diabetes mellitus, 2. Chronic kidney disease under medication. Laboratory data showed pyuria and acute kidney injury. Impression of 1. Urinary tract infection, 2. COVID-19 infection.
Hence ,we need your assessment for antibiotic treatment, thanks a lot!

for second course of Cravit, thank you.

## 被會診諮詢的科別：感染科

## 回覆醫師：陳XX

## 感染科回覆會診內容：
無

#護理師確認結果：
會診感染科徵求抗生素使用建議，等待回覆
"""
process_consultation(model, tokenizer, consultation_content)
# Previously observed output:
# {"relevant_text": ["Hence,we need your assessment for antibiotic treatment, thanks a lot!", "for second course of Cravit, thank you."]}

In [None]:
# Example 5: post-operative fever consult with embedded lab tables.
# The headers must match the training layout built in formatting_prompts_func
# ("#會診申請單：" ... blank line ... "#護理師確認結果："). The nurse-result header
# previously lacked the leading "#".
# NOTE(review): the previously observed output also returned the department / physician
# header lines, which are not part of the nurse summary.
consultation_content = """
#會診申請單：
申請會診內容：
Dear Dr.
This 65 years old male patient is a case of Enlarged prostate with lower urinary tract symptoms S/P Laser transurethral resection of the Prostate on 08/01
Due to high fever and leukocytosis were noted, so we need your expert to this patient for evaluation 
血液檢查 Blood
日     期         WBC             RBC         Hb     HCT      MCV 
= = = = =  ========== =============== ========== ======= ======== 
113/08/02   23730 uL  487 x10000/uL  13.7 g/dL  39.4 %  80.9 fl 

血液檢查 Blood
日     期       MCH       MCHC       Platelet  RDW-CV     MPV 
= = = = =  ======== ========== ============== ======= ======= 
113/08/02   28.1 pg  34.8 g/dL  269 x1000/uL  13.3 %  8.9 fL 

血液檢查 Blood
日     期   Neutrophil-Segmented  Lymphocyte  Monocyte 
= = = = =  ===================== =========== ========= 
113/08/02                 87.9 %       7.6 %     4.5 % 

血液檢查 Blood
日     期   Absolute neutrophil count(ANC) 
= = = = =  =============================== 
113/08/02                        20859 /uL 

血液檢查 Blood
Thank a lot!! 

被會診諮詢的科別：感染科

回覆醫師：楊XX

感染科回覆會診內容：
Switch Stazolin iv to Seforce 400 mg iv Q12h (CODE: 8523).

#護理師確認結果：
病患因高燒及白血球增多，申請感染科會診。感染科楊XX醫師回覆，同意將Stazolin iv改為Seforce 400 mg iv Q12h。
"""
process_consultation(model, tokenizer, consultation_content)
# Previously observed output:
# {"relevant_text": ["Due to high fever and leukocytosis were noted, so we need your expert to this patient for evaluation", 
# "被會診諮詢的科別：感染科", "回覆醫師：楊XX", "Switch Stazolin iv to Seforce 400 mg iv Q12h (CODE: 8523)."]}

In [None]:
##########################################
############# Save the model #############
##########################################

In [None]:
# [NOTE] This ONLY saves the LoRA adapters (plus the tokenizer), not the full model.
# For a merged model or a GGUF export, see the following cells.
# NOTE(review): new_model_name is also SFTConfig.output_dir, so the adapters are written
# next to the per-epoch checkpoint folders.
model.save_pretrained(new_model_name)  # Local saving
tokenizer.save_pretrained(new_model_name)

In [None]:
# Save the base model with the LoRA weights merged in, for use without PEFT.
# It goes to its own directory. Writing it into new_model_name (where the adapters
# were saved) would leave adapter_config.json beside the merged weights, and
# transformers' from_pretrained would then load base model + adapter instead.
model.save_pretrained_merged(f"{new_model_name}-merged", tokenizer)

In [None]:
# Export a GGUF file. It is written to its own directory so the export output does
# not mix with the LoRA adapter files saved under new_model_name.
model.save_pretrained_gguf(
    f"{new_model_name}-gguf",
    quantization_type = "Q8_0", # For now only Q8_0, BF16, F16 supported
)