[![Roboflow Notebooks](https://media.roboflow.com/notebooks/template/bannertest2-2.png?ik-sdk-version=javascript-1.4.3&updatedAt=1672932710194)](https://github.com/roboflow/notebooks)

# Fine-tune PaliGemma2 on Object Detection Dataset

---

[![GitHub](https://badges.aleen42.com/src/github.svg)](https://github.com/google-research/big_vision/blob/main/big_vision/configs/proj/paligemma/README.md)
[![arXiv](https://img.shields.io/badge/arXiv-2412.03555-b31b1b.svg)](https://arxiv.org/abs/2412.03555)

PaliGemma 2 is built by combining the SigLIP-So400m vision encoder with the more recent and capable language models from the Gemma 2 family.

![PaliGemma2 Figure.1](https://storage.googleapis.com/com-roboflow-marketing/notebooks/examples/paligemma2-1.png)

The authors use a 3-stage training approach similar to the original PaliGemma. In stage 1, they combine the pretrained vision and language model components and train them jointly on a multimodal task mixture. In stage 2, they train the models at higher resolutions of 448px^2 and 896px^2. In stage 3, they fine-tune the models on the target transfer tasks.

PaliGemma 2 models outperform the original PaliGemma at the same resolution and model size. Increasing the model size and resolution generally improves performance across a wide range of tasks, but the benefits differ depending on the task. Some tasks benefit more from increased resolution, while others benefit more from a larger language model.

![PaliGemma2 Figure.2](https://storage.googleapis.com/com-roboflow-marketing/notebooks/examples/paligemma2-2.png)

This notebook requires an A100 GPU with 40 GB of VRAM for training.

## Setup

### Configure your API keys

To fine-tune PaliGemma2, you need to provide your HuggingFace Token and Roboflow API key. Follow these steps:

- Open your [`HuggingFace Settings`](https://huggingface.co/settings) page. Click `Access Tokens`, then `New Token` to generate a new token.
- Go to your [`Roboflow Settings`](https://app.roboflow.com/settings/api) page. Click `Copy`. This will place your private key in the clipboard.
- In Colab, go to the left pane and click on `Secrets` (🔑).
    - Store HuggingFace Access Token under the name `HF_TOKEN`.
    - Store Roboflow API Key under the name `ROBOFLOW_API_KEY`.

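A minimal sketch of reading the two keys in code, assuming they are also exposed as environment variables named `HF_TOKEN` and `ROBOFLOW_API_KEY`. The helper `read_secret` is a hypothetical name used only for illustration. Inside Colab, `google.colab.userdata.get(name)` reads the same secrets.

```python
import os


def read_secret(name, env=None):
    """Return a secret by name, failing loudly if it is missing or empty.

    `read_secret` is a hypothetical helper, not part of Colab or Hugging Face.
    """
    env = os.environ if env is None else env
    value = env.get(name, "").strip()
    if not value:
        raise KeyError(f"Secret {name!r} is not set")
    return value


# Demo with an explicit mapping, so no real credentials are needed.
demo_env = {"HF_TOKEN": "hf_example", "ROBOFLOW_API_KEY": " rf_example "}
hf_token = read_secret("HF_TOKEN", demo_env)
roboflow_key = read_secret("ROBOFLOW_API_KEY", demo_env)
```

Failing early on a missing key tends to be easier to debug than an authentication error later, during model download.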
### Select the runtime

Let's make sure that we have access to a GPU. We can use the `nvidia-smi` command to do that. If there are any problems, navigate to `Edit` -> `Notebook settings` -> `Hardware accelerator`, set it to `T4 GPU`, and then click `Save`.

In [1]:
!nvidia-smi
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"  # expose GPUs 0 and 1 to this process

Fri Oct 24 08:50:14 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.82.07              Driver Version: 580.82.07      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 3090        On  |   00000000:43:00.0 Off |                  N/A |
| 30%   30C    P8             29W /  350W |       1MiB /  24576MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

In [2]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

### Download dataset from Roboflow Universe

To fine-tune PaliGemma2, prepare your dataset in JSONL format. You can use Roboflow to easily convert any dataset into this format.

In [3]:
#!pip install -q peft bitsandbytes transformers==4.47.0 tf-keras
!rsync -a --progress /data/lmbraid19/argusm/datasets/indoorCVPR_09.tar /tmp/ && mkdir -p /tmp/indoorCVPR && tar -xf /tmp/indoorCVPR_09.tar -C /tmp/indoorCVPR
!rsync -a --progress /work/dlclarge2/zhangj-zhangj-CFM/data/training2 /tmp/
!file /tmp/indoorCVPR
!file /tmp/training2

sending incremental file list
sending incremental file list
/tmp/indoorCVPR: directory
/tmp/training2: directory


**NOTE:** The cell below loads the dataset from `/tmp/training2`, sets up the augmentations, and splits off a fixed-size validation set.

In [4]:
%load_ext autoreload
%autoreload 2
from pathlib import Path
from cvla.data_loader_h5 import H5Dataset
from cvla.data_loader_jsonl import JSONLDataset
from cvla.data_augmentations import augment_image_rgb, RandomizeBackgrounds
from cvla.data_augmentations import complexify_text, DepthAugmentation
from cvla.data_loader_images import ImageFolderDataset
from torchvision import transforms
from torch.utils.data import random_split
import torch
import random

model_location = Path("/data/lmbraid19/argusm/models")
dataset_location = Path("/tmp/training2")

bg_image_dataset = ImageFolderDataset("/tmp/indoorCVPR/Images", transform=transforms.RandomResizedCrop((448,448)))
randomize_background = RandomizeBackgrounds(p=0.2, background_images=bg_image_dataset)
augment_depth = DepthAugmentation(depth_range=(25, 100), max_delta_depth=35)

full_dataset = H5Dataset(
    dataset_location,
    augment_rgb=augment_image_rgb,
    augment_text=complexify_text,
    augment_depth=augment_depth,
    return_depth=False,
    action_encoder="xyzrotvec-cam-512xy",
)

#image, sample = full_dataset[1]
#print(sample["rotation_labels"] )

# Manually set the validation set size
val_size = 1000  # fixed at 1000 samples
train_size = len(full_dataset) - val_size

generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size], generator=generator)

val_indices_small = random.sample(range(len(val_dataset)), 200)
val_dataset_small = torch.utils.data.Subset(val_dataset, val_indices_small)

print(f"Total samples: {len(full_dataset)} | Train: {len(train_dataset)} | Val: {len(val_dataset)}| Smallv:{len(val_dataset_small)}")


'''
train_dataset = H5Dataset(dataset_location, augment_rgb=augment_image_rgb, augment_text=complexify_text,
                          augment_depth=augment_depth, return_depth=True,action_encoder="xyzrotvec-cam-512xy")
#, augment_rgbds=randomize_background

print("dataset_location:", dataset_location,"samples:", len(train_dataset))
'''

  import pynvml  # type: ignore[import]
  warn("Failed to find system libvulkan. Fallback to SAPIEN builtin libvulkan.")


Total samples: 88244 | Train: 87244 | Val: 1000| Smallv:200



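The fixed-size validation split above relies on a seeded generator so that the same samples end up in the validation set on every run. The sketch below shows the same idea using only the standard library. It is an analogue of `random_split`, not a reimplementation, so the permutation differs from the one PyTorch produces. The function name `seeded_split` is hypothetical.

```python
import random


def seeded_split(num_samples, val_size, seed=42):
    """Split range(num_samples) into disjoint train/val index lists, reproducibly."""
    if not 0 <= val_size <= num_samples:
        raise ValueError("val_size must be between 0 and num_samples")
    indices = list(range(num_samples))
    # A private Random instance keeps the split independent of global RNG state.
    random.Random(seed).shuffle(indices)
    return indices[val_size:], indices[:val_size]


train_idx, val_idx = seeded_split(88244, 1000)
train_idx_again, val_idx_again = seeded_split(88244, 1000)

print(len(train_idx), len(val_idx))
print(val_idx == val_idx_again)
```

Note that `val_dataset_small` in the cell above is drawn with the unseeded global `random.sample`, so that subset may change between runs unless the global seed is fixed as well.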
### Set up and test data loaders

In [5]:
from cvla.utils_vis import render_example
import matplotlib.pyplot as plt
from cvla.utils_traj_tokens import getActionEncInstance

enc = getActionEncInstance("xyzrotvec-cam-512xy")
num_samples = 3*2
html_imgs = ""
for i in range(num_samples):
    image, sample = train_dataset[i]
    prefix = sample["prefix"]
    html_imgs += render_example(image[0], label=sample["suffix"], enc=enc, text=prefix, camera=sample["camera"])
    html_imgs += render_example(image[1], label=sample["suffix"], enc=enc, text=prefix, camera=sample["camera"])

plot_images = True
if plot_images:
    from IPython.display import display, HTML
    display(HTML(html_imgs))
    

### Load PaliGemma2 model

**NOTE:** PaliGemma2 offers 9 pre-trained models with sizes of `3B`, `10B`, and `28B` parameters, and resolutions of `224`, `448`, and `896` pixels. This tutorial was written around the [`google/paligemma2-3b-pt-448`](https://huggingface.co/google/paligemma2-3b-pt-448) checkpoint: resolution has a key impact on the mAP of the trained model, and `448` seems to offer the best balance between performance and the compute required for training. The loading cell below, however, sets `MODEL_NAME` to `google/paligemma2-3b-pt-224`.

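To make the resolution trade-off concrete, the sketch below counts the image tokens the vision encoder emits at each supported resolution, assuming the 14-pixel patch size of SigLIP-So400m/14. The helper name `num_image_tokens` is hypothetical.

```python
PATCH_SIZE = 14  # patch size of the SigLIP-So400m/14 vision encoder


def num_image_tokens(resolution, patch_size=PATCH_SIZE):
    """Number of patch tokens for a square image of the given side length."""
    if resolution % patch_size:
        raise ValueError("resolution must be a multiple of the patch size")
    side = resolution // patch_size
    return side * side


token_counts = {res: num_image_tokens(res) for res in (224, 448, 896)}
print(token_counts)
```

Each doubling of the resolution quadruples the number of image tokens in the sequence, which is the main reason higher-resolution checkpoints need more memory and compute.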
In [6]:
# from huggingface_hub import notebook_login
# notebook_login()

In [6]:
from transformers import PaliGemmaProcessor, PaliGemmaForConditionalGeneration
import torch
import transformers

#transformers.utils.logging.set_verbosity_error()

# setting device on GPU if available, else CPU
print("cuda visible devices:", os.environ["CUDA_VISIBLE_DEVICES"])
devices_good = sorted((int(x) for x in os.environ["CUDA_VISIBLE_DEVICES"].split(",")))
DEVICE = torch.device('cuda')
print(DEVICE)
print('Using device:', DEVICE)
print("Good devices", devices_good)

TORCH_DTYPE = torch.bfloat16
# use checkpoint
#LOCAL_CHECKPOINT = "/data/lmbraid19/argusm/models/_text_lr3e-05xyzrotvec-cam-512xy256d_2025-04-23_12-03-48/checkpoint-4687"

#fine-tune directly on paligemma2
MODEL_NAME = "google/paligemma2-3b-pt-224"

processor = PaliGemmaProcessor.from_pretrained(MODEL_NAME)
base_model = PaliGemmaForConditionalGeneration.from_pretrained(
    MODEL_NAME,
    torch_dtype=TORCH_DTYPE,
    device_map=None,
    attn_implementation="eager"
)
#.to("cuda") 
tokenizer = processor.tokenizer


cuda visible devices: 0,1
cuda
Using device: cuda
Good devices [0, 1]


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
import torch.nn as nn
class PaliGemmaWithAuxRotation(nn.Module):
    """
    Wraps PaliGemma2 for dual-view training:
    - side_inputs: main task (trajectory prediction)
    - top_inputs: auxiliary task (rotation prediction)
    """
    def __init__(self, base_model, hidden_dim=2304):
        super().__init__()
        self.base = base_model
        self.hidden_dim = hidden_dim

        # Lightweight rotation prediction head for top-view image features
        self.aux_head = nn.Sequential(
            nn.Linear(hidden_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 3)  # rotvec (x, y, z)
        )

    def forward(self, side_inputs=None, top_inputs=None, **kwargs):
        device = next(self.base.parameters()).device

        def move_to_device(batch):
            return {
                k: v.to(device) if hasattr(v, "to") else v
                for k, v in batch.items()
            }
    
        if side_inputs is not None:
            side_inputs = move_to_device(side_inputs)
        if top_inputs is not None:
            top_inputs = move_to_device(top_inputs)
            
        # ============ Main task (side view) ============
        if not self.training:  # evaluation phase
            with torch.no_grad():
                side_outputs = self.base(**side_inputs)
        else:
            side_outputs = self.base(**side_inputs)
        main_loss = getattr(side_outputs, "loss", None)

        # ============ Auxiliary task (top view) ============
        if top_inputs is not None:
            with torch.no_grad():  # no gradients here, to save GPU memory
                top_outputs = self.base(**top_inputs, output_hidden_states=True)
            # Extract visual features (mean pooling)
            hidden = top_outputs.image_hidden_states.mean(dim=1)
            pred_rot = self.aux_head(hidden)
        else:
            pred_rot = None

        # ============ Merge outputs ============
        # Convert to a dict, compatible with the transformers output format
        if not isinstance(side_outputs, dict):
            side_outputs = side_outputs.__dict__

        side_outputs["pred_rotations"] = pred_rot
        side_outputs["main_loss"] = main_loss

        return side_outputs
        
    def generate(self, *args, **kwargs):
        return self.base.generate(*args, **kwargs)
        
model = PaliGemmaWithAuxRotation(base_model)


In [8]:
import random
def augment_suffix(suffix):
    parts = suffix.split(' ; ')
    random.shuffle(parts)
    return ' ; '.join(parts)

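The suffix augmentation above can be tried in isolation. The sketch below repeats the same split, shuffle, and join logic but accepts an explicit random generator so the result is reproducible. The function name `shuffle_suffix_parts` and the placeholder parts `part_a`, `part_b`, `part_c` are hypothetical and do not reflect the real token format of the dataset.

```python
import random


def shuffle_suffix_parts(suffix, rng=random, sep=" ; "):
    """Shuffle the separator-delimited parts of a suffix string."""
    parts = suffix.split(sep)
    rng.shuffle(parts)  # in-place shuffle; the set of parts is unchanged
    return sep.join(parts)


example = "part_a ; part_b ; part_c"
shuffled = shuffle_suffix_parts(example, rng=random.Random(0))
single = shuffle_suffix_parts("only_part", rng=random.Random(0))

print(shuffled)
print(single)
```

Only the order of the parts changes. A suffix without the separator comes back unchanged.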
In [9]:
def collate_fn_multi_view(batch):
    # batch: [( [side_img, top_img], label_dict ), ...]
    images, labels = zip(*batch)

    # Split into side / top views
    side_images = [img_list[0] for img_list in images]
    top_images  = [img_list[1] for img_list in images]

    # Build prefixes and suffixes (same as the original setup)
    prefixes = ["<image>" + label["prefix"] for label in labels]
    suffixes = [augment_suffix(label["suffix"]) for label in labels]

    # Side-view inputs (main task)
    side_inputs = processor(
        text=prefixes,
        images=side_images,
        return_tensors="pt",
        suffix=suffixes,
        padding="longest"
    ).to(TORCH_DTYPE)
    
    
    # Top-view inputs (auxiliary task)
    # Note: no suffix here, because the top view is only used for visual feature extraction
    top_inputs = processor(
        text=["<image>" for _ in labels],  # a prompt could be added, e.g. "top view image"
        images=top_images,
        return_tensors="pt",
        padding="longest"
    ).to(TORCH_DTYPE)

    batch_out = {
        "side_inputs": side_inputs,
        "top_inputs": top_inputs,
    }

    # If the labels contain rotation supervision
    if "rotation_labels" in labels[0]:
        batch_out["rotation_labels"] = torch.stack(
            [label["rotation_labels"] for label in labels]
        )

    return batch_out


In [10]:
from cvla.utils_eval import Evaluator
from transformers import Seq2SeqTrainingArguments
from transformers import Seq2SeqTrainer
import torch
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from math import ceil

TRAIN_EXAMPLES = len(train_dataset)
BATCH_SIZE = 32
BATCH_SIZE_DEV = 2
GRAD_ACCUM = int(round(BATCH_SIZE / BATCH_SIZE_DEV))
TRAIN_STEPS = TRAIN_EXAMPLES // BATCH_SIZE
SEQLEN = 12
#EVAL_STEPS = 3
EVAL_STEPS = 200
SAVE_LIMIT = 5
LOGGING_STEPS = 10

run_name = "_topview_70000_toploss_value"
new_model_location = Path("/work/dlclarge2/zhangj-zhangj-CFM/models")
save_path = new_model_location / (str(Path(dataset_location).stem) + run_name)

print("save_path", save_path)
print("TRAIN_STEPS",TRAIN_STEPS)
print("GRAD_ACCUM", GRAD_ACCUM)

writer = SummaryWriter(log_dir=str(save_path / "tb_logs"))

import torch.nn.functional as F

class RotationAwareCustomTrainer(Seq2SeqTrainer):
    """
    Trainer that:
      - computes normal LM loss
      - adds auxiliary rotation loss
      - evaluates both sequence accuracy and rotation prediction accuracy
    """

    def __init__(self, lambda_rot=0.3, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.lambda_rot = lambda_rot

   

    # ============================
    #       TRAINING PHASE
    # ============================
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Training step: compute combined loss = LM loss + λ * rotation loss
        """
        outputs = model(**inputs)
        main_loss = getattr(outputs, "loss", None)

        if main_loss is None:
            raise ValueError("Model outputs do not contain 'loss' field.")

        rot_loss = torch.tensor(0.0, device=main_loss.device)
        if "rotation_labels" in inputs and "pred_rotations" in outputs:
            pred_rot = outputs["pred_rotations"]
            true_rot = inputs["rotation_labels"].to(pred_rot.device)
            #rot_loss = F.l1_loss(pred_rot, true_rot)
            rot_loss = 1 - F.cosine_similarity(pred_rot, true_rot, dim=-1).mean()
            total_loss = main_loss + self.lambda_rot * rot_loss
        else:
            total_loss = main_loss

        # Logging (every logging_steps)
        if self.state.global_step % self.args.logging_steps == 0:
            writer.add_scalar("train/loss_total", total_loss.item(), self.state.global_step)
            writer.add_scalar("train/loss_main", main_loss.item(), self.state.global_step)
            writer.add_scalar("train/loss_rot", rot_loss.item(), self.state.global_step)
            writer.add_scalar("train/lr", self.optimizer.param_groups[0]["lr"], self.state.global_step)

        return (total_loss, outputs) if return_outputs else total_loss
    
    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        """
        Overridden evaluation that generates predictions textually
        and computes spatial metrics via Evaluator.
        """
        self.model.eval()
        dataset = eval_dataset or self.eval_dataset
    
        # helper: unwrap nested Subsets to access H5Dataset
        def unwrap_dataset(dset):
            while hasattr(dset, "dataset"):
                dset = dset.dataset
            return dset
    
        base_dataset = unwrap_dataset(dataset)
        camera = dataset[0][1]["camera"]
    
        evaluator = Evaluator(
            getActionEncInstance("xyzrotvec-cam-512xy"),
            camera_fixed=camera,
            encoder_labels=base_dataset.action_encoder,
        )
    
        eval_batch_size = self.args.per_device_eval_batch_size
        test_samples = min(len(dataset), 200)
        device = next(self.model.parameters()).device
    
        for start_idx in tqdm(
            range(0, test_samples, eval_batch_size),
            total=ceil(test_samples / eval_batch_size),
        ):
            batch_i = range(start_idx, min(start_idx + eval_batch_size, test_samples))
            batch = [dataset[i] for i in batch_i]
    
            # ✅ Explicitly use collate_fn_multi_view
            inputs = collate_fn_multi_view(batch)
    
            # ---- move tensors to device ----
            for key in ["side_inputs", "top_inputs"]:
                if key in inputs:
                    inputs[key] = {
                        k: v.to(device) if isinstance(v, torch.Tensor) else v
                        for k, v in inputs[key].items()
                    }
            if "rotation_labels" in inputs:
                inputs["rotation_labels"] = inputs["rotation_labels"].to(device)
    
            side_inputs = inputs["side_inputs"]
    
            # ✅ Extract prefix length and labels
            prefix_length = side_inputs["input_ids"].shape[-1]
            labels = side_inputs.get("labels", None)
    
            with torch.inference_mode():
                generation = self.model.generate(
                    **side_inputs,
                    max_new_tokens=13,
                    do_sample=False,
                    use_cache=False,
                )
    
                decoded = [
                    self.processing_class.decode(x[prefix_length:], skip_special_tokens=True)
                    for x in generation
                ]
    
                decoded_labels = []
                if labels is not None:
                    decoded_labels = [
                        self.processing_class.decode(
                            [t for t in x.tolist() if t >= 0],
                            skip_special_tokens=True,
                        )
                        for x in labels
                    ]
    
            # optional debug
            if start_idx == 0:
                print("decoded[0]:", decoded[0] if decoded else None)
                print("decoded_label[0]:", decoded_labels[0] if decoded_labels else None)
    
            for pred, label in zip(decoded, decoded_labels):
                evaluator.evaluate(pred, label, camera=camera)
    
        stats = evaluator.report_stats()
        metrics = {f"{metric_key_prefix}_{k}": v for k, v in stats.items()}
    
        # log to TensorBoard
        for k, v in metrics.items():
            writer.add_scalar(k, v, self.state.global_step)
    
        self.log(metrics)
        return metrics
    


save_path /work/dlclarge2/zhangj-zhangj-CFM/models/training2_topview_70000_toploss_value
TRAIN_STEPS 2726
GRAD_ACCUM 16

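The batch-size arithmetic printed above can be reproduced with a small helper. This is a sketch under stated assumptions: `accumulation_plan` is a hypothetical function, and `num_devices` is a hypothetical parameter the original cell does not use. It is included because each optimizer step may process `per_device_batch * num_devices` samples when more than one GPU takes part in a step.

```python
def accumulation_plan(num_examples, effective_batch, per_device_batch, num_devices=1):
    """Return (gradient accumulation steps, optimizer steps per epoch)."""
    per_forward = per_device_batch * num_devices
    if effective_batch % per_forward:
        raise ValueError("effective_batch must be a multiple of the per-step batch")
    grad_accum = effective_batch // per_forward
    # One optimizer step consumes `effective_batch` samples; the remainder is dropped.
    optimizer_steps = num_examples // effective_batch
    return grad_accum, optimizer_steps


single_gpu = accumulation_plan(87244, effective_batch=32, per_device_batch=2)
two_gpus = accumulation_plan(87244, effective_batch=32, per_device_batch=2, num_devices=2)

print(single_gpu)
print(two_gpus)
```

With the values from this notebook, the single-device plan matches the printed `GRAD_ACCUM` and `TRAIN_STEPS`. If both visible GPUs contribute to each step, halving the accumulation steps would keep the effective batch at 32.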

In [11]:
'''
    # ============================
    #       EVALUATION PHASE
    # ============================
    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix="eval"):
        """
        Evaluate both language output quality (trajectory prediction)
        and rotation regression accuracy (auxiliary task).
        """
        self.model.eval()
        dataset = eval_dataset or self.eval_dataset

        # --- unwrap nested Subset ---
        def unwrap_dataset(dset):
            while hasattr(dset, "dataset"):
                dset = dset.dataset
            return dset

        base_dataset = unwrap_dataset(dataset)
        camera = dataset[0][1]["camera"]

        evaluator = Evaluator(
            getActionEncInstance("xyzrotvec-cam-512xy"),
            camera_fixed=camera,
            encoder_labels=base_dataset.action_encoder,
        )

        eval_batch_size = self.args.per_device_eval_batch_size
        test_samples = min(len(dataset), 200)
        device = next(self.model.parameters()).device

        total_rot_loss = 0.0
        total_rot_error = 0.0
        rot_count = 0

        for start_idx in tqdm(range(0, test_samples, eval_batch_size),
                              total=ceil(test_samples / eval_batch_size)):
            batch_i = range(start_idx, min(start_idx + eval_batch_size, test_samples))
            batch = [dataset[i] for i in batch_i]
            inputs = self.data_collator(batch)
            inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

            # forward pass
            with torch.inference_mode():
                outputs = self.model(**inputs)

                # ---------- 1️⃣ rotation evaluation ----------
                if "rotation_labels" in inputs and "pred_rotations" in outputs:
                    pred_rot = outputs["pred_rotations"]
                    true_rot = inputs["rotation_labels"].to(pred_rot.device)
                    batch_rot_loss = F.l1_loss(pred_rot, true_rot, reduction="mean").item()
                    total_rot_loss += batch_rot_loss * pred_rot.size(0)

                    # compute angular error in degrees
                    # Convert rotvec difference norm to degrees
                    diff = pred_rot - true_rot
                    batch_rot_error = torch.norm(diff, dim=-1).mean().item() * (180.0 / torch.pi)
                    total_rot_error += batch_rot_error * pred_rot.size(0)
                    rot_count += pred_rot.size(0)

                # ---------- 2️⃣ textual generation evaluation ----------
                prefix_length = inputs["input_ids"].shape[-1]
                generation = self.model.generate(
                    **inputs,
                    max_new_tokens=13,
                    do_sample=False,
                    use_cache=False
                )
                decoded = [
                    self.processing_class.decode(x[prefix_length:], skip_special_tokens=True)
                    for x in generation
                ]
                decoded_labels = [
                    self.processing_class.decode([t for t in x.tolist() if t >= 0], skip_special_tokens=True)
                    for x in inputs["labels"]
                ]

            # ---------- 3️⃣ spatial evaluation ----------
            for pred, label in zip(decoded, decoded_labels):
                evaluator.evaluate(pred, label, camera=camera)

        # --- Compute aggregate metrics ---
        stats = evaluator.report_stats()
        metrics = {f"{metric_key_prefix}_{k}": v for k, v in stats.items()}

        if rot_count > 0:
            metrics[f"{metric_key_prefix}_rot_loss"] = total_rot_loss / rot_count
            metrics[f"{metric_key_prefix}_rot_error_deg"] = total_rot_error / rot_count

        # --- Log everything ---
        for k, v in metrics.items():
            writer.add_scalar(k, v, self.state.global_step)

        self.log(metrics)
        return metrics
'''

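The training loss in `compute_loss` uses `1 - cosine_similarity` on rotation vectors, while the commented-out evaluation above measures the norm of the rotation-vector difference in degrees. The two behave differently: a rotation vector encodes the axis in its direction and the angle (in radians) in its norm, and cosine similarity only looks at direction. The sketch below illustrates this in plain Python. The function names are hypothetical, and the difference norm is only a rough proxy for the true angular distance between two rotations.

```python
import math


def cosine_rotation_loss(pred, target, eps=1e-8):
    """1 - cosine similarity between two 3D rotation vectors."""
    dot = sum(p * t for p, t in zip(pred, target))
    norm_p = math.sqrt(sum(p * p for p in pred))
    norm_t = math.sqrt(sum(t * t for t in target))
    return 1.0 - dot / max(norm_p * norm_t, eps)


def rotvec_diff_deg(pred, target):
    """Norm of the rotation-vector difference, converted to degrees."""
    diff = math.sqrt(sum((p - t) ** 2 for p, t in zip(pred, target)))
    return math.degrees(diff)


target = (0.0, 0.0, math.pi / 2)   # 90 degrees about the z axis
pred = (0.0, 0.0, math.pi / 18)    # 10 degrees about the same axis

loss = cosine_rotation_loss(pred, target)
error_deg = rotvec_diff_deg(pred, target)
print(loss, error_deg)
```

Here the axis is predicted perfectly, so the cosine loss is essentially zero even though the angle is off by about 80 degrees. If the rotation magnitude matters for the task, the commented-out L1 loss or a combined loss may be worth comparing.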



## Fine-tune with JAX settings

In [12]:

for param in model.base.vision_tower.parameters():
    param.requires_grad = False

for param in model.base.multi_modal_projector.parameters():
    param.requires_grad = False

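# ✅ Then freeze everything in the base model except the language model's self_attn layers.
# Vision-tower attention modules are also named `self_attn`, so exclude them explicitly
# to keep the vision tower frozen.
for name, param in model.base.named_parameters():
    param.requires_grad = "self_attn" in name and "vision_tower" not in name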
        
for param in model.aux_head.parameters():
    param.requires_grad = True

args_jax = Seq2SeqTrainingArguments(
    max_steps=TRAIN_STEPS,
    per_device_train_batch_size=BATCH_SIZE_DEV,
    gradient_accumulation_steps=GRAD_ACCUM,
    learning_rate=3e-5,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    generation_max_length=SEQLEN,
    logging_steps=LOGGING_STEPS,
    optim="adafactor",
    evaluation_strategy="steps",
    eval_steps=EVAL_STEPS,
    save_strategy="steps",
    save_steps=EVAL_STEPS,
    save_total_limit=SAVE_LIMIT,
    load_best_model_at_end=True,
    metric_for_best_model="cart_l1",
    greater_is_better=False,
    bf16=True,
    output_dir=save_path,
    report_to=["tensorboard"],
    dataloader_num_workers=4,
    remove_unused_columns=False,
)

trainer = RotationAwareCustomTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=val_dataset_small,
    data_collator=collate_fn_multi_view,  # 👈 use the dual-view collate function
    args=args_jax,
    lambda_rot=0.3,  # weight of the rotation loss
)


  super().__init__(*args, **kwargs)

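When parameters are selected by name, as in the cell above, it helps to check which names a substring filter actually matches. A bare test for `self_attn` also matches attention layers inside the vision tower, because they use the same submodule name. The sketch below uses illustrative parameter names, which may differ between `transformers` versions, and a hypothetical helper `trainable_names`. In practice, inspect `model.base.named_parameters()` directly.

```python
def trainable_names(names, keep="self_attn", exclude=("vision_tower",)):
    """Names containing `keep` but none of the `exclude` substrings."""
    return [
        name for name in names
        if keep in name and not any(part in name for part in exclude)
    ]


# Illustrative names only; real names depend on the library version.
names = [
    "vision_tower.vision_model.encoder.layers.0.self_attn.q_proj.weight",
    "multi_modal_projector.linear.weight",
    "language_model.model.layers.0.self_attn.q_proj.weight",
    "language_model.model.layers.0.mlp.up_proj.weight",
]

naive = [name for name in names if "self_attn" in name]
selected = trainable_names(names)

print(len(naive), len(selected))
print(selected)
```

Printing the number of trainable parameters after freezing is a cheap sanity check that the intended subset is being trained.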

In [13]:
'''
#only when recover from last time training
trainer.args.save_safetensors = False
last_checkpoint = "/work/dlclarge2/zhangj-zhangj-CFM/models/training2_topview_70000_basetop_loss/checkpoint-200"
trainer.train(resume_from_checkpoint=last_checkpoint)
trainer.save_model(str(save_path / "final_checkpoint"))
writer.close()
print("✅ Training completed successfully with Evaluator-based validation.")
'''

ValueError: Can't find a valid checkpoint at /work/dlclarge2/zhangj-zhangj-CFM/models/training2_topview_70000_basetop_loss/checkpoint-200

In [None]:
trainer.args.save_safetensors = False
trainer.train()
trainer.save_model(str(save_path / "final_checkpoint"))
writer.close()
print("✅ Training completed successfully with Evaluator-based validation.")

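The learning-rate settings in `args_jax` (`learning_rate=3e-5`, `lr_scheduler_type="cosine"`, `warmup_ratio=0.05`) describe a linear warmup followed by a cosine decay. The sketch below computes that curve in plain Python for the 2726 training steps of this run. It is an approximation of the schedule's shape; the rounding of the warmup step count is an assumption, and `cosine_lr_with_warmup` is a hypothetical helper rather than a `transformers` API.

```python
import math


def cosine_lr_with_warmup(step, max_steps, base_lr, warmup_ratio):
    """Learning rate at `step`: linear warmup, then cosine decay to zero."""
    warmup_steps = math.ceil(max_steps * warmup_ratio)  # assumed rounding
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


MAX_STEPS, BASE_LR, WARMUP_RATIO = 2726, 3e-5, 0.05
warmup_end = math.ceil(MAX_STEPS * WARMUP_RATIO)

lr_start = cosine_lr_with_warmup(0, MAX_STEPS, BASE_LR, WARMUP_RATIO)
lr_peak = cosine_lr_with_warmup(warmup_end, MAX_STEPS, BASE_LR, WARMUP_RATIO)
lr_end = cosine_lr_with_warmup(MAX_STEPS, MAX_STEPS, BASE_LR, WARMUP_RATIO)

print(warmup_end, lr_start, lr_peak, lr_end)
```

The `train/lr` scalar that `compute_loss` writes to TensorBoard can be compared against this curve to confirm the schedule behaves as expected.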

In [None]:
for key, value in inputs.items():
    if torch.is_tensor(value):
        inputs[key] = value.to(DEVICE)

In [None]:
print("Model device:", next(model.parameters()).device)
for k, v in inputs.items():
    if torch.is_tensor(v):
        print(f"  {k}: {v.device}")


In [None]:
print(next(model.parameters()).device)
print({k: v.device for k, v in inputs.items() if torch.is_tensor(v)})