### Set up params

In [2]:
import os
import time
import numpy as np
from ext.lab2im import utils
from SynthSR.brain_generator import BrainGenerator
import SimpleITK as sitk

# Inputs: label maps the synthetic images are generated from, and the real images (same subjects) that serve as
# regression targets.
labels_folder = 'data/labels'
images_folder = 'data/images'

# Outputs: number of synthetic examples to generate, and the folder they are written to.
n_examples = 3
result_dir = 'generated_images'

In [3]:
# general parameters
# We now generate 2 synthetic channels, which will both be used as input. Note that input_channels only contains True
# values, since we use real scans as regression targets. Bear in mind that input_channels only refers to synthetic
# channels (it never includes the real regression target).
input_channels = [True, True]
output_channel = None  # the regression targets are not synthetic, but real
target_res = None  # produce data at the resolution of the label maps
# Requested size of the generated volumes: 283 voxels per axis (283^3), not 128^3.
# NOTE(review): the generated volumes come out as (108, 87, 283), so axes of the label maps smaller than 283 are
# presumably left at their original size rather than padded — TODO confirm against BrainGenerator.
output_shape = 283

# label values of structure to generate from
generation_labels = 'data/labels_classes_priors/generation_labels.npy'
# classes associating similar structures to the same Gaussian distribution
generation_classes = 'data/labels_classes_priors/generation_classes.npy'

In [4]:

# GMM prior hyperparameters (means and standard deviations) for the two synthetic channels. The T1 arrays are
# concatenated before the T2 arrays, so the synthetic T1 is the first synthetic channel and the synthetic T2 the
# second; the resolution settings further down follow this same order.
_priors_dir = 'data/labels_classes_priors/'

prior_means_t1_lr = np.load(_priors_dir + 'prior_means_t1_lr.npy')
prior_means_t2 = np.load(_priors_dir + 'prior_means_t2.npy')
prior_stds_t1_lr = np.load(_priors_dir + 'prior_stds_t1_lr.npy')
prior_stds_t2 = np.load(_priors_dir + 'prior_stds_t2.npy')

# stack the per-contrast hyperparameters along the first axis: T1 rows first, then T2 rows
prior_means = np.concatenate((prior_means_t1_lr, prior_means_t2), axis=0)
prior_stds = np.concatenate((prior_stds_t1_lr, prior_stds_t2), axis=0)

In [7]:
# augmentation parameters
flipping = False
scaling_bounds = 0
rotation_bounds = 0
shearing_bounds = 0.01
translation_bounds = False
nonlin_std = 10.
bias_field_std = 0.2

# blurring/downsampling parameters
# We assume here that the T1 and T2 LR scans were not acquired at the same resolution/slice thickness. We provide the
# corresponding resolution in the same order as for the hyperparameters (T1 first, then T2). In this example we
# simulate: a 3mm coronal T1 with 3mm thickness (spacing [1, 1, 3]), and a 4.5mm sagittal T2 with 3mm thickness
# (spacing [1, 4.5, 1]).
data_res = np.array([[1., 1., 3.], [1., 4.5, 1.]])  # slice spacing
thickness = np.array([[1., 1., 3.], [1., 3., 1.]])  # slice thickness
downsample = False  # if True, downsample to the simulated LR resolution
build_reliability_maps = False  # if True, add reliability maps to the input channels
# In this example we introduce small variations in the blurring kernel, such that the downstream network is robust to
# small changes in acquisition resolution. We provide it here with this coefficient, where the blurring simulates a
# resolution sampled in the uniform distribution U(data_res/blur_range; data_res*blur_range). Therefore blur_range must
# be equal to 1 (no changes), or greater than 1.
blur_range = 1.15
# Registration errors between the two input channels (e.g. head movement between the two acquisitions, or scans not
# acquired in the same coordinate space, such as an orthogonal T1 and a T2 acquired along the hippocampal axis) can be
# simulated with respect to the first input channel by setting this flag to True. It is disabled in this example.
simulate_registration_error = False



In [8]:
# Build the synthetic-data generator from the parameters defined in the cells above. The keyword arguments are
# collected in one dict so the full configuration can be inspected or tweaked before construction.
generator_kwargs = dict(
    # data
    labels_dir=labels_folder,
    images_dir=images_folder,
    generation_labels=generation_labels,
    generation_classes=generation_classes,
    # channels and output geometry
    input_channels=input_channels,
    output_channel=output_channel,
    target_res=target_res,
    output_shape=output_shape,
    # GMM intensity priors
    prior_means=prior_means,
    prior_stds=prior_stds,
    prior_distributions='normal',
    # spatial / intensity augmentation
    flipping=flipping,
    scaling_bounds=scaling_bounds,
    rotation_bounds=rotation_bounds,
    shearing_bounds=shearing_bounds,
    translation_bounds=translation_bounds,
    simulate_registration_error=simulate_registration_error,
    nonlin_std=nonlin_std,
    bias_field_std=bias_field_std,
    # resolution simulation
    data_res=data_res,
    thickness=thickness,
    downsample=downsample,
    blur_range=blur_range,
    build_reliability_maps=build_reliability_maps,
)
brain_generator = BrainGenerator(**generator_kwargs)

### Generate images

In [9]:
# Draw one synthetic example: `input_channels` receives the stacked synthetic input channels (last axis = channel)
# and `regression_target` the matching target volume.
# NOTE(review): this rebinds `input_channels`, which previously held the boolean per-channel config passed to
# BrainGenerator; re-running the BrainGenerator cell after this one would pass the array instead. Consider a distinct
# name (e.g. `input_images`), updating the later cells that read `input_channels` accordingly.
input_channels, regression_target = brain_generator.generate_brain()

In [10]:
# Inspect shapes: the last axis of input_channels indexes the synthetic channels (2 here, without reliability maps).
input_channels.shape, regression_target.shape

((108, 87, 283, 2), (108, 87, 283))

### Save images

In [None]:
# Generate `n_examples` synthetic examples and save every input channel plus the regression target to `result_dir`.

# create the output folder up front so saving does not depend on it already existing
os.makedirs(result_dir, exist_ok=True)

# One file prefix per channel of the generated `input_channels` array. Reliability maps are only present when
# build_reliability_maps is True; otherwise the array holds just the synthetic T1 and T2 channels (shape (..., 2)),
# so indexing channels 2 and 3 would raise IndexError and channel 1 (the T2) must not be named a reliability map.
if build_reliability_maps:
    channel_prefixes = ['t1_input', 'reliability_map_t1_input', 't2_input', 'reliability_map_t2_input']
else:
    channel_prefixes = ['t1_input', 't2_input']

for n in range(n_examples):
    example_id = n + 1

    # generate !
    start = time.perf_counter()  # monotonic clock, suited to measuring durations
    input_channels, regression_target = brain_generator.generate_brain()
    elapsed = time.perf_counter() - start
    print('generation {0:d} took {1:.01f}s'.format(example_id, elapsed))

    if input_channels.shape[-1] != len(channel_prefixes):
        raise ValueError('expected %d input channels, got %d'
                         % (len(channel_prefixes), input_channels.shape[-1]))

    # save every input channel, then the regression target
    for channel_idx, prefix in enumerate(channel_prefixes):
        utils.save_volume(np.squeeze(input_channels[..., channel_idx]), brain_generator.aff, brain_generator.header,
                          os.path.join(result_dir, '%s_%s.nii.gz' % (prefix, example_id)))
    utils.save_volume(np.squeeze(regression_target), brain_generator.aff, brain_generator.header,
                      os.path.join(result_dir, 't1_target_%s.nii.gz' % example_id))


### Get image geometry parameters from the original image

In [26]:
# Read a reference volume from the original data; its origin/direction/spacing are reused further down to wrap the
# generated numpy arrays as SimpleITK images.
image = sitk.ReadImage('drive/driveData/Segmentations/cropped.nii.gz')


In [16]:
def get_image_params(image):
    """Return the geometry of a SimpleITK image as an ``(origin, direction, spacing)`` tuple.

    The tuple order is the one ``array_to_image`` expects for its ``params`` argument.
    """
    origin = image.GetOrigin()
    direction = image.GetDirection()
    spacing = image.GetSpacing()
    return origin, direction, spacing


def image_to_array(image):
    """Convert a SimpleITK image into ``(numpy_array, (origin, direction, spacing))``."""
    voxels = sitk.GetArrayFromImage(image)
    geometry = get_image_params(image)
    return voxels, geometry


def array_to_image(array, params):
    """Wrap a numpy array as a SimpleITK image and apply the given geometry.

    ``params`` is an ``(origin, direction, spacing)`` sequence such as the one returned by ``get_image_params``.
    SimpleITK reverses the axis order between numpy and image space (numpy ``(z, y, x)`` <-> image ``(x, y, z)``),
    so ``params`` must describe the array's axes in that reversed order.
    """
    result = sitk.GetImageFromArray(array)
    origin, direction, spacing = params[0], params[1], params[2]
    result.SetOrigin(origin)
    result.SetDirection(direction)
    result.SetSpacing(spacing)
    return result

In [18]:
# Geometry (origin, direction, spacing) of the reference image, reused to wrap generated arrays as SimpleITK images.
params = get_image_params(image)

### View synthesis result

In [135]:
# Display every synthetic input channel of the most recently generated example with sitk.Show. The channel count is
# read from the array instead of being hard-coded, so reliability-map channels are shown too when enabled.
# NOTE(review): generate_brain() arrays are (108, 87, 283) while the original volume's SimpleITK array is
# (283, 87, 108), so the geometry in `params` may not line up with these axes — confirm orientation.
for channel_idx in range(input_channels.shape[-1]):
    sitk.Show(array_to_image(input_channels[..., channel_idx], params))

In [136]:
# Display the regression target of the most recently generated example with sitk.Show.
sitk.Show(array_to_image(regression_target, params))

### Change the orientation of the original images

In [113]:
import SimpleITK as sitk
import numpy as np

# Folders holding the original segmentation volumes and the matching image volumes.
# (An earlier layout used 'drive/Labels' and 'drive/Atlas', reading 'before.nii.gz' instead of 'cropped.nii.gz'.)
label_path = 'drive/driveData/Segmentations'
atlas_path = 'drive/driveData/Atlases'

# both folders store the volume under the same file name
volume_file = 'cropped.nii.gz'
seg_volume = sitk.ReadImage(label_path + '/' + volume_file)
im_volume = sitk.ReadImage(atlas_path + '/' + volume_file)

In [114]:
# Convert both original volumes to numpy arrays, keeping their (origin, direction, spacing) geometry.
# NOTE(review): seg_params / atlas_params are not used by the cells below, which wrap arrays with the `params`
# read from 'drive/driveData/Segmentations/cropped.nii.gz' instead.
seg_volume_np, seg_params = image_to_array(seg_volume)
im_volume_np, atlas_params = image_to_array(im_volume)

In [34]:
# Inspect the numpy shape of the original image volume (axis order as returned by sitk.GetArrayFromImage).
im_volume_np.shape

(283, 87, 108)

In [116]:
# Exchange the first and last numpy axes of both original volumes, e.g. (283, 87, 108) -> (108, 87, 283),
# presumably so their layout matches the arrays produced by generate_brain().
seg_volume_np_sw = seg_volume_np.swapaxes(0, 2)
im_volume_np_sw = im_volume_np.swapaxes(0, 2)

In [117]:
# Wrap the axis-swapped arrays back into SimpleITK images, reusing the geometry stored in `params`.
# NOTE(review): despite the "_flipped" names these volumes are axis-swapped (transposed), not flipped; and `params`
# comes from the un-swapped segmentation, so its spacing/direction axes may not match the swapped layout — confirm
# before using these images quantitatively.
im_volume_flipped = array_to_image(im_volume_np_sw, params)
seg_volume_flipped = array_to_image(seg_volume_np_sw, params)

In [118]:
# Display the axis-swapped segmentation with sitk.Show.
sitk.Show(seg_volume_flipped)

### Save image object

In [43]:
# Save the most recently generated regression target as NIfTI, wrapped with the reference geometry from `params`.
augmented_target_image = array_to_image(regression_target, params)
im_writer = sitk.ImageFileWriter()
im_writer.SetFileName("cropped_augmented.nii.gz")
im_writer.Execute(augmented_target_image)