# OASIS Multimodal Classification
### OASIS Dataset
OASIS-1: Cross-Sectional: Principal Investigators: D. Marcus, R. Buckner, J. Csernansky, J. Morris; P50 AG05681, P01 AG03991, P01 AG026276, R01 AG021910, P20 MH071616, U24 RR021382
Open Access Series of Imaging Studies (OASIS): Cross-Sectional MRI Data in Young, Middle Aged, Nondemented, and Demented Older Adults. Marcus, DS, Wang, TH, Parker, J, Csernansky, JG, Morris, JC, Buckner, RL. Journal of Cognitive Neuroscience, 19, 1498-1507. doi: 10.1162/jocn.2007.19.9.1498

https://sites.wustl.edu/oasisbrains/home/oasis-1/

https://www.kaggle.com/datasets/ninadaithal/oasis-1-shinohara

### Converting to NIfTI
https://artiiicy.tistory.com/70

https://brainder.org/2011/08/13/converting-oasis-brains-to-nifti/

https://fsl.fmrib.ox.ac.uk/fsl/docs/#/install/index

Alternatively, the conversion can be done in Python with the nibabel library:

https://neurostars.org/t/convert-hdr-img-to-nii-fromat/3761

### MedicalNet

    @article{chen2019med3d,
        title={Med3D: Transfer Learning for 3D Medical Image Analysis},
        author={Chen, Sihong and Ma, Kai and Zheng, Yefeng},
        journal={arXiv preprint arXiv:1904.00625},
        year={2019}
    }

https://github.com/Tencent/MedicalNet/tree/master

https://www.kaggle.com/datasets/werus23/medicalnet

### Multimodal embedding

The multimodal embedding is based on CLIP (contrastive language-image pre-training):

https://velog.io/@ji1kang/paper-reading-clip
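In this CLIP-style setup, a batch of $N$ matched (MRI, caption) pairs is scored by the scaled cosine similarity $\ell_{ij} = e^{t}\,\mathbf{u}_i^\top \mathbf{v}_j$ between the L2-normalized image embedding $\mathbf{u}_i$ and text embedding $\mathbf{v}_j$, where $t$ is a learnable scalar (the `temperature` parameter of `MultimodalClassifier` later in this notebook). The `symmetric_loss` function averages the image-to-text and text-to-image cross-entropies, taking pair $i$ as the target for row $i$ and for column $i$:

$$
\mathcal{L} = -\frac{1}{2N}\sum_{i=1}^{N}\left[\log\frac{e^{\ell_{ii}}}{\sum_{j=1}^{N} e^{\ell_{ij}}} + \log\frac{e^{\ell_{ii}}}{\sum_{j=1}^{N} e^{\ell_{ji}}}\right]
$$

When all logits are equal (an uninformative model) the loss is exactly $\log N$, e.g. $\log 32 \approx 3.47$ for a full batch of 32.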


In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import os
import sys
import glob
import torch
import random
import shutil
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import nibabel as nib
import numpy as np
import argparse
from torch.utils.data import Dataset, DataLoader
from transformers import CLIPProcessor, CLIPModel
from sklearn.model_selection import train_test_split

In [3]:
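!git clone https://github.com/Tencent/MedicalNet
# The next two lines expect MedicalNet_pytorch_files.zip (the archive with the pretrained weights) in the
# working directory. It was not uploaded in this run, so both fail (see output below); the ResNet-10
# weights are copied from Google Drive in the next cell instead.
!mv MedicalNet_pytorch_files.zip MedicalNet/.
!cd MedicalNet && unzip MedicalNet_pytorch_files.zip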

Cloning into 'MedicalNet'...
remote: Enumerating objects: 120, done.
remote: Counting objects: 100% (69/69), done.
remote: Compressing objects: 100% (26/26), done.
remote: Total 120 (delta 50), reused 43 (delta 43), pack-reused 51 (from 1)
Receiving objects: 100% (120/120), 47.63 MiB | 16.13 MiB/s, done.
Resolving deltas: 100% (50/50), done.
mv: cannot stat 'MedicalNet_pytorch_files.zip': No such file or directory
unzip:  cannot find or open MedicalNet_pytorch_files.zip, MedicalNet_pytorch_files.zip.zip or MedicalNet_pytorch_files.zip.ZIP.


In [4]:
!cp "/content/drive/MyDrive/KUBIG/CV contest/Pretrained/resnet_10.pth" "/content/MedicalNet/"

In [5]:
!cp "/content/drive/MyDrive/KUBIG/CV contest/nifti.zip" "/content/"
!unzip nifti.zip
!rm nifti.zip

Archive:  nifti.zip
   creating: nifti/
   creating: nifti/OAS1_0303_MR1/
   creating: nifti/OAS1_0317_MR1/
   creating: nifti/OAS1_0101_MR1/
   ...

In [6]:
!cp "/content/drive/MyDrive/KUBIG/CV contest/Processed_OASIS_MRI_Data.csv" "/content/"

In [7]:
mri_folder = "/content/nifti/"
df = pd.read_csv("/content/Processed_OASIS_MRI_Data.csv")
mri_files = glob.glob(os.path.join(mri_folder, "**/*.nii.gz"))
print(f"Found {len(mri_files)} MRI files in total.")
df.head()

Found 1688 MRI files in total.


| Unnamed: 0 | ID | M/F | CDR | Age | Group | CLIP_Text |
|---|---|---|---|---|---|---|
| 0 | OAS1_0001_MR1 | F | 0.0 | 74 | NonDemented | This is an MRI scan of a 74-year-old female wi... |
| 1 | OAS1_0002_MR1 | F | 0.0 | 55 | NonDemented | This is an MRI scan of a 55-year-old female wi... |
| 2 | OAS1_0003_MR1 | F | 0.5 | 73 | VeryMildDementia | This is an MRI scan of a 73-year-old female wi... |
| 3 | OAS1_0004_MR1 | M | 0.0 | 28 | NonDemented | This is an MRI scan of a 28-year-old male with... |
| 4 | OAS1_0005_MR1 | M | 0.0 | 18 | NonDemented | This is an MRI scan of a 18-year-old male with... |


In [8]:
# Dataset loading: match each NIfTI file to its CSV row by the 4-digit subject number
# (e.g. "OAS1_0313_MR1_mpr-4_anon.nii.gz" -> "0313"). The session suffix (MR1/MR2) is dropped,
# so an MR2 rescan gets the metadata of the first CSV row for that subject.

def extract_id_from_filename(filename):
    basename = os.path.basename(filename)
    parts = basename.split('_')
    return parts[1]

def extract_id_from_csv(original_id):
    parts = original_id.split('_')
    return parts[1]

df['ID'] = df['ID'].apply(extract_id_from_csv)

matched_data = []
for filename in mri_files:
    subject_id = extract_id_from_filename(filename)   # renamed from `id` to avoid shadowing the built-in
    row = df[df['ID'] == subject_id]
    if not row.empty:
        matched_data.append({"mri_path" : filename,
                             "age" : row.iloc[0]["Age"],
                             "gender" : "male" if row.iloc[0]["M/F"] == "M" else "female",
                             "label" : row.iloc[0]["Group"]})


print(f"Matched {len(matched_data)} MRI files to rows in the CSV.")
matched_data[0]

Matched 1688 MRI files to rows in the CSV.


{'mri_path': '/content/nifti/OAS1_0313_MR1/OAS1_0313_MR1_mpr-4_anon.nii.gz',
 'age': 20,
 'gender': 'female',
 'label': 'NonDemented'}
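The `df[df['ID'] == ...]` filter above scans the whole table once per file. A minimal sketch of an equivalent match through a dictionary keyed by subject number, which also keeps the subject and session (MR1/MR2) for later subject-level splitting. The names `subject_of`, `match_files`, `subject` and `session` are hypothetical, and the filenames are assumed to follow the `OAS1_<subject>_<session>_...` pattern seen above. Like `row.iloc[0]`, the first CSV row per subject wins.

```python
import os


def subject_of(oasis_id):
    """'OAS1_0313_MR1_mpr-4_anon.nii.gz', 'OAS1_0313_MR1' or '0313' -> '0313'."""
    parts = os.path.basename(str(oasis_id)).split('_')
    return parts[1] if len(parts) > 1 else parts[0]


def match_files(mri_files, records):
    """Attach CSV metadata to each MRI file.

    records: iterable of dicts with keys 'ID', 'Age', 'M/F', 'Group',
    e.g. df.to_dict('records').
    """
    by_subject = {}
    for rec in records:
        by_subject.setdefault(subject_of(rec['ID']), rec)   # keep the first row per subject

    matched = []
    for path in mri_files:
        parts = os.path.basename(path).split('_')
        rec = by_subject.get(parts[1])
        if rec is None:
            continue                                         # no CSV row for this subject
        matched.append({"mri_path": path,
                         "subject": parts[1],                # hypothetical extra field for grouping
                         "session": parts[2],                # 'MR1' or 'MR2'
                         "age": rec['Age'],
                         "gender": "male" if rec['M/F'] == "M" else "female",
                         "label": rec['Group']})
    return matched
```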

In [27]:
class OASISDataset(Dataset):
    def __init__(self, matched_data, transform=None):
        self.data = matched_data
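        self.transform = transform   # optional callable applied to the (1, 16, 112, 112) image tensor

        # Diagnostic group -> class index, ordered by severity
        self.label_mapping = {
            "NonDemented" : 0,
            "VeryMildDementia" : 1,
            "MildDementia" : 2,
            "ModerateDementia" : 3,
        }

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        mri_path, age, gender, label = sample['mri_path'], sample['age'], sample['gender'], sample['label']

        nii_img = nib.load(mri_path)
        img_data = nii_img.get_fdata()
        # Min-max normalize to [0, 1]; the epsilon avoids division by zero for a constant volume
        img_data = (img_data - np.min(img_data)) / (np.max(img_data) - np.min(img_data) + 1e-8)

        # (X, Y, Z) -> (1, X, Y, Z), then resize to the 16 x 112 x 112 (D x H x W) encoder input;
        # the first voxel axis is treated as depth
        img_data = torch.tensor(img_data, dtype=torch.float32).unsqueeze(0)
        img_data = F.interpolate(img_data.unsqueeze(0), size=(16, 112, 112), mode='trilinear', align_corners=False).squeeze(0)
        if self.transform is not None:
            img_data = self.transform(img_data)

        # Caption for the CLIP text encoder (age and gender only; the diagnosis is not included)
        text_data = f"This is an MRI scan of a {age}-year-old {gender} patient."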
        label = self.label_mapping[label]

        return img_data, text_data, torch.tensor(label, dtype=torch.long)
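The image preprocessing in `__getitem__` can be factored into a standalone function and checked without any MRI files. A minimal sketch, assuming a 3-D voxel array as input; `preprocess_volume` is a hypothetical helper, and the small `eps` guards against dividing by zero on a constant volume.

```python
import numpy as np
import torch
import torch.nn.functional as F


def preprocess_volume(volume, size=(16, 112, 112), eps=1e-8):
    """Min-max normalize a 3-D array to [0, 1] and resize it to `size`.

    Returns a float32 tensor of shape (1, *size): one channel, which a DataLoader
    then batches into (B, 1, D, H, W).
    """
    vol = np.asarray(volume, dtype=np.float32)
    vol = (vol - vol.min()) / (vol.max() - vol.min() + eps)   # eps avoids 0/0 on constant input
    x = torch.as_tensor(vol, dtype=torch.float32)[None, None]  # (1, 1, X, Y, Z) for 3-D interpolation
    x = F.interpolate(x, size=size, mode='trilinear', align_corners=False)
    return x[0]                                                # drop the batch dim -> (1, *size)
```

Because trilinear interpolation only forms weighted averages of neighboring voxels, the resized values stay within the normalized [0, 1] range.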

In [None]:
# Fix random seeds for reproducibility

seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
generator = torch.Generator().manual_seed(seed)

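# Train / validation / test split (80 / 10 / 10), stratified by label.
# The split is per file, so different scans of the same subject can end up in different splits.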

matched_data = random.sample(matched_data, len(matched_data))
labels = [sample['label'] for sample in matched_data]

train_data, temp_data, train_labels, temp_labels = train_test_split(matched_data, labels, test_size=0.2, stratify=labels, random_state=seed)
val_data, test_data, val_labels, test_labels = train_test_split(temp_data, temp_labels, test_size=0.5, stratify=temp_labels, random_state=seed)

train_dataset = OASISDataset(train_data)
val_dataset = OASISDataset(val_data)
test_dataset = OASISDataset(test_data)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, generator=generator)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

print(f"Train size: {len(train_dataset)}")
print(f"Validation size: {len(val_dataset)}")
print(f"Test size: {len(test_dataset)}")

Train size: 1350
Validation size: 169
Test size: 169
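OASIS-1 provides several scans per session (`mpr-1` to `mpr-4`) and some subjects also have an `MR2` rescan, so a file-level split lets near-identical scans of one subject appear in both training and test data, which can inflate validation and test scores. A minimal sketch of a subject-level 80/10/10 split that stratifies on each subject's label, assuming the subject key can be parsed from the filename and that a subject's label is the same across its scans. `subject_key` and `subject_split` are hypothetical helpers. With classes that have very few subjects (OASIS-1 has very few ModerateDementia subjects), stratifying the second split can fail; merging rare classes or dropping `stratify` there are possible workarounds.

```python
import os
from sklearn.model_selection import train_test_split


def subject_key(sample):
    # 'OAS1_0313_MR1_mpr-4_anon.nii.gz' -> '0313' (MR1 and MR2 sessions share a key)
    return os.path.basename(sample['mri_path']).split('_')[1]


def subject_split(samples, seed=42):
    """Split samples 80/10/10 by subject so that no subject spans two splits."""
    label_of = {}
    for s in samples:
        label_of.setdefault(subject_key(s), s['label'])   # assumes one label per subject
    subjects = sorted(label_of)
    labels = [label_of[k] for k in subjects]

    train_s, temp_s, _, temp_y = train_test_split(
        subjects, labels, test_size=0.2, stratify=labels, random_state=seed)
    val_s, test_s = train_test_split(
        temp_s, test_size=0.5, stratify=temp_y, random_state=seed)

    groups = {'train': set(train_s), 'val': set(val_s), 'test': set(test_s)}
    return {name: [s for s in samples if subject_key(s) in keys]
            for name, keys in groups.items()}
```

The resulting lists can be passed to `OASISDataset` exactly like `train_data`, `val_data` and `test_data`.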


In [29]:
# checking GPU

print(torch.cuda.device_count())
print(torch.cuda.is_available())
print(torch.cuda.current_device())
print(torch.cuda.device(0))
!nvidia-smi

1
True
0
<torch.cuda.device object at 0x7a4e54587710>
Sun Feb 23 11:06:18 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   33C    P0             55W /  400W |    7961MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

In [30]:
# Load the pretrained MedicalNet image encoder

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

sys.path.append("/content/MedicalNet")

from MedicalNet.model import generate_model

class MedicalNetEncoder(nn.Module):
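    def __init__(self, feature_dim=32, use_pretrained=False):   # feature_dim is currently unused
        super(MedicalNetEncoder, self).__init__()
        opt = argparse.Namespace(model='resnet',
                                 model_depth=10,   # selects the ResNet variant; must match the pretrained weights
                                 input_W=112,
                                 input_H=112,
                                 input_D=16,
                                 resnet_shortcut='B',
                                 no_cuda=False,
                                 n_seg_classes=4,
                                 gpu_id=[0],
                                 phase='train',
                                 pretrain_path="/content/MedicalNet/resnet_10.pth" if use_pretrained else None,
                                 new_layer_names=[])

        self.model, _ = generate_model(opt)
        # MedicalNet's ResNet has no `fc` layer: its forward pass ends in the segmentation head `conv_seg`,
        # so this line has no effect and the output is a (B, n_seg_classes, 4, 28, 28) map for a 16 x 112 x 112 input
        self.model.fc = nn.Identity()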

        self.model.to(device)

    def forward(self, x):
        return self.model(x)

mri_encoder = MedicalNetEncoder(feature_dim=512, use_pretrained=True)

loading pretrained model /content/MedicalNet/resnet_10.pth


  m.weight = nn.init.kaiming_normal(m.weight, mode='fan_out')
  pretrain = torch.load(opt.pretrain_path)
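MedicalNet's ResNet forward pass ends in its segmentation head `conv_seg`, which is why the flattened encoder output used later has 12544 = 4 x 4 x 28 x 28 values rather than a compact feature vector. An alternative is to average-pool the backbone's last feature map into one vector per scan. A minimal sketch under that assumption: `PooledEncoder` is a hypothetical wrapper around any module that returns a (B, C, D, H, W) tensor. To use it with MedicalNet one would need to bypass `conv_seg` (for example by setting it to `nn.Identity()` on the underlying ResNet, i.e. `self.model.module` if it is wrapped in `nn.DataParallel`), after which ResNet-10 should give C = 512; this has not been tested here. The toy `nn.Conv3d` backbone only stands in for that network.

```python
import torch
import torch.nn as nn


class PooledEncoder(nn.Module):
    """Wrap a 3-D backbone and global-average-pool its feature map to (B, C)."""

    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        self.pool = nn.AdaptiveAvgPool3d(1)          # (B, C, D, H, W) -> (B, C, 1, 1, 1)

    def forward(self, x):
        feats = self.backbone(x)
        return torch.flatten(self.pool(feats), 1)    # -> (B, C)


# Usage with a toy stand-in backbone (not MedicalNet):
toy = nn.Sequential(nn.Conv3d(1, 8, kernel_size=3, stride=2, padding=1), nn.ReLU())
enc = PooledEncoder(toy)
print(enc(torch.randn(2, 1, 16, 112, 112)).shape)    # torch.Size([2, 8])
```

With pooled features, the `in_features` of `mri_fc` would become the channel count C instead of 12544.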


In [31]:
# CLIP text encoder

class TextEncoder(nn.Module):
    def __init__(self, model_name="openai/clip-vit-base-patch32"):
        super(TextEncoder, self).__init__()
        self.device = device
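        self.model = CLIPModel.from_pretrained(model_name).to(device)   # pretrained CLIP from the Hugging Face Hub
        self.processor = CLIPProcessor.from_pretrained(model_name)   # handles tokenization, padding, and truncation

    def forward(self, text):
        # Accept one caption or a batch of captions (the DataLoader yields a tuple of strings) and embed
        # every caption, giving one text row per image: (num_captions, 512). Keeping only text[0] would
        # collapse a list of captions to a single row.
        if isinstance(text, str):
            text = [text]
        else:
            text = list(text)

        inputs = self.processor(text=text, return_tensors="pt", padding=True, truncation=True).to(self.device)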
        text_features = self.model.get_text_features(**inputs)
        return text_features

text_encoder = TextEncoder()
sample_text = ["This is a sample text."]
text_features = text_encoder(sample_text)
print(text_features.shape)

torch.Size([1, 512])
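Inside the training loop, `text_inputs` comes from the DataLoader. For a dataset whose items are tuples, PyTorch's default collate function stacks tensors but returns a field of Python strings as a tuple of strings. A short check with a hypothetical toy dataset (`ToyPairs`, no CLIP involved) that mimics the `(image, caption, label)` items of `OASISDataset`; whatever the container, the text encoder has to embed every caption in it to produce one text row per image.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyPairs(Dataset):
    """Hypothetical stand-in for OASISDataset: (image tensor, caption, label)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        caption = f"This is an MRI scan of a {60 + idx}-year-old female patient."
        return torch.zeros(1, 2, 2, 2), caption, torch.tensor(0)


images, captions, labels = next(iter(DataLoader(ToyPairs(5), batch_size=4)))
print(images.shape)                             # torch.Size([4, 1, 2, 2, 2])
print(type(captions).__name__, len(captions))   # tuple 4
```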


In [None]:
# As in CLIP, score image-text pairs by the dot product of L2-normalized embeddings

class MultimodalClassifier(nn.Module):
    def __init__(self, feature_dim=512, num_classes=4):   # num_classes is currently unused
        super(MultimodalClassifier, self).__init__()

        # in_features must equal the flattened encoder output size (for a (batch_size, dim) output, use dim);
        # 12544 = 4 x 4 x 28 x 28 for MedicalNet ResNet-10 on a 16 x 112 x 112 input
        self.mri_fc = nn.Linear(12544, feature_dim)
        self.text_fc = nn.Linear(512, feature_dim)   # 512 = CLIP ViT-B/32 text embedding size
        # Learnable log of the logit scale: logits are multiplied by exp(temperature)
        self.temperature = nn.Parameter(torch.tensor(1.0))

    def forward(self, mri_features, text_features):

        mri_features = mri_features.view(mri_features.size(0), -1)
        mri_features = self.mri_fc(mri_features)
        text_features = self.text_fc(text_features)

        mri_features = F.normalize(mri_features, p=2, dim=1)
        text_features = F.normalize(text_features, p=2, dim=1)

        similarity = torch.matmul(mri_features, text_features.T) * torch.exp(self.temperature)   # (B, B) scaled cosine similarities
        return similarity

def symmetric_loss(logits):
    labels = torch.arange(logits.size(0)).to(logits.device)
    loss_i2t = F.cross_entropy(logits, labels)
    loss_t2i = F.cross_entropy(logits.T, labels)
    return (loss_i2t + loss_t2i) / 2

multimodal_model = MultimodalClassifier(feature_dim=512, num_classes=4)
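A quick sanity check of the symmetric contrastive loss with synthetic embeddings: when the logits carry no information the loss sits at log(N), and perfectly aligned pairs with a large logit scale drive it toward zero. `symmetric_clip_loss` is a hypothetical, self-contained variant of `symmetric_loss` that also normalizes and scales the embeddings. The clamp on the scale is an assumption borrowed from CLIP, which initializes its learnable log-scale at log(1/0.07) and caps the scale at 100; this notebook's model does not clamp.

```python
import math
import torch
import torch.nn.functional as F


def symmetric_clip_loss(image_emb, text_emb, log_scale, max_scale=100.0):
    """CLIP-style symmetric cross-entropy over in-batch image/text pairs.

    image_emb, text_emb: (N, d) tensors; row i of each forms a matched pair.
    log_scale: scalar tensor; the logit scale exp(log_scale) is clamped at max_scale.
    """
    image_emb = F.normalize(image_emb, dim=1)
    text_emb = F.normalize(text_emb, dim=1)
    scale = log_scale.exp().clamp(max=max_scale)
    logits = scale * image_emb @ text_emb.T          # (N, N) scaled cosine similarities
    targets = torch.arange(logits.size(0), device=logits.device)
    return (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets)) / 2


N, d = 8, 16
aligned = torch.eye(N, d)                            # orthonormal rows: every pair matches perfectly
print(symmetric_clip_loss(aligned, aligned, torch.tensor(math.log(1 / 0.07))))   # close to 0
print(symmetric_clip_loss(aligned, aligned, torch.tensor(-30.0)))                # about log(8) = 2.079
```

Compared with this reference, the first logged epoch (train loss 3.31 against log 32 = 3.47) shows the model starting only slightly better than chance.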

In [None]:
mri_encoder.to(device)
text_encoder.to(device)
multimodal_model.to(device)

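# text_encoder.parameters() also includes CLIP's vision tower, which is never used here;
# those parameters get no gradients, and Adam skips parameters whose .grad is None.
optimizer = torch.optim.Adam(list(mri_encoder.parameters()) +
                             list(text_encoder.parameters()) +
                             list(multimodal_model.parameters()),
                             lr=0.0001)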

num_epochs = 100

for epoch in range(num_epochs):   # start_epoch, num_epochs
    mri_encoder.train()
    text_encoder.train()
    multimodal_model.train()

    total_train_loss, total_train_correct, total_train_samples = 0, 0, 0

    for batch in train_loader:
        mri_tensor, text_inputs, labels = batch
        mri_tensor, labels = mri_tensor.to(device), labels.to(device)

        mri_features = mri_encoder(mri_tensor)
        text_features = text_encoder(text_inputs).to(device)

        logits = multimodal_model(mri_features, text_features)
        loss = symmetric_loss(logits)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

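        # In-batch matching accuracy: image i counts as correct when its most similar caption is caption i.
        # The diagnostic labels are not used by this objective or this metric, and identical captions
        # (same age and gender) produce ties that argmax resolves toward the first one.
        preds = torch.argmax(logits, dim=1)
        correct = (preds == torch.arange(logits.size(0)).to(device)).sum().item()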

        total_train_loss += loss.item()
        total_train_correct += correct
        total_train_samples += logits.size(0)

    train_loss = total_train_loss / len(train_loader)
    train_accuracy = total_train_correct / total_train_samples

    mri_encoder.eval()
    text_encoder.eval()
    multimodal_model.eval()

    total_val_loss, total_val_correct, total_val_samples = 0, 0, 0

    with torch.no_grad():
        for batch in val_loader:
            mri_tensor, text_inputs, labels = batch
            mri_tensor, labels = mri_tensor.to(device), labels.to(device)

            mri_features = mri_encoder(mri_tensor)
            text_features = text_encoder(text_inputs).to(device)

            logits = multimodal_model(mri_features, text_features)
            loss = symmetric_loss(logits)

            preds = torch.argmax(logits, dim=1)
            correct = (preds == torch.arange(logits.size(0)).to(device)).sum().item()

            total_val_loss += loss.item()
            total_val_correct += correct
            total_val_samples += logits.size(0)

    val_loss = total_val_loss / len(val_loader)
    val_accuracy = total_val_correct / total_val_samples

    print(f"Epoch [{epoch + 1}/{num_epochs}] | "
          f"Train Loss : {train_loss:.4f}, Train Acc : {train_accuracy:.4f} | "
          f"Val Loss : {val_loss:.4f}, Val Acc : {val_accuracy:.4f}")

Epoch [1/100] | Train Loss : 3.3101, Train Acc : 0.0637 | Val Loss : 3.4024, Val Acc : 0.0355
Epoch [2/100] | Train Loss : 3.0723, Train Acc : 0.0807 | Val Loss : 2.8577, Val Acc : 0.1183
Epoch [3/100] | Train Loss : 2.7265, Train Acc : 0.1267 | Val Loss : 2.4370, Val Acc : 0.1479
Epoch [4/100] | Train Loss : 2.5738, Train Acc : 0.1481 | Val Loss : 2.2451, Val Acc : 0.2130
Epoch [5/100] | Train Loss : 2.4079, Train Acc : 0.1889 | Val Loss : 2.1718, Val Acc : 0.1834
Epoch [6/100] | Train Loss : 2.2961, Train Acc : 0.2000 | Val Loss : 2.0783, Val Acc : 0.2249
Epoch [7/100] | Train Loss : 2.2018, Train Acc : 0.2333 | Val Loss : 2.0359, Val Acc : 0.2130
Epoch [8/100] | Train Loss : 2.1323, Train Acc : 0.2800 | Val Loss : 1.9279, Val Acc : 0.3254
Epoch [9/100] | Train Loss : 2.0695, Train Acc : 0.3156 | Val Loss : 1.8981, Val Acc : 0.3314
Epoch [10/100] | Train Loss : 2.0106, Train Acc : 0.3459 | Val Loss : 1.8418, Val Acc : 0.3728
Epoch [11/100] | Train Loss : 1.9427, Train Acc : 0.3837 | 

In [34]:
save_path = 'checkpoint.pth'

torch.save({
    'epoch' : epoch,
    'mri_encoder_state_dict' : mri_encoder.state_dict(),
    'text_encoder_state_dict' : text_encoder.state_dict(),
    'multimodal_model_state_dict' : multimodal_model.state_dict(),
    'optimizer_state_dict' : optimizer.state_dict(),
    'loss' : loss.item(),   # store a Python float rather than a tensor
}, save_path)
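A minimal sketch of restoring this checkpoint to resume training, using only the keys written above (`<name>_state_dict`, `optimizer_state_dict`, `epoch`). `resume` is a hypothetical helper. Because the training cell can be interrupted mid-epoch, the stored `epoch` may not have finished, so the sketch returns it unchanged: restarting from it repeats at most one partial epoch.

```python
import torch


def resume(path, modules, optimizer, map_location="cpu"):
    """Load a checkpoint written by the save cell and return the stored epoch index.

    modules: dict mapping checkpoint prefixes to nn.Modules, e.g.
        {"mri_encoder": mri_encoder, "text_encoder": text_encoder,
         "multimodal_model": multimodal_model}
    """
    ckpt = torch.load(path, map_location=map_location)
    for name, module in modules.items():
        module.load_state_dict(ckpt[f"{name}_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])   # moves optimizer state to the params' device
    return ckpt["epoch"]


# Hypothetical usage before re-running the training loop:
# start_epoch = resume('checkpoint.pth', {...}, optimizer, map_location=device)
# for epoch in range(start_epoch, num_epochs): ...
```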

In [None]:
mri_encoder.eval()
text_encoder.eval()
multimodal_model.eval()

total_test_loss, total_test_correct, total_test_samples = 0, 0, 0

with torch.no_grad():
    for batch in test_loader:
        mri_tensor, text_inputs, labels = batch
        mri_tensor, labels = mri_tensor.to(device), labels.to(device)

        mri_features = mri_encoder(mri_tensor)
        text_features = text_encoder(text_inputs)

        logits = multimodal_model(mri_features, text_features)
        loss = symmetric_loss(logits)

        preds = torch.argmax(logits, dim=1)
        correct = (preds == torch.arange(logits.size(0)).to(device)).sum().item()


        total_test_loss += loss.item()
        total_test_correct += correct
        total_test_samples += logits.size(0)

test_loss = total_test_loss / len(test_loader)
test_accuracy = total_test_correct / total_test_samples

print(f"Test Results | Loss : {test_loss:.4f} | Accuracy : {test_accuracy:.4f}")

KeyboardInterrupt: 
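The accuracy reported by these loops is in-batch image-text matching, so it does not directly measure dementia classification. One hedged way to read a 4-way prediction out of a CLIP-style model, following CLIP's zero-shot setup, is to embed one prompt per class and assign each scan to the most similar prompt. The sketch assumes image and prompt embeddings have already been projected into the shared space (for example with `mri_fc` and `text_fc`). The prompt wording and the name `classify_by_prompts` are hypothetical. Since the training captions here describe only age and gender, the prompts would need to mention the diagnosis, and the model would need to be trained on such captions, for this to be meaningful.

```python
import torch
import torch.nn.functional as F


def classify_by_prompts(image_emb, class_text_emb):
    """Return, for each image, the index of the most similar class prompt.

    image_emb: (B, d) image embeddings; class_text_emb: (K, d), one row per class,
    ordered like OASISDataset.label_mapping (0 = NonDemented, ..., 3 = ModerateDementia).
    """
    sims = F.normalize(image_emb, dim=1) @ F.normalize(class_text_emb, dim=1).T   # (B, K) cosine similarities
    return sims.argmax(dim=1)


# Hypothetical prompts, one per class, in label_mapping order:
class_prompts = [f"This is an MRI scan of a patient with {g}."
                 for g in ("no dementia", "very mild dementia", "mild dementia", "moderate dementia")]
```

Class accuracy would then be `(preds == labels).float().mean()`, which uses the `labels` that the DataLoader already returns.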