In [None]:
import importlib
import numpy as np
import matplotlib.pyplot as plt
import h5py

%load_ext autoreload
%autoreload 2

In [None]:
from utils.load import extract_spheroid_id_and_day_idx, load_srs_data, load_dc_data
from process.denoising import wavelet_denoise_3d, pca_denoise_3d
from process.segmentation import segmentation_pipeline
from process.calculate_statistics import calculate_spheroid_statistics

# Input acquisition processed by this notebook.
h5_path = "DOE_11/Seed_300/Plate_80_A1_D1.h5"

# Spheroid identifier and day index, parsed from the file path by the project helper.
spheroid_id, day_idx = extract_spheroid_id_and_day_idx(h5_path)
print(spheroid_id, day_idx)

# Both modalities come from the same HDF5 file; SRS is read first, then DC.
srs_data = load_srs_data(h5_path)
dc_data = load_dc_data(h5_path)
srs_image, dc_image = srs_data["data"], dc_data["data"]

# Pixel size as recorded in the SRS metadata (reported below in μm).
# NOTE(review): the key name suggests this may be an (x, y) pair rather than a scalar — confirm.
pixel_size_um = srs_data["attributes"]["Pix_size (x,y)"]

print(f"SRS image shape: {srs_image.shape}")
print(f"DC image shape: {dc_image.shape}")
print(f"Pixel size: {pixel_size_um} μm")


# Index along the first axis of srs_image / dc_image selecting the channel that is
# denoised and segmented. The `_2855` suffix on the variables below ties index 2 to
# the 2855 channel (presumably the 2855 cm^-1 Raman shift — confirm against metadata).
CHANNEL_IDX_2855 = 2
# Value passed to pca_denoise_3d as `n_components_ratio`; shared by SRS and DC so both
# modalities are denoised with identical settings.
PCA_N_COMPONENTS_RATIO = 0.8

# Alternative denoiser (not used):
#   wavelet_denoise_3d(<stack>, wavelet='db4', level=None, sigma=None, mode='soft')
denoised_srs_image_2855 = pca_denoise_3d(
    srs_image[CHANNEL_IDX_2855], n_components_ratio=PCA_N_COMPONENTS_RATIO, standardize=True
)
denoised_dc_image_2855 = pca_denoise_3d(
    dc_image[CHANNEL_IDX_2855], n_components_ratio=PCA_N_COMPONENTS_RATIO, standardize=True
)

# Joint segmentation of the denoised SRS + DC stacks. In the visible notebook only
# mask_3d is consumed downstream; the other outputs are kept for interactive inspection.
mask_3d, slice_metrics, processed_stack_srs, processed_stack_dc = segmentation_pipeline(
    denoised_srs_image_2855,
    denoised_dc_image_2855,
    pixel_size_um,
    spheroid_id,
    visualize_output_path=None,
    show_plot=True,
)


# Intensity statistics of the loaded (non-denoised) SRS stack inside the segmented 3D mask.
stats = calculate_spheroid_statistics(srs_image, mask_3d)

# x positions: one integer index per entry of the per-wavelength statistics.
# assumes stats['mean']['3d'] and stats['std']['3d'] are equal-length 1D sequences — TODO confirm
wavelengths = range(len(stats['mean']['3d']))

# Left panel: mean intensity; right panel: standard deviation.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(wavelengths, stats['mean']['3d'], 'o-', linewidth=2, markersize=6)
ax1.set(
    xlabel='Wavelength Index',
    ylabel='Mean Intensity',
    title='Mean SRS Intensity by Wavelength',
)

ax2.plot(wavelengths, stats['std']['3d'], 's-', color='orange', linewidth=2, markersize=6)
ax2.set(
    xlabel='Wavelength Index',
    ylabel='Standard Deviation',
    title='SRS Intensity Std Dev by Wavelength',
)

# Light grid on both panels.
ax1.grid(True, alpha=0.3)
ax2.grid(True, alpha=0.3)

fig.tight_layout()
plt.show()