## This notebook trains wGAN-GP models on two datasets (high and low confluence level) of U2-OS Cell Painting image data

In [1]:
import random
import pathlib
import sys
import yaml

import pandas as pd
import torch
import torch.optim as optim

## Read config

In [2]:
## config.yml is expected two directories above the notebook's working
## directory; it is parsed with the safe YAML loader
config_path = pathlib.Path('.').absolute().parent.parent / "config.yml"
with config_path.open("r") as file:
    config = yaml.safe_load(file)

## Import virtual_stain_flow software 

In [3]:
## Put the software location from the config on the module search path so the
## virtual_stain_flow imports below can be resolved
software_path = config['paths']['software_path']
sys.path.append(software_path)

## Show the directory two levels above the working directory
two_levels_up = pathlib.Path('.').absolute().parent.parent
print(str(two_levels_up))

## Dataset
from virtual_stain_flow.datasets.PatchDataset import PatchDataset
from virtual_stain_flow.datasets.CachedDataset import CachedDataset

## wGaN training
from virtual_stain_flow.models.unet import UNet
from virtual_stain_flow.models.discriminator import GlobalDiscriminator
from virtual_stain_flow.trainers.WGANTrainer import WGANTrainer

## wGaN losses
from virtual_stain_flow.losses.GradientPenaltyLoss import GradientPenaltyLoss
from virtual_stain_flow.losses.DiscriminatorLoss import WassersteinLoss
from virtual_stain_flow.losses.GeneratorLoss import GeneratorLoss

from virtual_stain_flow.transforms.MinMaxNormalize import MinMaxNormalize
from virtual_stain_flow.transforms.PixelDepthTransform import PixelDepthTransform

## Metrics
from virtual_stain_flow.metrics.MetricsWrapper import MetricsWrapper
from virtual_stain_flow.metrics.PSNR import PSNR
from virtual_stain_flow.metrics.SSIM import SSIM

## callback
from virtual_stain_flow.callbacks.MlflowLogger import MlflowLogger
from virtual_stain_flow.callbacks.IntermediatePlot import IntermediatePlot

/home/weishanli/Waylab/pediatric_cancer_atlas_analysis


  check_for_updates()


## Define paths and other train parameters

In [4]:
## Directory two levels above the notebook's working directory; the loaddata
## csv files produced by the data preprocessing step are addressed relative to it
REPO_ROOT = pathlib.Path('.').absolute().parent.parent
LOADDATA_DIR = REPO_ROOT / '0.data_preprocessing' / 'data_split_loaddata'

## Loaddata for train
LOADDATA_FILE_PATH = LOADDATA_DIR / 'loaddata_train.csv'
assert LOADDATA_FILE_PATH.exists(), f"File not found: {LOADDATA_FILE_PATH}"

## Loaddata for heldout (only used to plot predictions during training)
LOADDATA_HELDOUT_FILE_PATH = LOADDATA_DIR / 'loaddata_heldout.csv'
assert LOADDATA_HELDOUT_FILE_PATH.exists(), f"File not found: {LOADDATA_HELDOUT_FILE_PATH}"

## Directory with the per-plate single-cell feature parquet files
SC_FEATURES_DIR = pathlib.Path(config['paths']['sc_features_path'])

## Train Logging/Weight output directory
MLFLOW_DIR = pathlib.Path('.').absolute() / 'mlflow'
MLFLOW_DIR.mkdir(parents=True, exist_ok=True)

## Train intermediate plot output directory
PLOT_DIR = pathlib.Path('.').absolute() / 'train_plots'
PLOT_DIR.mkdir(parents=True, exist_ok=True)

## Basic data generation (patch size) and max epoch definition
PATCH_SIZE = 256
EPOCHS = 1_000

## Channels for input and target are read from config
INPUT_CHANNEL_NAMES = config['data']['input_channel_keys']
TARGET_CHANNEL_NAMES = config['data']['target_channel_keys']

## Define how the training data is divided to train separate models for two levels of confluence

In [5]:
## Maps each confluence group name to the loaddata column filter(s) that
## select its rows: column name -> accepted values
DATA_GROUPING = dict(
    high_confluence=dict(seeding_density=[12_000, 8_000]),
    low_confluence=dict(seeding_density=[4_000, 2_000, 1_000]),
)

## Set up metrics

In [6]:
## Metrics handed to the trainer (via its `metrics` argument), keyed by name
metric_fns = dict(
    mse_loss=MetricsWrapper(_metric_name='mse', module=torch.nn.MSELoss()),
    ssim_loss=SSIM(_metric_name="ssim"),
    psnr_loss=PSNR(_metric_name="psnr"),
)

## Define hyperparameters as described by Cross-Zamirski et al.
https://www.nature.com/articles/s41598-022-12914-x

In [7]:
## Generator depths to train; one model is trained per entry
GEN_DEPTHS = [4, 5]

## Generator and its Adam optimizer
GEN_DEPTH = 4
GEN_OPTIM_LR = 2e-4
GEN_OPTIM_BETA0, GEN_OPTIM_BETA1 = 0.0, 0.9
GEN_OPTIM_WEIGHT_DECAY = 0
GEN_UPDATE_FREQ = 5  # passed to the trainer as generator_update_freq

## Discriminator and its Adam optimizer
DISC_DEPTH = 4
DISC_N_FILTERS = 64
DISC_OPTIM_LR = 2e-4
DISC_OPTIM_BETA0, DISC_OPTIM_BETA1 = 0.0, 0.9
DISC_OPTIM_WEIGHT_DECAY = 1e-3
DISC_UPDATE_FREQ = 1  # passed to the trainer as discriminator_update_freq

## Weight of the gradient penalty term
GP_WEIGHT = 10.0

## Data loading / early stopping
BATCH_SIZE = 16  # reduce vRAM usage
PATIENCE = 20  # fixed patience for early stopping

## Create a patched dataset from the heldout data for plotting predictions during training
The heldout data is not used for training, only for visualization.

In [None]:
loaddata_heldout_df = pd.read_csv(LOADDATA_HELDOUT_FILE_PATH)

## Only the plate/well/site identifiers and the cell center coordinates are
## read from the single-cell feature files
sc_location_columns = [
    'Metadata_Plate',
    'Metadata_Well',
    'Metadata_Site',
    'Metadata_Cells_Location_Center_X',
    'Metadata_Cells_Location_Center_Y',
]

## Retrieve relevant sc feature information, one parquet file per plate.
## The per-plate frames are collected in a list and concatenated once instead
## of growing a DataFrame inside the loop.
sc_feature_frames = []
for plate in loaddata_heldout_df['Metadata_Plate'].unique():
    sc_features_parquet = SC_FEATURES_DIR / f'{plate}_sc_normalized.parquet'
    if not sc_features_parquet.exists():
        print(f'{sc_features_parquet} does not exist, skipping...')
        continue
    sc_feature_frames.append(
        pd.read_parquet(sc_features_parquet, columns=sc_location_columns)
    )

## pd.concat raises on an empty list, so fall back to an empty frame when no
## plate had a feature file
sc_features = pd.concat(sc_feature_frames) if sc_feature_frames else pd.DataFrame()

## Generate multi-channel patch dataset for plotting
pds_heldout = PatchDataset(
    _loaddata_csv=loaddata_heldout_df,
    _sc_feature=sc_features,
    _input_channel_keys=INPUT_CHANNEL_NAMES,
    _target_channel_keys=TARGET_CHANNEL_NAMES,
    _input_transform=PixelDepthTransform(src_bit_depth=16, target_bit_depth=8, _always_apply=True),
    _target_transform=MinMaxNormalize(_normalization_factor=(2 ** 16) - 1, _always_apply=True),
    patch_size=PATCH_SIZE,
    verbose=False,
    patch_generation_method="random_cell",
    n_expected_patches_per_img=5,
    patch_generation_random_seed=42
)

## Generate list of indices to plot. The seed is fixed so the same patches are
## drawn on every run; the count is capped at the dataset size because
## random.sample raises when asked for more items than are available.
N_VISUALIZATION_PATCHES = 5
n_patches = len(pds_heldout)
random.seed(42)
visualization_patch_indices = random.sample(
    range(n_patches),
    min(N_VISUALIZATION_PATCHES, n_patches)
)

## Train models per confluence level, per channel and per generator depth

In [None]:
TRAIN_DEVICE = 'cuda'

import gc
def free_gpu_memory():
    """Run Python garbage collection, then release PyTorch's cached CUDA memory."""
    gc.collect()
    torch.cuda.empty_cache()

## Train data split loaddata
loaddata_df = pd.read_csv(LOADDATA_FILE_PATH)

## Only the plate/well/site identifiers and the cell center coordinates are
## read from the single-cell feature files
sc_location_columns = [
    'Metadata_Plate',
    'Metadata_Well',
    'Metadata_Site',
    'Metadata_Cells_Location_Center_X',
    'Metadata_Cells_Location_Center_Y',
]

## Iterate over confluence group
for confluence_group_name, conditions in DATA_GROUPING.items():

    ## Subset Loaddata to confluence level: keep rows matching every
    ## column -> accepted-values condition of the group
    loaddata_condition_df = loaddata_df.copy()
    for condition, values in conditions.items():
        loaddata_condition_df = loaddata_condition_df[
            loaddata_condition_df[condition].isin(values)
        ]

    ## Retrieve relevant sc feature information, one parquet file per plate.
    ## Frames are collected in a list and concatenated once instead of
    ## growing a DataFrame inside the loop.
    sc_feature_frames = []
    for plate in loaddata_condition_df['Metadata_Plate'].unique():
        sc_features_parquet = SC_FEATURES_DIR / f'{plate}_sc_normalized.parquet'
        if not sc_features_parquet.exists():
            print(f'{sc_features_parquet} does not exist, skipping...')
            continue
        sc_feature_frames.append(
            pd.read_parquet(sc_features_parquet, columns=sc_location_columns)
        )

    ## pd.concat raises on an empty list, so fall back to an empty frame when
    ## no plate had a feature file
    sc_features = pd.concat(sc_feature_frames) if sc_feature_frames else pd.DataFrame()

    ## Generate multi-channel patch dataset for confluence group
    pds = PatchDataset(
        _loaddata_csv=loaddata_condition_df,
        _sc_feature=sc_features,
        _input_channel_keys=INPUT_CHANNEL_NAMES,
        _target_channel_keys=TARGET_CHANNEL_NAMES,
        _input_transform=PixelDepthTransform(src_bit_depth=16, target_bit_depth=8, _always_apply=True),
        _target_transform=MinMaxNormalize(_normalization_factor=(2 ** 16) - 1, _always_apply=True),
        patch_size=PATCH_SIZE,
        verbose=False,
        patch_generation_method="random_cell",
        n_expected_patches_per_img=50,
        patch_generation_random_seed=42
    )

    ## Make folder for per confluence plotting
    CONF_PLOT_DIR = PLOT_DIR / confluence_group_name
    CONF_PLOT_DIR.mkdir(exist_ok=True)

    ## Train a model for each target channel
    for channel_name in TARGET_CHANNEL_NAMES:

        ## Make folder for per channel plotting under confluence
        CHANNEL_CONF_PLOT_DIR = CONF_PLOT_DIR / channel_name
        CHANNEL_CONF_PLOT_DIR.mkdir(exist_ok=True)

        ## Cache single input/single target dataset to speed up training
        pds.set_input_channel_keys(INPUT_CHANNEL_NAMES)
        pds.set_target_channel_keys(channel_name)
        cds = CachedDataset(
            dataset=pds,
            prefill_cache=True
        )

        ## Configure input/target channel for heldout dataset for plotting
        pds_heldout.set_input_channel_keys(INPUT_CHANNEL_NAMES)
        pds_heldout.set_target_channel_keys(channel_name)

        ## Iterate over the hardcoded generator depths and train a model for each.
        ## A lowercase loop variable is used so the GEN_DEPTH constant from the
        ## hyperparameter cell is not overwritten.
        for gen_depth in GEN_DEPTHS:

            ## One plot folder per generator depth, so models of different
            ## depths do not write their plots into the same directory
            DEPTH_CHANNEL_CONF_PLOT_DIR = CHANNEL_CONF_PLOT_DIR / f'depth_{gen_depth}'
            DEPTH_CHANNEL_CONF_PLOT_DIR.mkdir(exist_ok=True)

            print(f"Beginning prototype training wGAN-gp for channel: {channel_name} for {confluence_group_name} with generator depth: {gen_depth}")

            generator = UNet(
                n_channels=1,
                n_classes=1,
                depth=gen_depth,
                bilinear=False
            )

            discriminator = GlobalDiscriminator(
                n_in_channels = 2, # 1 input brightfield + 1 target channel
                n_in_filters = DISC_N_FILTERS,
                _conv_depth = DISC_DEPTH,
                _pool_before_fc = True
            )

            generator_optimizer = optim.Adam(generator.parameters(),
                                            lr=GEN_OPTIM_LR,
                                            betas=(GEN_OPTIM_BETA0, GEN_OPTIM_BETA1),
                                            weight_decay=GEN_OPTIM_WEIGHT_DECAY)
            discriminator_optimizer = optim.Adam(discriminator.parameters(),
                                                lr=DISC_OPTIM_LR,
                                                betas=(DISC_OPTIM_BETA0, DISC_OPTIM_BETA1),
                                                weight_decay=DISC_OPTIM_WEIGHT_DECAY)

            gp_loss = GradientPenaltyLoss(
                _metric_name='gp_loss',
                discriminator=discriminator,
                weight=GP_WEIGHT,
            )

            gen_loss = GeneratorLoss(
                _metric_name='gen_loss'
            )

            disc_loss = WassersteinLoss(
                _metric_name='disc_loss'
            )

            ## Everything logged to mlflow for this run
            params = {
                        # generator optimizer hyperparameters
                        'gen_optim_lr': GEN_OPTIM_LR,
                        'gen_optim_beta0': GEN_OPTIM_BETA0,
                        'gen_optim_beta1': GEN_OPTIM_BETA1,
                        'gen_optim_weight_decay': GEN_OPTIM_WEIGHT_DECAY,
                        # generator hyperparameters
                        'gen_update_freq': GEN_UPDATE_FREQ,
                        'gen_depth': gen_depth,
                        # discriminator optimizer hyperparameters
                        'disc_optim_lr': DISC_OPTIM_LR,
                        'disc_optim_beta0': DISC_OPTIM_BETA0,
                        'disc_optim_beta1': DISC_OPTIM_BETA1,
                        'disc_optim_weight_decay': DISC_OPTIM_WEIGHT_DECAY,
                        # discriminator hyperparameters
                        'disc_n_filters': DISC_N_FILTERS,
                        'disc_update_freq': DISC_UPDATE_FREQ,
                        'disc_depth': DISC_DEPTH,
                        # gradient penalty weight
                        'gp_loss_weight': GP_WEIGHT,
                        # dataset hyperparameters
                        'patch_size': PATCH_SIZE,
                        'channel_name': channel_name,
                        'confluence_group': confluence_group_name,
                        # data loading hyperparameter(s)
                        'batch_size': BATCH_SIZE,
                        # training hyperparameters
                        "patience": PATIENCE,
                        "epochs": EPOCHS,
                    }

            mlflow_logger_callback = MlflowLogger(
                    name='mlflow_logger',
                    mlflow_uri=MLFLOW_DIR / 'mlruns',
                    mlflow_experiment_name=f'wGAN_gp_prototype_train_{confluence_group_name}',
                    mlflow_start_run_args={'run_name': f'wGAN_gp_prototype_train_{confluence_group_name}_{channel_name}_{gen_depth}', 'nested': True},
                    mlflow_log_params_args=params,
                )

            plot_callback = IntermediatePlot(
                name='plotter',
                path=DEPTH_CHANNEL_CONF_PLOT_DIR,
                dataset=pds_heldout,
                indices=visualization_patch_indices, # every model being trained will have the same visualization patch indices
                plot_metrics=[SSIM(_metric_name='ssim'), PSNR(_metric_name='psnr')],
                figsize=(20, 25),
                every_n_epochs=5,
                show_plot=False,
            )

            ## NOTE(review): 'L1Loss' is not one of the keys in metric_fns;
            ## confirm how the trainer resolves early_termination_metric
            wgan_trainer = WGANTrainer(
                dataset=cds,
                batch_size=BATCH_SIZE,
                epochs=EPOCHS,
                patience=PATIENCE, # early stopping patience (see PATIENCE)
                device=TRAIN_DEVICE,
                generator=generator,
                discriminator=discriminator,
                gen_optimizer=generator_optimizer,
                disc_optimizer=discriminator_optimizer,
                generator_loss_fn=gen_loss,
                discriminator_loss_fn=disc_loss,
                gradient_penalty_fn=gp_loss,
                discriminator_update_freq=DISC_UPDATE_FREQ,
                generator_update_freq=GEN_UPDATE_FREQ,
                callbacks=[mlflow_logger_callback, plot_callback],
                metrics=metric_fns,
                early_termination_metric='L1Loss'
            )

            wgan_trainer.train()

            ## Drop every name bound to this run's objects before reclaiming
            ## memory: the optimizers were built from the models' parameters
            ## and gp_loss was given the discriminator, so they go as well
            del wgan_trainer
            del mlflow_logger_callback, plot_callback
            del gp_loss, gen_loss, disc_loss
            del generator_optimizer, discriminator_optimizer
            del generator, discriminator
            free_gpu_memory()


Beginning training wGAN-gp for channel: OrigDNA for high_confluence


2025/02/26 14:13:39 INFO mlflow.tracking.fluent: Experiment with name 'wGAN_gp_prototype_train_high_confluence' does not exist. Creating a new experiment.


Early termination at epoch 27 with best validation metric -0.210029647900508
Beginning training wGAN-gp for channel: OrigER for high_confluence
Early termination at epoch 30 with best validation metric -0.023256127650921162
Beginning training wGAN-gp for channel: OrigAGP for high_confluence
Early termination at epoch 28 with best validation metric -0.008496610017923208
Beginning training wGAN-gp for channel: OrigMito for high_confluence
Early termination at epoch 37 with best validation metric -0.3187789022922516
Beginning training wGAN-gp for channel: OrigRNA for high_confluence
Early termination at epoch 32 with best validation metric -0.0027225567744328426


2025/02/26 15:20:42 INFO mlflow.tracking.fluent: Experiment with name 'wGAN_gp_prototype_train_low_confluence' does not exist. Creating a new experiment.


Beginning training wGAN-gp for channel: OrigDNA for low_confluence
