## In this notebook the FNet optimization results are validated against train and heldout data

In [None]:
import pathlib
import sys
import yaml
from collections import defaultdict

import pandas as pd
import torch
from torch.utils.data import DataLoader
import mlflow
import mlflow.artifacts
import optuna
from optuna.visualization import plot_param_importances, plot_optimization_history
import joblib

## Read config

In [None]:
## config.yml is looked up two directory levels above the current working directory
config_path = pathlib.Path('.').absolute().parent.parent / "config.yml"
with config_path.open("r") as config_file:
    config = yaml.safe_load(config_file)

## Import virtual_stain_flow software 

In [None]:
## Make the configured software location importable
software_path = config['paths']['software_path']
sys.path.append(software_path)

## Echo the directory two levels above the current working directory
two_levels_up = pathlib.Path('.').absolute().parent.parent
print(str(two_levels_up))

## Dataset
from virtual_stain_flow.datasets.PatchDataset import PatchDataset
from virtual_stain_flow.datasets.CachedDataset import CachedDataset

## FNet training
from virtual_stain_flow.models.fnet import FNet
from virtual_stain_flow.trainers.Trainer import Trainer

## wGaN training
from virtual_stain_flow.models.unet import UNet
from virtual_stain_flow.models.discriminator import GlobalDiscriminator
from virtual_stain_flow.trainers.WGaNTrainer import WGaNTrainer

## wGaN losses
from virtual_stain_flow.losses.GradientPenaltyLoss import GradientPenaltyLoss
from virtual_stain_flow.losses.DiscriminatorLoss import DiscriminatorLoss
from virtual_stain_flow.losses.GeneratorLoss import GeneratorLoss

from virtual_stain_flow.transforms.MinMaxNormalize import MinMaxNormalize
from virtual_stain_flow.transforms.PixelDepthTransform import PixelDepthTransform

## Metrics
from virtual_stain_flow.metrics.MetricsWrapper import MetricsWrapper
from virtual_stain_flow.metrics.PSNR import PSNR
from virtual_stain_flow.metrics.SSIM import SSIM

## callback
from virtual_stain_flow.callbacks.MlflowLogger import MlflowLogger
from virtual_stain_flow.callbacks.IntermediatePlot import IntermediatePatchPlot

## Define paths and other train parameters

In [None]:
## Loaddata csv files for the train and heldout splits (both must already exist)
LOADDATA_FILE_PATH = pathlib.Path('.').absolute().parent.parent.joinpath(
    '0.data_preprocessing', 'data_split_loaddata', 'loaddata_train.csv'
)
assert LOADDATA_FILE_PATH.exists()
LOADDATA_EVAL_FILE_PATH = pathlib.Path('.').absolute().parent.parent.joinpath(
    '0.data_preprocessing', 'data_split_loaddata', 'loaddata_heldout.csv'
)
assert LOADDATA_EVAL_FILE_PATH.exists()

## Directory of the corresponding sc features, which contain the cell coordinates used for patch generation
SC_FEATURES_DIR = pathlib.Path(config['paths']['sc_features_path'])

## Optimization outputs are expected under these directories (both must already exist)
MLFLOW_DIR = pathlib.Path('.').absolute().joinpath('optuna_mlflow')
assert MLFLOW_DIR.exists()

OPTUNA_JOBLIB_DIR = pathlib.Path('.').absolute().joinpath('optuna_joblib')
assert OPTUNA_JOBLIB_DIR.exists()

## Validation output directory, created when missing
VALIDATION_OUTPUT_PATH = pathlib.Path('.').absolute().joinpath('Validation')
VALIDATION_OUTPUT_PATH.mkdir(exist_ok=True)

## Patch size handed to the patch dataset
PATCH_SIZE = 256

## Input and target channel names are taken from the config
INPUT_CHANNEL_NAMES = config['data']['input_channel_keys']
TARGET_CHANNEL_NAMES = config['data']['target_channel_keys']

## Defines how the train data will be divided to train models on two levels of confluence

In [None]:
## Each confluence group maps a loaddata column name to the values that
## select the rows belonging to that group.
DATA_GROUPING = dict(
    high_confluence=dict(seeding_density=[12_000, 8_000]),
    ## Currently disabled group:
    # low_confluence=dict(seeding_density=[4_000, 2_000, 1_000]),
)

In [None]:
## Point mlflow at the tracking store located under MLFLOW_DIR
mlflow.set_tracking_uri(MLFLOW_DIR / 'mlruns')

## mlflow_results: confluence group -> dataframe of runs returned by mlflow.search_runs
## optuna_results: confluence group -> channel name -> loaded optuna study
mlflow_results = {}
optuna_results = defaultdict(dict)

## Only the group names are needed here, so iterate over the keys directly
for confluence_group_name in DATA_GROUPING:
    ## Access relevant optimization result and logs by confluence
    experiment_name = f'FNet_optimize_{confluence_group_name}'
    experiment = mlflow.get_experiment_by_name(experiment_name)
    if experiment is None:
        ## get_experiment_by_name returns None for an unknown experiment; fail with
        ## an explicit message rather than an AttributeError on `.experiment_id`
        raise LookupError(
            f"MLflow experiment '{experiment_name}' was not found under {MLFLOW_DIR / 'mlruns'}"
        )
    mlflow_results[confluence_group_name] = mlflow.search_runs(
        experiment_ids=[experiment.experiment_id]
    )

    for channel_name in TARGET_CHANNEL_NAMES:
        optuna_study_path = OPTUNA_JOBLIB_DIR / f"FNet_optimize_{channel_name}_{confluence_group_name}.joblib"
        ## NOTE: joblib.load unpickles the file, so only load study files that
        ## were produced by a trusted optimization run.
        study = joblib.load(optuna_study_path)
        optuna_results[confluence_group_name][channel_name] = study

        print(f"Optuna study {channel_name} {confluence_group_name}:")
        plot_param_importances(study).show()
        plot_optimization_history(study).show()

In [None]:
## Display the runs dataframe collected for the 'high_confluence' group
mlflow_results['high_confluence']

In [None]:
from virtual_stain_flow.evaluation.plot_utils import plot_patches
from virtual_stain_flow.evaluation.evaluation_utils import evaluate_per_image_metric
from virtual_stain_flow.evaluation.predict_utils import predict_image

In [None]:
EVAL_DEVICE = 'cpu'
EVAL_METRICS = [PSNR(_metric_name='psnr'), SSIM(_metric_name='ssim')]

## Columns read from each per-plate sc feature parquet file
_SC_FEATURE_COLUMNS = [
    'Metadata_Plate',
    'Metadata_Well',
    'Metadata_Site',
    'Metadata_Cells_Location_Center_X',
    'Metadata_Cells_Location_Center_Y',
]


def _subset_by_conditions(loaddata_df, conditions):
    """Return the rows of ``loaddata_df`` that satisfy every condition.

    :param loaddata_df: Dataframe to subset; it is copied, not modified.
    :param conditions: Mapping of column name to the collection of accepted values.
    :return: The filtered dataframe.
    """
    subset_df = loaddata_df.copy()
    for condition, values in conditions.items():
        subset_df = subset_df[subset_df[condition].isin(values)]
    return subset_df


def _collect_sc_features(plates):
    """Concatenate the sc feature parquet files found for ``plates``.

    Plates without a parquet file under SC_FEATURES_DIR are reported and skipped.

    :param plates: Iterable of plate names.
    :return: Concatenated dataframe, or an empty dataframe when no file was found.
    """
    frames = []
    for plate in plates:
        sc_features_parquet = SC_FEATURES_DIR / f'{plate}_sc_normalized.parquet'
        if not sc_features_parquet.exists():
            print(f'{sc_features_parquet} does not exist, skipping...')
            continue
        frames.append(pd.read_parquet(sc_features_parquet, columns=_SC_FEATURE_COLUMNS))
    return pd.concat(frames) if frames else pd.DataFrame()


## Per-run metrics are collected here and concatenated once at the end into
## all_metrics_df for convenience of comparison
all_metrics_frames = []
all_metrics_df = pd.DataFrame()

## Iterate over train and heldout
for datasplit, loaddata_path in (
    ('train', LOADDATA_FILE_PATH),
    ('heldout', LOADDATA_EVAL_FILE_PATH),
):
    loaddata_df = pd.read_csv(loaddata_path)

    DATASPLIT_VALIDATION_PATH = VALIDATION_OUTPUT_PATH / datasplit
    DATASPLIT_VALIDATION_PATH.mkdir(exist_ok=True)

    ## Iterate over confluence groups
    for confluence_group_name, conditions in DATA_GROUPING.items():

        CONFLUENCE_VALIDATION_PATH = DATASPLIT_VALIDATION_PATH / confluence_group_name
        CONFLUENCE_VALIDATION_PATH.mkdir(exist_ok=True)

        ## Subset loaddata to confluence group
        loaddata_condition_df = _subset_by_conditions(loaddata_df, conditions)

        ## Collect corresponding sc features
        sc_features = _collect_sc_features(loaddata_condition_df['Metadata_Plate'].unique())

        ## Load data
        pds = PatchDataset(
            _loaddata_csv=loaddata_condition_df,
            _sc_feature=sc_features,
            _input_channel_keys=INPUT_CHANNEL_NAMES,
            _target_channel_keys=TARGET_CHANNEL_NAMES,
            _input_transform=PixelDepthTransform(src_bit_depth=16, target_bit_depth=8, _always_apply=True),
            _target_transform=MinMaxNormalize(_normalization_factor=(2 ** 16) - 1, _always_apply=True),
            patch_size=PATCH_SIZE,
            verbose=False,
            patch_generation_method="random_cell",
            n_expected_patches_per_img=50,
            patch_generation_random_seed=42
        )

        ## Group evaluation by channel (mostly unnecessary)
        for target_channel_name, runs_df in mlflow_results[confluence_group_name].groupby('params.channel_name'):

            CHANNEL_VALIDATION_PATH = CONFLUENCE_VALIDATION_PATH / target_channel_name
            CHANNEL_VALIDATION_PATH.mkdir(exist_ok=True)

            EXAMPLE_PATCH_PLOT_PATH = CHANNEL_VALIDATION_PATH / 'example_patch_plots'
            EXAMPLE_PATCH_PLOT_PATH.mkdir(exist_ok=True)

            pds.set_input_channel_keys(INPUT_CHANNEL_NAMES)
            pds.set_target_channel_keys([target_channel_name])
            ## Pull every target patch in a single batch
            ## NOTE(review): assumes predict_image yields predictions in the same order -- confirm
            _, targets = next(iter(DataLoader(pds, batch_size=len(pds))))

            ## Iterate over models
            for _, row in runs_df.iterrows():
                model_run_id = row['run_id']

                ## Reuse metrics saved by a previous execution of this cell
                METRICS_FILE_PATH = CHANNEL_VALIDATION_PATH / f'{model_run_id}_metrics.csv'
                if METRICS_FILE_PATH.exists():
                    ## The file is written with its index, so read the first column
                    ## back as the index instead of as an extra 'Unnamed: 0' column
                    metrics_df = pd.read_csv(METRICS_FILE_PATH, index_col=0)
                    all_metrics_frames.append(metrics_df)
                    continue

                model_uri = row['artifact_uri']
                model_weight_path = (
                    pathlib.Path(mlflow.artifacts.download_artifacts(artifact_uri=model_uri))
                    / 'models' / 'best_model_weights.pth'
                )
                if not model_weight_path.exists():
                    print(f"Model weight not found for run {model_run_id}, skipping ...")
                    continue
                model_depth = int(row['params.depth'])

                model = FNet(depth=model_depth)
                try:
                    ## map_location lets weights saved on another device load on EVAL_DEVICE
                    state_dict = torch.load(
                        model_weight_path, map_location=EVAL_DEVICE, weights_only=True
                    )
                    model.load_state_dict(state_dict)
                except Exception as err:
                    ## Best-effort: report why the weights could not be loaded and move on;
                    ## KeyboardInterrupt/SystemExit are intentionally not caught here
                    print(f"Fail to load model weight for run {model_run_id}, skipping ...")
                    print(f"    reason: {err!r}")
                    continue
                model.to(EVAL_DEVICE)

                predictions = predict_image(
                    dataset=pds,
                    model=model,
                    device=EVAL_DEVICE
                )

                metrics_df = evaluate_per_image_metric(
                    predictions=predictions,
                    targets=targets,
                    metrics=EVAL_METRICS
                )
                ## Mean is taken before the descriptive columns are added below
                metrics_mean = metrics_df.mean()

                ## Annotate metrics with the datasplit, confluence group and run parameters
                metrics_df['datasplit'] = datasplit
                metrics_df['confluence'] = confluence_group_name
                params_values = {key: value for key, value in row.items() if key.startswith('params.')}
                for param, value in params_values.items():
                    metrics_df[param] = value
                metrics_df.to_csv(METRICS_FILE_PATH)
                all_metrics_frames.append(metrics_df)

                ## Example patch plot, named after the mean metrics and run parameters
                metrics_mean_str = '_'.join([f"{key}={value:.2f}" for key, value in metrics_mean.items()])
                params_str = '_'.join([f"{key.replace('params.','')}={value}" for key, value in params_values.items()])
                plot_patches(
                    dataset=pds,
                    n_patches=5,
                    model=model,
                    random_seed=42,
                    device=EVAL_DEVICE,
                    metrics=EVAL_METRICS,
                    show_plot=False,
                    save_path=EXAMPLE_PATCH_PLOT_PATH / f'{metrics_mean_str}_{params_str}.png'
                )

## Single concatenation of everything collected above
if all_metrics_frames:
    all_metrics_df = pd.concat(all_metrics_frames)