# This notebook trains FNet models on two subsets (high and low confluence) of U2-OS Cell Painting image data, using the optimal hyper-parameters found by Optuna

In [1]:
import pathlib
import sys
import yaml
import random

import numpy as np
import pandas as pd
import torch
import torch.optim as optim
import mlflow
import optuna
import joblib

## Read config

In [2]:
## Load the project-wide YAML configuration from the repository root (two levels up)
with (pathlib.Path('.').absolute().parent.parent / "config.yml").open("r") as file:
    config = yaml.safe_load(file)

## Import virtual_stain_flow software 

In [3]:
## Make the virtual_stain_flow package importable. Only append the path when it is
## not already present so re-running this cell does not keep growing sys.path.
SOFTWARE_PATH = str(config['paths']['software_path'])
if SOFTWARE_PATH not in sys.path:
    sys.path.append(SOFTWARE_PATH)
## Print the repository root for reference
print(str(pathlib.Path('.').absolute().parent.parent))

## Dataset
from virtual_stain_flow.datasets.PatchDataset import PatchDataset
from virtual_stain_flow.datasets.CachedDataset import CachedDataset

## FNet training
from virtual_stain_flow.models.fnet import FNet
from virtual_stain_flow.trainers.Trainer import Trainer

## wGaN training
from virtual_stain_flow.models.unet import UNet
from virtual_stain_flow.models.discriminator import GlobalDiscriminator
from virtual_stain_flow.trainers.WGaNTrainer import WGaNTrainer

## wGaN losses
from virtual_stain_flow.losses.GradientPenaltyLoss import GradientPenaltyLoss
from virtual_stain_flow.losses.DiscriminatorLoss import DiscriminatorLoss
from virtual_stain_flow.losses.GeneratorLoss import GeneratorLoss

from virtual_stain_flow.transforms.MinMaxNormalize import MinMaxNormalize
from virtual_stain_flow.transforms.PixelDepthTransform import PixelDepthTransform

## Metrics
from virtual_stain_flow.metrics.MetricsWrapper import MetricsWrapper
from virtual_stain_flow.metrics.PSNR import PSNR
from virtual_stain_flow.metrics.SSIM import SSIM

## callback
from virtual_stain_flow.callbacks.MlflowLogger import MlflowLogger
from virtual_stain_flow.callbacks.IntermediatePlot import IntermediatePatchPlot

/home/weishanli/Waylab/pediatric_cancer_atlas_analysis


## Define paths and other training parameters

In [4]:
## Repository root (the directory containing config.yml), two levels above the working directory
PROJECT_ROOT = pathlib.Path('.').absolute().parent.parent

## Optuna optimization results (one joblib file per target channel and confluence group)
OPTUNA_RESULTS_DIR = PROJECT_ROOT / \
    '1.model_optimization' / \
    '1.2.optimize_fnet_by_confluence' / \
    'optuna_joblib'
assert OPTUNA_RESULTS_DIR.exists(), \
    f"Optuna results directory not found: {OPTUNA_RESULTS_DIR}"

## Loaddata CSV describing the training split
LOADDATA_FILE_PATH = PROJECT_ROOT \
    / '0.data_preprocessing' / 'data_split_loaddata' / 'loaddata_train.csv'
assert LOADDATA_FILE_PATH.exists(), \
    f"Training loaddata CSV not found: {LOADDATA_FILE_PATH}"

## Directory holding the per-plate single-cell feature parquet files
SC_FEATURES_DIR = pathlib.Path(config['paths']['sc_features_path'])

## Training logging/weight output directory
MLFLOW_DIR = pathlib.Path('.').absolute() / 'mlflow'
MLFLOW_DIR.mkdir(parents=True, exist_ok=True)

## Training intermediate plot output directory
PLOT_DIR = pathlib.Path('.').absolute() / 'plots'
PLOT_DIR.mkdir(parents=True, exist_ok=True)

## Patch size and maximum number of training epochs
## (training may stop earlier through patience-based early termination;
## the convolutional depth comes from the Optuna study, not from here)
PATCH_SIZE = 256
EPOCHS = 1_000

## Channels for input and target are read from config
INPUT_CHANNEL_NAMES = config['data']['input_channel_keys']
TARGET_CHANNEL_NAMES = config['data']['target_channel_keys']

In [5]:
## Inspect the Optuna study of one example channel / confluence group
channel_name = "OrigAGP"
confluence_group_name = "high_confluence"
study_path = OPTUNA_RESULTS_DIR / \
            f"FNet_optimize_{channel_name}_{confluence_group_name}.joblib"
if not study_path.exists():
    # Fail loudly and name the missing file; a bare `assert False` carries no
    # message and is removed entirely when Python runs with -O.
    raise FileNotFoundError(f"Optuna study not found: {study_path}")
# joblib.load unpickles the file: only load studies produced by this project's
# own optimization step.
study = joblib.load(study_path)

In [6]:
## Display the best hyper-parameters recorded in the example study loaded above
study.best_params

{'lr': 0.0005690482786963295,
 'beta0': 0.8951555358866697,
 'beta1': 0.9770467004930141,
 'batch_size': 32,
 'patience': 15,
 'conv_depth': 4}

In [7]:
## Bind each tuned hyper-parameter of the example study to its own variable
(lr, beta0, beta1, patience, batch_size, conv_depth) = (
    study.best_params[param_name]
    for param_name in ('lr', 'beta0', 'beta1', 'patience', 'batch_size', 'conv_depth')
)

## Define how the training data is divided by seeding density, so that separate models are trained for two confluence levels

In [8]:
## Confluence group -> {loaddata column: accepted values}; rows matching every
## condition of a group form that group's training data
DATA_GROUPING = dict(
    high_confluence=dict(seeding_density=[12_000, 8_000]),
    low_confluence=dict(seeding_density=[4_000, 2_000, 1_000]),
)

## Set up metrics

In [9]:
## Metrics passed to every Trainer: MSE, SSIM and PSNR
metric_fns = dict(
    mse_loss=MetricsWrapper(_metric_name='mse', module=torch.nn.MSELoss()),
    ssim_loss=SSIM(_metric_name="ssim"),
    psnr_loss=PSNR(_metric_name="psnr"),
)

## Set the training seed (for reproducibility) and the training device

In [10]:
## Seed the Python and NumPy global RNGs (torch is seeded per model in the training loop)
train_seed = 42
random.seed(train_seed)
np.random.seed(train_seed)

## Prefer the GPU whenever one is available
if torch.cuda.is_available():
    TRAIN_DEVICE = 'cuda'
else:
    TRAIN_DEVICE = 'cpu'

## Train with the best hyper-parameters

In [11]:
loaddata_df = pd.read_csv(LOADDATA_FILE_PATH)

## Columns read from each per-plate single-cell feature table
## (plate/well/site keys plus cell centre coordinates)
SC_FEATURE_COLUMNS = [
    'Metadata_Plate',
    'Metadata_Well',
    'Metadata_Site',
    'Metadata_Cells_Location_Center_X',
    'Metadata_Cells_Location_Center_Y',
]


def _select_condition_rows(df, conditions):
    """Return a copy of the rows of ``df`` that satisfy every condition.

    Args:
        df: loaddata DataFrame to filter (left unmodified).
        conditions: mapping of column name -> iterable of accepted values.

    Returns:
        A new DataFrame containing only rows whose value in each condition
        column is one of the accepted values.
    """
    selected = df
    for column, values in conditions.items():
        selected = selected[selected[column].isin(values)]
    return selected.copy()


def _load_sc_features(plates, sc_features_dir, columns):
    """Concatenate the normalized single-cell feature tables of ``plates``.

    Plates whose ``{plate}_sc_normalized.parquet`` file is missing under
    ``sc_features_dir`` are reported and skipped.

    Returns:
        The concatenated ``columns`` of all readable tables, or an empty
        DataFrame when no table could be read.
    """
    frames = []
    for plate in plates:
        sc_features_parquet = sc_features_dir / f'{plate}_sc_normalized.parquet'
        if not sc_features_parquet.exists():
            print(f'{sc_features_parquet} does not exist, skipping...')
            continue
        frames.append(pd.read_parquet(sc_features_parquet, columns=columns))
    # Concatenate once instead of re-copying the accumulated frame for every plate;
    # pd.concat raises on an empty list, so fall back to an empty frame.
    return pd.concat(frames) if frames else pd.DataFrame()


## Iterate over confluence groups (a separate optimization was performed for each)
for confluence_group_name, conditions in DATA_GROUPING.items():

    ## Load the training rows of this confluence group
    loaddata_condition_df = _select_condition_rows(loaddata_df, conditions)

    ## Collect corresponding single-cell-level features
    sc_features = _load_sc_features(
        loaddata_condition_df['Metadata_Plate'].unique(),
        SC_FEATURES_DIR,
        SC_FEATURE_COLUMNS,
    )

    ## Create patch dataset
    pds = PatchDataset(
        _loaddata_csv=loaddata_condition_df,
        _sc_feature=sc_features,
        _input_channel_keys=INPUT_CHANNEL_NAMES,
        _target_channel_keys=TARGET_CHANNEL_NAMES,
        _input_transform=PixelDepthTransform(src_bit_depth=16, target_bit_depth=8, _always_apply=True),
        _target_transform=MinMaxNormalize(_normalization_factor=(2 ** 16) - 1, _always_apply=True),
        patch_size=PATCH_SIZE,
        verbose=False,
        patch_generation_method="random_cell",
        n_expected_patches_per_img=50,
        patch_generation_random_seed=42
    )

    ## Make folder for per-confluence plotting
    CONF_PLOT_DIR = PLOT_DIR / confluence_group_name
    CONF_PLOT_DIR.mkdir(exist_ok=True)

    for channel_name in TARGET_CHANNEL_NAMES:

        ## Access the relevant Optuna study first. If it is missing, skip this
        ## channel entirely instead of silently training with a `study` left over
        ## from a previous iteration (or from the exploratory cells above).
        study_path = OPTUNA_RESULTS_DIR / \
            f"FNet_optimize_{channel_name}_{confluence_group_name}.joblib"
        if not study_path.exists():
            print("Optimization result not found for channel: "
                  f"{channel_name} and group {confluence_group_name}, skipping ...")
            continue
        study = joblib.load(study_path)

        ## Make folder for per-channel plotting under the confluence folder
        CHANNEL_CONF_PLOT_DIR = CONF_PLOT_DIR / channel_name
        CHANNEL_CONF_PLOT_DIR.mkdir(exist_ok=True)

        ## Cache single input/single target dataset to speed up training
        pds.set_input_channel_keys(INPUT_CHANNEL_NAMES)
        pds.set_target_channel_keys(channel_name)
        cds = CachedDataset(
            dataset=pds,
            prefill_cache=True
        )

        print("Beginning FNet model training for channel: "
              f"{channel_name} and group {confluence_group_name}")

        ## Retrieve best parameters
        best_params = study.best_params
        lr = best_params['lr']
        beta0 = best_params['beta0']
        beta1 = best_params['beta1']
        betas = (beta0, beta1)
        patience = best_params['patience']
        batch_size = best_params['batch_size']
        conv_depth = best_params['conv_depth']

        ## Re-seed torch before every model so each run starts from the same state
        torch.manual_seed(train_seed)
        torch.cuda.manual_seed(train_seed)
        torch.cuda.manual_seed_all(train_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

        model = FNet(depth=conv_depth)
        optimizer = optim.Adam(model.parameters(), lr=lr, betas=betas)

        mlflow_logger_callback = MlflowLogger(
            name='mlflow_logger',
            mlflow_uri=MLFLOW_DIR / 'mlruns',
            mlflow_experiment_name=f'FNet_train_{confluence_group_name}',
            mlflow_start_run_args={'run_name': f'FNet_train_{confluence_group_name}_{channel_name}', 'nested': True},
            mlflow_log_params_args={
                "lr": lr,
                "beta0": beta0,
                "beta1": beta1,
                "depth": conv_depth,
                "patch_size": PATCH_SIZE,
                "batch_size": batch_size,
                "epochs": EPOCHS,
                "patience": patience,
                "channel_name": channel_name,
            },
        )

        plot_callback = IntermediatePatchPlot(
            name='plotter',
            path=CHANNEL_CONF_PLOT_DIR,
            dataset=pds,  # give it the patch dataset as opposed to the cached dataset
            plot_metrics=[SSIM(_metric_name='ssim'), PSNR(_metric_name='psnr')],
            figsize=(20, 25),
            every_n_epochs=5,
            show_plot=False,
        )

        trainer = Trainer(
            model=model,
            optimizer=optimizer,
            backprop_loss=torch.nn.L1Loss(),  # MAE loss for backpropagation
            dataset=cds,  # train on the prefilled cache built above
            batch_size=batch_size,
            epochs=EPOCHS,
            patience=patience,
            callbacks=[mlflow_logger_callback, plot_callback],
            metrics=metric_fns,
            device=TRAIN_DEVICE,  # honour the CPU fallback chosen in the seed/device cell
            early_termination_metric='L1Loss'
        )

        trainer.train()

        ## Release this channel's model, optimizer (which still references the
        ## model parameters), callbacks and cached data before the next channel
        del model, optimizer, trainer, mlflow_logger_callback, plot_callback, cds
        torch.cuda.empty_cache()  # no-op when CUDA was never initialized

Beginning FNet model training for channel: OrigDNA and group high_confluence


2025/02/26 23:18:41 INFO mlflow.tracking.fluent: Experiment with name 'FNet_train_high_confluence' does not exist. Creating a new experiment.


Early termination at epoch 41 with best validation metric 0.019157148897647858
Beginning FNet model training for channel: OrigER and group high_confluence
Early termination at epoch 48 with best validation metric 0.022135975903698375
Beginning FNet model training for channel: OrigAGP and group high_confluence
Early termination at epoch 60 with best validation metric 0.0030946837339018074
Beginning FNet model training for channel: OrigMito and group high_confluence
Early termination at epoch 23 with best validation metric 0.06919396238831374
Beginning FNet model training for channel: OrigRNA and group high_confluence
Early termination at epoch 29 with best validation metric 0.019445359563598268


2025/02/27 00:07:37 INFO mlflow.tracking.fluent: Experiment with name 'FNet_train_low_confluence' does not exist. Creating a new experiment.


Beginning FNet model training for channel: OrigDNA and group low_confluence
Early termination at epoch 34 with best validation metric 0.0074959140712101205
Beginning FNet model training for channel: OrigER and group low_confluence
Early termination at epoch 31 with best validation metric 0.015339045001095846
Beginning FNet model training for channel: OrigAGP and group low_confluence
Early termination at epoch 76 with best validation metric 0.0017006185003801395
Beginning FNet model training for channel: OrigMito and group low_confluence
Early termination at epoch 38 with best validation metric 0.0391944024179663
Beginning FNet model training for channel: OrigRNA and group low_confluence
Early termination at epoch 53 with best validation metric 0.010584966818753042
