In [3]:
import torch
import os
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import sys
sys.path.append('..')
from model import PRISM
from dataset import PRISM_MRI_Dataset
sys.path.remove('..')

# Loading Models and Data

In [None]:
## Model and data initialization
'''
Harmonization method 1 (preferred): Uniform harmonization - reduces inter-site AND intra-site variations
Harmonization method 2: Image-wise harmonization
'''
harmonization_method = 1
data_dir = '/kaggle/input'
model_dir = '/kaggle/working'
out_dir = f'/kaggle/working/harmonized_{harmonization_method}'
modality = 'T2'
target = 'Guys'


def build_prism():
    """Return a PRISM model with the configuration shared by every site."""
    return PRISM(intensity_levels=5, latent_dim=2, num_sites=4, gpu_id=0,
                 modality=modality, modalities=[modality])


def load_weights(module, path):
    """Load the state dict stored at ``path`` (mapped to CPU) into ``module``."""
    module.load_state_dict(torch.load(path, map_location='cpu'))


prism_g = build_prism()  # target site: Guys
prism_h = build_prism()  # source site: HH
prism_i = build_prism()  # source site: IOP
prism_a = build_prism()  # source site: ADNI1

# Target site: Guys - anatomy encoder, style encoder and decoder are all used downstream.
load_weights(prism_g.anatomy_encoder, f'{model_dir}/prism-anatomy-encoder_guys.pth')
load_weights(prism_g.style_encoder, f'{model_dir}/prism-style-encoder_guys.pth')
load_weights(prism_g.decoder, f'{model_dir}/prism-decoder_guys.pth')

# Source sites: HH, IOP and ADNI
for source_model, tag in [(prism_h, 'hh'), (prism_i, 'iop'), (prism_a, 'adni1')]:
    load_weights(source_model.anatomy_encoder, f'{model_dir}/prism-anatomy-encoder_{tag}.pth')
    # The harmonization cell calls each source model's get_style_code() to build the
    # pre-harmonization ("og") style codes. Without these weights that encoder would
    # be randomly initialised. Load them when available and warn loudly otherwise.
    style_path = f'{model_dir}/prism-style-encoder_{tag}.pth'
    if os.path.exists(style_path):
        load_weights(source_model.style_encoder, style_path)
    else:
        print(f'WARNING: {style_path} not found - pre-harmonization style codes for '
              f'{tag} will come from an untrained style encoder.')

In [None]:
# Load both train and test datasets for each site.
# NOTE: these .pt files are presumably whole pickled dataset objects (PRISM_MRI_Dataset is
# imported at the top, likely so they can be unpickled). In that case they need
# weights_only=False on PyTorch >= 2.6, where torch.load defaults to weights_only=True.
# Unpickling can execute arbitrary code, so only load files you trust.
def load_split(stem, split):
    """Load ``{data_dir}/{stem}-{split}.pt`` onto the CPU."""
    return torch.load(f'{data_dir}/{stem}-{split}.pt', map_location='cpu', weights_only=False)


guys_ds_train = load_split('IXI-Guys', 'train')
guys_ds_test = load_split('IXI-Guys', 'test')
hh_ds_train = load_split('IXI-HH', 'train')
hh_ds_test = load_split('IXI-HH', 'test')
iop_ds_train = load_split('IXI-IOP', 'train')
iop_ds_test = load_split('IXI-IOP', 'test')
adni_ds_train = load_split('ADNI1', 'train')
adni_ds_test = load_split('ADNI1', 'test')

# Site name -> model whose anatomy encoder processes that site, plus (dataset, split) pairs.
datasets = {
    'Guys': {'model': prism_g, 'data': [(guys_ds_train, 'train'), (guys_ds_test, 'test')]},
    'HH': {'model': prism_h, 'data': [(hh_ds_train, 'train'), (hh_ds_test, 'test')]},
    'IOP': {'model': prism_i, 'data': [(iop_ds_train, 'train'), (iop_ds_test, 'test')]},
    'ADNI1': {'model': prism_a, 'data': [(adni_ds_train, 'train'), (adni_ds_test, 'test')]}
}

# Harmonization

In [None]:
## Harmonization
'''
Harmonization method 1 (preferred): Uniform harmonization - reduces inter-site AND intra-site variations
Harmonization method 2: Image-wise harmonization
'''

style_codes = {site: [] for site in datasets.keys()}
style_codes_og = {site: [] for site in datasets.keys() if site != target}
anatomies = {site: [] for site in datasets.keys()}
masks = {site: [] for site in datasets.keys()}
names = {site: [] for site in datasets.keys()}
# Split ('train'/'test') of each stored subject. The decode pass uses it to write
# every image back into the split it came from.
data_modes = {site: [] for site in datasets.keys()}


def load_subject(model, subject):
    """Return (image, mask, subject_id) for ``model.modality``, moved to ``model.device``."""
    entry = subject[model.modality]
    image = entry['image'].to(model.device).unsqueeze(1)
    mask = entry['mask'].to(model.device).unsqueeze(1)
    return image, mask, entry['subject_id']


def save_harmonized(site, data_mode, subid, harmonized):
    """Write ``harmonized`` as a grayscale PNG under ``out_dir/site/data_mode/subid/``."""
    subject_dir = f'{out_dir}/{site}/{data_mode}/{subid}'
    os.makedirs(subject_dir, exist_ok=True)
    plt.imsave(f'{subject_dir}/IXI-{site}-{subid}-{modality}_harmonized.png',
               harmonized.squeeze().cpu().numpy(), cmap='gray')


if harmonization_method == 1:
    with torch.set_grad_enabled(False):
        # Set all models to eval mode
        for site_info in datasets.values():
            site_info['model'].anatomy_encoder.eval()
            site_info['model'].style_encoder.eval()
        prism_g.decoder.eval()

        # Pass 1: encode every subject of every site. Results are kept on the CPU so
        # that all sites/splits do not accumulate in GPU memory.
        for site, site_info in datasets.items():
            model = site_info['model']
            os.makedirs(f'{out_dir}/{site}', exist_ok=True)

            for dataset, data_mode in site_info['data']:
                for subject in dataset:
                    image, mask, subid = load_subject(model, subject)
                    _, anatomy = model.get_anatomy_representations(image, mask)
                    style_code, _, _ = prism_g.get_style_code(image)

                    anatomies[site].append(anatomy.detach().cpu().squeeze())
                    style_codes[site].append(style_code.detach().cpu().squeeze())
                    masks[site].append(mask.detach().cpu().squeeze())
                    names[site].append(subid)
                    data_modes[site].append(data_mode)

                    if site != target:
                        style_code_og, _, _ = model.get_style_code(image)
                        style_codes_og[site].append(style_code_og.detach().cpu().squeeze())

        # Pass 2 - Uniform Harmonization: decode every anatomy of a site with one
        # shared (mean) style code.
        # NOTE(review): this uses the mean of each site's own target-encoder codes, as
        # the original code did. If every site should be mapped to the target's
        # appearance, the mean of style_codes[target] may be intended instead - confirm.
        device = prism_g.device
        for site in datasets.keys():
            if not anatomies[site]:
                continue  # empty site: nothing to decode (and nothing to average)
            mean_style = torch.stack(style_codes[site]).mean(dim=0)
            style_code = mean_style.unsqueeze(1).unsqueeze(1).unsqueeze(0).to(device)

            for anatomy, mask, subid, data_mode in zip(anatomies[site], masks[site],
                                                       names[site], data_modes[site]):
                harmonized = prism_g.decode(anatomy.to(device).unsqueeze(0).unsqueeze(0),
                                            style_code,
                                            mask.to(device).unsqueeze(0))
                save_harmonized(site, data_mode, subid, harmonized)

elif harmonization_method == 2:
    with torch.set_grad_enabled(False):
        # Set all models to eval mode
        for site_info in datasets.values():
            site_info['model'].anatomy_encoder.eval()
        prism_g.style_encoder.eval()
        prism_g.decoder.eval()

        # Image-level harmonization: each image is decoded with its own style code.
        for site, site_info in datasets.items():
            model = site_info['model']
            os.makedirs(f'{out_dir}/{site}', exist_ok=True)

            for dataset, data_mode in site_info['data']:
                for subject in dataset:
                    image, mask, subid = load_subject(model, subject)
                    _, anatomy = model.get_anatomy_representations(image, mask)
                    style_code, _, _ = prism_g.get_style_code(image)
                    harmonized = prism_g.decode(anatomy, style_code, mask)
                    save_harmonized(site, data_mode, subid, harmonized)

else:
    raise ValueError(f'Unknown harmonization_method {harmonization_method!r}; expected 1 or 2')

# Latent Style Visualization
Latent style distributions before and after harmonization

In [None]:
# OPTIONAL: pickle each site's translated style codes and, for source sites, the
# original style codes as well, for use by visualization.py.
# Style codes are only collected when the target's list is non-empty.
if style_codes[target]:
    for site in datasets:
        with open(f'{model_dir}/style_codes_{site.lower()}.pkl', 'wb') as fh:
            pickle.dump(style_codes[site], fh)
        if site == target:
            continue  # the target site has no separate "original" codes
        with open(f'{model_dir}/style_codes_{site.lower()}_og.pkl', 'wb') as fh:
            pickle.dump(style_codes_og[site], fh)

In [1]:
# Stack the per-subject style-code lists into one tensor per site for plotting
# (presumably shaped (num_subjects, latent_dim); the plots index [:, 0] and [:, 1]).
#   styles    - codes from the target's style encoder (style_codes)
#   styles_og - source-site codes from style_codes_og (pre-harmonization view)
# Sites with no collected codes (e.g. harmonization method 2) stay None, because
# torch.stack raises on an empty list.
styles = {site: None for site in datasets.keys()}
styles_og = {site: None for site in datasets.keys() if site != target}
for site in datasets.keys():
    if style_codes[site]:
        styles[site] = torch.stack(style_codes[site])
    if site != target and style_codes_og.get(site):
        styles_og[site] = torch.stack(style_codes_og[site])

NameError: name 'datasets' is not defined

In [None]:
# Pre-harmonization latent style visualization.
# The target site is drawn from `styles`; every other site from `styles_og`.
# `marker` and `colours` are reused by the post-harmonization plot below.
marker = 'P'
colours = dict(
    Guys='darkcyan',
    HH='crimson',
    IOP='springgreen',
    ADNI1='orange',
)

fig, ax = plt.subplots(figsize=(12, 6))
for site in datasets.keys():
    site_codes = styles[site] if site == target else styles_og[site]
    sns.scatterplot(
        x=site_codes[:, 0],
        y=site_codes[:, 1],
        color=colours[site],
        marker=marker,
        label=f'Site {site}',
        s=50,
        ax=ax,
    )
ax.set_xlim(-25, 25)
ax.set_ylim(-25, 25)
ax.set_xlabel('Latent Dimension 1')
ax.set_ylabel('Latent Dimension 2')
ax.set_title('Latent Style distributions (pre-harmonization)')
ax.legend(fontsize=15)
plt.show()

In [None]:
# Post-harmonization latent style visualization: every site is drawn from `styles`,
# reusing the `marker` and `colours` defined for the pre-harmonization plot.
fig, ax = plt.subplots(figsize=(12, 6))
for site in datasets.keys():
    site_codes = styles[site]
    sns.scatterplot(
        x=site_codes[:, 0],
        y=site_codes[:, 1],
        color=colours[site],
        marker=marker,
        label=f'Site {site}',
        s=50,
        ax=ax,
    )
ax.set_xlim(-25, 25)
ax.set_ylim(-25, 25)
ax.set_xlabel('Latent Dimension 1')
ax.set_ylabel('Latent Dimension 2')
ax.set_title('Latent Style distributions (post-harmonization)')
ax.legend(fontsize=15)
plt.show()