In [1]:
from datetime import datetime
import pandas as pd
from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import os
import json
from io import BytesIO
import base64

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
exp_name = "gemma3-4b-sft-textonly-"
'''
本次實驗名稱
textonly:對應dataset/0513_SFTDataset/text/qa_pairs_sft.json，只包含文字問題
hitdata:對應dataset/0513_SFTDataset/hitdata/sft_training_data.json，為圖文對資料
'''

In [None]:
file_locate = "/tmp/pycharm_project_979/" #遠端環境，根據部署地做更改
jsonFile_path = file_locate+"dataset/0513_SFTDataset/text/qa_pairs_sft.json"

with open(jsonFile_path, "r", encoding="utf-8") as f:
    dataset = json.load(f)

# 微調階段

In [3]:
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig

In [4]:
from accelerate import Accelerator
from accelerate import PartialState
# Hugging Face model id
model_id = "google/gemma-3-4b-pt" # or `google/gemma-3-12b-pt`, `google/gemma-3-27-pt`

# Check if GPU benefits from bfloat16
if torch.cuda.get_device_capability()[0] < 8:
    raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")

device_string = PartialState().process_index
# Define model init arguments
model_kwargs = dict(
    attn_implementation="eager", # Use "flash_attention_2" when running on Ampere or newer GPU，安裝flash attn前必須先安裝ninja，詳細:https://blog.csdn.net/lckj2009/article/details/136054392
    torch_dtype=torch.bfloat16, # What torch dtype to use, defaults to auto
    device_map={'':device_string}
)

 #BitsAndBytesConfig int-4 config
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=model_kwargs["torch_dtype"],
    bnb_4bit_quant_storage=model_kwargs["torch_dtype"],
)

# Load model and tokenizer
model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)
processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")

Loading checkpoint shards: 100%|██████████| 2/2 [00:07<00:00,  3.57s/it]
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [8]:
print(model)

Gemma3ForConditionalGeneration(
  (model): Gemma3Model(
    (vision_tower): SiglipVisionModel(
      (vision_model): SiglipVisionTransformer(
        (embeddings): SiglipVisionEmbeddings(
          (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
          (position_embedding): Embedding(4096, 1152)
        )
        (encoder): SiglipEncoder(
          (layers): ModuleList(
            (0-26): 27 x SiglipEncoderLayer(
              (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
              (self_attn): SiglipAttention(
                (k_proj): Linear4bit(in_features=1152, out_features=1152, bias=True)
                (v_proj): Linear4bit(in_features=1152, out_features=1152, bias=True)
                (q_proj): Linear4bit(in_features=1152, out_features=1152, bias=True)
                (out_proj): Linear4bit(in_features=1152, out_features=1152, bias=True)
              )
              (layer_norm2): LayerNorm((1152,), eps=

In [5]:
from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=16,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=[
        "lm_head",
        "embed_tokens",
    ],
)

In [7]:
from trl import SFTConfig

args = SFTConfig(
    output_dir="model/gemma-test",     # directory to save and repository id
    num_train_epochs=1,                         # number of training epochs
    per_device_train_batch_size=4,              # batch size per device during training
    gradient_accumulation_steps=1,              # number of steps before performing a backward/update pass
    gradient_checkpointing=True,                # use gradient checkpointing to save memory
    optim="adamw_torch_fused",                  # use fused adamw optimizer
    logging_steps=5,                            # log every 5 steps
    save_strategy="epoch",                      # save checkpoint every epoch
    learning_rate=2e-4,                         # learning rate, based on QLoRA paper
    bf16=True,                                  # use bfloat16 precision
    max_grad_norm=0.3,                          # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                          # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",               # use constant learning rate scheduler
    report_to="tensorboard",                    # report metrics to tensorboard
    gradient_checkpointing_kwargs={
        "use_reentrant": False
    },  # use reentrant checkpointing
    dataset_text_field="",                      # need a dummy field for collator
    dataset_kwargs={"skip_prepare_dataset": True},  # important for collator
)
args.remove_unused_columns = False # important for collator

# Create a data collator to encode text and image pairs
def collate_fn(examples):
    texts = []
    images = []
    for example in examples:
        image_inputs = process_vision_info(example["messages"])
        text = processor.apply_chat_template(
            example["messages"], add_generation_prompt=False, tokenize=False
        )
        texts.append(text.strip())
        images.append(image_inputs)

    # Tokenize the texts and process the images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # The labels are the input_ids, and we mask the padding tokens and image tokens in the loss computation
    labels = batch["input_ids"].clone()

    # Mask image tokens
    image_token_id = [
        processor.tokenizer.convert_tokens_to_ids(
            processor.tokenizer.special_tokens_map["boi_token"]
        )
    ]
    # Mask tokens for not being used in the loss computation
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == image_token_id] = -100
    labels[labels == 262144] = -100

    batch["labels"] = labels.cpu()
    return batch

In [8]:
from trl import SFTTrainer

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=dataset,
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn,

)

[2025-08-24 09:46:13,748] [INFO] [real_accelerator.py:260:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/usr/bin/ld: cannot find -laio
collect2: error: ld returned 1 exit status
/usr/bin/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `dlopen'
/usr/bin/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `dlclose'
/usr/bin/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `dlerror'
/usr/bin/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `dlsym'
collect2: error: ld returned 1 exit status


[2025-08-24 09:46:14,868] [INFO] [logging.py:107:log_dist] [Rank -1] [TorchCheckpointEngine] Initialized with serialization = False


In [9]:
# Start training, the model will be automatically saved to the Hub and the output directory
trainer.train()

# Save the final model again to the Hugging Face Hub
trainer.save_model()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


ValueError: Default process group has not been initialized, please make sure to call init_process_group.

In [10]:
print(model)

Gemma3ForConditionalGeneration(
  (model): Gemma3Model(
    (vision_tower): SiglipVisionModel(
      (vision_model): SiglipVisionTransformer(
        (embeddings): SiglipVisionEmbeddings(
          (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
          (position_embedding): Embedding(4096, 1152)
        )
        (encoder): SiglipEncoder(
          (layers): ModuleList(
            (0-26): 27 x SiglipEncoderLayer(
              (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
              (self_attn): SiglipAttention(
                (k_proj): lora.Linear4bit(
                  (base_layer): Linear4bit(in_features=1152, out_features=1152, bias=True)
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.05, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=1152, out_features=16, bias=False)
                  )
   

In [35]:
# free the memory again
del model
del trainer
torch.cuda.empty_cache()

## 合併Lora

In [36]:
from peft import PeftModel

# Load Model base model
model = AutoModelForImageTextToText.from_pretrained(model_id, low_cpu_mem_usage=True)

# Merge LoRA and base model and save
peft_model = PeftModel.from_pretrained(model, args.output_dir)
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("merged_model", safe_serialization=True, max_shard_size="2GB")

processor = AutoProcessor.from_pretrained(args.output_dir)
processor.save_pretrained("merged_model")

Loading checkpoint shards: 100%|██████████| 2/2 [00:04<00:00,  2.22s/it]


['merged_model/processor_config.json']

##　載入已微調模型

In [37]:
from transformers import Gemma3ForConditionalGeneration
import torch

# Load Model with PEFT adapter
model = Gemma3ForConditionalGeneration.from_pretrained(
  args.output_dir,
  device_map="auto",
  torch_dtype=torch.bfloat16,
  attn_implementation="eager",
  output_hidden_states=True,
  output_attentions=True
)
processor = AutoProcessor.from_pretrained(args.output_dir)

Fetching 2 files: 100%|██████████| 2/2 [00:00<00:00, 40.69it/s]
The following generation flags are not valid and may be ignored: ['output_attentions', 'output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['output_attentions', 'output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.05s/it]
The following generation flags are not valid and may be ignored: ['output_attentions', 'output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['output_attentions', 'output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


# collect distill data

In [38]:
import os, io, base64, numpy as np, pandas as pd, torch
from tqdm import tqdm

def to_tensor_list(x):
    if x is None:
        return []
    if isinstance(x, (list, tuple)):
        return list(x)
    return [x]

def serialize_tensor_list(tensors) -> str:
    if not tensors:
        return ""
    buf = io.BytesIO()
    arrays = {f"arr_{i}": (t.squeeze(0).to(torch.float16).cpu().numpy() if isinstance(t, torch.Tensor) else np.array(t))
              for i, t in enumerate(tensors)}
    np.savez_compressed(buf, **arrays)
    return base64.b64encode(buf.getvalue()).decode("ascii")

# 確保會輸出所需中間特徵
model.config.output_hidden_states = True
model.config.output_attentions = True
if hasattr(model.config, "vision_config"):
    model.config.vision_config.output_hidden_states = True

rows = []
print("Generating and saving teacher signals to single CSV...")
for idx, sample in enumerate(tqdm(dataset, desc="Processing samples")):
    text_prompt = processor.apply_chat_template(
        sample["messages"], tokenize=False, add_generation_prompt=False
    ).strip()
    image_inputs = process_vision_info(sample["messages"])
    inputs = processor(text=[text_prompt], images=[image_inputs], return_tensors="pt", padding=False).to(model.device)

    with torch.inference_mode():
        out = model(**inputs)

    hs_b64 = serialize_tensor_list(to_tensor_list(getattr(out, "hidden_states", None)))
    attn_b64 = serialize_tensor_list(to_tensor_list(getattr(out, "attentions", None)))
    img_hs_b64 = serialize_tensor_list(to_tensor_list(getattr(out, "image_hidden_states", None)))

    rows.append({
        "id": idx,
        "teacher_hidden_states_b64": hs_b64,
        "teacher_attentions_b64": attn_b64,
        "teacher_image_hidden_states_b64": img_hs_b64,
    })

df = pd.DataFrame(rows)
csv_path = "dataset/distill_teacher_signals.csv"  # 單檔 CSV（gzip壓縮）
df.to_csv(csv_path, index=False)
print(f"Saved {len(df)} rows to {csv_path}")

Generating and saving teacher signals to single CSV...


Processing samples: 100%|██████████| 50/50 [10:25<00:00, 12.51s/it]


Saved 50 rows to dataset/distill_teacher_signals.csv


In [39]:

torch.cuda.empty_cache()

# Distill state

In [1]:
from torch.nn import MSELoss
from trl import SFTTrainer
from overrides import overrides
import torch

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BaseImageProcessor,
    DataCollator,
    FeatureExtractionMixin,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    ProcessorMixin,
    Trainer,
    TrainingArguments,
    is_wandb_available,
)


if is_wandb_available():
    import wandb

def get_cor_teacher(teacher_reps, student_reps, is_attn=False):
    """
    Selects the corresponding teacher layers for the student layers.
    This is used when the teacher model has more layers than the student model.
    """
    #進來的是tuple，裡面是對應層數的Tensor
    teacher_reps = [teacher_rep.detach() for teacher_rep in teacher_reps]
    teacher_layer_num = len(teacher_reps)
    student_layer_num = len(student_reps)
    #print(teacher_reps[0].shape,student_reps[0].shape)#1,8,405,405
    if is_attn:
        # For attention layers
        if teacher_layer_num % student_layer_num != 0:
            raise ValueError(f"Teacher attention layers ({teacher_layer_num}) not divisible by student's ({student_layer_num})")
        layers_per_block = teacher_layer_num // student_layer_num
        # Select the last layer from each corresponding teacher block
        new_teacher_reps = [teacher_reps[i * layers_per_block + layers_per_block - 1] for i in range(student_layer_num)]
    else:
        # For hidden states (including embeddings)
        print(teacher_reps[0].shape,student_reps[0].shape)
        if (teacher_layer_num - 1) % (student_layer_num - 1) != 0:
            raise ValueError(f"Teacher hidden layers ({teacher_layer_num - 1}) not divisible by student's ({student_layer_num - 1})")
        layers_per_block = (teacher_layer_num - 1) // (student_layer_num - 1)
        # Select layers from the teacher at regular intervals, starting from the embeddings
        new_teacher_reps = [teacher_reps[i * layers_per_block] for i in range(student_layer_num)]

    return new_teacher_reps


def get_kd_loss(student_reps, teacher_reps, loss_fn, is_attn=False, is_img=False):
    """
    Computes the knowledge distillation loss between student and teacher representations.
    """
    kd_loss = 0.0
    if student_reps is None or teacher_reps is None:
        return kd_loss

    if is_attn:
        for student_att, teacher_att in zip(student_reps, teacher_reps):
            '''
            if student_att.shape[1] != teacher_att.shape[1]:
                min_len = min(student_att.shape[1], teacher_att.shape[1])
                student_att = student_att[:, :min_len]
                teacher_att = teacher_att[:, :min_len]
            '''
            student_att = torch.where(student_att <= -1e2, torch.zeros_like(student_att), student_att)
            teacher_att = torch.where(teacher_att <= -1e2, torch.zeros_like(teacher_att), teacher_att)
            kd_loss += loss_fn(student_att, teacher_att)
            #print("att")
            #print(student_att.shape,teacher_att.shape)
    elif is_img:
        for student_rep, teacher_rep in zip(student_reps, teacher_reps):
            teacher_rep = teacher_rep[0]
            '''
            if student_rep.shape[1] != teacher_rep.shape[1]:
                min_len = min(student_rep.shape[1], teacher_rep.shape[1])
                student_rep = student_rep[:, :min_len]
                teacher_rep = teacher_rep[:, :min_len]
            '''
            #print("is_img")
            #print(student_rep.shape,teacher_rep.shape)
            kd_loss += loss_fn(student_rep, teacher_rep)
    else: # for hidden states
        for student_rep, teacher_rep in zip(student_reps, teacher_reps):
            '''
            if student_rep.shape[1] != teacher_rep.shape[1]:
                min_len = min(student_rep.shape[1], teacher_rep.shape[1])
                student_rep = student_rep[:, :min_len]
                teacher_rep = teacher_rep[:, :min_len]
            '''
            #print("hidden states")
            #print(student_rep.shape,teacher_rep.shape)
            kd_loss += loss_fn(student_rep, teacher_rep)

    return kd_loss

class DistillSTFTrainer(SFTTrainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def compute_hidden_states_loss(self, student_hidden_states, teacher_hidden_states,loss_fn):
        if teacher_hidden_states is None:
            return 0.0
        # Align teacher and student hidden states
        teacher_hidden_states = get_cor_teacher(teacher_hidden_states, student_hidden_states, is_attn=False)
        return get_kd_loss(student_hidden_states, teacher_hidden_states, loss_fn, is_attn=False, is_img=False)

    def compute_attentions_loss(self, student_attentions, teacher_attentions,loss_fn):
        if teacher_attentions is None:
            return 0.0
        # Align teacher and student attentions
        teacher_attentions = get_cor_teacher(teacher_attentions, student_attentions, is_attn=True)
        return get_kd_loss(student_attentions, teacher_attentions, loss_fn, is_attn=True, is_img=False)

    def compute_image_hidden_states_loss(self, student_image_hidden_states, teacher_image_hidden_states,loss_fn):
        if teacher_image_hidden_states is None or student_image_hidden_states is None:
            return 0.0
        # Vision towers are identical, no layer alignment needed
        return get_kd_loss(student_image_hidden_states, teacher_image_hidden_states, loss_fn, is_attn=False, is_img=True)

    @overrides()
    def compute_loss(self, model, inputs, return_outputs=False,num_items_in_batch=None):
        # Pop the teacher's outputs from the inputs
        teacher_hidden_states = inputs.pop("teacher_hidden_states", None)
        teacher_attentions = inputs.pop("teacher_attentions", None)
        teacher_image_hidden_states = inputs.pop("teacher_image_hidden_states", None)


        # Compute the original loss from SFTTrainer
        loss, outputs = super().compute_loss(model, inputs, return_outputs=True,num_items_in_batch=num_items_in_batch)

        # Get student's internal states
        student_hidden_states = outputs.hidden_states
        student_attentions = outputs.attentions
        student_image_hidden_states = outputs.image_hidden_states

        # Compute the distillation loss
        mse_loss = MSELoss()
        hidden_states_loss = self.compute_hidden_states_loss(student_hidden_states, teacher_hidden_states, mse_loss)
        attentions_loss = self.compute_attentions_loss(student_attentions, teacher_attentions, mse_loss)
        image_hidden_states_loss = self.compute_image_hidden_states_loss(student_image_hidden_states, teacher_image_hidden_states, mse_loss)
        print(hidden_states_loss,attentions_loss,image_hidden_states_loss)
        # Combine the losses (example: simple addition, could be weighted)
        distill_loss = hidden_states_loss + attentions_loss + image_hidden_states_loss

        # You can weigh the original loss and the distillation loss
        # Example: loss = 0.4 * loss + 0.6 * distill_loss
        loss += distill_loss

        return (loss, outputs) if return_outputs else loss


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
from torch.nn.utils.rnn import pad_sequence

# This collate_fn is designed to handle the output from our new distill_dataset
def collate_fn(examples):
    texts = []
    images = []
    teacher_hidden_states_list = []
    teacher_attentions_list = []
    teacher_image_hidden_states_list = []

    # 1. Extract data from each sample
    for example in examples:
        # Standard processing for text and images
        image_inputs = process_vision_info(example["messages"])
        text = processor.apply_chat_template(
            example["messages"], add_generation_prompt=False, tokenize=False
        )
        texts.append(text.strip())
        images.append(image_inputs)

        # Extract teacher outputs, ensuring they are not None
        if "teacher_hidden_states" in example and example["teacher_hidden_states"] is not None:
            teacher_hidden_states_list.append(example["teacher_hidden_states"])
        if "teacher_attentions" in example and example["teacher_attentions"] is not None:
            teacher_attentions_list.append(example["teacher_attentions"])
        if "teacher_image_hidden_states" in example and example["teacher_image_hidden_states"] is not None:
            teacher_image_hidden_states_list.append(example["teacher_image_hidden_states"])

    # 2. Process and tokenize text and images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # 3. Create labels, masking where necessary
    labels = batch["input_ids"].clone()
    image_token_id = processor.tokenizer.convert_tokens_to_ids(
        processor.tokenizer.special_tokens_map["boi_token"]
    )
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == image_token_id] = -100
    labels[labels == 262144] = -100
    batch["labels"] = labels

    # 4. Pad and stack the teacher's outputs
    if teacher_hidden_states_list:
        padded_hidden_states = []
        for layer_tensors in zip(*teacher_hidden_states_list):
            # pad_sequence expects a list of tensors
            padded_layer = pad_sequence(list(layer_tensors), batch_first=True, padding_value=0.0)
            padded_hidden_states.append(padded_layer)
        batch["teacher_hidden_states"] = tuple(padded_hidden_states)

    if teacher_attentions_list:
        padded_attentions = []
        for layer_tensors in zip(*teacher_attentions_list):
            padded_layer = pad_sequence(list(layer_tensors), batch_first=True, padding_value=0.0)
            padded_attentions.append(padded_layer)
        batch["teacher_attentions"] = tuple(padded_attentions)

    if teacher_image_hidden_states_list:
        padded_image_hidden_states = []
        for layer_tensors in zip(*teacher_image_hidden_states_list):
            padded_layer = pad_sequence(list(layer_tensors), batch_first=True, padding_value=0.0)
            padded_image_hidden_states.append(padded_layer)
        batch["teacher_image_hidden_states"] = tuple(padded_image_hidden_states)

    return batch


In [6]:
from trl import SFTConfig
from torch.nn.utils.rnn import pad_sequence

args = SFTConfig(
    output_dir="model/gemma3-distill",     # directory to save and repository id
    num_train_epochs=1,                         # number of training epochs
    per_device_train_batch_size=1,              # batch size per device during training
    gradient_accumulation_steps=4,              # number of steps before performing a backward/update pass
    gradient_checkpointing=True,                # use gradient checkpointing to save memory
    optim="adamw_torch_fused",                  # use fused adamw optimizer
    logging_steps=5,                            # log every 5 steps
    save_strategy="epoch",                      # save checkpoint every epoch
    learning_rate=2e-4,                         # learning rate, based on QLoRA paper
    bf16=True,                                  # use bfloat16 precision
    max_grad_norm=0.3,                          # max gradient norm based on QLoRA paper
    warmup_ratio=0.03,                          # warmup ratio based on QLoRA paper
    lr_scheduler_type="constant",               # use constant learning rate scheduler
    report_to="tensorboard",                    # report metrics to tensorboard
    gradient_checkpointing_kwargs={
        "use_reentrant": False
    },  # use reentrant checkpointing
    dataset_text_field="",                      # need a dummy field for collator
    dataset_kwargs={"skip_prepare_dataset": True},  # important for collator
)
args.remove_unused_columns = False # important for collator

# This collate_fn is designed to handle the output from our new distill_dataset
def collate_fn(examples):
    texts = []
    images = []
    teacher_hidden_states_list = []
    teacher_attentions_list = []
    teacher_image_hidden_states_list = []

    # 1. Extract data from each sample
    for example in examples:
        # Standard processing for text and images
        image_inputs = process_vision_info(example["messages"])
        text = processor.apply_chat_template(
            example["messages"], add_generation_prompt=False, tokenize=False
        )
        texts.append(text.strip())
        images.append(image_inputs)

        # Extract teacher outputs, ensuring they are not None
        if "teacher_hidden_states" in example and example["teacher_hidden_states"] is not None:
            teacher_hidden_states_list.append(example["teacher_hidden_states"])
        if "teacher_attentions" in example and example["teacher_attentions"] is not None:
            teacher_attentions_list.append(example["teacher_attentions"])
        if "teacher_image_hidden_states" in example and example["teacher_image_hidden_states"] is not None:
            teacher_image_hidden_states_list.append(example["teacher_image_hidden_states"])

    # 2. Process and tokenize text and images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # 3. Create labels, masking where necessary
    labels = batch["input_ids"].clone()
    image_token_id = processor.tokenizer.convert_tokens_to_ids(
        processor.tokenizer.special_tokens_map["boi_token"]
    )
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == image_token_id] = -100
    labels[labels == 262144] = -100
    batch["labels"] = labels

    # 4. Pad and stack the teacher's outputs
    if teacher_hidden_states_list:
        padded_hidden_states = []
        # Transpose and pad each layer
        for layer_tensors in zip(*teacher_hidden_states_list):
            padded_layer = pad_sequence(list(layer_tensors), batch_first=True, padding_value=0.0)
            padded_hidden_states.append(padded_layer)
        batch["teacher_hidden_states"] = tuple(padded_hidden_states)

    if teacher_attentions_list:
        padded_attentions = []
        for layer_tensors in zip(*teacher_attentions_list):
            padded_layer = pad_sequence(list(layer_tensors), batch_first=True, padding_value=0.0)
            padded_attentions.append(padded_layer)
        batch["teacher_attentions"] = tuple(padded_attentions)

    if teacher_image_hidden_states_list:
        padded_image_hidden_states = []
        for layer_tensors in zip(*teacher_image_hidden_states_list):
            padded_layer = pad_sequence(list(layer_tensors), batch_first=True, padding_value=0.0)
            padded_image_hidden_states.append(padded_layer)
        batch["teacher_image_hidden_states"] = tuple(padded_image_hidden_states)

    return batch

ValueError: Your setup doesn't support bf16/gpu.

## student model

In [45]:
import copy
import torch
from transformers import AutoModelForImageTextToText, Gemma3ForConditionalGeneration,BitsAndBytesConfig

model_id = "google/gemma-3-4b-pt"
# 1) 載入「原始 Gemma」(不要量化)

# Define model init arguments
model_kwargs = dict(
    attn_implementation="eager", # Use "flash_attention_2" when running on Ampere or newer GPU，安裝flash attn前必須先安裝ninja，詳細:https://blog.csdn.net/lckj2009/article/details/136054392
    torch_dtype=torch.bfloat16, # What torch dtype to use, defaults to auto
    device_map="cuda:0", # Let torch decide how to load the model
)

# 1) 載入「原始 Gemma」(不要量化)
base_model_clean = AutoModelForImageTextToText.from_pretrained(
    model_id,
    **model_kwargs
)

# 2) 建立半層數設定
student_config = copy.deepcopy(base_model_clean.config)
orig_layers = student_config.text_config.num_hidden_layers
half_layers = orig_layers // 2
student_config.text_config.num_hidden_layers = half_layers
print(f"Original decoder layers: {orig_layers} -> Student: {half_layers}")

# Explicitly enable outputting hidden states and attentions for the student model
student_config.output_hidden_states = True
student_config.output_attentions = True
if hasattr(student_config, "vision_config"):
    student_config.vision_config.output_hidden_states = True

# 3) 建立學生模型（非量化）
student_model = Gemma3ForConditionalGeneration(config=student_config).to(
    base_model_clean.device, dtype=base_model_clean.dtype
)

# 4) 複製非層級權重
with torch.no_grad():
    # 視覺塔與投影器
    student_model.model.vision_tower.load_state_dict(base_model_clean.model.vision_tower.state_dict())
    student_model.model.multi_modal_projector.load_state_dict(base_model_clean.model.multi_modal_projector.state_dict())
    # 文本嵌入與最終層正規化
    student_model.model.language_model.embed_tokens.load_state_dict(
        base_model_clean.model.language_model.embed_tokens.state_dict()
    )
    student_model.model.language_model.norm.load_state_dict(
        base_model_clean.model.language_model.norm.state_dict()
    )
    # LM Head
    student_model.lm_head.load_state_dict(base_model_clean.lm_head.state_dict())

# 5) 定義兩層合一層的權重合併（逐元素平均）
def average_state_dicts(sd_a: dict, sd_b: dict, alpha: float = 0.5) -> dict:
    merged = {}
    keys = sd_a.keys() & sd_b.keys()
    for k in keys:
        ta, tb = sd_a[k], sd_b[k]
        if isinstance(ta, torch.Tensor) and isinstance(tb, torch.Tensor) and ta.shape == tb.shape:
            # 只對浮點張量做平均，其餘沿用第一個
            if torch.is_floating_point(ta) and torch.is_floating_point(tb):
                merged[k] = (1 - alpha) * ta + alpha * tb
            else:
                merged[k] = ta
        else:
            merged[k] = ta
    # 帶入 sd_a 獨有的鍵
    for k in sd_a.keys() - keys:
        merged[k] = sd_a[k]
    return merged

# 6) 以「每 2 層合成 1 層」方式拷貝到學生層
with torch.no_grad():
    for i in range(half_layers):
        t_layer_a = base_model_clean.model.language_model.layers[2 * i]
        t_layer_b = base_model_clean.model.language_model.layers[2 * i + 1]
        s_layer = student_model.model.language_model.layers[i]

        merged_sd = average_state_dicts(
            t_layer_a.state_dict(),
            t_layer_b.state_dict(),
            alpha=1,  # 可調整權重比例
        )
        s_layer.load_state_dict(merged_sd, strict=False)
with torch.no_grad():
    student_model.lm_head.weight.data = student_model.model.language_model.embed_tokens.weight.data

print("Student model built with 2-teacher-layers -> 1-student-layer merging.")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 1.25 GiB. GPU 0 has a total capacity of 23.56 GiB of which 144.94 MiB is free. Process 2265815 has 22.73 GiB memory in use. Of the allocated memory 22.23 GiB is allocated by PyTorch, and 79.18 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [44]:
del base_model_clean
torch.cuda.empty_cache()

## Distillation

## 讀回蒸餾資料

In [8]:
import pandas as pd, io, base64, numpy as np, torch
from torch.utils.data import Dataset

def deserialize_tensor_list(b64str):
    if not isinstance(b64str, str) or b64str == "":
        return None
    data = base64.b64decode(b64str.encode("ascii"))
    buf = io.BytesIO(data)
    with np.load(buf, allow_pickle=False) as npz:
        keys = sorted(npz.files, key=lambda k: int(k.split("_")[1]))
        tensors = [torch.from_numpy(npz[k]).to(torch.float32) for k in keys]
    return tensors

class CSVTeacherSignalsDataset(Dataset):
    def __init__(self, csv_path: str, raw_dataset):
        self.df = pd.read_csv(csv_path, compression="infer")
        self.raw_dataset = raw_dataset
        assert self.df["id"].max() < len(raw_dataset), "CSV ids exceed raw dataset length"

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        raw = self.raw_dataset[int(row["id"])]
        return {
            "messages": raw["messages"],
            "teacher_hidden_states": deserialize_tensor_list(row.get("teacher_hidden_states_b64", "")),
            "teacher_attentions": deserialize_tensor_list(row.get("teacher_attentions_b64", "")),
            "teacher_image_hidden_states": deserialize_tensor_list(row.get("teacher_image_hidden_states_b64", "")),
        }

distill_dataset = CSVTeacherSignalsDataset("dataset/distill_teacher_signals.csv", dataset)

In [9]:
from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [22]:
trainer = DistillSTFTrainer(
    model=student_model,
    args=args,
    train_dataset=distill_dataset,
    processing_class=processor,
    data_collator=collate_fn,
)

In [23]:
# Start training, the model will be automatically saved to the Hub and the output directory
trainer.train()

# Save the final model again to the Hugging Face Hub
trainer.save_model(safe_serialization=False)

torch.Size([1, 405, 2560]) torch.Size([1, 405, 2560])
tensor(nan, device='cuda:0', grad_fn=<AddBackward0>) tensor(nan, device='cuda:0', grad_fn=<AddBackward0>) tensor(nan, device='cuda:0', grad_fn=<AddBackward0>)
torch.Size([1, 405, 2560]) torch.Size([1, 405, 2560])
tensor(nan, device='cuda:0', grad_fn=<AddBackward0>) tensor(nan, device='cuda:0', grad_fn=<AddBackward0>) tensor(nan, device='cuda:0', grad_fn=<AddBackward0>)
torch.Size([1, 458, 2560]) torch.Size([1, 458, 2560])
tensor(nan, device='cuda:0', grad_fn=<AddBackward0>) tensor(nan, device='cuda:0', grad_fn=<AddBackward0>) tensor(nan, device='cuda:0', grad_fn=<AddBackward0>)
torch.Size([1, 404, 2560]) torch.Size([1, 404, 2560])
tensor(nan, device='cuda:0', grad_fn=<AddBackward0>) tensor(nan, device='cuda:0', grad_fn=<AddBackward0>) tensor(nan, device='cuda:0', grad_fn=<AddBackward0>)


Step,Training Loss


torch.Size([1, 409, 2560]) torch.Size([1, 409, 2560])
tensor(nan, device='cuda:0', grad_fn=<AddBackward0>) tensor(nan, device='cuda:0', grad_fn=<AddBackward0>) tensor(nan, device='cuda:0', grad_fn=<AddBackward0>)


OutOfMemoryError: CUDA out of memory. Tried to allocate 1024.00 MiB. GPU 0 has a total capacity of 23.64 GiB of which 70.94 MiB is free. Process 926453 has 23.13 GiB memory in use. Of the allocated memory 21.62 GiB is allocated by PyTorch, and 1.04 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)