In [1]:
from medvae import MVAE
import torch
from nilearn.plotting import view_img
import numpy as np
import nibabel as nib

## Download Sample Data

In [None]:
import gdown

# Google Drive share link for the sample-data folder used by the X-ray, CT and
# MRI examples below (files land under ./data/mmg_data, ./data/ct_data, ...).
DEFAULT_DATA_URL = 'https://drive.google.com/drive/folders/13dHwCtAt9ou9Ee8PcrzTAGE2WNoEc6SV?usp=share_link'

# Download the whole folder into ./data. Depending on the gdown version, a failed
# download either raises or returns a falsy value (None/False), so check the
# result instead of unconditionally announcing success.
downloaded_files = gdown.download_folder(url=DEFAULT_DATA_URL, output="data", quiet=True)
if not downloaded_files:
    raise RuntimeError(f'Sample data download from {DEFAULT_DATA_URL} failed.')

print(f'Data downloaded: {len(downloaded_files)} files into data/.')

## X-ray Example with 2D MedVAE (f=16; C=3)

In [None]:
# Run on the first GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# 2D MedVAE for X-ray inputs. It is only used for inference here, so its
# parameters are frozen and the module is switched to evaluation mode.
model = MVAE(model_name='medvae_4_3_2d', modality='xray').to(device)
model.requires_grad_(False)
model.eval()

fpath = 'data/mmg_data/TQcBVJediTG8E34ftHnapA.png'

# Build the model's own preprocessing transform and run the image file through it,
# then prepend a batch axis and move the tensor onto the selected device.
transform = model.get_transform()
img = transform(fpath)
img = img.unsqueeze(0).to(device)

# Encode without building an autograd graph and bring the latent back to the CPU
# as a NumPy array for plotting.
with torch.no_grad():
    latent = model(img).detach().cpu().numpy()

In [None]:
# Move axis 0 of the latent to the end so its channels become the third volume
# axis (assumes `latent` is (channels, height, width) — TODO confirm against
# MVAE's output shape). The identity affine displays it in plain voxel units.
view_img(
    stat_map_img=nib.Nifti1Image(np.transpose(latent, (1, 2, 0)), affine=np.eye(4)),
    bg_img=False,
    cmap='gray',
)

## CT Example with 3D MedVAE (f=64; C=1)

In [None]:
# Prefer the first CUDA device; use the CPU when no GPU is available.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# 3D MedVAE configured for CT volumes, frozen and in evaluation mode for inference.
model = MVAE(model_name='medvae_4_1_3d', modality='CT').to(device)
model.requires_grad_(False)
model.eval()

fpath = 'data/ct_data/sino_7858_0398.nii.gz'

# apply_transform takes the file path and returns the preprocessed tensor in a
# single call (the simplest route); afterwards it is placed on the device.
img = model.apply_transform(fpath)
img = img.to(device)

# Gradient-free forward pass; the latent is returned to the CPU as a NumPy array.
with torch.no_grad():
    latent = model(img).detach().cpu().numpy()

In [None]:
# Show the CT latent as-is with an identity affine (voxel units). Assumes
# `latent` is already a 3D array with no batch/channel axes — TODO confirm.
view_img(
    stat_map_img=nib.Nifti1Image(latent, affine=np.eye(4)),
    bg_img=False,
    cmap='gray',
)

## MRI Example with 3D MedVAE (f=64; C=1)

In [None]:
# Select the compute device: first GPU if CUDA is usable, else CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Same model_name as the CT example, but configured with modality='MRI'.
model = MVAE(model_name='medvae_4_1_3d', modality='MRI').to(device)
model.requires_grad_(False)  # inference only: no parameter gradients needed
model.eval()

fpath = 'data/mri_data/t1oasis_case_1286.nii.gz'

# One-call preprocessing through the model's own transform, then move to the device.
img = model.apply_transform(fpath)
img = img.to(device)

# Encode without tracking gradients and hand the latent back as a NumPy array.
with torch.no_grad():
    latent = model(img).detach().cpu().numpy()

In [None]:
# Show the MRI latent as-is with an identity affine (voxel units). Assumes
# `latent` is already a 3D array with no batch/channel axes — TODO confirm.
view_img(
    stat_map_img=nib.Nifti1Image(latent, affine=np.eye(4)),
    bg_img=False,
    cmap='gray',
)