In [1]:
import torch
import torch.nn as nn
from models.classifier import CustomClassifier
from models.fusion_module import FusionModel, LitMMFusionModel , UnimodalModel
from encoders.image_encoders import DenseNet3D, ResNet50 , VIT
from encoders.tabular_encoders import TabEncoderMLP, ClinicalBERTEncoder
from datasets.datasets import MMClassificationCollator
from pytorch_lightning import Trainer
from transformers import AutoImageProcessor
from preprocessing.MRI_processing import GetMiddleSlice
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

from preprocessing.text_processing import get_serialize_data_processing
from datasets.datasets import get_adni_dataset_new, data_loader, MMClassificationCollator
from monai.transforms import (
    Compose,
    LoadImage,
    ScaleIntensity,
    EnsureChannelFirst,
    Resize,
    EnsureType
)

In [2]:
torch.cuda.is_available()

True

# Data loading

In [2]:
seed = 42
yaml_db_path="ADNI_dataset.yaml"
import yaml

with open(yaml_db_path, 'r') as f:
    dataset_config = yaml.safe_load(f)
print(dataset_config)


class SimpleConfig:
        def __init__(self):
            self.batch_size = 1 
            self.num_workers = 4
            self.pin_memory = True
            self.persistent_workers = True
            self.dataset = dataset_config
            
        class model:
            weight_samples = False
    
cfg = SimpleConfig()

# Load dataset using the same pattern as data_utils.py
dataset_args = dict(dataset_config)


X_train, y_train, X_val, y_val, X_test, y_test, class_names, le = get_adni_dataset_new(**dataset_args,seed=seed)


# 1. Set up the transforms
transforms = {
    "transform_fun": Compose([
        LoadImage(image_only=True),
        ScaleIntensity(),
        EnsureChannelFirst(),
        Resize(spatial_size=[128, 128, 128]),  # Depth, Height, Width
        EnsureType()
    ]),
    "params": {},
    "data_type": "nii.gz"
}


collate_fn = MMClassificationCollator( 
    image_processing=transforms,
    tabular_processing=None)


test_dataloader = data_loader(X_test, y_test,collate_fn=collate_fn, config=cfg, shuffle=True)


{'dataset_name': 'ADNI MRI dataset New', 'data_path': '/gpfs/projects/acad/mmfusion/datasets/adni/ADNIDenoise/AD_CN_clinical_duplicates_aligned.csv', 'image_column': 'nii_path', 'image_type': 'nii.gz', 'eeg_column': None, 'text_columns': None, 'target_column': 'Group', 'subject_column': 'Subject', 'train_val_size': 0.1, 'train_test_size': 0.1, 'normalize': True, 'make_dummies': False, 'drop_duplicates': False, 'patient_based_split': True, 'nunique_cat_values': [2, 2, 5, 2, 7, 2, 3, 3, 4, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], 'cat_cols': ['PTRACCAT', 'PTGENDER', 'PTHAND', 'PTMARRY', 'PTNOTRT', 'PTTLANG', 'MH14ALCH', 'MH17MALI', 'MH16SMOK', 'MH15DRUG', 'MH4CARD', 'MHPSYCH', 'MH2NEURL', 'MH6HEPAT', 'MH12RENA', 'DSPANFOR', 'DSPANBAC', 'CDGLOBAL', 'BCFAQ', 'BCDEPRES'], 'cat_cols_dummies': ['PTRACCAT_2', 'PTRACCAT_4', 'PTRACCAT_5', 'PTRACCAT_6', 'PTGENDER_1.0', 'PTGENDER_2.0', 'PTHAND_1.0', 'PTHAND_2.0', 'PTMARRY_1.0', 'PTMARRY_2.0', 'PTMARRY_3.0', 'PTMARRY_4.0', 'PTMARRY_5.0', 'PTNOTRT_0.

# 3D MRI models (best one for MRI modality)

In [7]:
densenet3d_chkpt_patient = "/gpfs/projects/acad/mmfusion/projects/multimodal_framework/best_models/best_model_densenet3D_epoch=14-step=3510.ckpt"

image_encoder = DenseNet3D(
            freeze=False,
            include_head=False
          )

classifier = CustomClassifier(
    hidden_dim=0,
    activation_fun=nn.ReLU(),
    num_class=2, 
    task="multiclass",  
    dropout_rate=0
)

encoders = {"image": image_encoder}
unimodal_model = UnimodalModel(
    encoders=encoders,
    classifier=classifier
)
dummy_input = next(iter(test_dataloader))

unimodal_model(dummy_input)
ckpt = torch.load(densenet3d_chkpt_patient, map_location="cpu")
state_dict = ckpt["state_dict"]
consume_prefix_in_state_dict_if_present(state_dict, "model.")
unimodal_model.load_state_dict(state_dict)

Initializing layers with input dimension: 1024


  ckpt = torch.load(densenet3d_chkpt_patient, map_location="cpu")


<All keys matched successfully>

In [8]:
unimodal_model.encoders['image'](next(iter(test_dataloader))['image']).shape

torch.Size([1, 1024])

# 2D MRI models Resnet50 (coronal View)

## Data Loading

In [9]:
# 1. Set up the transforms
transforms = {
    "transform_fun": Compose([
        GetMiddleSlice(axis="Coronal"),
        ScaleIntensity(),
        EnsureType()
    ]),
    "params": {},
    "data_type": "nii.gz"
}

collate_fn = MMClassificationCollator( 
    image_processing=transforms,
    tabular_processing=None)


test_dataloader = data_loader(X_test, y_test,collate_fn=collate_fn, config=cfg, shuffle=False)




## Model Loading

In [None]:
resnet50_ckpt = "/gpfs/projects/acad/mmfusion/projects/multimodal_framework/best_models/best_model_resnet50_epoch=9-step=590.ckpt"

image_encoder = ResNet50(
            freeze=False,
            include_head=False
          )

classifier = CustomClassifier(
    hidden_dim=0,  
    activation_fun=nn.ReLU(),
    num_class=2, 
    task="multiclass", 
    dropout_rate=0
)

encoders = {"image": image_encoder}
unimodal_model = UnimodalModel(
    encoders=encoders,
    classifier=classifier
)


dummy_input = next(iter(test_dataloader))

unimodal_model(dummy_input)

ckpt = torch.load(resnet50_ckpt, map_location="cpu")
state_dict = ckpt["state_dict"]
consume_prefix_in_state_dict_if_present(state_dict, "model.")
unimodal_model.load_state_dict(state_dict)


Initializing layers with input dimension: 2048


  ckpt = torch.load(resnet50_ckpt, map_location="cpu")


<All keys matched successfully>

# 2D MRI Model VIT (coronal view)

## Data Loading

In [3]:
cfg.batch_size = 2
# 1. Set up the transforms
transforms = {
    "transform_fun": Compose([
        GetMiddleSlice(axis="Coronal"),
        ScaleIntensity(),
        EnsureType(),
        AutoImageProcessor.from_pretrained(
            pretrained_model_name_or_path="google/vit-base-patch16-224-in21k",
            do_rescale=False,
            use_fast=True,
            return_tensors="pt"
        )
    ]),
    "params": {},
    "data_type": "nii.gz"
}

collate_fn = MMClassificationCollator( 
    image_processing=transforms,
    tabular_processing=None)


test_dataloader = data_loader(X_test, y_test,collate_fn=collate_fn, config=cfg, shuffle=False)



## Model Loading

In [6]:
vit_model_ckpt="/gpfs/projects/acad/mmfusion/projects/multimodal_framework/best_models/best_model_VIT_epoch=4-step=295.ckpt"

image_encoder = VIT(
            freeze=False,
            include_head=False
          )

classifier = CustomClassifier(
    hidden_dim=0,  
    activation_fun=nn.ReLU(),
    num_class=2,  
    task="multiclass",  
    dropout_rate=0
)

encoders = {"image": image_encoder}
unimodal_model = UnimodalModel(
    encoders=encoders,
    classifier=classifier
)

dummy_input = next(iter(test_dataloader))

unimodal_model(dummy_input)

ckpt = torch.load(vit_model_ckpt, map_location="cpu")
state_dict = ckpt["state_dict"]
consume_prefix_in_state_dict_if_present(state_dict, "model.")
unimodal_model.load_state_dict(state_dict)

Initializing layers with input dimension: 768


  ckpt = torch.load(vit_model_ckpt, map_location="cpu")


<All keys matched successfully>

# Tabular modality Model (MLP)

## Data Loading

In [9]:
seed = 42
yaml_db_path="ADNI_dataset.yaml"
import yaml

with open(yaml_db_path, 'r') as f:
    dataset_config = yaml.safe_load(f)

dataset_config["drop_duplicates"] = True
class SimpleConfig:
        def __init__(self):
            self.batch_size = 4 
            self.num_workers = 4
            self.pin_memory = True
            self.persistent_workers = True
            self.dataset = dataset_config
            
        class model:
            weight_samples = False
    
cfg = SimpleConfig()

# Load dataset using the same pattern as data_utils.py
dataset_args = dict(dataset_config)


X_train, y_train, X_val, y_val, X_test, y_test, class_names, le = get_adni_dataset_new(**dataset_args,seed=seed)



collate_fn = MMClassificationCollator( 
    image_processing=None,
    tabular_processing=True)


test_dataloader = data_loader(X_test, y_test,collate_fn=collate_fn, config=cfg, shuffle=True)

## Model Loading

In [10]:
model_ckpt= "/gpfs/projects/acad/mmfusion/projects/multimodal_framework/best_models/best_model_MLP_epoch=9-step=170-v2.ckpt"


tab_encoder = TabEncoderMLP(
            freeze=False,
            include_head=False,
          )

classifier = CustomClassifier(
    hidden_dim=0,  
    activation_fun=nn.ReLU(),
    num_class=2,  
    task="multiclass",  
    dropout_rate=0
)

encoders = {"tabular": tab_encoder}
unimodal_model = FusionModel(
    encoders=encoders,
    fusion=None,  # since we're only using MRIs only
    classifier=classifier
)

dummy_input = next(iter(test_dataloader))

unimodal_model(dummy_input)

ckpt = torch.load(model_ckpt, map_location="cpu")
state_dict = ckpt["state_dict"]
consume_prefix_in_state_dict_if_present(state_dict, "model.")
unimodal_model.load_state_dict(state_dict)

Initializing layers with input dimension: 32


  ckpt = torch.load(model_ckpt, map_location="cpu")


<All keys matched successfully>

# Tabular Model ClinicalBert (best for tabular modality)

## Data Loading

In [11]:
seed = 42
yaml_db_path="ADNI_dataset.yaml"
import yaml

with open(yaml_db_path, 'r') as f:
    dataset_config = yaml.safe_load(f)

dataset_config["drop_duplicates"] = True
dataset_config["normalize"] = False

class SimpleConfig:
        def __init__(self):
            self.batch_size = 1 
            self.num_workers = 4
            self.pin_memory = False
            self.persistent_workers = True
            self.dataset = dataset_config
            
        class model:
            weight_samples = False
    
cfg = SimpleConfig()

# Load dataset using the same pattern as data_utils.py
dataset_args = dict(dataset_config)


X_train, y_train, X_val, y_val, X_test, y_test, class_names, le = get_adni_dataset_new(**dataset_args,seed=seed)


collate_fn = MMClassificationCollator( 
    image_processing=None,
    tabular_processing=get_serialize_data_processing())


test_dataloader = data_loader(X_test, y_test,collate_fn=collate_fn, config=cfg, shuffle=True)

## Model Loading

In [12]:
#model_ckpt = "best_models/best_model_clinicalBert_epoch=7-step=72-v1.ckpt"
model_ckpt = "/gpfs/projects/acad/mmfusion/projects/multimodal_framework/output_runs_main/None_None_ClinicalBERTEncoder_None_None_ADNI MRI dataset New/epoch=8-step=81.ckpt"

tab_encoder = ClinicalBERTEncoder(
            freeze=False,
            include_head=False,
            checkpoint="medicalai/ClinicalBERT"
          )

classifier = CustomClassifier(
    hidden_dim=0,  
    activation_fun=nn.ReLU(),
    num_class=2,  
    task="multiclass",  
    dropout_rate=0
)

encoders = {"tabular": tab_encoder}
unimodal_model = FusionModel(
    encoders=encoders,
    fusion=None,  # since we're only using MRIs only
    classifier=classifier
)

dummy_input = next(iter(test_dataloader))

unimodal_model(dummy_input)



Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Initializing layers with input dimension: 768


tensor([[-0.6789, -0.2890]], grad_fn=<AddmmBackward0>)