## Import libraries and initialize TensorFlow

In [None]:
import warnings
warnings.filterwarnings('ignore')

import numpy as np

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from IPython.display import display
import pandas as pd
import spectral
import seaborn as sns

import shelve
import pickle
import os

import matplotlib.pyplot as plt

from IPython.display import display

# ==============================================================================
# 1. CONFIGURATION AND SETUP
# ==============================================================================

# --- Task naming ---
# One classification task per entry; the matching Keras output is "<task>_output".
TASK_NAMES = ['plant', 'age', 'part', 'health', 'lifecycle']
OUTPUT_NAMES = [task + "_output" for task in TASK_NAMES]
# Integer written in place of the 'N' label when labels are encoded.
IGNORE_VALUE = -1

# --- Training hyperparameters ---
BATCH_SIZE = 32
# Upper bound on pre-training epochs (the author's note mentions 400 as an
# alternative value and that EarlyStopping picks the best epoch).
PRETRAIN_EPOCHS = 800
# Upper bound on fine-tuning epochs.
FINETUNE_EPOCHS = 200
# Early-stopping patience.
PATIENCE = 30
# Learning rates for the pre-training and fine-tuning phases.
LEARNING_RATE = 1e-4
FT_LEARNING_RATE = 1e-4

# --- Feature flags ---
# When True, evaluation reports are saved as CSV files.
EXPORT_TO_CSV = True

# --- Per-task loss weights ---
# Every output head contributes equally to the total loss.
LOSS_WEIGHTS = dict.fromkeys(OUTPUT_NAMES, 1.0)

# --- Output directories ---
for _output_dir in ('models', 'reports', 'scalers'):
    os.makedirs(_output_dir, exist_ok=True)

# --- Quieter console output ---
warnings.filterwarnings('ignore')
# NOTE(review): TensorFlow has already been imported when this line runs; the
# variable may need to be set before the import to silence import-time
# logging — confirm.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'


print(f"Using TensorFlow version: {tf.__version__}")
# Optional: let TensorFlow grow GPU memory on demand.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        print(err)
    else:
        print(f"Configured memory growth for {len(gpus)} GPU(s)")

## Load full spectral library

In [None]:
import sys
sys.path.append('util/')
import importlib

import util_scripts as util

import os
from dotenv import load_dotenv

from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi

# Read environment variables (e.g. the MongoDB connection string) from a .env file.
_ = load_dotenv()

MONGO_DBR_URI = os.getenv('MONGO_DBR_URI')

# Set to True to re-download the library from MongoDB and refresh the local
# pickle cache; otherwise the cached pickle is used.
reload_data_driver = False

if reload_data_driver:

    records = []

    # Create a new client and connect to the server. Using the client as a
    # context manager closes its connections once the download is done.
    with MongoClient(MONGO_DBR_URI, server_api=ServerApi('1')) as client:

        # Send a ping to confirm a successful connection
        try:
            client.admin.command('ping')
            print("Pinged your deployment. You successfully connected to MongoDB!")
        except Exception as e:
            print(e)

        db = client["upwins_db"]
        view_name = "spectral_library"
        spectral_library = db[view_name]

        # Materialize the cursor while the connection is still open.
        records = list(spectral_library.find())

    df = pd.DataFrame(records)
    df.to_pickle('data/pkl/library_with_Genus_species.pkl')

else:
    # NOTE(review): read_pickle unpickles the file; only use cache files from
    # a trusted source.
    df = pd.read_pickle('data/pkl/library_with_Genus_species.pkl')
    #df = pd.read_pickle('data/pkl/library.pkl')

# Wrap the records in the project's SpectralCollection helper and pull out the
# wavelengths, spectra and per-spectrum labels used below.
sc = util.SpectralCollection(df)
wl_lib = sc.wl
# NOTE(review): `name` is reused as a loop variable in the ROI label-parsing
# cell; use `plant_array_lib` when the library names are needed later.
name = sc.name
spectra = sc.spectra
print(sc.spectra.shape)

plant_array_lib = sc.name
age_array_lib = sc.age
part_array_lib = sc.principle_part
health_array_lib = sc.health
lifecycle_array_lib = sc.lifecycle

## Option 1: Get target bands for resampling from Imagery

In [None]:
# Image whose band centres define the resampling target. Uncomment one of the
# alternative fname / fname_hdr pairs to target a different image.
#fname = 'data/morven_4000/raw_4000_or_ref.img'
#fname_hdr = 'data/morven_4000/raw_4000_or_ref.hdr'

fname = 'data/morven_9-2025/raw_36286_or_ref.img'
fname_hdr = 'data/morven_9-2025/raw_36286_or_ref.hdr'

#fname = 'data/5-8-2025/100133_Allied_05_08_2025_2015_06_04_17_50_15/raw_0_ref'
#fname_hdr = 'data/5-8-2025/100133_Allied_05_08_2025_2015_06_04_17_50_15/raw_0_ref.hdr'

# Open the ENVI image (header + data file) and take its band-centre
# wavelengths as the target bands. Pixel data is loaded separately.
im = spectral.envi.open(fname_hdr, fname)
wl_img = np.asarray(im.bands.centers)

### Optional: Load the image into memory

In [None]:
# OPTIONAL: pull the whole image into memory and flatten it to a pixel list.
nr, nc, nb = im.nrows, im.ncols, im.nbands

im.Arr = im.load()
print(f'Shape of Im.Arr = {im.Arr.shape}')

# One row per pixel, one column per band.
im.List = np.reshape(im.Arr, (im.nrows*im.ncols, im.nbands))
print(f'Shape of im.List = {im.List.shape}')

# Keep only pixels whose band values sum to more than zero.
valid_pixel_mask = np.sum(im.List, axis=1) > 0
dataList = im.List[valid_pixel_mask, :]
print(f'Shape of dataList = {dataList.shape}')

## Option 2: Get target bands for resampling from ROI

In [None]:
roi_filepath = 'data/pkl/rois_labeled/crisfield/Crisfield_October_Training_ROIs_Img14_0513_Iva_fru.pkl'

# Unpickle the ROI container; its `df` attribute holds all the data for the ROI.
with open(roi_filepath, 'rb') as f:
    roiData = pickle.load(f)
roi_df = roiData.df

# Everything from column 4 onwards is treated as spectral data; the column
# labels from that position are used as the band wavelengths.
roi_spectra = roi_df.to_numpy()[:, 4:].astype(np.float32)
wl_roi = roi_df.columns.to_numpy()[4:]

# The rest of the notebook resamples to `wl_img`, so point it at the ROI bands.
wl_img = wl_roi

print("ROI target bands count: ", len(wl_roi))

## Resample library to match target bands

In [None]:
# Build a BandResampler that maps spectra sampled at the library wavelengths
# (first argument) onto the target wavelengths (second argument).
# See: https://www.spectralpython.net/class_func_ref.html?highlight=resampling#spectral.algorithms.resampling.BandResampler
# The library spectra are transposed before and after the call so that the
# result again has one row per spectrum.

resampler = spectral.BandResampler(wl_lib, wl_img)
spectra_resampled = resampler(spectra.T).T

print(f'The shape of the resampled spectral library is {spectra_resampled.shape}.')
print(f'({spectra_resampled.shape[0]} spectra with {spectra_resampled.shape[1]} bands.)')

## Prepare ROI data for training

In [None]:
import pickle
import os

def find_roi_files(root_dir):
    """Recursively list files under ``root_dir`` whose name contains '.pkl'.

    The match is a case-insensitive substring test on the file name, so the
    fragment does not have to be the final extension.

    Args:
        root_dir: Directory to walk.

    Returns:
        list[str]: Paths (directory joined with file name) of matching files.
    """
    required_fragments = ('.pkl',)
    return [
        os.path.join(folder, fname)
        for folder, _subdirs, fnames in os.walk(root_dir)
        for fname in fnames
        if all(fragment in fname.lower() for fragment in required_fragments)
    ]

# Project codes for labeling ROI data
# **IMPORTANT**: ROIs should be named using the **same** naming convention used to label ASD files 

# Plant code -> [genus, species, common name, growth form, reference URL].
# Only the genus and species entries are used to build class labels; the
# remaining fields are descriptive metadata.
plant_codes = {
    'Ammo_bre': ['Ammophila', 'breviligulata', 'American Beachgrass', 'grass', 'https://en.wikipedia.org/wiki/Ammophila_breviligulata'],
    'Chas_lat': ['Chasmanthium', 'latifolium', 'River Oats', 'grass', 'https://en.wikipedia.org/wiki/Chasmanthium_latifolium'],
    'Pani_ama': ['Panicum', 'amarum', 'Coastal Panic Grass', 'grass', 'https://en.wikipedia.org/wiki/Panicum_amarum'],
    'Pani_vir': ['Panicum', 'virgatum', 'Switch Grass', 'grass', 'https://en.wikipedia.org/wiki/Panicum_virgatum'],
    # URL previously pointed at the Chasmanthium latifolium page (copy-paste error).
    'Soli_sem': ['Solidago', 'sempervirens', 'Seaside Goldenrod', 'succulent', 'https://en.wikipedia.org/wiki/Solidago_sempervirens'],
    'Robi_his': ['Robinia', 'hispida', 'Bristly locust', 'shrub', 'https://en.wikipedia.org/wiki/Robinia_hispida'],
    # Common name previously duplicated 'Bristly locust' from the entry above.
    # The species spelling is left untouched because it is part of the class label.
    'More_pen': ['Morella', 'pennsylvanica', 'Northern Bayberry', 'shrub', 'https://en.wikipedia.org/wiki/Myrica_pensylvanica'],
    'Rosa_rug': ['Rosa', 'rugosa', 'Sandy Beach Rose', 'shrub', 'https://en.wikipedia.org/wiki/Rosa_rugosa'],
    'Cham_fas': ['Chamaecrista', 'fasciculata', 'Partridge Pea', 'legume', 'https://en.wikipedia.org/wiki/Chamaecrista_fasciculata'],
    'Soli_rug': ['Solidago', 'rugosa', 'Wrinkleleaf goldenrod', 'shrub', 'https://en.wikipedia.org/wiki/Solidago_rugosa'],
    'Bacc_hal': ['Baccharis', 'halimifolia', 'Groundseltree', 'shrub', 'https://en.wikipedia.org/wiki/Baccharis_halimifolia'],
    'Iva_fru_': ['Iva', 'frutescens', 'Jesuits Bark ', 'shrub', 'https://en.wikipedia.org/wiki/Iva_frutescens'],
    'Ilex_vom': ['Ilex', 'vomitoria', 'Yaupon Holly', 'evergreen shrub', 'https://en.wikipedia.org/wiki/Ilex_vomitoria'],
    'Genus_spe': ['Genus', 'species', 'vegetation', 'background', '']
}
# Each table maps a code found in an ROI/ASD name to
# [description, canonical label]. Several legacy codes are folded into a
# current label. Key order is kept as-is because the label parser lets later
# matching keys override earlier ones.

# Age codes. 'E' (formerly "Emergence (from seed)") is folded into 'PE'.
# Disabled legacy codes: 'RE' (Re-emergence), 'D' (Dormant), '1F' (Year 1 Flowering).
age_codes = {
    'PE': ['Post Germination Emergence', 'PE'],
    'E': ['Post Germination Emergence', 'PE'],
    '1G': ['Year 1 growth', '1G'],
    '2G': ['Year 2 growth', '2G'],
    'J': ['Juvenile', 'J'],
    'M': ['Mature', 'M'],
}

# Principal plant-part codes. Folded legacy codes: 'SA' (Shoot Apex) -> 'ST',
# 'CS' (Colar Sprout) and 'RS' (Root Sprout) -> 'SP', 'B' (Blade) -> 'L'.
# Disabled legacy codes: 'S' (Seed, superseded by 'SE'), 'IS' (Internode Stem,
# superseded by 'ST'), 'St' (Stalk).
principal_part_codes = {
    'MX': ['Mix', 'MX'],
    'SA': ['Internode Stem', 'ST'],
    'L': ['Leaf/Blade', 'L'],
    'ST': ['Internode Stem', 'ST'],
    'SP': ['Sprout', 'SP'],
    'CS': ['Sprout', 'SP'],
    'RS': ['Sprout', 'SP'],
    'LG': ['Lignin', 'LG'],
    'FL': ['Flower', 'FL'],
    'B': ['Leaf/Blade', 'L'],
    'FR': ['Fruit', 'FR'],
    'SE': ['Seed', 'SE'],
}

# Health codes. Disabled legacy code: 'D' (Dormant).
health_codes = {
    'MH': ['Healthy/Unhealthy Mix', 'MH'],
    'DS': ['Drought Stress', 'DS'],
    'SS': ['Salt Stress (soak)', 'SS'],
    'SY': ['Salt Stress (spray)', 'SY'],
    'S': ['Stressed', 'S'],
    'LLRZ': ['LLRZ Lab Stress', 'LLRZ'],
    'R': ['Rust', 'R'],
    'H': ['Healthy', 'H'],
}

# Lifecycle codes.
lifecycle_codes = {
    'D': ['Dormant', 'D'],
    'RE': ['Re-emergence', 'RE'],
    'FLG': ['Flowering', 'FLG'],
    'FRG': ['Fruiting', 'FRG'],
    'FFG': ['Fruiting and Flowering', 'FFG'],
    'N': ['Neither', 'N'],
}

# Accumulators filled while parsing ROI pixels: one list of spectra plus one
# list of string labels per task.
d_spectra, d_plant, d_part, d_health, d_age, d_lifecycle = ([] for _ in range(6))

# Task name -> label accumulator (the same list objects as above).
yd_all_dict_str = dict(
    plant=d_plant,
    age=d_age,
    part=d_part,
    health=d_health,
    lifecycle=d_lifecycle,
)

# Task name -> code table used to decode that task from an ROI name.
code_category_dict = dict(
    plant=plant_codes,
    age=age_codes,
    part=principal_part_codes,
    health=health_codes,
    lifecycle=lifecycle_codes,
)

In [None]:
## Find ROI data
# Collect every ROI pickle under the labeled-ROI directory (searched recursively).

roi_files = find_roi_files('data/pkl/rois_labeled')
print(f"Number of ROI files found: {len(roi_files)}")

In [None]:
## Prepare ROI data

# ==============================================================================
# Helper Function for Stratified Sampling
# ==============================================================================
def stratified_sample_with_min_per_roi(spectra_df, min_per_roi=50, total_samples=300, random_state=None):
    """
    Selects a subset of labeled ROI pixels from a DataFrame.

    Every ROI first contributes up to ``min_per_roi`` randomly chosen pixels
    (all of its pixels if it has fewer). If that guaranteed set is smaller
    than ``total_samples``, it is topped up with pixels drawn at random from
    the rows that were NOT already selected, so no pixel is returned twice.

    Args:
        spectra_df (pd.DataFrame): DataFrame containing spectral data and a 'roi_name' column.
        min_per_roi (int): The minimum number of pixels to select from each ROI.
        total_samples (int): The target total number of pixels. A falsy value
            (None or 0) disables the top-up step. The result can be larger than
            this when the per-ROI minimums alone exceed it, and smaller when
            the input does not hold enough rows.
        random_state (int | np.random.Generator | None): Optional seed or
            generator for reproducible sampling. None keeps pandas' default
            (unseeded) behaviour.

    Returns:
        pd.DataFrame: The selected pixels with a fresh 0..n-1 index; the
        guaranteed per-ROI rows come first, followed by any top-up rows.
    """
    rng = None if random_state is None else np.random.default_rng(random_state)

    # Work on a 0..n-1 index so that selected rows can be excluded from the
    # top-up pool by label, even if the caller's index contains duplicates.
    candidates_df = spectra_df.reset_index(drop=True)

    # Sample per ROI while KEEPING the row labels: they are what identifies
    # the already-selected rows below.
    per_roi_samples = [
        roi_rows.sample(n=min(len(roi_rows), min_per_roi), random_state=rng)
        for _, roi_rows in candidates_df.groupby('roi_name')
    ]
    if per_roi_samples:
        guaranteed_samples = pd.concat(per_roi_samples)
    else:
        # No groups at all (empty input): nothing is guaranteed.
        guaranteed_samples = candidates_df.iloc[0:0]

    final_selection_df = guaranteed_samples

    if total_samples and total_samples > len(guaranteed_samples):
        remaining_to_select = total_samples - len(guaranteed_samples)

        # Pool of pixels that have not been selected yet.
        remaining_pixels_df = candidates_df.drop(index=guaranteed_samples.index)

        if not remaining_pixels_df.empty:
            num_to_sample_from_remaining = min(remaining_to_select, len(remaining_pixels_df))
            additional_samples = remaining_pixels_df.sample(
                n=num_to_sample_from_remaining, random_state=rng
            )
            final_selection_df = pd.concat([guaranteed_samples, additional_samples])

    return final_selection_df.reset_index(drop=True)


# ==============================================================================
# Main Data Processing and Sampling Logic
# ==============================================================================

# --- STAGE 1: Aggregate all ROI data from all files ---
# Each ROI pickle is loaded, its spectra are (optionally) resampled to the
# target bands in `wl_img`, and the result is collected into one DataFrame
# with a sampling-group column ('roi_name') and the untouched ROI name
# ('original_name').

# This list will hold DataFrames of processed ROI data from each file
all_rois_data_list = []
# When True, every file's spectra are resampled from its own band centres to `wl_img`.
resample_rois = True

for roi_filename in roi_files:
    # Unpickle the ROI container.
    # NOTE(review): pickle.load can execute arbitrary code from the file; only
    # load ROI pickles that come from a trusted source.
    with open(roi_filename, 'rb') as f:
        roiData = pickle.load(f)
        roi_df = roiData.df  # The DataFrame holding all the data for the ROI

    # Create a unique identifier for sampling based on both Name and Color
    # This creates a unique integer for each ('Name', 'Color') group
    # NOTE(review): ngroup() numbering restarts for every file, so ROIs that
    # share a Name in different files can receive the same 'roi_name' once the
    # files are concatenated — confirm whether pooling them is intended.
    color_group_id = roi_df.groupby(['Name', 'Color']).ngroup().astype(str)
    # We append this ID to the original name to create a new, unique ROI name for sampling
    roi_df['unique_roi_name'] = roi_df['Name'] + '_' + color_group_id

    # Spectral values: column 4 up to (not including) the last column, which is
    # the 'unique_roi_name' column appended just above.
    # Presumably the first four columns are ROI metadata — TODO confirm.
    roi_spectra = roi_df.iloc[:, 4:-1].to_numpy().astype(np.float32)

    if resample_rois:
        # The column labels of the spectral columns are used as this file's wavelengths.
        wl_roi = roi_df.columns.to_numpy()[4:-1]
        resampler_roi = spectral.BandResampler(wl_roi, wl_img)
        print("Number of bands: ", len(wl_roi))
        roi_spectra = resampler_roi(roi_spectra.T).T

    if "crisfield" in roi_filename.lower():
        # Files with 'crisfield' in their path get a per-pixel min-max rescale to [0, 1].
        print("Running pixel-wise normalization for 'crisfield'")
        min_vals_pixel = np.min(roi_spectra, axis=1, keepdims=True)
        max_vals_pixel = np.max(roi_spectra, axis=1, keepdims=True)
        range_vals_pixel = max_vals_pixel - min_vals_pixel
        # Flat spectra would divide by zero; use a range of 1 so they map to all zeros.
        range_vals_pixel[range_vals_pixel == 0] = 1
        roi_spectra = (roi_spectra - min_vals_pixel) / range_vals_pixel

    # Create a temporary DataFrame for the current file's data
    temp_df = pd.DataFrame(roi_spectra)
    
    # This takes the raw data from the roi_df columns and assigns it by position.
    temp_df['roi_name'] = roi_df['unique_roi_name'].values
    temp_df['original_name'] = roi_df['Name'].values
    
    all_rois_data_list.append(temp_df)
    print(f"Processed and aggregated data from {roi_filename}")

# Concatenate all data into a single master DataFrame
master_roi_df = pd.concat(all_rois_data_list, ignore_index=True)

print("\n" + "="*50)
print("ROI Data Aggregation Complete.")
print(f"Total number of ROI pixels collected: {len(master_roi_df)}")
print("Pixel distribution before sampling (groups are Name + Color):")
print(master_roi_df['roi_name'].value_counts())
print("="*50 + "\n")


# --- STAGE 2: Perform Stratified Sampling on the aggregated data ---
# Each 'roi_name' group contributes at least MIN_PIXELS_PER_ROI pixels (or all
# of its pixels if it has fewer); the selection is then topped up towards
# TOTAL_ROI_PIXELS_TO_SELECT.

MIN_PIXELS_PER_ROI = 30 
TOTAL_ROI_PIXELS_TO_SELECT = 300

print(f"Starting stratified sampling with min {MIN_PIXELS_PER_ROI}/group and a target total of {TOTAL_ROI_PIXELS_TO_SELECT}...")

selected_rois_df = stratified_sample_with_min_per_roi(
    master_roi_df,
    min_per_roi=MIN_PIXELS_PER_ROI,
    total_samples=TOTAL_ROI_PIXELS_TO_SELECT
)

print("\n" + "="*50)
print("Stratified Sampling Complete.")
print(f"Total number of pixels selected: {len(selected_rois_df)}")
print("Pixel distribution after sampling (groups are Name + Color):")
print(selected_rois_df['roi_name'].value_counts())
print("="*50 + "\n")


# --- STAGE 3: Populate final data structures with the selected subset ---

# Split the sampled frame into the spectra matrix and the ORIGINAL ROI names
# (the names carry the label codes that are parsed below).
selected_spectra = selected_rois_df.drop(columns=['roi_name', 'original_name']).to_numpy()
selected_original_names = selected_rois_df['original_name'].to_numpy()

# ROI names that do not follow the naming convention, mapped to the
# conventional spelling.
roi_name_aliases = {
    'Genus_spe_N_N_N_N': 'Genus_spe_MX_N_N_N',
    'Genus_species_MX_N_N_N': 'Genus_spe_MX_N_N_N',
    'Soli_semp_MX_M_H_FLG': 'Soli_sem_MX_M_H_FLG',
    'Iva_frut_MX_M_H_N': 'Iva_fru_MX_M_H_N',
}

for roi_spectrum, name in zip(selected_spectra, selected_original_names):
    d_spectra.append(roi_spectrum)

    name = roi_name_aliases.get(name, name)

    # A trailing '_' lets the last code be matched with the '_CODE_' pattern.
    if name[-1] != '_':
        name = name + '_'

    # Plant codes are 8 or 9 characters long, so both prefixes are candidates.
    plant_prefixes = (name[:8].lower(), name[:9].lower())

    class_data_dict = {}
    for cat, codes in code_category_dict.items():
        label = 'N'  # default when no code of this category is found in the name
        for key, value in codes.items():
            if cat == 'plant':
                if key.lower() in plant_prefixes:
                    label = value[0] + '_' + value[1]
            elif '_' + key + '_' in name:
                label = value[1]
        class_data_dict[cat] = label

    for task, task_labels in yd_all_dict_str.items():
        task_labels.append(class_data_dict[task])

print("Final `d_spectra` and `yd_all_dict_str` have been populated with the sampled data.")
use_rois = True

# Freeze the accumulated lists into arrays.
d_spectra = np.asarray(d_spectra)
print(d_spectra.shape)

for key in yd_all_dict_str:
    yd_all_dict_str[key] = np.asarray(yd_all_dict_str[key])
    print(key, yd_all_dict_str[key].shape)

## Prepare data for training and testing

In [None]:
# --- Label Encoding and 'N' Value Replacement Functions ---

def assign_integer_labels(data_array, label_array):
    """Encode each entry of ``data_array`` as its position in ``label_array``.

    Args:
        data_array: Iterable of labels to encode.
        label_array: Complete collection of possible labels; a label's
            position in it becomes its integer code.

    Returns:
        np.ndarray: The integer code of every entry of ``data_array``.

    Raises:
        KeyError: If ``data_array`` contains a label missing from ``label_array``.
    """
    code_of = {}
    for code, label in enumerate(label_array):
        code_of[label] = code
    return np.array(list(map(code_of.__getitem__, data_array)))

def replace_n_with_ignore_val(integer_array, label_array, ignore_value=None):
    """Replaces the integer corresponding to the label 'N' with the ignore value.

    Args:
        integer_array: Integer-encoded labels (array-like).
        label_array: The labels that the integer codes index into. Plain lists
            are accepted as well as NumPy arrays.
        ignore_value: Value written in place of the 'N' code. Defaults to the
            module-level IGNORE_VALUE.

    Returns:
        np.ndarray: A copy of ``integer_array`` with every occurrence of the
        'N' code replaced. If 'N' is not among the labels the copy is returned
        unchanged. The input is never modified.
    """
    if ignore_value is None:
        ignore_value = IGNORE_VALUE

    # Coerce to an array first: comparing a plain list with 'N' yields a single
    # False instead of an element-wise result, which would hide the 'N' label.
    labels = np.asarray(label_array)
    encoded = np.array(integer_array)  # np.array copies, so the caller's data is untouched

    n_positions = np.flatnonzero(labels == 'N')
    if n_positions.size:
        encoded[encoded == n_positions[0]] = ignore_value
    return encoded


# --- Processing All Labels for Consistent Encoding ---
# Library and ROI ("nano") labels are encoded against the SAME per-task label
# set so that an integer code means the same class in both datasets.

# Create a comprehensive list of all possible labels for each task
# (np.unique returns the distinct labels in sorted order).
y_plant_labels_full = np.unique(np.concatenate([plant_array_lib, yd_all_dict_str['plant']]))
y_age_labels_full = np.unique(np.concatenate([age_array_lib, yd_all_dict_str['age']]))
y_part_labels_full = np.unique(np.concatenate([part_array_lib, yd_all_dict_str['part']]))
y_health_labels_full = np.unique(np.concatenate([health_array_lib, yd_all_dict_str['health']]))
y_lifecycle_labels_full = np.unique(np.concatenate([lifecycle_array_lib, yd_all_dict_str['lifecycle']]))

# Create human-readable label maps for evaluation reports (excluding 'N' where it's an invalid label)
# NOTE(review): the integer codes assigned below are positions in the *full*
# label arrays, which still contain 'N' when it occurs. Because these maps drop
# 'N', any label that sorts after 'N' sits one position earlier here than its
# integer code — confirm that the reporting code accounts for that offset.
label_maps = {
    'plant': np.sort([l for l in y_plant_labels_full if l != 'N']),
    'age': np.sort([l for l in y_age_labels_full if l != 'N']),
    'part': np.sort([l for l in y_part_labels_full if l != 'N']),
    'health': np.sort([l for l in y_health_labels_full if l != 'N']),
    'lifecycle': y_lifecycle_labels_full # For lifecycle, 'N' is a valid class, so we don't filter it
}

# Store the label_maps dictionary in a shelve key-value store
with shelve.open('data/shelve/label_maps_store_5') as db:
    db['label_maps'] = label_maps

# Process the library labels
Yn_int_lib = assign_integer_labels(plant_array_lib, y_plant_labels_full)
Ya_int_lib = assign_integer_labels(age_array_lib, y_age_labels_full)
Yp_int_lib = assign_integer_labels(part_array_lib, y_part_labels_full)
Yh_int_lib = assign_integer_labels(health_array_lib, y_health_labels_full)
Yl_int_lib = assign_integer_labels(lifecycle_array_lib, y_lifecycle_labels_full)

# Per-task integer targets for the library; the 'N' code is swapped for
# IGNORE_VALUE in every task except lifecycle.
y_library_tasks_orig = {
    'plant': replace_n_with_ignore_val(Yn_int_lib, y_plant_labels_full),
    'age': replace_n_with_ignore_val(Ya_int_lib, y_age_labels_full),
    'part': replace_n_with_ignore_val(Yp_int_lib, y_part_labels_full),
    'health': replace_n_with_ignore_val(Yh_int_lib, y_health_labels_full),
    'lifecycle': Yl_int_lib  # No 'N' replacement
}

# Process the nano-imagery labels
Yn_int_nano = assign_integer_labels(yd_all_dict_str['plant'], y_plant_labels_full)
Ya_int_nano = assign_integer_labels(yd_all_dict_str['age'], y_age_labels_full)
Yp_int_nano = assign_integer_labels(yd_all_dict_str['part'], y_part_labels_full)
Yh_int_nano = assign_integer_labels(yd_all_dict_str['health'], y_health_labels_full)
Yl_int_nano = assign_integer_labels(yd_all_dict_str['lifecycle'], y_lifecycle_labels_full)

# Per-task integer targets for the ROI pixels, built the same way as for the library.
y_nano_tasks_orig = {
    'plant': replace_n_with_ignore_val(Yn_int_nano, y_plant_labels_full),
    'age': replace_n_with_ignore_val(Ya_int_nano, y_age_labels_full),
    'part': replace_n_with_ignore_val(Yp_int_nano, y_part_labels_full),
    'health': replace_n_with_ignore_val(Yh_int_nano, y_health_labels_full),
    'lifecycle': Yl_int_nano # No 'N' replacement
}

print("\nDefining number of classes based on the full label sets...")

# Class counts come from the full label sets, so they include the 'N' entry
# whenever it occurs.
n_plant_classes = len(y_plant_labels_full)
n_age_classes = len(y_age_labels_full)
n_part_classes = len(y_part_labels_full)
n_health_classes = len(y_health_labels_full)
n_lifecycle_classes = len(y_lifecycle_labels_full)

# -- DEBUG --
print(f"Number of classes: Plant={n_plant_classes}, Age={n_age_classes}, Part={n_part_classes}, Health={n_health_classes}, Lifecycle={n_lifecycle_classes}")
print(len(label_maps['plant']), len(label_maps['age']), len(label_maps['part']), len(label_maps['health']), len(label_maps['lifecycle']))

# --- DEBUG: Print labels where plant is 'N' ---
print("\n" + "="*30)
print("DEBUG: Checking for 'N' plant labels")
print("="*30)

print("\n--- Library Labels (plant='N') ---")
lib_n_indices = np.where(plant_array_lib == 'N')[0]
for i in lib_n_indices:
    print(f"Index {i}: plant='{plant_array_lib[i]}', age='{age_array_lib[i]}', part='{part_array_lib[i]}', health='{health_array_lib[i]}', lifecycle='{lifecycle_array_lib[i]}'")

print("\n--- Nano-Imagery Labels (plant='N') ---")
nano_n_indices = np.where(yd_all_dict_str['plant'] == 'N')[0]
for i in nano_n_indices:
    print(f"Index {i}: plant='{yd_all_dict_str['plant'][i]}', age='{yd_all_dict_str['age'][i]}', part='{yd_all_dict_str['part'][i]}', health='{yd_all_dict_str['health'][i]}', lifecycle='{yd_all_dict_str['lifecycle'][i]}'")
print("="*30 + "\n")

## REUSABLE MODEL & EVALUATION FUNCTIONS

In [None]:
# ==============================================================================
# 3. REUSABLE MODEL & EVALUATION FUNCTIONS
# ==============================================================================

def build_multitask_cnn(input_shape, n_plant, n_age, n_part, n_health, n_lifecycle):
    """Assemble the multi-task 1D CNN: a shared backbone feeding five softmax heads.

    All backbone layers carry a ``backbone_`` name prefix so they can be
    addressed by name for transfer learning. The hidden layer of each head is
    left unnamed; only the final layer of a head is named ``<task>_output``.

    Args:
        input_shape: Shape passed to ``keras.Input`` for a single spectrum.
        n_plant, n_age, n_part, n_health, n_lifecycle: Number of units in the
            softmax layer of the corresponding head.

    Returns:
        keras.Model: Model named "multitask_cnn" whose outputs are a dict keyed
        by ``<task>_output``. The model summary is printed as a side effect.
    """
    inputs = keras.Input(shape=input_shape, name='spectrum_input')

    # --- Shared feature extractor ---
    # (filters, kernel_size, dropout_rate) of each Conv -> BatchNorm -> MaxPool -> Dropout stage.
    backbone_stages = (
        (32, 7, 0.25),
        (64, 5, 0.25),
        (128, 3, 0.3),
    )
    features = inputs
    for stage_no, (n_filters, kernel_width, drop_rate) in enumerate(backbone_stages, start=1):
        features = layers.Conv1D(
            filters=n_filters,
            kernel_size=kernel_width,
            activation='relu',
            padding='same',
            name=f'backbone_conv{stage_no}',
        )(features)
        features = layers.BatchNormalization(name=f'backbone_bn{stage_no}')(features)
        features = layers.MaxPooling1D(pool_size=3, name=f'backbone_pool{stage_no}')(features)
        features = layers.Dropout(drop_rate, name=f'backbone_drop{stage_no}')(features)

    features = layers.Flatten(name='backbone_flatten')(features)
    features = layers.Dense(128, activation='relu', name='backbone_dense')(features)
    features = layers.BatchNormalization(name='backbone_dense_bn')(features)
    shared_features_drop = layers.Dropout(0.5, name='backbone_dense_drop')(features)

    # --- Task-specific heads ---
    # (output layer name, hidden units, number of classes); the creation order
    # is kept so that auto-generated names of the hidden layers do not change.
    head_specs = (
        ('plant_output', 64, n_plant),
        ('age_output', 32, n_age),
        ('part_output', 32, n_part),
        ('health_output', 32, n_health),
        ('lifecycle_output', 32, n_lifecycle),
    )
    task_outputs = {}
    for head_name, hidden_units, n_classes in head_specs:
        hidden = layers.Dense(hidden_units, activation='relu')(shared_features_drop)
        task_outputs[head_name] = layers.Dense(n_classes, activation='softmax', name=head_name)(hidden)

    # --- Build and return the final model ---
    model = keras.Model(inputs=inputs, outputs=task_outputs, name="multitask_cnn")
    model.summary()
    return model


def evaluate_and_report(model, X_test_scaled, y_test_orig, y_test_dict_keras, sample_weights_test, scenario_name):
    """
    Evaluate a trained multi-task model on a test set and report the results.

    The function (1) runs ``model.evaluate`` for a quick loss/accuracy summary,
    (2) derives hard predictions from ``model.predict``, (3) prints a
    classification report and plots a confusion matrix for every task,
    (4) tabulates the metrics averaged across tasks, and (5) prints the label
    counts of the test set grouped by plant species. When ``EXPORT_TO_CSV`` is
    set, the per-task and overall tables are written under ``reports/``.

    Relies on the module-level names ``label_maps``, ``TASK_NAMES``,
    ``OUTPUT_NAMES``, ``BATCH_SIZE``, ``EXPORT_TO_CSV`` and ``IGNORE_VALUE``.

    Args:
        model: Trained Keras model exposing one output per name in ``OUTPUT_NAMES``.
        X_test_scaled: Test inputs, passed unchanged to ``evaluate``/``predict``.
        y_test_orig: Dict mapping task name -> integer label array; entries equal
            to ``IGNORE_VALUE`` are printed as "IGNORED" in the per-species counts.
        y_test_dict_keras: Dict mapping output name -> label array handed to
            ``model.evaluate``.
        sample_weights_test: Dict mapping output name -> per-sample weight array;
            for every task except 'lifecycle' only samples with weight 1.0 are scored.
        scenario_name: Tag used in printed headers and in the CSV file names.

    Returns:
        None. All results are printed, displayed, plotted or written to disk.
    """
    from collections import Counter

    def _task_average(scores):
        # np.mean([]) emits a RuntimeWarning; when every task was skipped,
        # report NaN explicitly instead.
        return np.mean(scores) if scores else np.nan

    print(f"\n--- Evaluating on Test Set for {scenario_name} ---")

    # --- 1. Keras Evaluation (Quick Summary) ---
    print("\nRunning model.evaluate() for a quick summary...")
    results = model.evaluate(
        X_test_scaled,
        y_test_dict_keras,
        sample_weight=sample_weights_test,
        batch_size=BATCH_SIZE,
        verbose=0,
        return_dict=True
    )

    print("\nTest Set Keras Evaluation Results:")
    print(f"Overall Loss (Weighted Sum): {results['loss']:.4f}")
    print("\nTest Weighted Metrics (Accuracy ignoring invalid samples):")
    for name in OUTPUT_NAMES:
        task_title = name.replace('_output', '').capitalize()
        metric_key = f"{name}_weighted_sparse_categorical_accuracy"
        # Fallback for slightly different key names
        metric_key_alt = f"weighted_{name}_sparse_categorical_accuracy"
        if metric_key in results:
            print(f"  {task_title}: {results[metric_key]:.4f}")
        elif metric_key_alt in results:
            print(f"  {task_title} (alt key): {results[metric_key_alt]:.4f}")
        else:
            print(f"  {task_title}: Weighted metric key not found in results.")

    # --- 2. Generate Predictions for Detailed Reports ---
    print("\n\n--- Generating Detailed Reports and Visualizations ---")
    y_pred_probs = model.predict(X_test_scaled, batch_size=BATCH_SIZE, verbose=0)

    # Normalise list-style predictions into a dict keyed by output name
    if not isinstance(y_pred_probs, dict):
        y_pred_probs = dict(zip(model.output_names, y_pred_probs))

    y_pred_labels = {name: np.argmax(y_pred_probs[name], axis=1) for name in OUTPUT_NAMES}

    overall_scores = {'accuracy': [], 'precision': [], 'recall': [], 'f1': []}

    # --- 3. Generate a Detailed Report for Each Task ---
    for task_name in TASK_NAMES:
        output_name = f"{task_name}_output"

        print(f"\n\n{'='*45}\nMetrics for: {task_name.capitalize()}\n{'='*45}")

        # Get true labels, predicted labels, and sample weights
        y_true_all = y_test_orig[task_name]
        y_pred_all = y_pred_labels[output_name]
        weights = sample_weights_test[output_name]
        original_labels = label_maps[task_name]

        # By default every class of the task is reported
        valid_indices = list(range(len(original_labels)))
        valid_string_labels = list(original_labels)

        # Conditional Filtering: Special handling for 'lifecycle' vs. other tasks
        if task_name == 'lifecycle':
            # Include 'N' category for lifecycle task
            y_true_valid = y_true_all
            y_pred_valid = y_pred_all
            print(f"Including 'N' category. Evaluating on {len(y_true_valid)} samples.")
        else:
            # Exclude 'N' category for all other tasks using the weights
            valid_mask = (weights == 1.0)
            y_true_valid = y_true_all[valid_mask]
            y_pred_valid = y_pred_all[valid_mask]

            # np.asarray makes the comparison element-wise even when the label
            # map is a plain list (where `list == 'N'` is just the scalar False).
            n_positions = np.flatnonzero(np.asarray(original_labels) == 'N')
            if n_positions.size:
                n_index = n_positions[0]
                valid_indices = [i for i in range(len(original_labels)) if i != n_index]
                valid_string_labels = [label for label in original_labels if label != 'N']
            print(f"Excluding 'N' category. Evaluating on {len(y_true_valid)} samples.")

        if len(y_true_valid) == 0:
            print(f"No valid samples to evaluate for task '{task_name}'. Skipping.")
            continue

        # Keep only the classes that actually occur (as truth or prediction) so
        # the report and the confusion matrix carry no all-empty rows.
        present_labels_in_data = set(np.unique(np.concatenate((y_true_valid, y_pred_valid))).tolist())
        kept_classes = [
            (idx, label)
            for idx, label in zip(valid_indices, valid_string_labels)
            if idx in present_labels_in_data
        ]
        final_cm_indices = [idx for idx, _ in kept_classes]
        final_cm_labels = [label for _, label in kept_classes]

        if not final_cm_labels:
            print(f"No valid classes remain for task '{task_name}' after filtering. Skipping.")
            continue

        # Calculate overall metrics for this task
        accuracy = accuracy_score(y_true_valid, y_pred_valid)
        precision = precision_score(y_true_valid, y_pred_valid, average='weighted', zero_division=0)
        recall = recall_score(y_true_valid, y_pred_valid, average='weighted', zero_division=0)
        f1 = f1_score(y_true_valid, y_pred_valid, average='weighted', zero_division=0)

        overall_scores['accuracy'].append(accuracy)
        overall_scores['precision'].append(precision)
        overall_scores['recall'].append(recall)
        overall_scores['f1'].append(f1)

        # Generate, display, and save the DataFrame report
        print("\n--- DataFrame Report ---")
        report_dict = classification_report(
            y_true_valid,
            y_pred_valid,
            labels=final_cm_indices,
            target_names=final_cm_labels,
            zero_division=0,
            output_dict=True
        )
        report_df = pd.DataFrame(report_dict).transpose()
        display(report_df)

        if EXPORT_TO_CSV:
            csv_filename = f'reports/classification_report_{scenario_name}_{task_name}.csv'
            report_df.to_csv(csv_filename)
            print(f"DataFrame report saved to '{csv_filename}'")

        # Calculate and plot the Confusion Matrix
        print("\n--- Confusion Matrix ---")
        cm = confusion_matrix(y_true_valid, y_pred_valid, labels=final_cm_indices)
        # Scale the figure with the number of classes so tick labels stay legible
        plt.figure(figsize=(max(8, len(final_cm_labels)*0.8), max(6, len(final_cm_labels)*0.6)))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=final_cm_labels,
                    yticklabels=final_cm_labels)
        plt.title(f'Confusion Matrix - Task: {task_name.capitalize()}')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.xticks(rotation=45, ha='right')
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.show()

    # --- 4. Display and Save Overall Model Performance ---
    print(f"\n\n{'='*55}\n--- Overall Model Performance (Averaged Across Tasks) ---\n{'='*55}")
    overall_metrics_dict = {
        'Overall Accuracy': [_task_average(overall_scores['accuracy'])],
        'Overall Precision (Weighted Avg)': [_task_average(overall_scores['precision'])],
        'Overall Recall (Weighted Avg)': [_task_average(overall_scores['recall'])],
        'Overall F1-Score (Weighted Avg)': [_task_average(overall_scores['f1'])]
    }
    overall_df = pd.DataFrame(overall_metrics_dict).T
    overall_df.columns = ['Average Score']
    display(overall_df)

    if EXPORT_TO_CSV:
        overall_csv_filename = f'reports/classification_report_{scenario_name}_overall.csv'
        overall_df.to_csv(overall_csv_filename)
        print(f"\nOverall report for {scenario_name} saved to '{overall_csv_filename}'")

    # --- 5. Optional: Labeled Codes and Counts in Test Set for Each Species ---
    print(f"\n\n{'='*60}\n--- Detailed Label Counts in Test Set by Plant Species ---\n{'='*60}")

    for i, plant_name in enumerate(label_maps['plant']):
        print(f"\n--- Analysis for Species: {plant_name} ---")

        # Find indices corresponding to the current plant species
        selected_indices = np.where(y_test_orig['plant'] == i)[0]

        if len(selected_indices) == 0:
            print("No test samples found for this species.")
            continue

        # Filter the original label dictionaries for this species
        y_test_filtered_by_species = {key: arr[selected_indices] for key, arr in y_test_orig.items()}

        # Count the occurrences of each label for each task
        for task, labels in y_test_filtered_by_species.items():
            if task == 'plant':
                continue  # Skip counting the plant itself

            print(f"  Task '{task}':")
            counts = Counter(labels)
            for value_idx, count in sorted(counts.items()):
                # Map integer index back to string label
                if value_idx == IGNORE_VALUE:
                    label_str = "IGNORED"
                else:
                    label_str = label_maps[task][value_idx]
                print(f"    {label_str}: {count}")

## DATA SPLITTING AND PREP FOR FINE-TUNING (NANO IMAGERY)

In [None]:
# ==============================================================================
# 4. DATA SPLITTING AND PREP FOR FINE-TUNING (NANO IMAGERY)
# ==============================================================================
# Partitions the nano-imagery spectra into train / validation / test subsets and
# derives, for each subset, the raw labels, the Keras-ready labels and the
# per-task sample weights. Also creates the EarlyStopping callback shared by
# the training cells below.

print("\nSplitting nano-imagery data for fine-tuning...")

# Split row indices (not the arrays) so spectra and all task labels stay aligned.
nano_indices = np.arange(len(d_spectra))
train_idx, test_idx = train_test_split(
    nano_indices, test_size=0.20, random_state=42
)
# Second cut: 20% of the remaining 80% (= 16% overall) becomes validation data.
train_idx, val_idx = train_test_split(
    train_idx, test_size=0.20, random_state=42
)

X_nano_train = d_spectra[train_idx]
X_nano_val = d_spectra[val_idx]
X_nano_test = d_spectra[test_idx]

# Original integer labels per subset (these still contain the -1 ignore value)
y_nano_train_orig, y_nano_val_orig, y_nano_test_orig = (
    {task: y_nano_tasks_orig[task][rows] for task in TASK_NAMES}
    for rows in (train_idx, val_idx, test_idx)
)

# Keras-ready labels: -1 is clamped to 0 so the loss sees a valid class index;
# the matching sample weight of 0 below removes those samples from the loss.
y_nano_train_keras, y_nano_val_keras, y_nano_test_keras = (
    {f"{task}_output": np.maximum(0, subset_labels[task]) for task in TASK_NAMES}
    for subset_labels in (y_nano_train_orig, y_nano_val_orig, y_nano_test_orig)
)

# Sample weights: 0.0 where the label is the ignore value, 1.0 otherwise
sample_weights_train, sample_weights_val, sample_weights_test = (
    {
        f"{task}_output": (subset_labels[task] != IGNORE_VALUE).astype(np.float32)
        for task in TASK_NAMES
    }
    for subset_labels in (y_nano_train_orig, y_nano_val_orig, y_nano_test_orig)
)

print(f"Nano data split: Train={len(X_nano_train)}, Val={len(X_nano_val)}, Test={len(X_nano_test)}")

# Shared EarlyStopping callback: watches validation loss, waits PATIENCE epochs
# without improvement, then restores the weights of the best epoch.
early_stopping = keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=PATIENCE,
    restore_best_weights=True,
    verbose=1,
)

## SCENARIO 2.2a: BASELINE TRANSFER LEARNING

In [None]:
# ==============================================================================
# 5. SCENARIO 2.2a: BASELINE TRANSFER LEARNING
# ==============================================================================
# Pre-trains the multi-task CNN on the resampled library spectra, then saves the
# model together with the scaler fitted on the library training split.

print("\n" + "="*70)
print("--- Starting Scenario 2.2a: Baseline Transfer Learning ---")
print("="*70 + "\n")

# --- 2.2a Step 1: Pre-train on down-sampled library data ---
print("Step 2.2a.1: Preparing and splitting library data for pre-training...")

# --- Split the library data into training and validation sets ---
library_indices = np.arange(len(spectra_resampled))
lib_train_idx, lib_val_idx = train_test_split(library_indices, test_size=0.15, random_state=42)

X_lib_train = spectra_resampled[lib_train_idx]
X_lib_val = spectra_resampled[lib_val_idx]

# --- Scale the data (fit on the training split only, reuse for validation) ---
scaler_2a_pretrain = StandardScaler()

# The trailing axis adds the single channel dimension expected by Conv1D.
X_lib_train_scaled = scaler_2a_pretrain.fit_transform(X_lib_train)[..., np.newaxis]
X_lib_val_scaled = scaler_2a_pretrain.transform(X_lib_val)[..., np.newaxis]
input_shape_259 = X_lib_train_scaled.shape[1:]

# Stop here, before the model is built, if the scaled inputs contain NaNs:
# training on them can only produce NaN losses, so continuing would waste the
# whole pre-training run and save an unusable model.
if np.isnan(X_lib_train_scaled).any():
    raise ValueError("FATAL ERROR: NaN values found in the scaled training data!")

# --- Prepare labels and sample weights for both training and validation sets ---
y_lib_train_orig = {task: y_library_tasks_orig[task][lib_train_idx] for task in TASK_NAMES}
y_lib_val_orig = {task: y_library_tasks_orig[task][lib_val_idx] for task in TASK_NAMES}

# Ignored labels (-1) are clamped to 0 for the loss and masked out via a zero sample weight.
y_lib_train_keras = {f"{t}_output": np.maximum(0, y_lib_train_orig[t]) for t in TASK_NAMES}
y_lib_val_keras = {f"{t}_output": np.maximum(0, y_lib_val_orig[t]) for t in TASK_NAMES}

sample_weights_lib_train = {f"{t}_output": (y_lib_train_orig[t] != IGNORE_VALUE).astype(np.float32) for t in TASK_NAMES}
sample_weights_lib_val = {f"{t}_output": (y_lib_val_orig[t] != IGNORE_VALUE).astype(np.float32) for t in TASK_NAMES}


pretrain_model_2a = build_multitask_cnn(
    input_shape_259, n_plant_classes, n_age_classes, n_part_classes, n_health_classes, n_lifecycle_classes
)

# --- Define the explicit, per-output losses and metrics ---
losses = {name: 'sparse_categorical_crossentropy' for name in OUTPUT_NAMES}
metrics = {name: 'sparse_categorical_accuracy' for name in OUTPUT_NAMES}

pretrain_model_2a.compile(
    optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss=losses,
    loss_weights=LOSS_WEIGHTS,
    metrics=metrics,
    weighted_metrics=metrics
)

print("Step 2.2a.2: Pre-training on 259-band library data...")

pretrain_model_2a.fit(
    X_lib_train_scaled,
    y_lib_train_keras,
    sample_weight=sample_weights_lib_train,
    epochs=PRETRAIN_EPOCHS,
    batch_size=BATCH_SIZE,
    validation_data=(X_lib_val_scaled, y_lib_val_keras, sample_weights_lib_val),
    callbacks=[early_stopping],
    verbose=1
)

# --- Save the pre-trained model and the scaler ---
print("\nStep 2.2a.3: Saving pre-trained model and scaler...")
os.makedirs('models', exist_ok=True)
os.makedirs('scalers', exist_ok=True)
pretrain_model_2a.save('models/pretrain_model_2a.keras')
with open('scalers/scaler_2a_pretrain.pkl', 'wb') as f:
    pickle.dump(scaler_2a_pretrain, f)
print("Pre-trained model and scaler saved successfully.")

In [None]:
# ==============================================================================
# 5. SCENARIO 2.2a: OPTIONALLY LOAD PRE-TRAINED MODEL
# ==============================================================================
# Restores the 2.2a pre-trained model and its scaler from disk when both files
# exist; otherwise falls back to the objects left in memory by the training cell.
import os
import pickle
from tensorflow import keras

# Define file paths
pretrain_model_path = 'models/pretrain_model_2a.keras'
scaler_path = 'scalers/scaler_2a_pretrain.pkl'

# Load only when BOTH artefacts are present, so model and scaler always match
if os.path.exists(pretrain_model_path) and os.path.exists(scaler_path):
    print(f"Found pre-trained model at: {pretrain_model_path}")
    print("Loading model and scaler from files...")
    pretrain_model_2a = keras.models.load_model(pretrain_model_path)
    # NOTE: unpickling can execute arbitrary code; only load scaler files that
    # were written by this notebook.
    with open(scaler_path, 'rb') as f:
        scaler_2a_pretrain = pickle.load(f)
    print("Successfully loaded pre-trained model and scaler.")
    # Re-derive the input shape from the loaded model (batch axis dropped)
    input_shape_259 = pretrain_model_2a.input_shape[1:]

else:
    print("Pre-trained model file not found.")
    # Falling back to in-memory objects only works if the training cell has
    # run in this session; fail here with a clear message rather than with a
    # NameError in a later cell.
    missing_in_memory_2a = [
        name for name in ('pretrain_model_2a', 'scaler_2a_pretrain')
        if name not in globals()
    ]
    if missing_in_memory_2a:
        raise RuntimeError(
            "No saved pre-trained model/scaler on disk and nothing in memory "
            f"for: {', '.join(missing_in_memory_2a)}. Run the 2.2a pre-training cell first."
        )
    print("Proceeding with the model just trained in the previous step.")

In [None]:
# --- 2.2a Step 2: Build fine-tuning model ---
# Re-uses the pre-trained network up to 'backbone_dense_drop' as a frozen
# feature extractor and attaches freshly initialised task heads on top.

print("\nStep 2.2a.2: Building the fine-tuning model...")

backbone_2a = keras.Model(
    inputs=pretrain_model_2a.input,
    outputs=pretrain_model_2a.get_layer('backbone_dense_drop').output,
    name="pretrained_backbone_259",
)
# Freeze every backbone weight; only the new heads are trained below.
backbone_2a.trainable = False

nano_inputs_2a = keras.Input(shape=input_shape_259, name='nano_input_259')
# training=False runs the backbone in inference mode inside the new model.
features_2a = backbone_2a(nano_inputs_2a, training=False)

# One small dense head per task (hidden layer followed by a softmax classifier)
plant_head_2a = layers.Dense(64, activation='relu')(features_2a)
plant_output_2a = layers.Dense(
    n_plant_classes, activation='softmax', name='plant_output'
)(plant_head_2a)

age_head_2a = layers.Dense(32, activation='relu')(features_2a)
age_output_2a = layers.Dense(
    n_age_classes, activation='softmax', name='age_output'
)(age_head_2a)

part_head_2a = layers.Dense(32, activation='relu')(features_2a)
part_output_2a = layers.Dense(
    n_part_classes, activation='softmax', name='part_output'
)(part_head_2a)

health_head_2a = layers.Dense(32, activation='relu')(features_2a)
health_output_2a = layers.Dense(
    n_health_classes, activation='softmax', name='health_output'
)(health_head_2a)

lifecycle_head_2a = layers.Dense(32, activation='relu')(features_2a)
lifecycle_output_2a = layers.Dense(
    n_lifecycle_classes, activation='softmax', name='lifecycle_output'
)(lifecycle_head_2a)

finetune_model_2a = keras.Model(
    inputs=nano_inputs_2a,
    outputs={
        'plant_output': plant_output_2a,
        'age_output': age_output_2a,
        'part_output': part_output_2a,
        'health_output': health_output_2a,
        'lifecycle_output': lifecycle_output_2a,
    },
)

# --- 2.2a Step 3: Compile and fine-tune on nano imagery ---
print("\nStep 2.2a.3: Fine-tuning on nano imagery...")

# The nano spectra are scaled with the pre-training scaler (loaded or just
# fitted); a channel axis is appended for the Conv1D backbone.
X_nano_train_scaled, X_nano_val_scaled, X_nano_test_scaled = (
    scaler_2a_pretrain.transform(subset)[..., np.newaxis]
    for subset in (X_nano_train, X_nano_val, X_nano_test)
)

# Same per-output loss and metric for every head, as in pre-training
losses = dict.fromkeys(OUTPUT_NAMES, 'sparse_categorical_crossentropy')
metrics = dict.fromkeys(OUTPUT_NAMES, 'sparse_categorical_accuracy')

finetune_model_2a.compile(
    optimizer=keras.optimizers.Adam(learning_rate=FT_LEARNING_RATE),
    loss=losses,
    loss_weights=LOSS_WEIGHTS,
    metrics=metrics,
    weighted_metrics=metrics,
)

print("\nFine-tuning model compiled with weighted metrics.")

finetune_model_2a.fit(
    X_nano_train_scaled,
    y_nano_train_keras,
    sample_weight=sample_weights_train,
    batch_size=BATCH_SIZE,
    epochs=FINETUNE_EPOCHS,
    validation_data=(X_nano_val_scaled, y_nano_val_keras, sample_weights_val),
    callbacks=[early_stopping],
    verbose=1,
)

# --- 2.2a Step 4: Full Evaluation and Saving ---
evaluate_and_report(
    model=finetune_model_2a,
    X_test_scaled=X_nano_test_scaled,
    y_test_orig=y_nano_test_orig,
    y_test_dict_keras=y_nano_test_keras,
    sample_weights_test=sample_weights_test,
    scenario_name="Scenario_2a",
)

# Persist the fine-tuned model; the scaler written next to it is the
# pre-training scaler, because that is the one applied to the nano data above.
finetune_model_2a.save('models/finetune_model_2a.keras')
with open('scalers/scaler_2a_finetune.pkl', 'wb') as f:
    pickle.dump(scaler_2a_pretrain, f)

print("\n--- Scenario 2.2a Finished ---")

## SCENARIO 2.2b: REFINED TRANSFER LEARNING WITH ADAPTER

In [None]:
# ==============================================================================
# 6. SCENARIO 2.2b: REFINED TRANSFER LEARNING WITH ADAPTER
# ==============================================================================
# Pre-trains the multi-task CNN on the high-resolution library spectra cropped
# to the 425-900 nm window, then saves the model and its scaler.

print("\n" + "="*70)
print("--- Starting Scenario 2.2b: Refined Transfer Learning with Adapter ---")
print("="*70 + "\n")

# --- 2.2b Step 1: Pre-train on high-resolution data ---
print("Step 2.2b.1: Preparing and splitting high-resolution library data for pre-training...")

# Column positions of the crop boundaries in the library wavelength axis
index_425nm = np.where(wl_lib == 425)[0][0]
index_900nm = np.where(wl_lib == 900)[0][0]

# NOTE(review): the slice end is exclusive, so the 900 nm band itself is left
# out - presumably intentional to obtain 475 bands; confirm.
X_library_cropped_highres = spectra[:, index_425nm:index_900nm]

# --- Train / validation split on row indices ---
library_indices_highres = np.arange(len(X_library_cropped_highres))
lib_train_idx, lib_val_idx = train_test_split(
    library_indices_highres, test_size=0.15, random_state=42
)

X_lib_train_highres = X_library_cropped_highres[lib_train_idx]
X_lib_val_highres = X_library_cropped_highres[lib_val_idx]

# --- Standardise: statistics come from the training split only ---
scaler_2b_pretrain = StandardScaler()
X_lib_train_scaled_highres = scaler_2b_pretrain.fit_transform(X_lib_train_highres)[..., np.newaxis]
X_lib_val_scaled_highres = scaler_2b_pretrain.transform(X_lib_val_highres)[..., np.newaxis]
input_shape_475 = X_lib_train_scaled_highres.shape[1:]

# --- Labels and sample weights for both splits ---
y_lib_train_orig, y_lib_val_orig = (
    {task: y_library_tasks_orig[task][rows] for task in TASK_NAMES}
    for rows in (lib_train_idx, lib_val_idx)
)
# Ignored labels (-1) are clamped to 0 for the loss ...
y_lib_train_keras, y_lib_val_keras = (
    {f"{task}_output": np.maximum(0, split_labels[task]) for task in TASK_NAMES}
    for split_labels in (y_lib_train_orig, y_lib_val_orig)
)
# ... and given a sample weight of 0 so they do not contribute to it.
sample_weights_lib_train, sample_weights_lib_val = (
    {
        f"{task}_output": (split_labels[task] != IGNORE_VALUE).astype(np.float32)
        for task in TASK_NAMES
    }
    for split_labels in (y_lib_train_orig, y_lib_val_orig)
)

# --- Build, Compile, and Fit ---
pretrain_model_2b = build_multitask_cnn(
    input_shape_475,
    n_plant_classes,
    n_age_classes,
    n_part_classes,
    n_health_classes,
    n_lifecycle_classes,
)

# Same loss and metric for every output head
losses = dict.fromkeys(OUTPUT_NAMES, 'sparse_categorical_crossentropy')
metrics = dict.fromkeys(OUTPUT_NAMES, 'sparse_categorical_accuracy')
pretrain_model_2b.compile(
    optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
    loss=losses,
    loss_weights=LOSS_WEIGHTS,
    metrics=metrics,
    weighted_metrics=metrics,
)

print("Step 2.2b.2: Pre-training on 475-band library data...")
# Validation data is passed as an explicit (inputs, labels, sample_weights) tuple
pretrain_model_2b.fit(
    X_lib_train_scaled_highres,
    y_lib_train_keras,
    sample_weight=sample_weights_lib_train,
    batch_size=BATCH_SIZE,
    epochs=PRETRAIN_EPOCHS,
    validation_data=(X_lib_val_scaled_highres, y_lib_val_keras, sample_weights_lib_val),
    callbacks=[early_stopping],
    verbose=1,
)

# --- Persist the pre-trained model and the scaler ---
print("\nStep 2.2b.3: Saving pre-trained model and scaler...")
os.makedirs('models', exist_ok=True)
os.makedirs('scalers', exist_ok=True)
pretrain_model_2b.save('models/pretrain_model_2b.keras')
with open('scalers/scaler_2b_pretrain.pkl', 'wb') as f:
    pickle.dump(scaler_2b_pretrain, f)
print("Pre-trained model and scaler saved successfully.")

In [None]:
# ==============================================================================
# 6. SCENARIO 2.2b: OPTIONALLY LOAD PRE-TRAINED MODEL
# ==============================================================================
# Restores the 2.2b pre-trained model and its scaler from disk when both files
# exist; otherwise falls back to the objects left in memory by the training cell.
import os
import pickle
from tensorflow import keras

# Define file paths
pretrain_model_path = 'models/pretrain_model_2b.keras'
scaler_path = 'scalers/scaler_2b_pretrain.pkl'

# Load only when BOTH artefacts are present, so model and scaler always match
if os.path.exists(pretrain_model_path) and os.path.exists(scaler_path):
    print(f"Found pre-trained model at: {pretrain_model_path}")
    print("Loading model and scaler from files...")
    pretrain_model_2b = keras.models.load_model(pretrain_model_path)
    # NOTE: unpickling can execute arbitrary code; only load scaler files that
    # were written by this notebook.
    with open(scaler_path, 'rb') as f:
        scaler_2b_pretrain = pickle.load(f)
    print("Successfully loaded pre-trained model and scaler.")
    # Re-derive the input shape from the loaded model (batch axis dropped)
    input_shape_475 = pretrain_model_2b.input_shape[1:]

else:
    print("Pre-trained model file not found.")
    # Falling back to in-memory objects only works if the training cell has
    # run in this session; fail here with a clear message rather than with a
    # NameError in a later cell.
    missing_in_memory_2b = [
        name for name in ('pretrain_model_2b', 'scaler_2b_pretrain')
        if name not in globals()
    ]
    if missing_in_memory_2b:
        raise RuntimeError(
            "No saved pre-trained model/scaler on disk and nothing in memory "
            f"for: {', '.join(missing_in_memory_2b)}. Run the 2.2b pre-training cell first."
        )
    print("Proceeding with the model just trained in the previous step.")

In [None]:
# --- 2.2b Step 2: Build fine-tuning model with an Adapter Layer ---
# The frozen high-resolution backbone is fed through a trainable dense adapter
# that maps each nano spectrum onto the band count the backbone was trained on.
print("\nStep 2.2b.4: Building the fine-tuning model with an Adapter Layer...")

backbone_475 = keras.Model(
    inputs=pretrain_model_2b.input,
    outputs=pretrain_model_2b.get_layer('backbone_dense_drop').output,
    name="pretrained_backbone_475"
)
backbone_475.trainable = False # Freeze the high-res backbone

# Take both band counts from the actual data/model instead of hard-coding
# 259 and 475, so the adapter always matches the loaded backbone and the
# nano spectra. With the default data the layer names are unchanged
# ('nano_input_259', 'adapter_259_to_475').
n_nano_bands = X_nano_train.shape[1]
n_backbone_bands = pretrain_model_2b.input_shape[1]

# Build the adapter and new heads
nano_inputs_2b = keras.Input(shape=(n_nano_bands, 1), name=f'nano_input_{n_nano_bands}')

# This is the trainable adapter: it learns to map from the nano band count to
# the backbone band count.
# NOTE(review): the ReLU makes the adapter output non-negative, while the
# backbone was pre-trained on standardized inputs - confirm this is intended.
flatten_layer = layers.Flatten()(nano_inputs_2b)
adapter_output = layers.Dense(
    n_backbone_bands,
    activation='relu',
    name=f'adapter_{n_nano_bands}_to_{n_backbone_bands}'
)(flatten_layer)
reshaped_for_cnn = layers.Reshape((n_backbone_bands, 1))(adapter_output)

# Pass the adapted input through the frozen backbone (inference mode)
features_2b = backbone_475(reshaped_for_cnn, training=False)

# Add new, trainable output heads
plant_head_2b = layers.Dense(64, activation='relu')(features_2b)
plant_output_2b = layers.Dense(n_plant_classes, activation='softmax', name='plant_output')(plant_head_2b)

age_head_2b = layers.Dense(32, activation='relu')(features_2b)
age_output_2b = layers.Dense(n_age_classes, activation='softmax', name='age_output')(age_head_2b)

part_head_2b = layers.Dense(32, activation='relu')(features_2b)
part_output_2b = layers.Dense(n_part_classes, activation='softmax', name='part_output')(part_head_2b)

health_head_2b = layers.Dense(32, activation='relu')(features_2b)
health_output_2b = layers.Dense(n_health_classes, activation='softmax', name='health_output')(health_head_2b)

lifecycle_head_2b = layers.Dense(32, activation='relu')(features_2b)
lifecycle_output_2b = layers.Dense(n_lifecycle_classes, activation='softmax', name='lifecycle_output')(lifecycle_head_2b)

finetune_model_2b = keras.Model(
    inputs=nano_inputs_2b,
    outputs={
        'plant_output': plant_output_2b,
        'age_output': age_output_2b,
        'part_output': part_output_2b,
        'health_output': health_output_2b,
        'lifecycle_output': lifecycle_output_2b
    }
)

# --- 2.2b Step 3: Compile and fine-tune on nano imagery ---
print("\nStep 2.2b.5: Fine-tuning on nano imagery...")

# A new scaler is needed because the nano input has a different band count than
# the pre-training data; it is fitted on the nano training split only.
scaler_2b_finetune = StandardScaler()
X_nano_train_scaled_2b = scaler_2b_finetune.fit_transform(X_nano_train)[..., np.newaxis]
X_nano_val_scaled_2b = scaler_2b_finetune.transform(X_nano_val)[..., np.newaxis]
X_nano_test_scaled_2b = scaler_2b_finetune.transform(X_nano_test)[..., np.newaxis]

losses = {name: 'sparse_categorical_crossentropy' for name in OUTPUT_NAMES}
metrics = {name: 'sparse_categorical_accuracy' for name in OUTPUT_NAMES}

finetune_model_2b.compile(
    optimizer=keras.optimizers.Adam(learning_rate=FT_LEARNING_RATE),
    loss=losses,
    loss_weights=LOSS_WEIGHTS,
    metrics=metrics,
    weighted_metrics=metrics
)

finetune_model_2b.fit(
    X_nano_train_scaled_2b,
    y_nano_train_keras,
    sample_weight=sample_weights_train,
    validation_data=(X_nano_val_scaled_2b, y_nano_val_keras, sample_weights_val),
    epochs=FINETUNE_EPOCHS,
    batch_size=BATCH_SIZE,
    callbacks=[early_stopping],
    verbose=1
)


# --- 2.2b Step 4: Full Evaluation and Saving ---
evaluate_and_report(
    finetune_model_2b, 
    X_nano_test_scaled_2b, 
    y_nano_test_orig, 
    y_nano_test_keras, 
    sample_weights_test, 
    "Scenario_2b"
)

# Save the fine-tuned model together with the scaler fitted on the nano data
finetune_model_2b.save('models/finetune_model_2b.keras')
with open('scalers/scaler_2b_finetune.pkl', 'wb') as f:
    pickle.dump(scaler_2b_finetune, f)

print("\n--- Scenario 2.2b Finished ---")