## Atlas Registration

This notebook registers every atlas case to every other atlas case. The resulting registrations are needed to conduct a leave-one-out analysis.

In [None]:
import os
import sys
import gc
import re
import time

sys.path.append('../../..')

import pandas as pd

import SimpleITK as sitk

from loguru import logger

# Format the output a bit nicer for Jupyter: drop the handlers that are
# currently registered, then log to stdout as "timestamp level message"
logger.remove()
notebook_log_format = "{time:YYYY-MM-DD HH:mm:ss} {level} {message}"
logger.add(sys.stdout, level="DEBUG", format=notebook_log_format)

from impit.segmentation.atlas.registration import initial_registration, fast_symmetric_forces_demons_registration, transform_propagation, apply_field

data_path = './data'
working_path = "./working"
# exist_ok avoids a separate os.path.exists() check before creating the directory
os.makedirs(working_path, exist_ok=True)


def parse_structure_name(filename):
    """Split a data file name into a dictionary key and an optional observer.

    Everything after the first '.' is dropped, double underscores are
    collapsed to a single one, and the key is upper-cased. If the cleaned
    stem contains an underscore followed by a digit, the text before that
    underscore becomes the key and the digit becomes the observer.

    Args:
        filename: File name without its directory, e.g. 'Struct_Heart_1.nii.gz'.

    Returns:
        Tuple (name, observer); observer is None when no suffix was found.
    """
    stem = filename.split('.')[0]

    # Clean up names with double underscore. This has to happen before the
    # observer suffix is matched, otherwise observer-tagged files keep the
    # double underscore while all other files lose it.
    stem = stem.replace('__', '_')

    matches = re.findall(r"(.*)_([0-9])", stem)
    if matches:
        prefix, observer = matches[0]
        return prefix.upper(), observer

    return stem.upper(), None


# Read the data into a dictionary:
#   data[case][NAME] = file_path                  (no observer suffix)
#   data[case][NAME] = {observer: file_path, ...} (observer suffix present)

data = {}

for root, dirs, files in os.walk(data_path, topdown=False):

    # Files directly inside data_path do not belong to any case
    if root == data_path:
        continue

    # The directory name is the case identifier (basename is separator-agnostic)
    case = os.path.basename(root)
    data[case] = {}

    for f in files:
        file_path = os.path.join(root, f)
        name, observer = parse_structure_name(f)

        if observer is None:
            data[case][name] = file_path
        else:
            # Group the files of all observers under the same structure name
            data[case].setdefault(name, {})[observer] = file_path



### Register each case to every other case

These registrations are used later for the leave-one-out atlas analysis.

In [None]:
# Log to file to avoid large amounts of output in Notebook
logger.remove()
logger_handler = logger.add("./logs/file_{time}.log", format="{time:YYYY-MM-DD HH:mm:ss} {level} {message}", level="DEBUG")
print("Logging to file, check 'logs' directory")

try:
    # Every case takes a turn as the fixed image; all other cases are
    # registered to it as moving images.
    for i in data:

        logger.info(f'Will run case: {i}')

        # The fixed image only depends on i, so read it once per fixed case
        fixed_image = sitk.ReadImage(data[i]['CT'])

        for j in data:

            if i == j:
                continue

            start = time.time()

            moving_image = sitk.ReadImage(data[j]['CT'])

            # All outputs of this pair go to <working_path>/<fixed>_<moving>
            reg_dir = os.path.join(working_path, f"{i}_{j}")
            os.makedirs(reg_dir, exist_ok=True)

            # Rigid Reg
            logger.info(f'Rigidly register case {j} (moving) to {i} (fixed)')
            ct_reg_rig, rigid_transform = initial_registration(fixed_image, moving_image)
            ct_reg_rig_path = os.path.join(reg_dir, f'{j}_ct_reg_rig.nii.gz')
            sitk.WriteImage(ct_reg_rig, ct_reg_rig_path)

            # Deformable Reg, starting from the rigidly registered image
            logger.info(f'Deformably register case {j} (moving) to {i} (fixed)')
            ct_reg_def, deform_field = fast_symmetric_forces_demons_registration(fixed_image, ct_reg_rig, resolution_staging=[16,4,2,1], iteration_staging=[2,2,2,2], ncores=12)
            ct_reg_def_path = os.path.join(reg_dir, f'{j}_ct_reg_def.nii.gz')
            sitk.WriteImage(ct_reg_def, ct_reg_def_path)

            # Push every structure of the moving case through both transforms
            for s in data[j]:
                if not s.upper().startswith('STRUCT'):
                    continue

                # An entry is {observer: path} when the file names carry an
                # observer suffix, and a bare path string otherwise.
                entry = data[j][s]
                if isinstance(entry, str):
                    structure_paths = [entry]
                else:
                    structure_paths = list(entry.values())

                for structure_path in structure_paths:

                    logger.info(f'Deforming structure: {s}')

                    structure = sitk.ReadImage(structure_path)
                    # Rigid transform first, then the deformation field
                    structure_transformed = transform_propagation(fixed_image, structure, rigid_transform, structure=True)
                    structure_deformed = apply_field(structure_transformed, deform_field, structure=True)

                    # Tag the output name with the moving case: Struct_<x> -> Struct_def_<j>_<x>
                    structure_name = os.path.basename(structure_path).replace('Struct_', f'Struct_def_{j}_')
                    structure_file = os.path.join(reg_dir, structure_name)
                    sitk.WriteImage(structure_deformed, structure_file)

            end = time.time()
            print(f'Took: {end-start}')
finally:
    # Revert to logging to output, even when a registration failed part-way
    logger.remove()
    logger.add(sys.stdout, format="{time:YYYY-MM-DD HH:mm:ss} {level} {message}", level="DEBUG")

### Read the registered structures into the data dictionary

In [None]:

# Reset the registration results of every case so that re-running this cell
# does not keep entries from a previous run
for k in data:
    data[k]['moving'] = {}

# Result layout:
#   data[fixed]['moving'][moving]['CT']  = deformed CT path
#   data[fixed]['moving'][moving][NAME]  = path, or {observer: path, ...}
for root, dirs, files in os.walk(working_path, topdown=False):

    # Only directories whose path ends in "<digits>_<digits>" are picked up;
    # the first number is taken as the fixed case, the second as the moving case
    parts = re.findall(r'([0-9]+)_([0-9]+)$', root)

    if len(parts) < 1:
        continue

    fixed, moving = parts[0]

    # The working directory may still hold results for cases that are not
    # (or no longer) part of the loaded data; skip them instead of failing
    if fixed not in data:
        logger.warning(f'Skipping {root}: fixed case {fixed} is not in the loaded data')
        continue

    moving_entry = data[fixed]['moving'].setdefault(moving, {})

    for f in files:

        # Undo the "Struct_def_<moving>" tag to recover the structure name
        structure_key = f.replace(f'Struct_def_{moving}', 'Struct').split('.')[0]

        # Clean up names with double underscore:
        structure_key = structure_key.replace('__','_')

        if 'OLD' in structure_key:
            continue

        file_path = os.path.join(root, f)

        if '_ct_reg_def' in f:
            moving_entry['CT'] = file_path

        # Everything that is not a structure file is ignored from here on
        if 'Struct' not in structure_key:
            continue

        name = structure_key.upper()
        observer = None

        # An underscore followed by a digit marks the observer of the structure
        matches = re.findall(r"(.*)_([0-9])", structure_key)

        if len(matches) > 0:
            name = matches[0][0].upper()
            observer = matches[0][1]

        if observer:
            # Group the files of all observers under the same structure name
            moving_entry.setdefault(name, {})[observer] = file_path
        else:
            moving_entry[name] = file_path
