## Imports and initialize TensorFlow

In [None]:
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import pandas as pd
import spectral
import shelve
import pickle

import matplotlib.pyplot as plt

# --- Configuration ---
BATCH_SIZE = 32
EPOCHS = 600  # Upper bound only; EarlyStopping is expected to stop training sooner
LEARNING_RATE = 1e-4
PATIENCE = 30  # Passed to EarlyStopping as its patience
IGNORE_VALUE = -1  # Integer code that stands in for 'N' (ignored) labels

# One entry per prediction task; each task's output layer is "<task>_output"
TASK_NAMES = ['plant', 'age', 'part', 'health', 'lifecycle']
OUTPUT_NAMES = [f"{task}_output" for task in TASK_NAMES]

# Per-head weight in the combined loss (all heads start with equal weight)
LOSS_WEIGHTS = {
    'plant_output': 1.0,
    'age_output': 1.0,
    'part_output': 1.0,
    'health_output': 1.0,
    'lifecycle_output': 1.0,
}

print(f"Using TensorFlow version: {tf.__version__}")

# Optional: ask TensorFlow to grow GPU memory on demand instead of reserving it all.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Reported, not fatal: the notebook continues with default memory behaviour.
        print(e)
    else:
        print(f"Configured memory growth for {len(gpus)} GPU(s)")

In [None]:
# Tag shared by every artifact written by this training run.
RUN_TAG = '10-23-25_0'

# Pickled StandardScaler fitted on the training split.
SCALER_FILEPATH = f'data/pkl/scaler_{RUN_TAG}.pkl'
# Weights-only checkpoint. Keras 3 `Model.save_weights` raises unless the file
# name ends in ".weights.h5", so the run tag goes *before* that suffix. The name
# still ends in ".h5", so older Keras versions keep writing HDF5 as before.
# NOTE: any notebook that loads the weights must use this same path.
CHECKPOINT_FILEPATH = f'data/checkpoints/best_spectral_model_{RUN_TAG}.weights.h5'
# Full model (architecture + weights) in the native .keras format.
MODEL_FILEPATH = f'data/checkpoints/model_{RUN_TAG}.keras'

## Load full spectral library

In [None]:
import sys
sys.path.append('util/')
import importlib

import util_scripts as util

import os
from dotenv import load_dotenv

from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi

# Read environment variables (including MONGO_DBR_URI) from a local .env file.
_ = load_dotenv()

MONGO_DBR_URI = os.getenv('MONGO_DBR_URI')

# True  -> re-download the library from MongoDB and refresh the local pickle.
# False -> read the cached pickle from disk.
reload_data_driver = False

if reload_data_driver:

    # Fail early with a clear message: with no URI the driver would silently
    # fall back to its default host instead of the intended deployment.
    if not MONGO_DBR_URI:
        raise RuntimeError(
            "MONGO_DBR_URI is not set; define it in the environment or .env "
            "file before reloading the spectral library."
        )

    # The context manager closes the connection once the records are in memory.
    with MongoClient(MONGO_DBR_URI, server_api=ServerApi('1')) as client:

        # Send a ping to confirm a successful connection (best effort: a failed
        # ping is reported and the query below is still attempted).
        try:
            client.admin.command('ping')
            print("Pinged your deployment. You successfully connected to MongoDB!")
        except Exception as e:
            print(e)

        db = client["upwins_db"]
        view_name = "spectral_library"
        spectral_library = db[view_name]

        # Materialise the cursor while the connection is still open.
        records = list(spectral_library.find())

    df = pd.DataFrame(records)
    df.to_pickle('data/pkl/library_with_Genus_species.pkl')

else:
    df = pd.read_pickle('data/pkl/library_with_Genus_species.pkl')
    # Alternative cached library:
    #df = pd.read_pickle('data/pkl/library.pkl')

# Wrap the DataFrame in the project's SpectralCollection helper and pull out
# the wavelength, name and spectra members used by the following cells.
sc = util.SpectralCollection(df)
wl_lib = sc.wl
name = sc.name
spectra = sc.spectra
print(sc.spectra.shape)

## Open Nano imagery

In [None]:
# Alternative input scene (uncomment this pair and comment out the active one to switch):
#fname = 'data/morven_4000/raw_4000_or_ref.img'
#fname_hdr = 'data/morven_4000/raw_4000_or_ref.hdr'

# Active scene: ENVI data file and its matching header file.
fname = 'data/morven_9-2025/raw_55691_or_ref.img'
fname_hdr = 'data/morven_9-2025/raw_55691_or_ref.hdr'

# Alternative input scene:
#fname = 'data/5-8-2025/100133_Allied_05_08_2025_2015_06_04_17_50_15/raw_0_ref'
#fname_hdr = 'data/5-8-2025/100133_Allied_05_08_2025_2015_06_04_17_50_15/raw_0_ref.hdr'

# Open the ENVI image (header path first, then data path).
im = spectral.envi.open(fname_hdr, fname)
# Band-center wavelengths of the image as a NumPy array; used as the resampling target.
wl_img = np.asarray(im.bands.centers)

In [None]:
# OPTIONAL: Load the image into memory
im.Arr = im.load()
print(f'Shape of Im.Arr = {im.Arr.shape}')

# Image dimensions: rows, columns, bands
nr, nc, nb = im.nrows, im.ncols, im.nbands

# Flatten the cube to one row per pixel: (rows * cols, bands)
im.List = np.reshape(im.Arr, (nr * nc, nb))
print(f'Shape of im.List = {im.List.shape}')

# Keep only pixels whose band values sum to more than zero
valid_pixel_mask = np.sum(im.List, axis=1) > 0
dataList = im.List[valid_pixel_mask, :]
print(f'Shape of dataList = {dataList.shape}')

## OPTIONAL: Get target bands for resampling from ROI

In [None]:
roi_filepath = 'data/pkl/rois_labeled/crisfield/Crisfield_October_Training_ROIs_Img14_0513_Genus_spec.pkl'

# NOTE(review): pickle.load can execute arbitrary code; only open trusted ROI files.
with open(roi_filepath, 'rb') as f:
    roiData = pickle.load(f)
roi_df = roiData.df  # DataFrame with all the data for this ROI

# Columns from position 4 onward are treated as the spectral bands
band_columns = slice(4, None)
roi_spectra = roi_df.to_numpy()[:, band_columns].astype(np.float32)
wl_roi = roi_df.columns.to_numpy()[band_columns]

print("ROI target bands count: ", len(wl_roi))

# From here on the ROI band centers replace the image band centers as the
# resampling target used by the rest of the notebook.
wl_img = wl_roi

## Resample library to match input image

In [None]:
# Create a BandResampler, a callable that resamples spectra from one set of band centers to another.
# See: https://www.spectralpython.net/class_func_ref.html?highlight=resampling#spectral.algorithms.resampling.BandResampler
# Inputs: first the wavelengths of the spectra being resampled (library), then the target wavelengths.
resampler = spectral.BandResampler(wl_lib, wl_img)
# The library is transposed before and after the call so the result is again one spectrum per row.
spectra_resampled = resampler(spectra.T).T

print(f'The shape of the spectral library is {spectra_resampled.shape}.')
print(f'({spectra_resampled.shape[0]} spectra with {spectra_resampled.shape[1]} bands.)')

# Later cells use `spectra`, so it now refers to the resampled library.
spectra = spectra_resampled

## Prepare ROI data for training

In [None]:
import pickle
import os

def find_roi_files(root_dir):
    """Recursively collect ROI pickle files below *root_dir*.

    A file matches when its lower-cased file name contains both '.pkl' and
    'roi' (directory names are not considered).

    Args:
        root_dir: Directory to walk.

    Returns:
        A list of paths (directory joined with file name) in os.walk order.
    """
    required_fragments = ('.pkl', 'roi')
    return [
        os.path.join(dirpath, fname)
        for dirpath, _dirnames, fnames in os.walk(root_dir)
        for fname in fnames
        if all(fragment in fname.lower() for fragment in required_fragments)
    ]

# Project codes for labeling ROI data
# **IMPORTANT**: ROIs should be named using the **same** naming convention used to label ASD files
#
# Key   : plant code that an ROI name starts with.
# Value : [genus, species, common name, plant type, reference URL].
#         Only genus and species are used to build the class label
#         ("<genus>_<species>"), so their spelling must not change.

plant_codes = {
    'Ammo_bre': ['Ammophila', 'breviligulata', 'American Beachgrass', 'grass', 'https://en.wikipedia.org/wiki/Ammophila_breviligulata'],
    'Chas_lat': ['Chasmanthium', 'latifolium', 'River Oats', 'grass', 'https://en.wikipedia.org/wiki/Chasmanthium_latifolium'],
    'Pani_ama': ['Panicum', 'amarum', 'Coastal Panic Grass', 'grass', 'https://en.wikipedia.org/wiki/Panicum_amarum'],
    'Pani_vir': ['Panicum', 'virgatum', 'Switch Grass', 'grass', 'https://en.wikipedia.org/wiki/Panicum_virgatum'],
    # URL previously pointed at the Chasmanthium latifolium page (copy-paste slip).
    'Soli_sem': ['Solidago', 'sempervirens', 'Seaside Goldenrod', 'succulent', 'https://en.wikipedia.org/wiki/Solidago_sempervirens'],
    'Robi_his': ['Robinia', 'hispida', 'Bristly locust', 'shrub', 'https://en.wikipedia.org/wiki/Robinia_hispida'],
    # Common name previously repeated 'Bristly locust' from the Robinia entry.
    'More_pen': ['Morella', 'pennsylvanica', 'Northern Bayberry', 'shrub', 'https://en.wikipedia.org/wiki/Myrica_pensylvanica'],
    'Rosa_rug': ['Rosa', 'rugosa', 'Sandy Beach Rose', 'shrub', 'https://en.wikipedia.org/wiki/Rosa_rugosa'],
    'Cham_fas': ['Chamaecrista', 'fasciculata', 'Partridge Pea', 'legume', 'https://en.wikipedia.org/wiki/Chamaecrista_fasciculata'],
    'Soli_rug': ['Solidago', 'rugosa', 'Wrinkleleaf goldenrod', 'shrub', 'https://en.wikipedia.org/wiki/Solidago_rugosa'],
    'Bacc_hal': ['Baccharis', 'halimifolia', 'Groundseltree', 'shrub', 'https://en.wikipedia.org/wiki/Baccharis_halimifolia'],
    # Key keeps its trailing underscore so it is 8 characters like the other codes.
    'Iva_fru_': ['Iva', 'frutescens', 'Jesuits Bark', 'shrub', 'https://en.wikipedia.org/wiki/Iva_frutescens'],
    'Ilex_vom': ['Ilex', 'vomitoria', 'Yaupon Holly', 'evergreen shrub', 'https://en.wikipedia.org/wiki/Ilex_vomitoria'],
    # Generic vegetation / background class (9-character key).
    'Genus_spe': ['Genus', 'species', 'vegetation', 'background', ''],
}
# For every table below: key = code as it appears in an ROI name (between
# underscores), value = [description, canonical code]. Several legacy codes are
# deliberately folded into a canonical code (e.g. 'E' -> 'PE').

age_codes = {
    'PE': ['Post Germination Emergence', 'PE'],
    # Retired codes: 'RE' (Re-emergence), 'E' (Emergence from seed), 'D' (Dormant)
    'E': ['Post Germination Emergence', 'PE'],  # legacy alias of 'PE'
    '1G': ['Year 1 growth', '1G'],
    '2G': ['Year 2 growth', '2G'],
    # Retired code: '1F' (Year 1 Flowering)
    'J': ['Juvenile', 'J'],
    'M': ['Mature', 'M'],
}

principal_part_codes = {
    'MX': ['Mix', 'MX'],
    # Retired codes: 'S' (Seed, old code clashing with other tables), 'SA' (Shoot Apex)
    'SA': ['Internode Stem', 'ST'],  # folded into 'ST'
    'L': ['Leaf/Blade', 'L'],
    # Retired code: 'IS' (Internode Stem)
    'ST': ['Internode Stem', 'ST'],
    'SP': ['Sprout', 'SP'],
    'CS': ['Sprout', 'SP'],  # Collar sprout, folded into 'SP'
    'RS': ['Sprout', 'SP'],  # Root sprout, folded into 'SP'
    'LG': ['Lignin', 'LG'],
    'FL': ['Flower', 'FL'],
    'B': ['Leaf/Blade', 'L'],  # Blade, folded into 'L'
    'FR': ['Fruit', 'FR'],
    'SE': ['Seed', 'SE'],
    # Retired code: 'St' (Stalk)
}

health_codes = {
    'MH': ['Healthy/Unhealthy Mix', 'MH'],
    'DS': ['Drought Stress', 'DS'],
    'SS': ['Salt Stress (soak)', 'SS'],
    'SY': ['Salt Stress (spray)', 'SY'],
    'S': ['Stressed', 'S'],
    'LLRZ': ['LLRZ Lab Stress', 'LLRZ'],
    # Retired code: 'D' (Dormant), now part of lifecycle_codes
    'R': ['Rust', 'R'],
    'H': ['Healthy', 'H'],
}

lifecycle_codes = {
    'D': ['Dormant', 'D'],
    'RE': ['Re-emergence', 'RE'],
    'FLG': ['Flowering', 'FLG'],
    'FRG': ['Fruiting', 'FRG'],
    'FFG': ['Fruiting and Flowering', 'FFG'],
    'N': ['Neither', 'N'],
}

# Accumulators filled while parsing ROI names: one spectrum list plus one
# label list per task.

d_spectra = []
d_plant, d_part, d_health, d_age, d_lifecycle = [], [], [], [], []

# Task name -> the label list for that task (same list objects as above).
yd_all_dict = dict(
    plant=d_plant,
    age=d_age,
    part=d_part,
    health=d_health,
    lifecycle=d_lifecycle,
)

# Task name -> the code table used to decode that task from an ROI name.
code_category_dict = dict(
    plant=plant_codes,
    age=age_codes,
    part=principal_part_codes,
    health=health_codes,
    lifecycle=lifecycle_codes,
)

In [None]:
## Find ROI data

# Recursively collect every ROI pickle below the labeled-ROI directory.
roi_files = find_roi_files('data/pkl/rois_labeled')

print(f"Number of ROI files found: {len(roi_files)}")

In [None]:
## Prepare ROI data

# ==============================================================================
# Helper Function for Stratified Sampling
# ==============================================================================
def stratified_sample_with_min_per_roi(spectra_df, min_per_roi=50, total_samples=300, random_state=None):
    """
    Selects a subset of labeled ROI pixels from a DataFrame without duplicates.

    First every ROI contributes up to ``min_per_roi`` randomly chosen pixels
    (all of them if the ROI is smaller). If that guaranteed set is smaller than
    ``total_samples``, it is topped up with pixels drawn at random from the
    rows that were *not* already selected.

    Args:
        spectra_df (pd.DataFrame): DataFrame containing spectral data and a 'roi_name' column.
        min_per_roi (int): The minimum number of pixels to select from each ROI.
        total_samples (int): Target total number of pixels. When it is falsy or
            not larger than the guaranteed set, only the guaranteed set is
            returned (so the result can exceed ``total_samples``).
        random_state (int | None): Optional seed for reproducible sampling.
            ``None`` keeps pandas' default (global NumPy random state).

    Returns:
        pd.DataFrame: The selected rows with a fresh 0..n-1 index.
    """
    if spectra_df.empty:
        return spectra_df.copy()

    # Work on a positional index so "already selected" rows can be identified
    # reliably even if the caller's index contains duplicates.
    work_df = spectra_df.reset_index(drop=True)

    # One shared generator so successive draws differ but the run is repeatable.
    rng = None if random_state is None else np.random.default_rng(random_state)

    # Explicit per-group sampling keeps the 'roi_name' column and the row index
    # intact (the previous groupby.apply + reset_index lost the index, so the
    # wrong rows were excluded from the top-up pool).
    guaranteed_parts = [
        group.sample(n=min(len(group), min_per_roi), random_state=rng)
        for _, group in work_df.groupby('roi_name')
    ]
    selection = pd.concat(guaranteed_parts) if guaranteed_parts else work_df.iloc[0:0]

    shortfall = (total_samples or 0) - len(selection)
    if shortfall > 0:
        # Pool of pixels that have not been picked yet.
        remaining_pool = work_df.drop(index=selection.index)
        if not remaining_pool.empty:
            n_extra = min(shortfall, len(remaining_pool))
            extra = remaining_pool.sample(n=n_extra, random_state=rng)
            selection = pd.concat([selection, extra])

    return selection.reset_index(drop=True)


# ==============================================================================
# Main Data Processing and Sampling Logic
# ==============================================================================

# --- STAGE 1: Aggregate all ROI data from all files ---

# Holds one DataFrame of processed ROI spectra (plus 'roi_name') per file
all_rois_data_list = []
# When True, each file's spectra are resampled from its own band centers to wl_img
resample_rois = True

# roi_files is the list of ROI pickle paths collected earlier
for roi_filename in roi_files:
    # Unpickle the ROI container object; only its `.df` attribute is used here.
    # NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
    with open(roi_filename, 'rb') as f:
        roiData = pickle.load(f)
        roi_df = roiData.df  # a DataFrame holding all the data for the ROI

    # Columns from position 4 onward are taken as the spectral values
    roi_spectra = roi_df.to_numpy()[:, 4:].astype(np.float32)
    roi_spectra_names = roi_df['Name'].to_numpy() # These are the labels for each pixel

    if resample_rois:
        # The matching column labels are used as this file's band centers
        wl_roi = roi_df.columns.to_numpy()[4:]
        resampler_roi = spectral.BandResampler(wl_roi, wl_img)
        print("Number of bands: ", len(wl_roi))
        roi_spectra_resampled = resampler_roi(roi_spectra.T).T
        roi_spectra = roi_spectra_resampled

    # Files whose path contains "crisfield" get a per-pixel min-max rescale to [0, 1]
    if "crisfield" in roi_filename.lower():
        print("Running pixel-wise normalization for 'crisfield'")
        min_vals_pixel = np.min(roi_spectra, axis=1, keepdims=True)
        max_vals_pixel = np.max(roi_spectra, axis=1, keepdims=True)
        range_vals_pixel = max_vals_pixel - min_vals_pixel
        # Flat spectra (max == min) would divide by zero; use 1 so they map to all zeros
        range_vals_pixel[range_vals_pixel == 0] = 1
        roi_spectra = (roi_spectra - min_vals_pixel) / range_vals_pixel

    # Create a temporary DataFrame for the current file's data
    temp_df = pd.DataFrame(roi_spectra)
    temp_df['roi_name'] = roi_spectra_names
    
    # Add the processed data to our master list
    all_rois_data_list.append(temp_df)
    
    print(f"Processed and aggregated data from {roi_filename}")

# Concatenate all data into a single master DataFrame with a fresh 0..n-1 index
master_roi_df = pd.concat(all_rois_data_list, ignore_index=True)

print("\n" + "="*50)
print("ROI Data Aggregation Complete.")
print(f"Total number of ROI pixels collected: {len(master_roi_df)}")
print("Pixel distribution before sampling:")
print(master_roi_df['roi_name'].value_counts().to_string())
print("="*50 + "\n")


# --- STAGE 2: Perform Stratified Sampling on the aggregated data ---

# Sampling parameters: per-ROI minimum and overall target number of pixels
MIN_PIXELS_PER_ROI = 50 
TOTAL_ROI_PIXELS_TO_SELECT = 300

print(f"Starting stratified sampling with min {MIN_PIXELS_PER_ROI}/ROI and a target total of {TOTAL_ROI_PIXELS_TO_SELECT}...")

# No seed is passed, so the selected pixels differ from run to run.
selected_rois_df = stratified_sample_with_min_per_roi(
    master_roi_df,
    min_per_roi=MIN_PIXELS_PER_ROI,
    total_samples=TOTAL_ROI_PIXELS_TO_SELECT
)

print("\n" + "="*50)
print("Stratified Sampling Complete.")
print(f"Total number of pixels selected: {len(selected_rois_df)}")
print("Pixel distribution after sampling:")
print(selected_rois_df['roi_name'].value_counts().to_string())
print("="*50 + "\n")


# --- STAGE 3: Populate final data structures with the selected subset ---

# Split the sampled DataFrame into the spectra matrix and the ROI-name vector
selected_spectra = selected_rois_df.drop(columns=['roi_name']).to_numpy()
selected_names = selected_rois_df['roi_name'].to_numpy()

for i in range(len(selected_spectra)):
    roi_spectrum, name = selected_spectra[i], selected_names[i]

    d_spectra.append(roi_spectrum)

    # Fully unlabeled background ROIs are treated as a mixed-part background
    if name == 'Genus_spe_N_N_N_N':
        new_name = 'Genus_spe_MX_N_N_N'
        print(f"Renaming {name} to {new_name}")
        name = new_name

    # A trailing underscore lets the last code be matched as '_<code>_'
    if name[-1] != '_':
        name = name + '_'

    # Plant codes are 8 or 9 characters long, so compare against both prefixes
    plant_prefixes = (name[:8].lower(), name[:9].lower())

    # Every task defaults to 'N'; a later matching key overrides an earlier one
    class_data_dict = {}
    for cat, codes in code_category_dict.items():
        class_data_dict[cat] = 'N'
        is_plant = cat == 'plant'
        for key, value in codes.items():
            if is_plant:
                if key.lower() in plant_prefixes:
                    class_data_dict[cat] = value[0] + '_' + value[1]
            elif '_' + key + '_' in name:
                class_data_dict[cat] = value[1]

    # Record the decoded label of each task for this spectrum
    for key, label_list in yd_all_dict.items():
        label_list.append(class_data_dict[key])

print("Final `d_spectra` and `yd_all_dict` have been populated with the sampled data.")

# Freeze the accumulators into NumPy arrays
d_spectra = np.asarray(d_spectra)
print(d_spectra.shape)

for key in yd_all_dict:
    yd_all_dict[key] = np.asarray(yd_all_dict[key])
    print(key, yd_all_dict[key].shape)

## Add selected ROI spectra to library spectra

In [None]:
# Disabled alternative: randomly keep only `num_to_select` of the ROI spectra
# before adding them to the library. Kept for reference.
# num_to_select = 250  # Here (250 or 300) randomly selected ROI spectra are added to the full spectral library

# total_indices = d_spectra.shape[0]

# selected_indices = np.random.choice(total_indices, size=num_to_select, replace=False)

# selected_d_spectra = d_spectra[selected_indices]
# print("\nNew d_spectra shape:", selected_d_spectra.shape)

# selected_yd_all_dict = {}
# for key in yd_all_dict:
#     # Select the corresponding rows/elements from the original array
#     selected_yd_all_dict[key] = yd_all_dict[key][selected_indices]
#     print(f"New {key} shape: {selected_yd_all_dict[key].shape}")

# Active path: use every sampled ROI spectrum and its labels unchanged.
selected_d_spectra = d_spectra
selected_yd_all_dict = yd_all_dict

In [None]:
# Append the ROI spectra after the library spectra (library rows stay first).
spectra = np.concatenate((spectra, selected_d_spectra))

# SpectralCollection attribute -> task key holding the matching ROI labels.
# sc.spectra itself is left untouched; only the label arrays are extended.
_label_sources = (
    ('name', 'plant'),
    ('age', 'age'),
    ('principle_part', 'part'),
    ('health', 'health'),
    ('lifecycle', 'lifecycle'),
)
for _attr, _task in _label_sources:
    _extended = np.concatenate((getattr(sc, _attr), selected_yd_all_dict[_task]))
    setattr(sc, _attr, _extended)

print(f"New spectra shape: {spectra.shape}")
for _attr, _task in _label_sources:
    print(f"New sc.{_attr} shape: {getattr(sc, _attr).shape}")

## Prepare data for training and testing

In [None]:
def assign_integer_labels(data_array, label_array):
    """
    Encode each entry of a data array as the position of its label.

    Args:
        data_array:  Iterable of label values to encode (e.g. plant names per spectrum).
        label_array: Iterable of the known labels; a label's position is its integer code.

    Returns:
        A NumPy array with one integer code per entry of data_array.

    Raises:
        KeyError: If data_array contains a value that is not in label_array.
    """
    code_of = {}
    for position, label in enumerate(label_array):
        code_of[label] = position

    return np.array([code_of[item] for item in data_array])


# Feature matrix: library spectra followed by the appended ROI spectra.
X_all = spectra

# Sorted unique label strings per task; a label's position is its integer class id.
y_plant_labels = np.unique(sc.name)
y_part_labels = np.unique(sc.principle_part)
y_health_labels = np.unique(sc.health)
y_age_labels = np.unique(sc.age)
y_lifecycle_labels = np.unique(sc.lifecycle)

# Task name -> label array, used to turn predicted class ids back into strings.
label_maps = {'plant': y_plant_labels, 'age': y_age_labels, 'part': y_part_labels, 'health': y_health_labels, 'lifecycle': y_lifecycle_labels}

# Store the label_maps dictionary in a shelve key-value store
with shelve.open('data/shelve/label_maps_store') as db:
    db['label_maps'] = label_maps

# Number of classes for each task (this count still includes 'N' where present)
n_plant_classes = len(y_plant_labels)
n_age_classes = len(y_age_labels)
n_part_classes = len(y_part_labels)
n_health_classes = len(y_health_labels)
n_lifecycle_classes = len(y_lifecycle_labels)

# Integer-encode every task's labels using the arrays above.
Yn_int = assign_integer_labels(sc.name, y_plant_labels)
Yp_int = assign_integer_labels(sc.principle_part, y_part_labels)
Yh_int = assign_integer_labels(sc.health, y_health_labels)
Ya_int = assign_integer_labels(sc.age, y_age_labels)
Yl_int = assign_integer_labels(sc.lifecycle, y_lifecycle_labels)

def replace_n_with_ignore_val(integer_array, label_array):
    """
    Mark every entry whose label is 'N' with IGNORE_VALUE.

    The position of 'N' in label_array is looked up and all entries of
    integer_array equal to that position are overwritten. The array is
    modified in place and also returned.

    Args:
        integer_array: The integer-encoded data array (e.g., Yn_int, Yp_int, etc.).
        label_array: The array of unique labels (e.g., y_plant_labels, y_part_labels, etc.).

    Returns:
        The same array object, with IGNORE_VALUE wherever the label was 'N'.
        If 'N' is not among the labels, a message is printed and the array
        is returned unchanged.
    """
    try:
        n_positions = np.where(label_array == 'N')
        n_code = n_positions[0][0]  # first (only) position of 'N'
        is_ignored = integer_array == n_code
        integer_array[is_ignored] = IGNORE_VALUE
    except IndexError:
        print("'N' not found in the label array.  No replacements made.")
    return integer_array


# Mark 'N' labels as IGNORE_VALUE for these four tasks. The helper edits its
# argument in place, so e.g. y_plant and Yn_int are the same array afterwards.
y_plant = replace_n_with_ignore_val(Yn_int, y_plant_labels)
y_part = replace_n_with_ignore_val(Yp_int, y_part_labels)
y_age = replace_n_with_ignore_val(Ya_int, y_age_labels)
y_health = replace_n_with_ignore_val(Yh_int, y_health_labels)

# Lifecycle keeps 'N' as an ordinary class (lifecycle_codes defines it as 'Neither').
y_lifecycle = Yl_int

# Task name -> integer labels, with IGNORE_VALUE marking entries to skip.
y_all_dict_original = {'plant': y_plant, 'part': y_part, 'age': y_age, 'health': y_health, 'lifecycle': y_lifecycle}

In [None]:
# --- Data Splitting (Train/Validation/Test on features and ORIGINAL labels) ---
# Row indices are split (not the data itself) so all tasks share the same partition.
indices = np.arange(len(X_all))
# First hold out 15% for testing...
train_indices, test_indices = train_test_split(indices, test_size=0.15, random_state=42)
# ...then 17.65% of the remaining 85% (about 15% of all rows) for validation.
train_indices, val_indices = train_test_split(train_indices, test_size=0.1765, random_state=42)
X_train, X_val, X_test = X_all[train_indices], X_all[val_indices], X_all[test_indices]
y_train_dict_orig = {task: y_all_dict_original[task][train_indices] for task in TASK_NAMES}
y_val_dict_orig = {task: y_all_dict_original[task][val_indices] for task in TASK_NAMES}
y_test_dict_orig = {task: y_all_dict_original[task][test_indices] for task in TASK_NAMES}
print(f"Data split sizes: Train={len(X_train)}, Val={len(X_val)}, Test={len(X_test)}")

# Library rows come first in X_all, so an index below len(sc.spectra) is a
# library spectrum and anything at or above it is an appended ROI spectrum.
test_indices_library = test_indices[test_indices < len(sc.spectra)]
print("Number of library spectra in test set: ", len(test_indices_library))
test_indices_rois = test_indices[test_indices >= len(sc.spectra)]
print("Number of ROI spectra in test set: ", len(test_indices_rois))

X_test_library = X_all[test_indices_library]
X_test_rois = X_all[test_indices_rois]

# Per-source test labels (still containing IGNORE_VALUE).
y_test_dict_library = {task: y_all_dict_original[task][test_indices_library] for task in TASK_NAMES}
y_test_dict_rois = {task: y_all_dict_original[task][test_indices_rois] for task in TASK_NAMES}

In [None]:
# --- Preprocessing (Standardization) ---
def _scale_for_conv1d(features):
    """Apply the fitted scaler, then add a trailing channel axis for Conv1D."""
    return scaler.transform(features)[..., np.newaxis]


def _keras_labels(labels_by_task):
    """Return labels keyed by output name, with IGNORE_VALUE clipped to class 0."""
    return {f"{task}_output": np.maximum(0, labels_by_task[task]) for task in TASK_NAMES}


def _valid_sample_weights(labels_by_task):
    """Return per-output weights: 1.0 for real labels, 0.0 for IGNORE_VALUE entries."""
    return {
        f"{task}_output": (labels_by_task[task] != IGNORE_VALUE).astype(np.float32)
        for task in TASK_NAMES
    }


# Fit the scaler on the training split only; the other splits reuse its statistics.
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)[..., np.newaxis]

# Persist the fitted scaler so the same transform can be reapplied later.
scaler_filename = SCALER_FILEPATH
with open(scaler_filename, 'wb') as file:
    pickle.dump(scaler, file)

print(f"Scaler saved to {scaler_filename}")

# --- Scale and reshape for Conv1D: (batch, steps, channels=1) ---
X_val_scaled = _scale_for_conv1d(X_val)
X_test_scaled = _scale_for_conv1d(X_test)

print(f"Feature shapes: Train={X_train_scaled.shape}, Val={X_val_scaled.shape}, Test={X_test_scaled.shape}")

# --- Prepare Labels for Keras Fit/Evaluate (Replace IGNORE_VALUE with 0) ---
# The placeholder class 0 never contributes because its sample weight is 0.
y_train_dict_keras = _keras_labels(y_train_dict_orig)
y_val_dict_keras = _keras_labels(y_val_dict_orig)
y_test_dict_keras = _keras_labels(y_test_dict_orig)
print(f"Example Keras 'age_output' labels for training: {y_train_dict_keras['age_output'][:20]}")

# --- Create Sample Weight Dictionaries ---
sample_weights_train = _valid_sample_weights(y_train_dict_orig)
sample_weights_val = _valid_sample_weights(y_val_dict_orig)
sample_weights_test = _valid_sample_weights(y_test_dict_orig)
print(f"Example train sample weights for 'age_output': {sample_weights_train['age_output'][:20]}")

# --- Same preparation for the library-only and ROI-only test subsets ---
X_test_library_scaled = _scale_for_conv1d(X_test_library)
X_test_rois_scaled = _scale_for_conv1d(X_test_rois)

y_test_dict_library_keras = _keras_labels(y_test_dict_library)
y_test_dict_rois_keras = _keras_labels(y_test_dict_rois)

sample_weights_test_library = _valid_sample_weights(y_test_dict_library)
sample_weights_test_rois = _valid_sample_weights(y_test_dict_rois)

## Build NN

In [None]:
# Building a 1D CNN to capture local patterns (like absorption peaks or slopes across adjacent bands)

def build_spectral_cnn(input_shape, n_plant, n_age, n_part, n_health, n_lifecycle):
    """Build the multi-task 1D-CNN: a shared backbone with one softmax head per task.

    Args:
        input_shape: Shape of one input spectrum, passed to keras.Input.
        n_plant, n_age, n_part, n_health, n_lifecycle: Number of classes of
            the corresponding output head.

    Returns:
        An uncompiled keras.Model named "spectral_multi_task_cnn" whose outputs
        are a dict keyed 'plant_output', 'age_output', 'part_output',
        'health_output' and 'lifecycle_output'.
    """
    inputs = keras.Input(shape=input_shape, name='spectrum_input')

    # --- Shared Feature Extractor (1D CNN Backbone) ---
    # Each stage is Conv1D -> BatchNorm -> MaxPool(3) -> Dropout.
    # Tuple layout: (filters, kernel_size, dropout_rate)
    conv_stages = (
        (32, 7, 0.25),
        (64, 5, 0.25),
        (128, 3, 0.3),
    )
    x = inputs
    for n_filters, kernel_width, drop_rate in conv_stages:
        x = layers.Conv1D(filters=n_filters, kernel_size=kernel_width, activation='relu', padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.MaxPooling1D(pool_size=3)(x)
        x = layers.Dropout(drop_rate)(x)

    # Shared dense representation fed to every head
    shared_features = layers.Flatten()(x)
    shared_features = layers.Dense(128, activation='relu')(shared_features)
    shared_features = layers.BatchNormalization()(shared_features)
    shared_features = layers.Dropout(0.5)(shared_features)

    # --- Output Heads (One per task) ---
    # Tuple layout: (output layer name, hidden units, number of classes).
    # Heads are created in this order so auto-generated layer names stay stable.
    head_specs = (
        ('plant_output', 64, n_plant),
        ('age_output', 32, n_age),
        ('part_output', 32, n_part),
        ('health_output', 32, n_health),
        ('lifecycle_output', 32, n_lifecycle),
    )
    outputs = {}
    for head_name, hidden_units, n_classes in head_specs:
        hidden = layers.Dense(hidden_units, activation='relu')(shared_features)
        outputs[head_name] = layers.Dense(n_classes, activation='softmax', name=head_name)(hidden)

    # --- Build the Model ---
    return keras.Model(inputs=inputs, outputs=outputs, name="spectral_multi_task_cnn")

# Per-sample input shape: everything after the batch axis of the training features.
input_shape = X_train_scaled.shape[1:]
# One head per task, sized by the number of labels found for that task.
model = build_spectral_cnn(input_shape, n_plant_classes, n_age_classes, n_part_classes, n_health_classes, n_lifecycle_classes)
model.summary()

### Compile model

In [None]:
optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE)
# Every head uses the same loss and accuracy metric, keyed by its output layer name.
losses = {name: 'sparse_categorical_crossentropy' for name in OUTPUT_NAMES}
metrics = {name: 'sparse_categorical_accuracy' for name in OUTPUT_NAMES}

model.compile(
    optimizer=optimizer,
    loss=losses,
    loss_weights=LOSS_WEIGHTS,
    metrics=metrics, # Standard metrics
    weighted_metrics=metrics # Weighted metrics
)

print("\nModel Compiled.")
print(f"Losses: {model.loss}")
print(f"Metrics: {model.metrics_names}")

## Train NN

In [None]:
# Train the compiled model from scratch, then save the weights and the full model.

# Callbacks
# Stop once val_loss has not improved for PATIENCE epochs and restore the best weights.
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=PATIENCE,
    restore_best_weights=True,
    verbose=1
)

print("\nStarting Training...")
history = model.fit(
    X_train_scaled,
    y_train_dict_keras,           # Labels dictionary (keys match output names)
    sample_weight=sample_weights_train, # Sample weights dictionary (keys match output names)
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_data=(X_val_scaled, y_val_dict_keras, sample_weights_val), # Also pass val weights
    callbacks=[early_stopping],
    verbose=1
)

print("\nTraining Finished.")

# Persist both a weights-only checkpoint and the complete model.
model.save_weights(CHECKPOINT_FILEPATH)
model.save(MODEL_FILEPATH)

## Evaluate NN

In [None]:
# Task name -> array of label strings; a class id indexes into the array.
# Rebuilt here so the evaluation cells can run without the label-encoding cell's dict.
label_maps = {
    'plant': y_plant_labels,
    'age': y_age_labels,
    'part': y_part_labels,
    'health': y_health_labels,
    'lifecycle': y_lifecycle_labels
}

In [None]:
import warnings

# Ignore all warnings
warnings.filterwarnings('ignore')

print("\nEvaluating on Test Set with Best Model (using sample weights)...")

# Full test split; sample weights zero out entries whose label is ignored.
results = model.evaluate(
    X_test_scaled,
    y_test_dict_keras,
    sample_weight=sample_weights_test,
    batch_size=BATCH_SIZE,
    verbose=0,
    return_dict=True,
)

print("\nTest Set Evaluation Results:")
print(f"Overall Loss (Weighted Sum): {results['loss']:.4f}")

print("\nTest Weighted Metrics (Accuracy ignoring invalid samples):")
for name in OUTPUT_NAMES:
    task_title = name.replace('_output', '').capitalize()
    # The weighted-metric key is looked up under two possible naming layouts.
    metric_key = f"{name}_weighted_sparse_categorical_accuracy"
    metric_key_alt = f"weighted_{name}_sparse_categorical_accuracy"
    if metric_key in results:
        print(f"  {task_title}: {results[metric_key]:.4f}")
    elif metric_key_alt in results:
        print(f"  {task_title} (alt key): {results[metric_key_alt]:.4f}")
    else:
        print(f"  {task_title}: Weighted metric key not found in results.")


In [None]:
# Evaluate on Test Set from Spectral Library only

import warnings

# Ignore all warnings
warnings.filterwarnings('ignore')

print("\nEvaluating on Test Set (Library) with Best Model (using sample weights)...")

# Only the test rows that came from the spectral library.
results = model.evaluate(
    X_test_library_scaled,
    y_test_dict_library_keras,
    sample_weight=sample_weights_test_library,
    batch_size=BATCH_SIZE,
    verbose=0,
    return_dict=True,
)

print("\nTest Set Evaluation Results:")
print(f"Overall Loss (Weighted Sum): {results['loss']:.4f}")

print("\nTest Weighted Metrics (Accuracy ignoring invalid samples):")
for name in OUTPUT_NAMES:
    task_title = name.replace('_output', '').capitalize()
    # The weighted-metric key is looked up under two possible naming layouts.
    metric_key = f"{name}_weighted_sparse_categorical_accuracy"
    metric_key_alt = f"weighted_{name}_sparse_categorical_accuracy"
    if metric_key in results:
        print(f"  {task_title}: {results[metric_key]:.4f}")
    elif metric_key_alt in results:
        print(f"  {task_title} (alt key): {results[metric_key_alt]:.4f}")
    else:
        print(f"  {task_title}: Weighted metric key not found in results.")

In [None]:
# Evaluate on Test Set from ROIs only

import warnings

# Ignore all warnings
warnings.filterwarnings('ignore')

print("\nEvaluating on Test Set (ROIs) with Best Model (using sample weights)...")

# Only the test rows that came from the appended ROI spectra.
results = model.evaluate(
    X_test_rois_scaled,
    y_test_dict_rois_keras,
    sample_weight=sample_weights_test_rois,
    batch_size=BATCH_SIZE,
    verbose=0,
    return_dict=True,
)

print("\nTest Set Evaluation Results:")
print(f"Overall Loss (Weighted Sum): {results['loss']:.4f}")

print("\nTest Weighted Metrics (Accuracy ignoring invalid samples):")
for name in OUTPUT_NAMES:
    task_title = name.replace('_output', '').capitalize()
    # The weighted-metric key is looked up under two possible naming layouts.
    metric_key = f"{name}_weighted_sparse_categorical_accuracy"
    metric_key_alt = f"weighted_{name}_sparse_categorical_accuracy"
    if metric_key in results:
        print(f"  {task_title}: {results[metric_key]:.4f}")
    elif metric_key_alt in results:
        print(f"  {task_title} (alt key): {results[metric_key_alt]:.4f}")
    else:
        print(f"  {task_title}: Weighted metric key not found in results.")

In [None]:
# Optional: Print Labeled Codes and Counts in Test Set for Each Species (checking for a balanced test set)
#
# For every plant class this cell:
#   1. selects the test samples whose original plant label equals that class,
#   2. prints how often each label code occurs per task within that subset,
#   3. evaluates the model on the subset and prints the weighted metrics.

from collections import Counter

import warnings

# Ignore all warnings
warnings.filterwarnings('ignore')

for i, val in enumerate(y_plant_labels):

    # Integer positions of the test samples whose original plant label is class i
    plant_true = y_test_dict_orig['plant']
    selected_indices = np.arange(len(plant_true))[plant_true == i]

    # A class with no test samples cannot be evaluated: evaluating on zero
    # samples produces no metrics, so report it and move on.
    if len(selected_indices) == 0:
        print(f"\nNo test samples for plant={val}; skipping evaluation.")
        print("\n____________________________________________________________\n")
        continue

    # Restrict every task's labels and sample weights to the selected samples
    y_test_dict_keras_filtered = {
        key: labels[selected_indices] for key, labels in y_test_dict_keras.items()
    }
    sample_weights_test_filtered = {
        key: weights[selected_indices] for key, weights in sample_weights_test.items()
    }

    # Per-task frequency of each integer label code in the subset
    value_counts = {
        key: Counter(arr) for key, arr in y_test_dict_keras_filtered.items()
    }

    print("Labeled Codes and Counts in Test Set for Each Species")

    for output_key, counts in value_counts.items():

        # 'plant_output' -> 'plant', the key used by label_maps
        task = output_key.replace('_output', '')

        print(f" '{task}':")
        for value, count in counts.items():
            if value == IGNORE_VALUE:
                # IGNORE_VALUE is negative, so indexing label_maps with it would
                # silently wrap around and print the *last* class name.
                label = f"ignored ({value})"
            else:
                label = label_maps[task][value]
            print(f"    {label}: {count} {'spectra' if task == 'plant' else ''}")

    results = model.evaluate(
        X_test_scaled[selected_indices],
        y_test_dict_keras_filtered,
        sample_weight=sample_weights_test_filtered,
        batch_size=BATCH_SIZE,
        verbose=0,
        return_dict=True
    )

    print(f"\nTest Set Evaluation Results for plant={val}:")
    print(f"Overall Loss (Weighted Sum): {results['loss']:.4f}")

    print("\nTest Weighted Metrics (Accuracy ignoring invalid samples):")
    for name in OUTPUT_NAMES:
        task_title = name.replace('_output', '').capitalize()
        metric_key = f"{name}_weighted_sparse_categorical_accuracy"
        metric_key_alt = f"weighted_{name}_sparse_categorical_accuracy"
        # The weighted-accuracy entry may appear under either key spelling
        if metric_key in results:
            print(f"  {task_title}: {results[metric_key]:.4f}")
        elif metric_key_alt in results:
            print(f"  {task_title} (alt key): {results[metric_key_alt]:.4f}")
        else:
            print(f"  {task_title}: Weighted metric key not found in results.")
    print("\n____________________________________________________________\n")


In [None]:
# Optional: Show confusion matrices for each task
#
# For every task in TASK_NAMES this cell predicts on the test set, builds a
# confusion matrix restricted to the classes actually present, plots it as a
# heatmap and prints a classification report.

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
import numpy as np # Ensure numpy is imported

# --- Predictions on the Test Set ---
print("\nGenerating predictions on the test set...")
predictions = model.predict(X_test_scaled, batch_size=BATCH_SIZE)
print("Predictions generated.")

# predict() may hand back a sequence with one array per output instead of a
# dict; normalize to a dict keyed by output layer name so tasks can be looked
# up by '<task>_output' below.
if not isinstance(predictions, dict):
    output_layer_names = model.output_names # Or use OUTPUT_NAMES if they match exactly
    predictions = dict(zip(output_layer_names, predictions))


# NOTE: `label_maps` is expected to exist already (it is not defined in this
# cell). It maps each task name to its sequence of class label strings, e.g.
# {'plant': y_plant_labels, 'age': y_age_labels, 'part': y_part_labels,
#  'health': y_health_labels, 'lifecycle': y_lifecycle_labels}.

# --- Generate and Plot Confusion Matrices ---
print("\nGenerating Confusion Matrices and Classification Reports...")

for task in TASK_NAMES:
    output_name = f"{task}_output" # e.g., 'plant_output'

    print(f"\n--- Task: {task.capitalize()} ---")

    # 1. Get True Labels (Original, includes integer mapping for 'N')
    y_true_all = np.asarray(y_test_dict_orig[task])

    # 2. Get Predicted Labels (Integers)
    if output_name not in predictions:
        print(f"Warning: Output key '{output_name}' not found in model predictions. Skipping task '{task}'.")
        continue
    y_pred_probs = predictions[output_name]
    y_pred_all = np.argmax(y_pred_probs, axis=-1) # Get the class index with the highest probability

    # 3. Get Sample Weights
    weights = np.asarray(sample_weights_test[output_name])

    # 4. Get the corresponding class names (labels)
    original_labels = label_maps[task] # e.g., y_lifecycle_labels
    all_indices = list(range(len(original_labels))) # 0, 1, 2,... N_classes-1

    # 5. *Conditional* Filtering and Label Definition
    if task == 'lifecycle':
        # --- Lifecycle Task: Include 'N' ---
        # Samples are deliberately NOT filtered by weight here, so the way 'N'
        # samples get classified stays visible in the matrix.
        y_true_valid = y_true_all
        y_pred_valid = y_pred_all

        # Use every original category as a candidate axis label. Any value in
        # y_true_valid outside this list (e.g. IGNORE_VALUE) is dropped later,
        # because only indices from `valid_indices` are passed as `labels`.
        valid_indices = all_indices
        valid_string_labels = list(original_labels) # Includes 'N' string

        print(f"Including 'N' category for {task.capitalize()}. Evaluating on {len(y_true_valid)} samples.")

    else:
        # --- Other Tasks: Exclude 'N' ---
        # Keep samples with a positive weight. This avoids an exact float
        # comparison and is identical to `== 1.0` when weights are only 0 or 1.
        valid_mask = weights > 0
        y_true_valid = y_true_all[valid_mask]
        y_pred_valid = y_pred_all[valid_mask]

        # Locate the index assigned to 'N'. np.asarray makes the comparison
        # element-wise even when the labels are a plain Python list (a bare
        # `list == 'N'` is just False and would hide an existing 'N').
        n_positions = np.flatnonzero(np.asarray(original_labels) == 'N')
        if n_positions.size:
            n_index = int(n_positions[0])
            valid_indices = [i for i in all_indices if i != n_index]
            valid_string_labels = [label for i, label in enumerate(original_labels) if i != n_index]
        else:
            # 'N' category doesn't exist for this task, include all labels
            print(f"Note: 'N' category not found in labels for task '{task}'. Including all labels.")
            valid_indices = all_indices
            valid_string_labels = list(original_labels)

        print(f"Excluding 'N' category for {task.capitalize()}. Evaluating on {len(y_true_valid)} samples.")


    # 6. Check if there are any valid samples left for this task
    if len(y_true_valid) == 0:
        print(f"No samples to evaluate for task '{task}' after filtering (if applicable). Skipping CM.")
        continue

    # 7. Determine the final set of labels/indices *present* in the data for the CM.
    # Even if all labels are candidates, some may not occur in this particular
    # y_true_valid / y_pred_valid subset. Only candidates from `valid_indices`
    # are kept, which also excludes any out-of-range value such as IGNORE_VALUE.
    present_indices = set(np.unique(np.concatenate((y_true_valid, y_pred_valid))).tolist())
    cm_pairs = [
        (idx, label)
        for idx, label in zip(valid_indices, valid_string_labels)
        if idx in present_indices
    ]
    final_cm_indices = [idx for idx, _ in cm_pairs]
    final_cm_labels = [label for _, label in cm_pairs]


    # Handle case where after filtering, no valid classes remain
    if not final_cm_indices:
        print(f"No valid classes found in true/predicted labels for task '{task}' for CM. Skipping CM.")
        continue

    # 8. Calculate Confusion Matrix using the *final* determined indices
    # `labels=final_cm_indices` ensures the matrix axes match `final_cm_labels`
    cm = confusion_matrix(y_true_valid, y_pred_valid, labels=final_cm_indices)

    # 9. Plot Confusion Matrix
    plt.figure(figsize=(max(8, len(final_cm_labels)*0.8), max(6, len(final_cm_labels)*0.6))) # Adjust size based on num labels
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=final_cm_labels,
                yticklabels=final_cm_labels)
    plt.title(f'Confusion Matrix - Task: {task.capitalize()}')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

    # 10. Print Classification Report
    print("\nClassification Report:")
    try:
        # zero_division=0 reports 0 instead of warning for classes with no samples
        report = classification_report(y_true_valid, y_pred_valid,
                                       labels=final_cm_indices,
                                       target_names=final_cm_labels,
                                       zero_division=0)
        print(report)
    except ValueError as e:
        # Best-effort: print diagnostics and carry on with the next task
        print(f"Could not generate classification report for {task}: {e}")
        print("This might happen if predicted values contain labels not present in true values after filtering.")
        print(f"Unique True values considered: {np.unique(y_true_valid)}")
        print(f"Unique Pred values considered: {np.unique(y_pred_valid)}")
        print(f"Indices used for report: {final_cm_indices}")
        print(f"Labels used for report: {final_cm_labels}")


print("\nFinished generating visualizations.")