In [1]:
from __future__ import annotations

import os 

from torch.utils.data import random_split

from mmpfn.datasets.cbis_ddsm import CBISDDSMDataset

import os 
import torch 
import numpy as np 
import pandas as pd

from sklearn.metrics import accuracy_score, log_loss, roc_auc_score, root_mean_squared_error

from mmpfn.models.tabpfn_v2 import TabPFNClassifier
from mmpfn.models.dino_v2.models.vision_transformer import vit_base
from mmpfn.models.tabpfn_v2.constants import ModelInterfaceConfig
from mmpfn.models.tabpfn_v2.preprocessing import PreprocessorConfig
from mmpfn.scripts_finetune.finetune_tabpfn_main import fine_tune_tabpfn



In [2]:
# data_path = os.path.join(os.getenv('HOME'), "workspace/works/tabular_image/MultiModalPFN/mmpfn/data/cbis_ddsm")
data_path = os.path.join(os.getenv('HOME'), "works/research/MultiModalPFN/mmpfn/data/cbis_ddsm")

kind = 'calc'  # mass calc
image_type = 'all' # all full crop roi
test_dataset = CBISDDSMDataset(data_path=data_path, data_name=f'csv/{kind}_case_description_test_set.csv', kind=kind, image_type=image_type)
# _ = test_dataset.get_images()
# _ = test_dataset.get_embeddings(mode='test')
train_dataset = CBISDDSMDataset(data_path=data_path, data_name=f'csv/{kind}_case_description_train_set.csv', kind=kind, image_type=image_type)
# _ = train_dataset.get_images()
# _ = train_dataset.get_embeddings(mode='train')

In [3]:
accuracy_scores = []
for seed in range(5):
    torch.manual_seed(seed)
    # np.random.seed(seed)
    # print(f"Finetuning with seed: {seed}")
    
    X_train = train_dataset.x
    y_train = train_dataset.y
    X_test = test_dataset.x
    y_test = test_dataset.y

    for i in range(X_train.shape[1]):
        col = X_train[:, i]
        col[np.isnan(col)] = np.nanmin(col) - 1
    for i in range(X_test.shape[1]):
        col = X_test[:, i]
        col[np.isnan(col)] = np.nanmin(col) - 1

    torch.cuda.empty_cache()

    save_path_to_fine_tuned_model = "./finetuned_tabpfn_cbis_calc.ckpt"
    
    fine_tune_tabpfn(
        # path_to_base_model="auto",
        save_path_to_fine_tuned_model=save_path_to_fine_tuned_model,
        # Finetuning HPs
        time_limit=60,
        finetuning_config={"learning_rate": 0.00001, "batch_size": 1, "max_steps": 100},
        validation_metric="log_loss",
        # Input Data
        X_train=pd.DataFrame(X_train),
        y_train=pd.Series(y_train),
        categorical_features_index=None,
        device="cuda",  # use "cpu" if you don't have a GPU
        task_type="binary",
        # Optional
        show_training_curve=False,  # Shows a final report after finetuning.
        logger_level=0,  # Shows all logs, higher values shows less
    )

    # disables preprocessing at inference time to match fine-tuning
    no_preprocessing_inference_config = ModelInterfaceConfig(
        FINGERPRINT_FEATURE=False,
        PREPROCESS_TRANSFORMS=[PreprocessorConfig(name='none')]
    )

    # Evaluate on Test Data
    model_finetuned = TabPFNClassifier(
        model_path=save_path_to_fine_tuned_model,
        inference_config=no_preprocessing_inference_config,
        ignore_pretraining_limits=True,
    )

    clf_finetuned = model_finetuned.fit(X_train, y_train)
    acc_score = accuracy_score(y_test, clf_finetuned.predict(X_test))
    print("accuracy_score (Finetuned):", acc_score)
    accuracy_scores.append(acc_score)

Fine-tuning Steps:   1%|          | 1/100 [00:00<?, ?it/s][2025-09-17 23:50:15,221] INFO - 
Optimizer step skipped due to NaNs/infs in grad scaling.
Fine-tuning Steps:  31%|███       | 31/100 [00:08<00:18,  3.78it/s, Best Val. Loss=0.385, Best Val. Score=-0.385, Training Loss=0.411, Val. Loss=0.385, Patience=21, Utilization=0, Grad Norm=6.17][2025-09-17 23:50:23,329] INFO - 
Optimizer step skipped due to NaNs/infs in grad scaling.
Fine-tuning Steps: 101it [00:27,  3.69it/s, Best Val. Loss=0.329, Best Val. Score=-0.329, Training Loss=0.422, Val. Loss=0.329, Patience=-48, Utilization=0, Grad Norm=5.25]                         
[2025-09-17 23:50:42,029] INFO - Initial Validation Loss: 0.6356307716527124 Best Validation Loss: 0.32850282097971256 Total Steps: 101 Best Step: 100 Total Time Spent: 28.08517050743103


accuracy_score (Finetuned): 0.7361963190184049


Fine-tuning Steps:   1%|          | 1/100 [00:00<?, ?it/s][2025-09-17 23:50:42,861] INFO - 
Optimizer step skipped due to NaNs/infs in grad scaling.
Fine-tuning Steps:  31%|███       | 31/100 [00:08<00:18,  3.66it/s, Best Val. Loss=0.388, Best Val. Score=-0.388, Training Loss=0.422, Val. Loss=0.388, Patience=21, Utilization=0, Grad Norm=5.5] [2025-09-17 23:50:50,890] INFO - 
Optimizer step skipped due to NaNs/infs in grad scaling.
Fine-tuning Steps:  47%|████▋     | 47/100 [00:11<00:13,  4.07it/s, Best Val. Loss=0.365, Best Val. Score=-0.365, Training Loss=0.304, Val. Loss=0.365, Patience=6, Utilization=0, Grad Norm=4.23] [2025-09-17 23:50:54,795] INFO - 
Optimizer step skipped due to NaNs/infs in grad scaling.
Fine-tuning Steps: 101it [00:25,  3.86it/s, Best Val. Loss=0.327, Best Val. Score=-0.327, Training Loss=0.422, Val. Loss=0.327, Patience=-47, Utilization=0, Grad Norm=5.93]                         
[2025-09-17 23:51:08,575] INFO - Initial Validation Loss: 0.6356307716527124 Best

accuracy_score (Finetuned): 0.7269938650306749


Fine-tuning Steps:   1%|          | 1/100 [00:00<?, ?it/s][2025-09-17 23:51:09,325] INFO - 
Optimizer step skipped due to NaNs/infs in grad scaling.
Fine-tuning Steps:  31%|███       | 31/100 [00:07<00:18,  3.73it/s, Best Val. Loss=0.386, Best Val. Score=-0.386, Training Loss=0.413, Val. Loss=0.386, Patience=21, Utilization=0, Grad Norm=5.93][2025-09-17 23:51:17,174] INFO - 
Optimizer step skipped due to NaNs/infs in grad scaling.
Fine-tuning Steps: 101it [00:26,  3.81it/s, Best Val. Loss=0.329, Best Val. Score=-0.329, Training Loss=0.411, Val. Loss=0.329, Patience=-48, Utilization=0, Grad Norm=4.67]                         
[2025-09-17 23:51:35,393] INFO - Initial Validation Loss: 0.6356307716527124 Best Validation Loss: 0.3290179525979057 Total Steps: 101 Best Step: 99 Total Time Spent: 26.497071743011475


accuracy_score (Finetuned): 0.7361963190184049


Fine-tuning Steps:   1%|          | 1/100 [00:00<?, ?it/s][2025-09-17 23:51:36,137] INFO - 
Optimizer step skipped due to NaNs/infs in grad scaling.
Fine-tuning Steps:  31%|███       | 31/100 [00:07<00:18,  3.80it/s, Best Val. Loss=0.386, Best Val. Score=-0.386, Training Loss=0.417, Val. Loss=0.386, Patience=21, Utilization=0, Grad Norm=5.57][2025-09-17 23:51:44,051] INFO - 
Optimizer step skipped due to NaNs/infs in grad scaling.
Fine-tuning Steps: 101it [00:26,  3.80it/s, Best Val. Loss=0.328, Best Val. Score=-0.328, Training Loss=0.417, Val. Loss=0.328, Patience=-48, Utilization=0, Grad Norm=4.7]                          
[2025-09-17 23:52:02,271] INFO - Initial Validation Loss: 0.6356307716527124 Best Validation Loss: 0.3278344809926605 Total Steps: 101 Best Step: 100 Total Time Spent: 26.55355715751648


accuracy_score (Finetuned): 0.7300613496932515


Fine-tuning Steps:   1%|          | 1/100 [00:00<?, ?it/s][2025-09-17 23:52:03,061] INFO - 
Optimizer step skipped due to NaNs/infs in grad scaling.
Fine-tuning Steps:  31%|███       | 31/100 [00:08<00:18,  3.77it/s, Best Val. Loss=0.386, Best Val. Score=-0.386, Training Loss=0.417, Val. Loss=0.386, Patience=21, Utilization=0, Grad Norm=5.61][2025-09-17 23:52:11,083] INFO - 
Optimizer step skipped due to NaNs/infs in grad scaling.
Fine-tuning Steps: 101it [00:27,  3.60it/s, Best Val. Loss=0.329, Best Val. Score=-0.329, Training Loss=0.415, Val. Loss=0.329, Patience=-48, Utilization=0, Grad Norm=4.69]                         
[2025-09-17 23:52:30,588] INFO - Initial Validation Loss: 0.6356307716527124 Best Validation Loss: 0.32891966359035546 Total Steps: 101 Best Step: 100 Total Time Spent: 27.963973999023438


accuracy_score (Finetuned): 0.7361963190184049


In [4]:
# get mean and std of accuracy scores
mean_accuracy = np.mean(accuracy_scores)
std_accuracy = np.std(accuracy_scores)
print("Mean Accuracy:", mean_accuracy)
print("Std Accuracy:", std_accuracy)

Mean Accuracy: 0.7331288343558283
Std Accuracy: 0.0038800952885501365


In [5]:
# Mean Accuracy: 0.7527607361963191
# Std Accuracy: 0.008366982636187638
# Mean Accuracy: 0.7533742331288342
# Std Accuracy: 0.00858895705521472