# Hypothesis 1, Part 1: Running Unsupervised Models with MSE and NCC Without Affine Preprocessing

- To determine whether label-driven approaches are required, we first experiment with models whose loss functions are driven only by an image-similarity metric and a regularisation term.
- Here, we train models using MSE and NCC with the hyperparameters recommended in the literature, to show how poorly these losses perform on their own. No affine preprocessing is used at first.

## Imports

In [None]:
# # imports
# import os, sys

# # third party imports
# import numpy as np
# import tensorflow as tf
# assert tf.__version__.startswith('2.'), 'This tutorial assumes Tensorflow 2.0+'

# # local imports
# import voxelmorph as vxm
# import neurite as ne
# imports
import os, sys

# third party imports
import numpy as np
import tensorflow as tf
assert tf.__version__.startswith('2.'), 'This tutorial assumes Tensorflow 2.0+'
from tensorflow.keras import layers
from tensorflow import keras

# local imports
import voxelmorph as vxm
import neurite as ne


from skimage.transform import resize
import nibabel as nib

import matplotlib.pyplot as plt

import os, re
import numpy as np
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import SimpleITK as sitk
import csv

import os
#os.environ["CUDA_VISIBLE_DEVICES"]=""

import tensorflow as tf
import numpy as np
from scipy.ndimage import _ni_support
from scipy.ndimage import generate_binary_structure, distance_transform_edt, binary_erosion

#from data_generator import resize_3d_image
from tqdm import tqdm

In [None]:

# Shape (channels last) that every moving/fixed image and label channel is resized to.
_IMAGE_SHAPE = (64, 64, 64, 1)

# Number of label channels used per case: training samples one of channels 0..5 at random
# and validation evaluates each of them in turn.
_NUM_LABELS = 6

# File stems (``<stem>.npy`` in the run directory) of the per-trial histories that are
# saved after every trial and reloaded when resuming.
_HISTORY_FILES = ("losses", "transformer_losses", "conv3d_losses", "val_dice")


def _get_model(moving_image_shape, fixed_image_shape, with_label_inputs=True,
               up_filters=(64, 128, 256), down_filters=(256, 128, 64, 32)):
    """Build the residual encoder-decoder backbone that predicts a dense displacement field.

    The moving and fixed images are concatenated on the channel axis. They are halved in
    size by a strided entry convolution and once more per entry of ``up_filters``. They are
    then doubled once per entry of ``down_filters``. With the defaults this is 4 halvings
    and 4 doublings, so spatial dimensions divisible by 16 come back at their input size.

    Args:
        moving_image_shape: shape of one moving image, channels last, e.g. (64, 64, 64, 1).
        fixed_image_shape: shape of one fixed image, channels last.
        with_label_inputs: if True, two label inputs are also declared as model inputs.
            They are not connected to the graph.
        up_filters: filter counts of the *downsampling* (encoder) blocks. The name is kept
            for compatibility.
        down_filters: filter counts of the *upsampling* (decoder) blocks. The name is kept
            for compatibility.

    Returns:
        A ``keras.Model`` whose single output is a 3-channel, linear-activation DDF.
    """
    input_moving_image = keras.Input(moving_image_shape)
    input_fixed_image = keras.Input(fixed_image_shape)

    if with_label_inputs:
        input_moving_label = keras.Input(moving_image_shape)
        input_fixed_label = keras.Input(fixed_image_shape)

    x = layers.Concatenate(axis=-1)([input_moving_image, input_fixed_image])

    # Entry block: a strided convolution halves every spatial dimension.
    x = layers.Conv3D(32, 3, strides=2, padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)
    previous_block_activation = x  # residual source for the first encoder block

    # Encoder: two conv/BN pairs, a stride-2 max-pool, plus a strided 1x1x1 projection of
    # the block input added back as a residual.
    for filters in up_filters:
        x = layers.Activation("relu")(x)
        x = layers.Conv3D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.Conv3D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling3D(3, strides=2, padding="same")(x)

        residual = layers.Conv3D(filters, 1, strides=2, padding="same")(previous_block_activation)
        x = layers.add([x, residual])
        previous_block_activation = x

    # Decoder: two transposed-conv/BN pairs, a 2x upsample, plus an upsampled 1x1x1
    # projection of the block input added back as a residual.
    for filters in down_filters:
        x = layers.Activation("relu")(x)
        x = layers.Conv3DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.Conv3DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.UpSampling3D(2)(x)

        residual = layers.UpSampling3D(2)(previous_block_activation)
        residual = layers.Conv3D(filters, 1, padding="same")(residual)
        x = layers.add([x, residual])
        previous_block_activation = x

    # Output head: a 3-channel linear convolution giving the per-voxel displacement (DDF).
    out_ddf = layers.Conv3D(3, 3, activation="linear", padding="same")(x)

    if with_label_inputs:
        inputs = [input_moving_image, input_fixed_image, input_moving_label, input_fixed_label]
    else:
        inputs = [input_moving_image, input_fixed_image]
    return keras.Model(inputs=inputs, outputs=[out_ddf])


def _resize_3d_image(image, shape):
    """Resize *image* to *shape* and min-max normalise it to [0, 1].

    A constant-valued result is returned without normalisation, to avoid dividing by zero.
    """
    resized_image = resize(image, output_shape=shape)
    low, high = np.amin(resized_image), np.amax(resized_image)
    if high == low:
        return resized_image
    return (resized_image - low) / (high - low)


def _resize_label_channels(label, shape):
    """Resize and normalise every channel of a label volume once, up front.

    Returns an array of shape ``(n_channels, *shape)`` where ``result[k]`` equals
    ``_resize_3d_image(label[..., k], shape)``. A 3-D label is treated as a single channel.
    Precomputing this avoids re-running skimage's resize on every batch of every trial.
    """
    label = np.asarray(label)
    if label.ndim == 3:
        label = label[..., np.newaxis]
    return np.stack([_resize_3d_image(label[..., k], shape) for k in range(label.shape[-1])])


def _load_dataset_into_cache(f_path, moving_image_shape, fixed_image_shape, with_label_inputs=True):
    """Load, resize and normalise every case under *f_path* into a dict keyed by file name.

    Expects ``us_images``, ``mr_images`` (and, when *with_label_inputs*, ``us_labels`` and
    ``mr_labels``) sub-directories holding files with matching names. The US volumes are
    the moving images and the MR volumes are the fixed images. Labels are stored as
    ``(n_channels, *shape)`` arrays (see ``_resize_label_channels``).
    """
    moving_images_path = os.path.join(f_path, 'us_images')
    fixed_images_path = os.path.join(f_path, 'mr_images')
    moving_labels_path = os.path.join(f_path, 'us_labels')
    fixed_labels_path = os.path.join(f_path, 'mr_labels')

    cache = {}
    # Sorted so that the validation batch order does not depend on the filesystem.
    for f_name in sorted(os.listdir(fixed_images_path)):
        moving_image = nib.load(os.path.join(moving_images_path, f_name)).get_fdata()
        fixed_image = nib.load(os.path.join(fixed_images_path, f_name)).get_fdata()

        entry = {
            "moving": _resize_3d_image(moving_image, moving_image_shape),
            "fixed": _resize_3d_image(fixed_image, fixed_image_shape),
        }

        if with_label_inputs:
            moving_label = nib.load(os.path.join(moving_labels_path, f_name)).get_fdata()
            fixed_label = nib.load(os.path.join(fixed_labels_path, f_name)).get_fdata()
            entry["moving_label"] = _resize_label_channels(moving_label, moving_image_shape)
            entry["fixed_label"] = _resize_label_channels(fixed_label, fixed_image_shape)

        cache[f_name] = entry

    return cache


def _train_generator(cache, batch_size, moving_image_shape, fixed_image_shape, with_label_inputs=True):
    """Yield random training batches forever as ``(inputs, targets)`` tuples.

    Each batch draws up to *batch_size* distinct cases. When labels are used, each case
    gets one randomly chosen label channel. The targets are:

    - the fixed image, for the moved-image output;
    - a zero DDF, for the regulariser output;
    - and, with labels, the fixed label, for the moved-label output.
    """
    all_names = list(cache.keys())

    while True:
        batch_names = np.random.permutation(all_names)[:batch_size]

        moving_images_batch = np.zeros((batch_size, *moving_image_shape))
        fixed_images_batch = np.zeros((batch_size, *fixed_image_shape))
        if with_label_inputs:
            moving_labels_batch = np.zeros((batch_size, *moving_image_shape))
            fixed_labels_batch = np.zeros((batch_size, *fixed_image_shape))

        for i, f_name in enumerate(batch_names):
            entry = cache[f_name]
            moving_images_batch[i] = entry["moving"]
            fixed_images_batch[i] = entry["fixed"]
            if with_label_inputs:
                label_to_select = np.random.randint(_NUM_LABELS)
                moving_labels_batch[i] = entry["moving_label"][label_to_select]
                fixed_labels_batch[i] = entry["fixed_label"][label_to_select]

        zero_phis = np.zeros([batch_size, *moving_image_shape[:-1], 3])

        if with_label_inputs:
            yield ((moving_images_batch, fixed_images_batch, moving_labels_batch, fixed_labels_batch),
                   (fixed_images_batch, zero_phis, fixed_labels_batch))
        else:
            yield (moving_images_batch, fixed_images_batch), (fixed_images_batch, zero_phis)


def _test_generator(cache, batch_size, moving_image_shape, fixed_image_shape, start_index, end_index,
                    label_num, with_label_inputs=True):
    """Yield validation batches in cache order for one label channel, as ``[inputs], [targets]``.

    Only ``len(cases) // batch_size`` full batches are produced, so trailing cases that do
    not fill a batch are skipped. Raises IndexError if a case has no channel *label_num*.
    """
    all_names = list(cache.keys())[start_index:end_index]
    n_steps = len(all_names) // batch_size

    for step in range(n_steps):
        batch_names = all_names[step * batch_size:(step + 1) * batch_size]

        moving_images_batch = np.zeros((batch_size, *moving_image_shape))
        fixed_images_batch = np.zeros((batch_size, *fixed_image_shape))
        if with_label_inputs:
            moving_labels_batch = np.zeros((batch_size, *moving_image_shape))
            fixed_labels_batch = np.zeros((batch_size, *fixed_image_shape))

        for i, f_name in enumerate(batch_names):
            entry = cache[f_name]
            moving_images_batch[i] = entry["moving"]
            fixed_images_batch[i] = entry["fixed"]
            if with_label_inputs:
                moving_labels_batch[i] = entry["moving_label"][label_num]
                fixed_labels_batch[i] = entry["fixed_label"][label_num]

        zero_phis = np.zeros([batch_size, *moving_image_shape[:-1], 3])

        if with_label_inputs:
            yield ([moving_images_batch, fixed_images_batch, moving_labels_batch, fixed_labels_batch],
                   [fixed_images_batch, zero_phis, fixed_labels_batch])
        else:
            yield [moving_images_batch, fixed_images_batch], [fixed_images_batch, zero_phis]


def _build_registration_model(moving_image_shape, fixed_image_shape):
    """Wrap the backbone with a spatial transformer that warps the moving image and label.

    Inputs: ``[moving image, fixed image, moving label, fixed label]``. The fixed label input
    is not used by the graph; it is accepted so the inputs match the data generators.
    Outputs: ``[moved image, DDF, moved label]``.
    """
    backbone = _get_model(moving_image_shape, fixed_image_shape, with_label_inputs=False)

    print('\nBackbone model inputs and outputs:')
    print('    input shape: ', ', '.join([str(t.shape) for t in backbone.inputs]))
    print('    output shape:', ', '.join([str(t.shape) for t in backbone.outputs]))

    spatial_transformer = vxm.layers.SpatialTransformer(name='transformer')

    moving_image, fixed_image = backbone.inputs
    input_moving_label = keras.Input(moving_image_shape, name="moving_label")
    input_fixed_label = keras.Input(fixed_image_shape, name="fixed_label")
    ddf = backbone.outputs[0]

    # The same network-predicted DDF warps both the moving image and the moving label.
    moved_image = spatial_transformer([moving_image, ddf])
    moved_label = spatial_transformer([input_moving_label, ddf])

    registration_model = keras.Model(
        inputs=[moving_image, fixed_image, input_moving_label, input_fixed_label],
        outputs=[moved_image, ddf, moved_label],
    )

    print('\nRegistration network inputs and outputs:')
    print('    input shape: ', ', '.join([str(t.shape) for t in registration_model.inputs]))
    print('    output shape:', ', '.join([str(t.shape) for t in registration_model.outputs]))
    return registration_model


def _build_losses(similarity_metric, weak_supervision, lambda_param):
    """Return ``(loss_fns, loss_weights)`` for the outputs ``[moved image, DDF, moved label]``.

    The losses are, in order: the image similarity (*similarity_metric*), an L2 gradient
    regulariser on the DDF scaled by *lambda_param*, and a label Dice loss.

    Raises:
        ValueError: if *similarity_metric* is not "NCC" or "MSE".
    """
    similarity_losses = {"NCC": vxm.losses.NCC, "MSE": vxm.losses.MSE}
    if similarity_metric not in similarity_losses:
        raise ValueError(f"Unknown similarity metric {similarity_metric!r}; expected 'NCC' or 'MSE'")

    loss_fns = [similarity_losses[similarity_metric]().loss, vxm.losses.Grad('l2').loss, vxm.losses.Dice().loss]

    if not weak_supervision:
        # Unsupervised: image similarity + smoothness. The label Dice output is tracked but
        # not optimised. (MSE previously had weight 0 here, leaving only the regulariser.)
        loss_weights = [1, lambda_param, 0]
    elif similarity_metric == "NCC":
        loss_weights = [0, lambda_param, 1]
    else:
        # NOTE(review): MSE keeps the intensity term alongside Dice, unlike NCC above
        # (weights [0, lambda, 1]); confirm this asymmetry is intended.
        loss_weights = [1, lambda_param, 1]
    return loss_fns, loss_weights


def _mid_slice(batch):
    """Return the middle slice, along the first spatial axis, of the first volume in *batch*."""
    volume = np.squeeze(np.asarray(batch[0]))
    return volume[volume.shape[0] // 2, :, :]


def _save_slice_triptych(slices, titles, save_path):
    """Plot three 2-D slices side by side in a fresh figure, then save, show and close it."""
    # An explicit new figure: otherwise the panels would be drawn into whatever figure is
    # current, e.g. a DDF quiver plot left open from the previous trial.
    plt.figure()
    for position, (image_slice, title) in enumerate(zip(slices, titles), start=1):
        plt.subplot(1, 3, position)
        plt.imshow(image_slice, cmap='gray')
        plt.title(title)
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    plt.close()


def _save_validation_figures(model_save_path, trial, moving_images, fixed_images, moved_images,
                             moving_labels, fixed_labels, moved_labels, ddf):
    """Save image and label slice comparisons for *trial* and show the predicted DDF as a quiver plot."""
    save_name_img = os.path.join(model_save_path, f"image_slices_trial_{trial}.png")
    _save_slice_triptych(
        [_mid_slice(moved_images), _mid_slice(moving_images), _mid_slice(fixed_images)],
        ['Moved Image', 'Moving Image', 'Fixed Image'],
        save_name_img,
    )
    print(f"Saved image slices to {save_name_img}")

    save_name_lbl = os.path.join(model_save_path, f"label_slices_trial_{trial}.png")
    _save_slice_triptych(
        [_mid_slice(moved_labels), _mid_slice(moving_labels), _mid_slice(fixed_labels)],
        ['Moved Label', 'Moving Label', 'Fixed Label'],
        save_name_lbl,
    )
    print(f"Saved label slices to {save_name_lbl}")

    # Middle slice of the first volume's DDF. Only the first two displacement components are
    # shown, since the plot is 2-D.
    ddf_volume = np.squeeze(ddf[0])
    mid_plane = ddf_volume[ddf_volume.shape[0] // 2, :, :, :]
    ne.plot.flow([mid_plane[..., :2]], width=5)
    plt.close()  # release the quiver figure so figures do not accumulate over trials


def _validate(registration_model, test_cache, moving_image_shape, fixed_image_shape,
              model_save_path, trial, batch_size=4):
    """Return one Dice score per full validation batch, for every label channel.

    For the first batch of label channel 0 it also saves and shows the validation figures
    for *trial*.
    """
    dice_loss = vxm.losses.Dice().loss
    dice_scores = []

    for label_num in range(_NUM_LABELS):
        val_gen = _test_generator(test_cache, batch_size, moving_image_shape, fixed_image_shape,
                                  start_index=None, end_index=None, label_num=label_num,
                                  with_label_inputs=True)
        try:
            for batch_index, (val_inputs, val_outputs) in enumerate(val_gen):
                moving_images_val, fixed_images_val, moving_labels_val, _ = val_inputs
                fixed_labels_val = val_outputs[2]
                # The model's own outputs already are [moved image, DDF, moved label].
                moved_images_val, ddf_val, moved_labels_val = registration_model.predict(val_inputs, verbose=0)

                if label_num == 0 and batch_index == 0:
                    _save_validation_figures(model_save_path, trial, moving_images_val, fixed_images_val,
                                             moved_images_val, moving_labels_val, fixed_labels_val,
                                             moved_labels_val, ddf_val)

                loss = dice_loss(tf.convert_to_tensor(moved_labels_val, dtype='float32'),
                                 tf.convert_to_tensor(fixed_labels_val, dtype='float32'))
                dice_scores.append(-float(loss))  # negate the loss back into a Dice score
                print('.', end='')
        except IndexError:
            # Best-effort, as before: if a case lacks this label channel, evaluation of this
            # channel stops and the remaining channels still run.
            pass

    return dice_scores


def _plot_training_curves(save_path, history):
    """Save loss and validation-Dice curves for all trials so far to *save_path*."""
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 3, 1)
    plt.subplots_adjust(wspace=0.5)
    plt.plot(history["losses"], label="Total Loss")
    # 'transformer_loss' is the loss on the first output (the moved image), i.e. the
    # image-similarity term, not the regulariser.
    plt.plot(history["transformer_losses"], label="Similarity Loss")
    plt.xlabel('Trials')
    plt.ylabel('Losses')
    plt.legend()

    plt.subplot(1, 3, 2)
    plt.plot(history["conv3d_losses"], label="Regularization Loss")
    plt.xlabel('Trials')
    plt.ylabel('Losses')
    plt.legend()

    plt.subplot(1, 3, 3)
    plt.plot(history["val_dice"], 'r')
    plt.xlabel('Trials')
    plt.ylabel('Dice')

    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Saved plot to {save_path}")


def _save_checkpoint(registration_model, model_save_path, history):
    """Write the latest weights and the per-trial histories that ``_train_model`` reloads on resume."""
    registration_model.save_weights(os.path.join(model_save_path, "latest.weights.h5"))
    for stem, values in history.items():
        np.save(os.path.join(model_save_path, f"{stem}.npy"), np.asarray(values, dtype=float))


def _train_model(name, train_cache, test_cache, similarity_metric, weak_supervision,
                 last_trial=None, latest_weights=None, Verbose=False):
    """Train a registration network, validating and checkpointing after every trial.

    Args:
        name: run name. Outputs go to ``H1_Experiments/{name}_Part_1_Checkpoints_More_Epochs``.
        train_cache: dict from ``_load_dataset_into_cache`` (with labels) used for training.
        test_cache: dict from ``_load_dataset_into_cache`` (with labels) used for validation.
        similarity_metric: "NCC" or "MSE".
        weak_supervision: selects the loss weights (see ``_build_losses``).
        last_trial: index of the last completed trial when resuming. Histories are saved
            every trial, so this is the length of the saved histories minus one.
        latest_weights: weights file to resume from (e.g. ``latest.weights.h5`` in the run
            directory). ``None`` starts from scratch.
        Verbose: currently unused; kept for interface compatibility.

    Raises:
        ValueError: if resuming without *last_trial*, or for an unknown *similarity_metric*.
    """
    model_save_path = f'H1_Experiments/{name}_Part_1_Checkpoints_More_Epochs'
    # makedirs (not mkdir) so that a missing H1_Experiments parent directory is created too.
    os.makedirs(model_save_path, exist_ok=True)

    moving_image_shape = _IMAGE_SHAPE
    fixed_image_shape = _IMAGE_SHAPE
    batch_size = 8  # Decrease this if you are running out of RAM, e.g. to 4.
    num_trials = 1024  # This may be well above what is required.

    for lambda_param in [0.1]:  # Extend this list to tune the regularisation weight.
        registration_model = _build_registration_model(moving_image_shape, fixed_image_shape)
        loss_fns, loss_weights = _build_losses(similarity_metric, weak_supervision, lambda_param)
        registration_model.compile(optimizer='Adam', loss=loss_fns, loss_weights=loss_weights)

        if latest_weights is not None:
            if last_trial is None:
                raise ValueError("last_trial must be given when resuming from latest_weights")
            print(f"Resuming from {latest_weights} (trial {last_trial})")
            registration_model.load_weights(latest_weights)
            start_trial = last_trial + 1
            # Keep trials 0..last_trial of each saved history.
            history = {
                stem: np.load(os.path.join(model_save_path, f"{stem}.npy")).tolist()[:last_trial + 1]
                for stem in _HISTORY_FILES
            }
        else:
            print("No previous weights found, starting fresh")
            start_trial = 0
            history = {stem: [] for stem in _HISTORY_FILES}

        train_gen = _train_generator(train_cache, batch_size, moving_image_shape, fixed_image_shape,
                                     with_label_inputs=True)

        for trial in range(start_trial, num_trials):
            print(f'\nTrial {trial} / {num_trials-1}:')

            hist = registration_model.fit(train_gen, epochs=1, steps_per_epoch=32, verbose=1)

            dice_scores = _validate(registration_model, test_cache, moving_image_shape, fixed_image_shape,
                                    model_save_path, trial)
            print(dice_scores)

            history["losses"].append(hist.history["loss"][0])
            history["transformer_losses"].append(hist.history["transformer_loss"][0])
            # The DDF output's loss key is named after the backbone's final Conv3D layer.
            conv_loss_key = next(
                (key for key in hist.history if key.startswith('conv3d_') and key.endswith('_loss')),
                None,
            )
            history["conv3d_losses"].append(
                hist.history[conv_loss_key][0] if conv_loss_key is not None else float('nan')
            )
            history["val_dice"].append(np.mean(dice_scores))

            _plot_training_curves(os.path.join(model_save_path, f"training_curves_trial_{trial}.png"), history)
            _save_checkpoint(registration_model, model_save_path, history)

            print('    Validation Dice: ', np.mean(dice_scores))


def run_experiments(name, intensity_metric, weakly_supervised, train_path, val_path, Verbose=False):
    """Load a train/validation split and train one registration experiment on it.

    Args:
        name: run name, used for the output directory
            ``H1_Experiments/{name}_Part_1_Checkpoints_More_Epochs``.
        intensity_metric: "NCC" or "MSE" image-similarity loss.
        weakly_supervised: if True, the label Dice loss is weighted in (see ``_build_losses``).
        train_path: directory with ``us_images``, ``mr_images``, ``us_labels`` and
            ``mr_labels`` sub-directories holding files with matching names.
        val_path: validation directory with the same layout as *train_path*.
        Verbose: forwarded to the trainer (currently unused there).
    """
    moving_image_shape = _IMAGE_SHAPE
    fixed_image_shape = _IMAGE_SHAPE

    train_cache = _load_dataset_into_cache(train_path, moving_image_shape, fixed_image_shape, with_label_inputs=True)
    test_cache = _load_dataset_into_cache(val_path, moving_image_shape, fixed_image_shape, with_label_inputs=True)

    # NOTE(review): TensorFlow is imported (and may already have initialised its GPUs and
    # logger) at the top of the notebook. These variables likely only take effect if they
    # are set before that import; confirm before relying on GPU 1 being selected.
    os.environ['CUDA_VISIBLE_DEVICES'] = '1'
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

    # Verbose is passed by keyword: positionally it would land in last_trial.
    _train_model(name, train_cache, test_cache, intensity_metric, weakly_supervised, Verbose=Verbose)

In [None]:
# Weakly supervised NCC runs on the affine-preprocessed data, then on the raw data.
# The True flag is tagged "weakly_supervised" (it was tagged "unsupervised"), so these
# runs no longer share output directories with the unsupervised (False) runs.
for weak_supervision in [[True, "weakly_supervised"]]:
    for intensity_metric in ["NCC"]:
        for data in [["nifti_data_preprocessed/train", "nifti_data_preprocessed/val"], ["nifti_data/train", "nifti_data/val"]]:

            if data[0] == "nifti_data/train":
                name = f"{intensity_metric}_{weak_supervision[1]}_no_preprocessing"
            elif data[0] == "nifti_data_preprocessed/train":
                name = f"{intensity_metric}_{weak_supervision[1]}_affine_preprocessing"
            else:
                # Fail loudly rather than silently reuse the previous iteration's name.
                raise ValueError(f"No run name defined for training data {data[0]!r}")

            run_experiments(name, intensity_metric, weak_supervision[0], data[0], data[1])

In [None]:
# Weakly supervised NCC run on the raw (non-affine-preprocessed) data.
# The True flag is tagged "weakly_supervised" (it was tagged "unsupervised"), so this run
# no longer writes into the same directory as the unsupervised (False) run on this data.
for weak_supervision in [[True, "weakly_supervised"]]:
    for intensity_metric in ["NCC"]:
        for data in [["nifti_data/train", "nifti_data/val"]]:

            if data[0] == "nifti_data/train":
                name = f"{intensity_metric}_{weak_supervision[1]}_no_preprocessing"
            elif data[0] == "nifti_data_preprocessed/train":
                name = f"{intensity_metric}_{weak_supervision[1]}_affine_preprocessing"
            else:
                # Fail loudly rather than silently reuse the previous iteration's name.
                raise ValueError(f"No run name defined for training data {data[0]!r}")

            run_experiments(name, intensity_metric, weak_supervision[0], data[0], data[1])

In [None]:
# Unsupervised NCC run (weak-supervision flag False) on the raw, non-affine-preprocessed data.
for weakly_supervised, supervision_tag in [[False, "unsupervised"]]:
    for intensity_metric in ["NCC"]:
        for train_dir, val_dir in [["nifti_data/train", "nifti_data/val"]]:
            # The run name records the metric, the supervision mode and the preprocessing used.
            if train_dir == "nifti_data/train":
                name = f"{intensity_metric}_{supervision_tag}_no_preprocessing"
            elif train_dir == "nifti_data_preprocessed/train":
                name = f"{intensity_metric}_{supervision_tag}_affine_preprocessing"

            run_experiments(name, intensity_metric, weakly_supervised, train_dir, val_dir)

In [None]:
# Unsupervised NCC runs (weak-supervision flag False): raw data first, then affine-preprocessed data.
for weakly_supervised, supervision_tag in [[False, "unsupervised"]]:
    for intensity_metric in ["NCC"]:
        for train_dir, val_dir in [["nifti_data/train", "nifti_data/val"], ["nifti_data_preprocessed/train", "nifti_data_preprocessed/val"]]:
            # The run name records the metric, the supervision mode and the preprocessing used.
            if train_dir == "nifti_data/train":
                name = f"{intensity_metric}_{supervision_tag}_no_preprocessing"
            elif train_dir == "nifti_data_preprocessed/train":
                name = f"{intensity_metric}_{supervision_tag}_affine_preprocessing"

            run_experiments(name, intensity_metric, weakly_supervised, train_dir, val_dir)