In [1]:
from medvae import MVAE
import torch
from nilearn.plotting import view_img
import numpy as np
import nibabel as nib

## Download Sample Data

If you have trouble downloading from Hugging Face, delete the local cache (`rm -rf .cache`) and restart your Jupyter kernel.

In [None]:
from huggingface_hub import snapshot_download
import shutil

repo_id = "stanfordmimi/MedVAE"

# Fetch only the example_data/ folder of the repo into the current working
# directory, so the files land under ./example_data/.
local_dir = snapshot_download(
    repo_id=repo_id,
    allow_patterns=["example_data/*"],  # Only download files in example_data folder
    local_dir="./",                     # Download into the current directory
    max_workers=1,
    etag_timeout=10000,
    force_download=True,
)

# Rename example_data/ to data/. A plain shutil.move would nest the folder as
# data/example_data/ whenever data/ already exists (e.g. on a re-run), which
# breaks the 'data/...' paths used by the cells below. Merging into data/ and
# then removing the download folder keeps this cell idempotent.
shutil.copytree("example_data", "data", dirs_exist_ok=True)
shutil.rmtree("example_data")

# Remove the local .cache directory; ignore_errors keeps a re-run (or a run
# where it was never created) from crashing on a missing directory.
shutil.rmtree(".cache", ignore_errors=True)

print("Download completed")

## X-ray (Mammography) Example with 2D MedVAE (f=16; C=3)

In [None]:
# Run on the first CUDA GPU when one is present; otherwise use the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# 2D MedVAE checkpoint with the x-ray weights, frozen for inference only.
model = MVAE(modality='xray', model_name='medvae_4_3_2d')
model = model.to(device)
model.requires_grad_(False)  # freeze every parameter
model.eval()                 # switch to evaluation mode

fpath = 'data/mmg_data/TQcBVJediTG8E34ftHnapA.png'

# Build the model's preprocessing pipeline explicitly (the long-hand
# alternative to apply_transform), run the file through it, then add a
# leading batch axis before moving the tensor to the device.
transform = model.get_transform()
img = transform(fpath)
img = img.unsqueeze(0).to(device)

# Encode with autograd disabled and copy the result back to a NumPy array.
with torch.no_grad():
    latent = model(img).cpu().detach().numpy()

In [None]:
# transpose(1, 2, 0) requires a 3-D array; it moves the leading axis (presumably
# the C=3 latent channels -- TODO confirm against MVAE.forward) to the end so
# the viewer treats it as the third volume axis.
latent_hwc = latent.transpose(1, 2, 0)
# Identity affine: voxel indices are used directly as world coordinates.
latent_nifti = nib.Nifti1Image(latent_hwc, np.eye(4))
view_img(stat_map_img=latent_nifti, bg_img=False, cmap='gray')

## CT Example with 3D MedVAE (f=64; C=1)

In [None]:
# Run on the first CUDA GPU when one is present; otherwise use the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# 3D MedVAE checkpoint configured for CT, frozen for inference only.
model = MVAE(modality='CT', model_name='medvae_4_1_3d')
model = model.to(device)
model.requires_grad_(False)  # freeze every parameter
model.eval()                 # switch to evaluation mode

fpath = 'data/ct_data/sino_7858_0398.nii.gz'

# apply_transform runs the model's own preprocessing on the file in a single
# call -- the easiest way to get a model-ready tensor.
img = model.apply_transform(fpath)
img = img.to(device)

# Encode with autograd disabled and copy the result back to a NumPy array.
with torch.no_grad():
    latent = model(img).cpu().detach().numpy()

In [None]:
# Wrap the CT latent (presumably a single 3-D volume -- TODO confirm the model
# drops batch/channel axes) in a NIfTI image with an identity affine, so voxel
# indices act as world coordinates, and show it without a background image.
identity_affine = np.eye(4)
view_img(
    stat_map_img=nib.Nifti1Image(latent, identity_affine),
    bg_img=False,
    cmap='gray',
)

## MRI Example with 3D MedVAE (f=64; C=1)

In [None]:
# Run on the first CUDA GPU when one is present; otherwise use the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# 3D MedVAE checkpoint configured for MRI, frozen for inference only.
model = MVAE(modality='MRI', model_name='medvae_4_1_3d')
model = model.to(device)
model.requires_grad_(False)  # freeze every parameter
model.eval()                 # switch to evaluation mode

fpath = 'data/mri_data/t1oasis_case_1286.nii.gz'

# apply_transform runs the model's own preprocessing on the file in a single
# call -- the easiest way to get a model-ready tensor.
img = model.apply_transform(fpath)
img = img.to(device)

# Encode with autograd disabled and copy the result back to a NumPy array.
with torch.no_grad():
    latent = model(img).cpu().detach().numpy()

In [None]:
# Wrap the MRI latent (presumably a single 3-D volume -- TODO confirm the model
# drops batch/channel axes) in a NIfTI image with an identity affine, so voxel
# indices act as world coordinates, and show it without a background image.
identity_affine = np.eye(4)
view_img(
    stat_map_img=nib.Nifti1Image(latent, identity_affine),
    bg_img=False,
    cmap='gray',
)