In [None]:
import tifffile
import numpy as np
from tqdm import tqdm
from pathlib import Path
import os
import pandas as pd
import sys
sys.path.append(r'\\192.168.10.106\imdea\DataDriven_UT_AlbertoVicente\10_code\UTvsXCT-preprocessing')
from dbtools import dbtools as db
from dbtools import load
from preprocess_tools import io, aligner, reslicer, signal, register

# Database connection

In [None]:
# Open the database connection that is used later to store the registrations.
# A failed connection is reported and then re-raised: without `conn` every
# sample would still be loaded and registered, only to fail when the result
# is saved to the database.
try:
    conn = db.connect()
    print("Connected to the database")
except Exception as error:
    print(error)
    raise

# Measurement type id

Select the measurement type ids of the monoelement (reference) and the XCT (registered) measurements.

In [None]:
# Load the measurement-type table; its 'id_measurementtype' values are used
# below to pick the reference and registered measurement types.
measurementtypes_table = db.get_data_metadata('measurementtypes')

# Bare expression so the notebook displays the table.
measurementtypes_table

In [None]:
# Measurement type ids selected for this run.
# The reference id filters the monoelement measurements and the registered id
# filters the XCT measurements in the filtering cells further down.
reference_measurementtype_id = 5
registered_measurementtype_id = 6

In [None]:
# Resolution of each measurement type. The reference resolution is fixed here.
reference_resolution = 1

# The voxel size of the registered type is stored as text; only the token
# before the first space is numeric (what follows is presumably a unit or
# type suffix - TODO confirm).
is_registered_type = measurementtypes_table['id_measurementtype'] == registered_measurementtype_id
voxel_size_text = measurementtypes_table.loc[is_registered_type, 'voxel_size_measurementtype'].values[0]
registered_resolution = float(voxel_size_text.split(' ')[0])

print(f"Reference resolution: {reference_resolution}"
      f"\nRegistered resolution: {registered_resolution}")

# Measurement filtering

We have to filter the measurements to select which of them to register.

## Select samples

We select the panel

In [None]:
# Load the panels table so a panel id can be chosen for the next cell.
panels_table = db.get_data_metadata('panels')

# Bare expression so the notebook displays the table.
panels_table

In [None]:
# Id of the panel whose samples are selected below (matched against 'panel_id_sample').
panel_id = 3

Get all the samples from the selected panel.

Get only the ones with keyholes.

Get only the id and name of the samples.

In [None]:
# Samples of the selected panel that are flagged as having a keyhole,
# reduced to their id and name.
samples_table = (
    db.get_data_metadata('samples')
    .loc[lambda df: df['panel_id_sample'] == panel_id]
    .loc[lambda df: df['keyhole_sample'] == 'True bool']
)
samples_table = samples_table[['id_sample', 'name_sample']]

samples_table

## Select measurements

We have to filter the measurements

### Monoelement filtering

In [None]:
# Measurements joined with their samples, narrowed in three steps:
#   1. keep the reference (monoelement) measurement type,
#   2. drop the columns that are entirely NA for that type,
#   3. keep the measurements belonging to the selected samples.
measurements_table_monoelement = (
    db.relation_metadata('measurements', 'samples', 'sample_measurements')
    .loc[lambda df: df['measurementtype_id_measurement'] == reference_measurementtype_id]
    .dropna(axis=1, how='all')
    .loc[lambda df: df['id_sample'].isin(samples_table['id_sample'])]
)

#group by file_path_measurement
def agg_func(x):
    """Collapse one grouped column into a single cell value.

    Args:
        x: pandas Series holding the values of one column within a group.

    Returns:
        A list of the distinct values when the group holds more than one
        distinct value, otherwise the first element of the Series.
    """
    distinct_values = x.unique()
    return list(distinct_values) if len(distinct_values) > 1 else x.iloc[0]

# One row per measurement file: columns with several distinct values inside a
# file group are collapsed by agg_func. The sample filter is applied again on
# the grouped table before keeping only the identifying columns.
measurements_table_monoelement = (
    measurements_table_monoelement
    .groupby(['file_path_measurement'], dropna=False)
    .agg(agg_func)
    .reset_index()
    .loc[lambda df: df['id_sample'].isin(samples_table['id_sample'])]
)
measurements_table_monoelement = measurements_table_monoelement[
    ['file_path_measurement','id_measurement','id_sample','name_sample']
]

measurements_table_monoelement

### XCT filtering

In [None]:
# Measurements joined with their samples, narrowed in three steps:
#   1. keep the registered (XCT) measurement type,
#   2. drop the columns that are entirely NA for that type,
#   3. keep the measurements belonging to the selected samples.
measurements_table_xct = (
    db.relation_metadata('measurements', 'samples', 'sample_measurements')
    .loc[lambda df: df['measurementtype_id_measurement'] == registered_measurementtype_id]
    .dropna(axis=1, how='all')
    .loc[lambda df: df['id_sample'].isin(samples_table['id_sample'])]
)

#group by file_path_measurement
# NOTE: same helper as the one defined in the monoelement filtering cell;
# it is redefined here so this cell does not depend on that cell having run.
def agg_func(x):
    """Aggregate a grouped column into one value per group.

    Args:
        x: pandas Series with the values of one column for a single group.

    Returns:
        The list of distinct values if there is more than one, otherwise the
        first element of the Series.
    """
    distinct_values = list(x.unique())
    if len(distinct_values) > 1:
        return distinct_values
    return x.iloc[0]

# One row per measurement file (agg_func collapses columns with several
# distinct values), then keep only the selected samples whose measurement is
# flagged as equalized and as aligned.
measurements_table_xct = (
    measurements_table_xct
    .groupby(['file_path_measurement'], dropna=False)
    .agg(agg_func)
    .reset_index()
    .loc[lambda df: df['id_sample'].isin(samples_table['id_sample'])]
    .loc[lambda df: df['equalized_measurement'] == 'True bool']
    .loc[lambda df: df['aligned_measurement'] == 'True bool']
)

# Only the identifying columns are displayed; measurements_table_xct itself
# keeps every column.
measurements_table_xct[['file_path_measurement','id_measurement','id_sample','name_sample']]

## Get existing registrations

In [None]:
# Identifier of the registration method; it selects the already stored
# registrations here and labels the new ones written by the registration loop.
registration_type = 'Alberto 2024 registration method, UTvsXCTPreprocessing toolkit 0.1.14 , file register.py function register_ut_xct_monoelement. Extract the centers of the holes from UT and XCT, and register them using a rigid body transformation text'

registrations_table = db.get_data_metadata('registrations')

# Keep only the registrations produced with this method. If the filter cannot
# be applied, the table is treated as having no registrations of this type.
try:
    registrations_table = registrations_table[registrations_table['type_registration'] == registration_type]
except Exception:
    print(f"No registrations available for type: {registration_type}")
    available_registrations = False
else:
    available_registrations = True


# Registering

In [None]:
# Register, sample by sample, the XCT measurement against the monoelement
# reference measurement and store the resulting parameters in the database.
# Each sample is processed inside its own try block so that a failure is
# reported and the loop moves on to the next sample.
n_samples = len(samples_table)
sample_rows = zip(samples_table['id_sample'], samples_table['name_sample'])

for sample_number, (sample_id, sample_name) in enumerate(sample_rows, start=1):

    try:

        print(f"Processing sample {sample_number}/{n_samples}: {sample_name}")

        # Reference (monoelement) measurements of this sample
        measurements_table_reference_sample = measurements_table_monoelement[measurements_table_monoelement['id_sample'] == sample_id]

        if measurements_table_reference_sample.empty:
            print(f"No reference measurements found for sample {sample_name}")
            continue

        # Registered (XCT) measurements of this sample
        measurements_table_registered_sample = measurements_table_xct[measurements_table_xct['id_sample'] == sample_id]

        if measurements_table_registered_sample.empty:
            print(f"No registered measurements found for sample {sample_name}")
            continue

        # Only the first measurement of each kind is used for the sample.
        reference_row = measurements_table_reference_sample.iloc[0]
        registered_row = measurements_table_registered_sample.iloc[0]

        # Skip the sample when this exact reference/registered pair is already
        # stored. Both ids must match on the same registration row; testing
        # them independently could skip a pair that was never registered.
        if available_registrations:
            same_reference = registrations_table['reference_measurement_id_registration'] == reference_row['id_measurement']
            same_registered = registrations_table['registered_measurement_id_registration'] == registered_row['id_measurement']
            if (same_reference & same_registered).any():
                print(f"Sample {sample_name} has already been registered.")
                continue

        # File paths of both measurements
        reference_measurement_path = Path(reference_row['file_path_measurement'])
        registered_measurement_path = Path(registered_row['file_path_measurement'])

        # Load both volumes
        rf = io.load_tif(reference_measurement_path)
        xct = io.load_tif(registered_measurement_path)

        # Auxiliary outputs go to an 'aux_registration' folder next to the registered file
        aux_registration_folder = registered_measurement_path.parent / 'aux_registration'
        aux_registration_folder.mkdir(parents=True, exist_ok=True)

        # UT preprocessing: reverse the axis order (z,y,x -> x,y,z) for the
        # envelope computation, then reverse it back for the registration.
        amp = signal.envelope(np.transpose(rf, (2, 1, 0)))
        amp = np.transpose(amp, (2, 1, 0))

        # Register
        parameters, ut_centers, xct_centers, transformed = register.register_ut_xct_monoelement(
            amp, xct, reference_resolution, registered_resolution
        )

        # Save the auxiliary images
        io.save_tif(aux_registration_folder / 'centers.tif', ut_centers)
        io.save_tif(aux_registration_folder / 'centers_xct.tif', xct_centers)
        io.save_tif(aux_registration_folder / 'transformed_centers_xct.tif', transformed)

        # Registration error estimate: fraction of elements where the binarized
        # UT centers and the binarized transformed XCT centers disagree.
        mse = np.mean(((ut_centers > 0)*1 - (transformed > 0)*1) ** 2)
        print(f"Mean Squared Error (MSE): {mse}")

        # Save into the database. The parameters are converted to nested lists
        # of plain floats without reusing the sample loop variables, so the
        # error message below always names the sample being processed.
        axes = ['x', 'y']
        parameters_list = [[float(value) for value in row] for row in parameters]

        load.load_registration(conn, parameters_list, str(reference_measurement_path), str(registered_measurement_path), registration_type, axes)

    except Exception as e:
        print(f"Error processing sample {sample_name}: {e}")