# imports

In [None]:
import sys
sys.path.append(r'\\192.168.10.106\imdea\DataDriven_UT_AlbertoVicente\10_code\UTvsXCT-preprocessing')
from preprocess_tools import onlypores , datasetmaker, io, aligner, register, reslicer
import numpy as np
from dbtools import dbtools as db
from dbtools import load as load
import pandas as pd
from pathlib import Path
import ast

# Database connection

In [None]:
# Open the database connection used later to register the generated datasets.
# A failure is only reported (not re-raised), so `conn` stays undefined if
# the connection cannot be established.
try:
    conn = db.connect()
except Exception as error:
    print(error)
else:
    print("Connected to the database")

# Select measurements to create the datasets

## Measurement type

In [None]:
# Fetch the 'measurementtypes' table; it is used further down to read the
# XCT voxel size and the UT x-resolution.
measurementtypes_table = db.get_data_metadata('measurementtypes')

# Bare expression: shown as the cell output when run in a notebook.
measurementtypes_table

In [None]:
# Measurement type id used to select the UT measurements (compared against
# 'measurementtype_id_measurement' and 'id_measurementtype').
# NOTE(review): presumably picked from the measurementtypes table - confirm.
ut_type = 7

# Measurement type id used to select the XCT measurements.
xct_type = 6

## Selecting UT measurements

In [None]:
# Fetch the full 'measurements' table; it is narrowed down to UT measurements afterwards.
ut_measurements_table = db.get_data_metadata('measurements')

In [None]:
# Keep only the measurements whose measurement type id equals `ut_type`.
ut_measurements_table = ut_measurements_table.loc[
    lambda table: table['measurementtype_id_measurement'] == ut_type
]

# Bare expression: shown as the cell output when run in a notebook.
ut_measurements_table

## Selecting XCT measurements

In [None]:
# Fetch the full 'measurements' table; it is narrowed down to XCT measurements afterwards.
xct_measurements_table = db.get_data_metadata('measurements')

In [None]:
# Keep only the measurements whose measurement type id equals `xct_type`.
xct_measurements_table = xct_measurements_table.loc[
    lambda table: table['measurementtype_id_measurement'] == xct_type
]

# Bare expression: shown as the cell output when run in a notebook.
xct_measurements_table

## Getting registered pairs

In [None]:
# Load every registration and keep only those whose reference measurement is
# one of the selected UT measurements AND whose registered measurement is one
# of the selected XCT measurements.
registrations_table = db.get_data_metadata('registrations').loc[
    lambda table: (
        table['reference_measurement_id_registration'].isin(ut_measurements_table['id_measurement'])
        & table['registered_measurement_id_registration'].isin(xct_measurements_table['id_measurement'])
    )
]

# Bare expression: shown as the cell output when run in a notebook.
registrations_table

In [None]:
# Three parallel lists, one entry per registration: the UT (reference)
# measurement row, the XCT (registered) measurement row and the registration id.
reference_measurements, registered_measurements, registration_ids = [], [], []

for index, row in registrations_table.iterrows():
    reference_id = row['reference_measurement_id_registration']
    registered_id = row['registered_measurement_id_registration']
    registration_id = row['id_registration']

    # Rows of each measurement table matching the ids of this registration.
    ut_matches = ut_measurements_table.loc[ut_measurements_table['id_measurement'] == reference_id]
    xct_matches = xct_measurements_table.loc[xct_measurements_table['id_measurement'] == registered_id]

    # Only the first matching row is kept for each side.
    reference_measurements.append(ut_matches.iloc[0])
    registered_measurements.append(xct_matches.iloc[0])
    registration_ids.append(registration_id)

print(f"Found {len(reference_measurements)} registrations")

## Sample names

In [None]:
# Fetch the samples <-> measurements relation; it is used below to look up
# 'name_sample' from a measurement's 'id_measurement'.
sample_measurements_table = db.relation_metadata('samples','measurements','sample_measurements')

# Bare expression: shown as the cell output when run in a notebook.
sample_measurements_table

In [None]:
# Sample name of every reference (UT) measurement, in the same order as
# `reference_measurements`. The first matching relation row provides the name.
sample_names = [
    sample_measurements_table[
        sample_measurements_table['id_measurement'] == measurement['id_measurement']
    ].iloc[0]['name_sample']
    for measurement in reference_measurements
]

print(f"Found {len(sample_names)} samples")

# Dataset type selection

In [None]:
# Fetch the 'datasettypes' table so the available dataset types can be inspected.
datasettype_table = db.get_data('datasettypes')

# Bare expression: shown as the cell output when run in a notebook.
datasettype_table

In [None]:
# Dataset type id to generate: used to filter the already existing datasets
# and passed as `datasettype_id` when new datasets are stored.
# NOTE(review): presumably picked from the datasettypes table - confirm.
datasettype = 3

# Discard already computed datasets

In [None]:
# Collect the registration ids that already have a dataset of the selected
# type, so they can be skipped during generation. Any failure while querying
# or filtering falls back to an empty list (nothing is skipped).
try:
    dataset_registrations_table = db.relation_metadata('datasets','registrations','dataset_registrations')

    # Restrict the relation to datasets of the selected dataset type.
    dataset_registrations_table = dataset_registrations_table.loc[
        lambda table: table['datasettype_id_dataset'] == datasettype
    ]

    dataset_registrations = dataset_registrations_table['id_registration'].values

except Exception as e:
    print("No dataset registrations found or error occurred:", e)
    dataset_registrations = []

# Saving folder

In [None]:
# Root output folder (network share); one sub-folder per sample is created
# beneath it when the datasets are generated.
folder = Path(r'\\192.168.10.106\imdea\DataDriven_UT_AlbertoVicente\04_ML_data\Airbus\Panel Pegaso\2025 dataset')

# Patch sizes

In [None]:
# Values passed as `ut_patch_size` to datasetmaker.main; one CSV dataset is
# produced per value for every registration.
patch_sizes = [3, 5, 7, 9]

# Resolutions

In [None]:
# Resolutions are read from the measurement type rows. Each stored value is a
# string whose first space-separated token is parsed as a float
# (NOTE(review): presumably "<number> <unit>" - confirm against the table).
xct_resolution = float(
    measurementtypes_table.loc[
        measurementtypes_table['id_measurementtype'] == xct_type,
        'voxel_size_measurementtype',
    ].values[0].split(' ')[0]
)
ut_resolution = float(
    measurementtypes_table.loc[
        measurementtypes_table['id_measurementtype'] == ut_type,
        'x_resolution_measurementtype',
    ].values[0].split(' ')[0]
)

# Dataset Generation

In [None]:
# Generate one dataset per registration and patch size.
# For each registered UT/XCT pair: load both volumes, apply the stored
# registration to the XCT volume, extract the pores, then build, save and
# register in the database one CSV per patch size.
for reference_measurement, registered_measurement, registration_id, sample_name in zip(
        reference_measurements, registered_measurements, registration_ids, sample_names):
    reference_measurement_path = reference_measurement['file_path_measurement']
    registered_measurement_path = registered_measurement['file_path_measurement']

    # Registrations that already have a dataset of the selected type are skipped.
    if registration_id in dataset_registrations:
        print(f"Dataset for registration {registration_id} already exists, skipping...")
        continue

    print(f"Creating dataset for registration {registration_id} with reference measurement {reference_measurement_path} and registered measurement {registered_measurement_path}")

    # Load the reference (UT) and registered (XCT) volumes.
    ut_volume = io.load_tif(reference_measurement_path)
    xct_volume = io.load_tif(registered_measurement_path)

    # Both volumes are assumed to be stored as (z, y, x) - TODO confirm.
    # The (1, 2, 0) permutation moves the first axis to the end, i.e. (y, x, z).
    first_axis_last = (1, 2, 0)
    xct_volume = np.transpose(xct_volume, first_axis_last)
    ut_volume = np.transpose(ut_volume, first_axis_last)

    # The registration matrix is stored as a Python-literal string in the
    # registrations table; parse it and apply it to the XCT volume.
    registration_parameters = registrations_table.loc[
        registrations_table['id_registration'] == registration_id,
        'registration_matrix_registration',
    ].values[0]
    registration_parameters = np.array(ast.literal_eval(registration_parameters))

    xct_volume = register.apply_registration(ut_volume,xct_volume,registration_parameters,ut_resolution,xct_resolution,parallel=True)

    # Front wall and back wall of the registered XCT volume.
    _,frontwall,backwall = aligner.crop_walls(xct_volume)

    # (2, 0, 1) is the inverse of the permutation above: back to the original axis order.
    last_axis_first = (2, 0, 1)
    xct_volume = np.transpose(xct_volume, last_axis_first)
    ut_volume = np.transpose(ut_volume, last_axis_first)

    # Compute the only-pores volume and the material mask.
    onlypores_volume,material_mask,_ = onlypores.onlypores(xct_volume,frontwall,backwall,min_size_filtering=8)

    # One output folder per sample.
    dataset_folder = folder / f"{sample_name}"
    dataset_folder.mkdir(parents=True, exist_ok=True)

    # One dataset per patch size.
    for patch_size in patch_sizes:
        reconstruction_shape,df = datasetmaker.main(onlypores_volume,material_mask,ut_volume,xct_resolution,ut_resolution,ut_patch_size=patch_size, ut_step_size=1)

        # Save the dataset as CSV.
        dataset_path = dataset_folder / f"patch_size_{patch_size}_volfrac_areafrac.csv"
        df.to_csv(dataset_path, index=False)

        # Register the saved dataset in the database.
        rows = len(df)
        targets = ['volfrac','areafrac']
        description = 'Created with the production notebook in preprocess tools v 0.1.17'

        load.load_dataset(
            conn,
            datasettype_id=datasettype,
            file_path=str(dataset_path),
            rows=rows,
            patch_size=str(patch_size),
            targets=targets,
            reconstruction_shape=reconstruction_shape,
            registration_ids=[registration_id],
            description=description,
        )

        print(f"Dataset for registration {registration_id} with patch size {patch_size} saved to {dataset_path}")