## Imports

In [None]:
!pip install -U "jax[tpu]" optuna-dashboard plotly nbformat optuna grain clu jdc munch omegaconf aim

Collecting jax[tpu]
  Downloading jax-0.6.1-py3-none-any.whl (2.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m23.5 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hCollecting optuna-dashboard
  Downloading optuna_dashboard-0.18.0-py3-none-any.whl (8.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.4/8.4 MB[0m [31m75.3 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hCollecting plotly
  Downloading plotly-6.1.2-py3-none-any.whl (16.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.3/16.3 MB[0m [31m57.4 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting optuna
  Downloading optuna-4.3.0-py3-none-any.whl (386 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m386.6/386.6 kB[0m [31m14.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting grain
  Downloading grain-0.2.9-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (479 kB)
[2K     [90m━━━━━━━━━━━━━━━━━

In [None]:
# pip install -U "jax[cuda12]"

In [None]:
import numpy as np
from numpy.typing import NDArray
from typing import List, Any, Tuple, Dict
import pandas as pd
import os
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.image as mpimg
from pathlib import Path
import tensorflow as tf

from tensorflow.keras.preprocessing.image import ImageDataGenerator

from mpl_toolkits.axes_grid1 import ImageGrid
from math import ceil
import jax, jax.numpy as jnp, optax, jax.random as jr

from flax import linen as nn
from flax.training import train_state  # Useful dataclass to keep train state

import gc

import optuna
from optuna.visualization import plot_contour, plot_param_importances, plot_optimization_history, plot_slice, plot_parallel_coordinate

import logging
import sys

from omegaconf import OmegaConf

import torch
from functools import partial
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from jax.sharding import PartitionSpec as P, NamedSharding
import random
import orbax.checkpoint as ocp
from etils import epath
import json
from grain.python import ShardByJaxProcess, Batch

from clu import metrics
from flax import struct     
import glob

import time
import threading
from optuna_dashboard import wsgi
import optuna
from wsgiref.simple_server import make_server

import jdc
%matplotlib inline

## Project Topic: **CNN Cancer Detection Kaggle Mini-Project**

### Exploratory Data Analysis

In [None]:
# Kaggle competition root and its image folders (trailing slash kept: paths are built by concatenation)
WRK_DIR = '/kaggle/input/histopathologic-cancer-detection'
TRAIN_DIR, TEST_DIR = (f'{WRK_DIR}/{split}/' for split in ('train', 'test'))

In [None]:
# labels CSV lives next to the image folders; build the path from WRK_DIR instead of repeating it
df_train = pd.read_csv(f'{WRK_DIR}/train_labels.csv')
# glob order is filesystem-dependent; sort so the test set (and its predictions) has a stable order
df_test = pd.DataFrame({'path': sorted(glob.glob(os.path.join(TEST_DIR, '*.tif')))})
# df_test['id'] = df_test['path'].str.extract(r'([^//]+).tif$')
LABEL_MAPPER  = {0: '0: No Cancer', 1: '1: Cancer'}

In [None]:
df_train['image_path'] = TRAIN_DIR + df_train['id']+'.tif'

In [None]:
# Non-destructive and idempotent: re-running this cell no longer raises KeyError on the already-dropped 'id'
df_train = df_train.drop(columns='id', errors='ignore')
# Keras flow_from_dataframe(class_mode='binary') requires string class labels
df_train['label'] = df_train['label'].astype(str)

In [None]:
df_train.head()

In [None]:
df_test.head()

In [None]:
df_train.info()  # info() prints itself and returns None; wrapping it in print() added a stray "None"

In [None]:
df_test.info()  # info() prints itself and returns None; wrapping it in print() added a stray "None"

In [None]:
df_train.shape

In [None]:
# number of files in the train folder = size of train dataset
train_files = glob.glob(os.path.join(TRAIN_DIR, '*.tif'))
print(len(train_files))

In [None]:
# number of files in the test folder = size of test dataset
test_files = glob.glob(os.path.join(TEST_DIR, '*.tif'))
print(len(test_files))

In [None]:
# value_counts() orders classes by frequency; sort by class id so the slice labels
# are guaranteed to match LABEL_MAPPER (labels are strings '0'/'1' at this point)
class_counts = df_train['label'].value_counts().sort_index()
class_names = [LABEL_MAPPER[int(c)] for c in class_counts.index]

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
class_counts.plot.pie(explode=[0, 0.1], autopct='%1.1f%%', ax=ax[0], shadow=True, labels=class_names)
ax[0].set_ylabel('')
sns.countplot(df_train, x='label', order=list(class_counts.index), ax=ax[1])
ax[1].set_xticks(range(len(class_names)), labels=class_names)
fig.suptitle("Class Distribution")
plt.show()

As seen from the plot above, the dataset is mildly imbalanced$^\href{https://developers.google.com/machine-learning/crash-course/overfitting/imbalanced-datasets}{1}$, with the majority class being 0 (no cancer, the negative class) and the minority class being 1 (presence of cancer, the positive class).

In [None]:
#https://matplotlib.org/stable/gallery/axes_grid1/demo_axes_grid2.html#sphx-glr-gallery-axes-grid1-demo-axes-grid2-py
def add_inner_title(ax, title, loc, **kwargs):
    """Draw ``title`` inside ``ax`` as a frameless, white-outlined anchored label.

    Adapted from the matplotlib axes_grid1 demo.

    Args:
        ax: matplotlib Axes to annotate.
        title: text to draw.
        loc: anchor location understood by ``AnchoredText`` (e.g. ``'upper left'``).
        **kwargs: forwarded to ``AnchoredText``.

    Returns:
        The ``AnchoredText`` artist that was added to ``ax``.
    """
    from matplotlib.offsetbox import AnchoredText
    from matplotlib.patheffects import withStroke

    # white stroke keeps the text readable on top of arbitrary image content
    text_props = {
        'path_effects': [withStroke(foreground='w', linewidth=2)],
        'size': plt.rcParams['legend.fontsize'],
    }
    anchored = AnchoredText(title, loc=loc, prop=text_props, pad=0., borderpad=0.5, frameon=False, **kwargs)
    ax.add_artist(anchored)
    return anchored

In [None]:
def show_images(dirs: List[Tuple[str, int]], preds: List[int] = None, ncols: int = 3) -> None:
    """Plot images on a grid, labelled with their true (and optionally predicted) class.

    Args:
        dirs: non-empty list of ``(image_path, true_label)`` pairs; ``true_label`` must be
            a key of ``LABEL_MAPPER``.
        preds: optional list of predicted labels, one per entry of ``dirs``. Values found in
            ``LABEL_MAPPER`` are shown by name, anything else is shown as-is.
        ncols: number of grid columns.

    Paths that do not exist are skipped together with their prediction, so the remaining
    predictions stay aligned with their images.
    """
    assert isinstance(dirs, list) and dirs, 'must be a list of valid image paths/files'
    preds = [] if preds is None else preds  # None default avoids a shared mutable [] default
    if len(preds) > 0:
        assert isinstance(preds, list) and len(preds) == len(dirs), 'if provided, preds must be list of predictions and must be of same len with dirs'

    # pair each prediction with its image *before* filtering, otherwise a missing file shifts every later prediction
    items = [
        (mpimg.imread(path), true_label, preds[i] if preds else None)
        for i, (path, true_label) in enumerate(dirs)
        if os.path.exists(path)
    ]
    if not items:
        print('None of the given image paths exist; nothing to show.')
        return

    fig = plt.figure(figsize=(10., 10.))
    grid = ImageGrid(fig, 111,
                     nrows_ncols=(ceil(len(items) / ncols), ncols),
                     axes_pad=0.05, label_mode="all")
    for ax, (img, true_label, pred) in zip(grid, items):
        ax.imshow(img)
        title = f'true: {LABEL_MAPPER[true_label]}'
        if pred is not None:
            title += f'\npredicted: {LABEL_MAPPER.get(pred, pred)}'
        add_inner_title(ax, title, loc='upper left')  # one artist instead of two overlapping ones
    for ax in list(grid)[len(items):]:
        ax.set_axis_off()  # hide unused cells of a partially filled last row
    plt.show()

In [None]:
# Show real examples of each class: 3 randomly chosen tiles per label (fixed seed for reproducibility).
# The previous cell showed one file four times with made-up labels (both 0 and 1).
# Labels are stored as strings at this point, while LABEL_MAPPER is keyed by int.
examples = df_train.groupby('label').sample(n=3, random_state=0)
show_images(list(zip(examples['image_path'], examples['label'].astype(int).tolist())))

## Modelling

### Dataset

In [None]:
BATCH_SIZE = 32
img_height = img_width = 96  # square 96x96 patches
MAX_EPOCHS = 2  # deliberately tiny for quick runs

In [None]:
import flax
@flax.struct.dataclass
class F1_score(metrics.Metric):
    """Binary F1 score (positive class = 1), accumulated over batches as int32 counts.

    Predictions are the arg-max of 2-class logits.
    """
    true_pos: jnp.array  # predicted 1 and label 1
    pred_pos: jnp.array  # predicted 1
    actual_pos: jnp.array  # label 1

    @classmethod
    def empty(cls):
        """All-zero counts (the identity for ``merge``)."""
        return cls(true_pos=jnp.array(0, jnp.int32), pred_pos=jnp.array(0, jnp.int32), actual_pos=jnp.array(0, jnp.int32))

    @classmethod
    def from_model_output(cls, *, logits: jnp.array, labels: jnp.array, **_) -> metrics.Metric:
        """Count positives for one batch; ``logits`` must have a trailing dimension of 2."""
        assert logits.shape[-1] == 2, "Expected binary logits."
        preds = logits.argmax(axis=-1)
        return cls(
            true_pos=((preds == 1) & (labels == 1)).sum(), # predicted and ground truth is 1- positive class
            pred_pos=(preds == 1).sum(), # sum of predicted positives
            actual_pos=(labels == 1).sum() # sum of ground truth positive
        )

    def merge(self, other: metrics.Metric) -> metrics.Metric:
        """Element-wise sum of the counts of two partial results."""
        return type(self)(
            true_pos=self.true_pos + other.true_pos,
            pred_pos=self.pred_pos + other.pred_pos,
            actual_pos=self.actual_pos + other.actual_pos
        )

    def compute(self):
        """Return F1 as a float32 scalar, or 0.0 when there are no predicted and no actual positives.

        Uses F1 = 2*tp / (pred_pos + actual_pos), which equals 2 / (1/precision + 1/recall)
        whenever both are defined. The harmonic-mean form produced NaN as soon as
        the model predicted no positives (precision = 0/0).
        """
        tp = jnp.asarray(self.true_pos, jnp.float32)
        denom = jnp.asarray(self.pred_pos + self.actual_pos, jnp.float32)
        safe_denom = jnp.where(denom > 0, denom, 1.0)  # avoid evaluating 0/0
        return jnp.where(denom > 0, 2.0 * tp / safe_denom, 0.0)

In [None]:
# Matthews Correlation Coefficient (MCC)
# mcc = (tp * tn - fp * fn) / sqrt(pred_pos * actual_pos * actual_neg * pred_neg)
@flax.struct.dataclass
class MCC(metrics.Metric):
    """Matthews correlation coefficient for binary classification, accumulated over batches.

    Every field is an int32 count summed over all merged batches. Predictions are the
    arg-max of 2-class logits.
    """
    tp: jnp.array  # predicted 1, label 1
    tn: jnp.array  # predicted 0, label 0
    fp: jnp.array  # predicted 1, label 0
    fn: jnp.array  # predicted 0, label 1
    pred_pos: jnp.array  # number of predicted positives
    pred_neg: jnp.array  # number of predicted negatives
    actual_pos: jnp.array  # number of positive labels
    actual_neg: jnp.array  # number of negative labels

    @classmethod
    def empty(cls):
        """All-zero counts (the identity for ``merge``)."""
        return cls(tp=jnp.array(0, jnp.int32), tn=jnp.array(0, jnp.int32),
                   fp=jnp.array(0, jnp.int32), fn=jnp.array(0, jnp.int32),
                   pred_pos=jnp.array(0, jnp.int32), pred_neg=jnp.array(0, jnp.int32),
                   actual_pos=jnp.array(0, jnp.int32), actual_neg=jnp.array(0, jnp.int32))

    @classmethod
    def from_model_output(cls, *, logits: jnp.array, labels: jnp.array, **_) -> metrics.Metric:
        """Confusion-matrix counts for one batch; ``logits`` must have a trailing dimension of 2."""
        assert logits.shape[-1] == 2, "Expected binary logits."
        preds = logits.argmax(axis=-1)
        return cls(
            tp=((preds == 1) & (labels == 1)).sum(),  # predicted 1, ground truth 1
            tn=((preds == 0) & (labels == 0)).sum(),  # predicted 0, ground truth 0
            fp=((preds == 1) & (labels == 0)).sum(),  # predicted 1, ground truth 0
            fn=((preds == 0) & (labels == 1)).sum(),  # predicted 0, ground truth 1
            pred_pos=(preds == 1).sum(),  # predicted positives
            pred_neg=(preds == 0).sum(),  # predicted negatives
            actual_pos=(labels == 1).sum(),  # ground-truth positives
            actual_neg=(labels == 0).sum()  # ground-truth negatives
        )

    def merge(self, other: metrics.Metric) -> metrics.Metric:
        """Element-wise sum of the counts of two partial results."""
        return type(self)(
            tp=self.tp + other.tp,
            tn=self.tn + other.tn,
            fp=self.fp + other.fp,
            fn=self.fn + other.fn,
            pred_pos=self.pred_pos + other.pred_pos,
            pred_neg=self.pred_neg + other.pred_neg,
            actual_pos=self.actual_pos + other.actual_pos,
            actual_neg=self.actual_neg + other.actual_neg,
        )

    def compute(self):
        """Return the MCC as a float32 scalar, or 0.0 when any marginal count is zero.

        Counts are cast to float32 *before* multiplying. In int32, the product of the four
        marginals (and tp * tn) overflows with only a few hundred samples per class and
        wraps silently, which produced garbage or NaN.
        """
        tp, tn, fp, fn, pp, pn, ap, an = (
            jnp.asarray(count, jnp.float32)
            for count in (self.tp, self.tn, self.fp, self.fn,
                          self.pred_pos, self.pred_neg, self.actual_pos, self.actual_neg)
        )
        # split the square root so each intermediate product stays moderate
        denom = jnp.sqrt(pp * ap) * jnp.sqrt(an * pn)
        safe_denom = jnp.where(denom > 0, denom, 1.0)  # avoid evaluating x/0
        return jnp.where(denom > 0, (tp * tn - fp * fn) / safe_denom, 0.0)

In [None]:
# Flax dataclass collecting the metrics tracked during training and evaluation
@struct.dataclass
class Metrics(metrics.Collection):
    """Metric collection stored on the train state.

    Updated per batch in ``_compute_metrics`` (via ``single_from_model_output`` + ``merge``)
    and read back with ``.compute()`` in ``fit``.
    """
    loss: metrics.Average.from_output('loss')  # running mean of the 'loss' keyword passed per batch
    f1_score: F1_score  # binary F1 from the arg-max of the 2-class logits
    mcc: MCC  # Matthews correlation coefficient from the same predictions

In [None]:
class TrainState(train_state.TrainState):
  """Flax ``TrainState`` extended with a ``Metrics`` collection carried through the jitted steps."""
  metrics: Metrics  # reset with metrics.empty() at the start of every epoch in fit

In [None]:
# class Writer:
#     cache: Dict[int, List[Tuple[Any, str]]
#     def __init__(self, save_dir):
#         self.save_dir = save_dir
#     def add_scalar(tag: str, val: Any, key: int):
#         cache[key] = (val, tag)
        

In [None]:
# trainer class
class Trainer:
    """Training harness: sharded Flax train state, TensorBoard logging and Orbax checkpoints.

    Further methods (``_init_train_state``, ``fit``, ...) are attached in later cells via ``%%add_to``.
    """
    def __init__(self, model, params, logger, key):
        """
        Args:
            model: Flax module to train.
            params: hyper-parameter mapping with attribute access (e.g. ``DefaultMunch``).
                Reads ``log_dir``, ``run_name``, ``chkpt_dir``, ``max_epochs``, ``shape`` and ``lr``.
            logger: ``logging.Logger`` for progress messages.
            key: JAX PRNG key used for parameter initialisation.
        """
        self.model = model
        # Work on a copy. chkpt_dir is specialised per run below. Mutating the caller's object
        # appended run_name again every time another Trainer was built from the same hparams.
        self.hparams = params.copy()
        self.key = key
        self.logger = logger # event logger
        self.writer = SummaryWriter(f"{self.hparams.log_dir}/{params.run_name}") #log to tensorboard
        self.hparams.chkpt_dir = f"{self.hparams.chkpt_dir}/{params.run_name}"
        self._init_train_state()
        self._configure_checkpointer()


    def _configure_checkpointer(self):
        """Create an asynchronous Orbax ``CheckpointManager`` under ``self.hparams.chkpt_dir``.

        Keeps up to ``max_epochs`` checkpoints, each holding a ``'state'`` and a ``'hparams'`` item.
        """
        options = ocp.CheckpointManagerOptions(max_to_keep=self.hparams.max_epochs, save_interval_steps=1, enable_async_checkpointing=True, create=True)
        path = epath.Path(os.path.abspath(self.hparams.chkpt_dir))
        self.checkpoint_mngr = ocp.CheckpointManager(path, options=options, item_names =('state', 'hparams'))
        self.logger.info("Checkpointer configured")
        self.logger.info("Checkpointer configured")

In [None]:
%%add_to Trainer
def _init_train_state(self):
    """Initialise parameters and optimiser, and replicate the train state on every local device.

    Sets ``self.n_devices``, ``self.model_state`` and (via ``_configure_optimizers``)
    ``self.optim`` / ``self.opt_state``. Advances ``self.key``.
    """
    # Count devices on JAX's default backend (TPU, GPU or CPU), the same devices jax.make_mesh uses.
    # Hard-coding backend='tpu' made the Trainer fail on non-TPU hosts.
    self.n_devices = jax.local_device_count()
    self.logger.info(f"Number of devices found: {self.n_devices}")
    mesh = jax.make_mesh((self.n_devices,), ('batch',))
    model_sharding = NamedSharding(mesh, P())  # P() = fully replicated
    self.key, model_key = jr.split(self.key)
    # a single dummy example is enough to trace shapes for parameter initialisation
    variables = self.model.init(model_key, jnp.ones((1,) + tuple(self.hparams.shape)))
    self._configure_optimizers(variables)
    model_state = TrainState.create(
        apply_fn=self.model.apply, params=variables["params"],
        tx=self.optim, metrics=Metrics.empty()
    )
    self.model_state = jax.device_put(model_state, model_sharding)
    self.logger.info("Train state initialized")
    

In [None]:
%%add_to Trainer
def _configure_optimizers(self, variables):
    """Build the optimiser: global-norm clipping, then Adam with an exponentially decaying learning rate.

    Reads ``lr`` and, optionally, ``lr_decay_steps`` (default 1000) and ``lr_decay_rate``
    (default 0.97) from ``self.hparams``. Sets ``self.optim`` and ``self.opt_state``.

    Args:
        variables: the variables returned by ``model.init``. Only its ``"params"`` entry is optimised.
    """
    scheduler = optax.exponential_decay(
                    init_value=self.hparams.lr,
                    transition_steps=self.hparams.get('lr_decay_steps', 1000),
                    decay_rate=self.hparams.get('lr_decay_rate', 0.97)
                )
    self.optim = optax.chain(
        optax.clip_by_global_norm(1.0),  # Clip the gradient by the global norm.
        optax.scale_by_adam(),  # Use the updates from adam.
        optax.scale_by_schedule(scheduler),  # Use the learning rate from the scheduler.
        # Scale updates by -1 since optax.apply_updates is additive and we want to descend on the loss.
        optax.scale(-1.0)
    )
    # Initialise on the parameters only, matching the opt state TrainState.create builds from
    # variables["params"]. Using the whole variables dict gave a differently-structured
    # second copy of the Adam moments.
    self.opt_state = self.optim.init(variables.get("params", variables))
    self.logger.info("Optimizers configured")

In [None]:
%%add_to Trainer
@staticmethod
@jax.jit
def _step(model_state, batch):
    """Apply one optimisation step on ``batch`` and return the updated train state.

    ``batch`` is ``(images, labels)``. Any leading batch dimensions are flattened so the model
    sees ``(N, H, W, C)`` images and ``(N,)`` integer labels. The loss is mean softmax cross-entropy.
    """
    images, labels = batch
    labels = labels.reshape(-1)
    images = images.reshape((-1,) + images.shape[-3:])

    def _mean_xent(params):
        logits = model_state.apply_fn({"params": params}, images)
        per_example = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels)
        return per_example.mean()

    gradients = jax.grad(_mean_xent)(model_state.params)
    return model_state.apply_gradients(grads=gradients)

In [None]:
%%add_to Trainer
def fit(self, train_loader: Any, val_loader: Any):
    """Train for ``self.hparams.max_epochs`` epochs, evaluating and checkpointing after each one.

    Args:
        train_loader: dataset exposing ``as_numpy_iterator()`` that yields ``(images, labels)`` batches.
        val_loader: same, for validation.

    Per epoch this logs loss / F1 / MCC for both splits to TensorBoard under fixed tags
    (``train/loss``, ``val/mcc``, ...) with the epoch as global step. It also writes one PR
    curve of the validation predictions and saves a checkpoint of the evaluated state plus
    the hyper-parameters. The epoch with the lowest validation loss is reported at the end.
    """
    max_epochs = self.hparams.max_epochs or MAX_EPOCHS  # same setting the checkpointer uses
    mesh = jax.make_mesh((self.n_devices,), ('batch',))
    data_sharding = NamedSharding(mesh, P('batch'))

    def log_split(split, computed, epoch):
        # constant tag names + epoch as step -> one curve per metric (not one single-point chart per epoch)
        for name in ('loss', 'f1_score', 'mcc'):
            self.writer.add_scalar(f"{split}/{name}", np.asarray(jax.device_get(computed[name])), epoch)

    best_loss, best_epoch = float('inf'), 0
    epoch_pbar = tqdm(range(1, max_epochs + 1), leave=True, total=max_epochs)
    for epoch in epoch_pbar:
        epoch_pbar.set_description(f"Epoch: {epoch}")

        # training: update on each batch, then accumulate metrics with the updated params
        for batch in train_loader.as_numpy_iterator():
            batch = jax.device_put(batch, data_sharding)
            self.model_state = Trainer._step(self.model_state, batch)
            self.model_state, _ = Trainer._compute_metrics(self.model_state, batch)
        train_metrics = self.model_state.metrics.compute()
        tloss = float(jax.device_get(train_metrics['loss']))
        log_split('train', train_metrics, epoch)
        self.model_state = self.model_state.replace(metrics=self.model_state.metrics.empty())

        # evaluation: metrics were just reset, so eval_state accumulates validation stats only
        eval_state = self.model_state
        val_labels, val_preds = [], []
        for batch in val_loader.as_numpy_iterator():
            batch = jax.device_put(batch, data_sharding)
            eval_state, preds = Trainer._compute_metrics(eval_state, batch)
            val_labels.append(np.asarray(batch[1].reshape(-1)))
            val_preds.append(np.asarray(preds))
        val_metrics = eval_state.metrics.compute()
        vloss = float(jax.device_get(val_metrics['loss']))
        epoch_pbar.set_postfix(train_loss=tloss, val_loss=vloss)
        log_split('val', val_metrics, epoch)
        if val_labels:
            # one curve over the whole validation pass (it used to be overwritten batch by batch)
            self.writer.add_pr_curve('val/pr_curve', np.concatenate(val_labels), np.concatenate(val_preds), epoch)

        self.checkpoint_mngr.save(
            epoch,
            args=ocp.args.Composite(
                state=ocp.args.StandardSave(eval_state),  # the copy carrying the eval stats
                hparams=ocp.args.JsonSave(dict(self.hparams)),  # a JSON object, not a JSON-encoded string
            )
        )
        if vloss < best_loss:  # strict '<' keeps the earliest epoch on ties, like min() on (loss, epoch)
            best_loss, best_epoch = vloss, epoch

    try:
        self.checkpoint_mngr.wait_until_finished()
    except Exception:
        self.logger.exception("Error saving checkpoint")
        return
    # Every epoch is retained (max_to_keep == max_epochs), so the best one only needs reporting.
    # CheckpointManager.should_save() is a predicate and never saved anything.
    self.logger.info(f"best model is checkpoint step {best_epoch} with val loss {best_loss}")
        

In [None]:
%%add_to Trainer
@staticmethod
@jax.jit
def _compute_metrics(model_state, batch):
    """Forward ``batch`` and fold its loss / F1 / MCC statistics into ``model_state.metrics``.

    Returns:
        ``(updated_state, preds)``, where ``preds`` are the arg-max class ids of the flattened batch.
    """
    images, labels = batch
    images = images.reshape((-1,) + images.shape[-3:])
    labels = labels.reshape(-1)
    logits = model_state.apply_fn({"params": model_state.params}, images)
    batch_loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels).mean()
    batch_metrics = model_state.metrics.single_from_model_output(
        logits=logits, labels=labels, loss=batch_loss)
    merged = model_state.metrics.merge(batch_metrics)
    return model_state.replace(metrics=merged), jnp.argmax(logits, axis=-1)

In [None]:
%%add_to Trainer
def predict(self, x):
    """Return the model's raw output (logits) for ``x`` with the current parameters; no softmax/argmax."""
    state = self.model_state
    variables = {"params": state.params}
    return state.apply_fn(variables, x)

In [None]:
%%add_to Trainer    
@staticmethod
def load_from_checkpoint(path, load_epoch: int = 0):
    """Load a trained model and its saved hyper-parameters from a checkpoint.

    Intended to return a ``Trainer`` restored from the Orbax checkpoint at ``path``
    (step ``load_epoch``). This is not implemented yet.

    Raises:
        NotImplementedError: always. The previous stub silently returned ``None``, which
            callers would only notice much later.
    """
    raise NotImplementedError(
        f"Trainer.load_from_checkpoint is not implemented yet (path={path!r}, load_epoch={load_epoch})"
    )

In [None]:
def get_class_weights(ds):
    """Return the relative class frequencies ``[p(y=0), p(y=1)]`` (float32) of a labelled dataset.

    Note: despite the name these are class *frequencies* (priors), not inverse-frequency
    weights. That is what ``rejection_resample(initial_dist=...)`` expects.

    Args:
        ds: ``tf.data.Dataset`` yielding ``(features, labels)`` batches.
    """
    def _accumulate(totals, batch):
        _, labels = batch
        return {
            'y_0': totals['y_0'] + tf.reduce_sum(tf.cast(labels == 0, tf.int32)),
            'y_1': totals['y_1'] + tf.reduce_sum(tf.cast(labels == 1, tf.int32)),
        }

    totals = ds.reduce(initial_state={'y_0': 0, 'y_1': 0}, reduce_func=_accumulate)
    frequencies = np.asarray([totals['y_0'].numpy(), totals['y_1'].numpy()], dtype=np.float32)
    return frequencies / frequencies.sum()

In [None]:
def get_dataset(dfs: Dict[str, Any], num_epochs: int=10, batch_size:int=32, split_ratio: float = 0.2,):
    """Build train / validation / test ``tf.data`` pipelines on top of Keras dataframe iterators.

    Args:
        dfs: ``{'train': (train_df, train_dir), 'test': (test_df, test_dir)}``. ``train_df`` needs
            ``image_path`` and string ``label`` columns; ``test_df`` needs a ``path`` column.
        num_epochs: currently unused; epochs are driven by the caller iterating the datasets.
        batch_size: images per batch.
        split_ratio: fraction of ``train_df`` held out for validation.

    Returns:
        ``(train_ds, val_ds, test_ds)``. The Keras iterators emit single-image batches, so
        train/val elements are ``(images, labels)`` shaped ``(batch, 1, H, W, 3)`` / ``(batch, 1)``
        and test elements are ``(batch, 1, H, W, 3)``. Each split is capped at
        ``batch_size * 2`` batches (debug-sized runs).
    """
    AUTOTUNE = tf.data.AUTOTUNE
    target_size = (img_height, img_width)  # Keras expects (height, width)
    train_df, train_dir = dfs['train']
    test_df, test_dir = dfs['test']

    # random augmentation is applied to the training subset only
    train_datagen = ImageDataGenerator(rescale=1/255, validation_split=split_ratio, rotation_range=20,
                                       width_shift_range=0.2, height_shift_range=0.2,
                                       horizontal_flip=True, vertical_flip=True,
                                       zoom_range=0.2, shear_range=0.2, fill_mode='nearest'
                                      ) # data augmentation
    # Same validation_split gives the same train/validation partition of train_df, but validation
    # images are only rescaled. Previously they went through the augmenting generator too.
    val_datagen = ImageDataGenerator(rescale=1/255, validation_split=split_ratio)
    test_datagen = ImageDataGenerator(rescale=1/255)

    train_generator = train_datagen.flow_from_dataframe(dataframe=train_df, directory=train_dir,
                                                        x_col='image_path', y_col='label', batch_size=1,
                                                        target_size=target_size,
                                                        subset='training', class_mode='binary')
    val_generator = val_datagen.flow_from_dataframe(dataframe=train_df, directory=train_dir,
                                                    x_col='image_path', y_col='label', batch_size=1,
                                                    target_size=target_size,
                                                    subset='validation', class_mode='binary')
    # keep file order so predictions can be matched back to test_df rows
    test_generator = test_datagen.flow_from_dataframe(dataframe=test_df, directory=test_dir,
                                                      x_col='path', y_col=None, batch_size=1,
                                                      target_size=target_size, class_mode=None, shuffle=False)

    def _one_pass(generator, output_types, output_shapes):
        # Keras iterators cycle forever. Take exactly one pass (len == number of images, since
        # batch_size=1) so each dataset is finite and a loop over it terminates.
        return tf.data.Dataset.from_generator(
            lambda: generator, output_types=output_types, output_shapes=output_shapes
        ).take(len(generator))

    image_shape = [None, img_height, img_width, 3]
    # no .cache() on the training split: it would freeze the first pass's random augmentations
    train_ds = _one_pass(train_generator, (tf.float32, tf.int32), (image_shape, [None]))
    val_ds = _one_pass(val_generator, (tf.float32, tf.int32), (image_shape, [None])).cache()
    test_ds = _one_pass(test_generator, tf.float32, image_shape).cache()
    # class_weights = get_class_weights(train_ds)
    train_ds = train_ds.shuffle(1024).batch(batch_size, drop_remainder=True).prefetch(buffer_size=AUTOTUNE)
    # train_ds = train_ds.map(lambda extra_label, features_and_label: features_and_label) # drop extra_label returned from rejection_resample method
    val_ds = val_ds.shuffle(1024).batch(batch_size, drop_remainder=True).prefetch(buffer_size=AUTOTUNE)
    # keep the final partial batch: every test image needs a prediction
    test_ds = test_ds.batch(batch_size).prefetch(buffer_size=AUTOTUNE)
    # debug cap: batch_size * 2 batches per split
    return train_ds.take(batch_size*2), val_ds.take(batch_size * 2), test_ds.take(batch_size * 2)

In [None]:
class CNN(nn.Module):
    """A simple CNN: two conv/ReLU/avg-pool stages followed by a two-layer MLP head.

    Attributes:
        img_size: despite its name, the number of filters of the first convolution
            (``create_model`` forwards its ``img_size`` argument here).
        out_dim: number of output logits.
    """
    img_size: int = 32
    out_dim: int = 1

    @nn.compact
    def __call__(self, x):
        # conv(3x3) -> relu -> 2x2 average pool, twice; submodules are created in the same
        # order as before, so parameter names (Conv_0, Conv_1, Dense_0, Dense_1) are unchanged
        for n_filters in (self.img_size, 64):
            x = nn.Conv(features=n_filters, kernel_size=(3, 3))(x)
            x = nn.avg_pool(nn.relu(x), window_shape=(2, 2), strides=(2, 2))
        flat = x.reshape((x.shape[0], -1))  # keep the batch axis, flatten the rest
        hidden = nn.relu(nn.Dense(features=256)(flat))
        return nn.Dense(features=self.out_dim)(hidden)


def create_model(img_size, out_dim):
    """Construct a ``CNN`` with ``img_size`` first-layer filters and ``out_dim`` output logits."""
    return CNN(img_size=img_size, out_dim=out_dim)

In [None]:
from munch import DefaultMunch

hparams = DefaultMunch.fromDict({
    'batch_size': BATCH_SIZE,
    'image_size': img_width,
    'shape' : (img_width,) * 2 + (3,),
    'run_name' : datetime.now().strftime("%Y%m%d-%H%M%S"),
    'out_dim' : 2,
    'lr': 0.005,
    'max_epochs': MAX_EPOCHS,
    'log_dir': '/kaggle/working/',
    'chkpt_dir': '/kaggle/working/'
})

In [None]:
import logging

logger = logging.getLogger(__name__)
logging.basicConfig(filename='cancer.log', level=logging.INFO)
model = create_model(img_height, hparams.out_dim)

In [None]:
print(model.tabulate(jax.random.key(0), jnp.ones((1, img_width, img_height, 3)),
                   compute_flops=True, compute_vjp_flops=True))

In [None]:
dfs = {'train': (df_train, TRAIN_DIR), 'test': (df_test, TEST_DIR)}
train_ds, val_ds, test_ds = get_dataset(dfs, num_epochs = MAX_EPOCHS, batch_size=BATCH_SIZE)

In [None]:
key = jax.random.key(0)
trainer = Trainer(model, hparams, logger, key)

In [None]:
hparams.log_dir+hparams.run_name

In [None]:
%load_ext aim

In [None]:
# self.hparams.log_dir}/{dt}
%aim convert tensorboard --logdir /kaggle/working/20250604-181356

In [None]:
%aim up

In [None]:
trainer.fit(train_ds, val_ds)

In [None]:
# class_weights = get_class_weights(train_ds.batch(BATCH_SIZE))
# train_ds = train_ds.unbatch().rejection_resample(lambda x, y: y, 
#                                                     target_dist=[0.5, 0.5], 
#                                                     initial_dist=class_weights
#                                                   ).shuffle(1024).batch(batch_size, drop_remainder=True)
#     # train_ds = train_ds.map(lambda extra_label, features_and_label: features_and_label) # drop extra_label returned from rejection_resample method

### Hyperparameter Optimization Using Optuna

In [None]:
NUM_TRIALS = 100

In [None]:
def objective(trial):
    """Placeholder Optuna objective: suggests nothing and scores every trial as 0.

    TODO: train a model with hyper-parameters sampled from ``trial`` and return its validation F1.
    """
    placeholder_score = 0
    return placeholder_score

In [None]:
optuna_logger = optuna.logging.get_logger("optuna")
# avoid stacking duplicate stdout handlers (and duplicated log lines) when this cell is re-run
if not any(isinstance(h, logging.StreamHandler) and getattr(h, 'stream', None) is sys.stdout
           for h in optuna_logger.handlers):
    optuna_logger.addHandler(logging.StreamHandler(sys.stdout))
study_name = "cancer_project"  # Unique identifier of the study.
storage_name = f"sqlite:///{study_name}.db"
# F1 is higher-is-better, so the study must maximize it (it was set to minimize).
# NOTE(review): an existing cancer_project.db created with "minimize" probably has to be removed first.
study = optuna.create_study(study_name=study_name, storage=storage_name, direction="maximize", load_if_exists=True)
study.set_metric_names(["f1_score"])

study.optimize(objective, n_trials=NUM_TRIALS, timeout=600)

In [None]:
print(f"Number of trials on the Pareto front: {len(study.best_trials)}")

# For each objective, report the best trial according to the study's own direction.
# Works for single- and multi-objective studies; the old values[1] raised IndexError here.
objective_names = study.metric_names or [f"objective_{i}" for i in range(len(study.directions))]
for i, (objective_name, direction) in enumerate(zip(objective_names, study.directions)):
    select = max if direction == optuna.study.StudyDirection.MAXIMIZE else min
    best = select(study.best_trials, key=lambda t: t.values[i])
    print(f"Trial with best {objective_name} ({direction.name.lower()}): ")
    print(f"\tnumber: {best.number}")
    print(f"\tparams: {best.params}")
    print(f"\tvalues: {best.values}")

In [None]:
# run this cell to view the optuna dashboard in another window (following the link in the output cell)
# https://stackoverflow.com/questions/76033104/launching-optuna-dashboard-in-google-colaboratory
port =  9005
# Point the dashboard at the same database the study was written to.
# The hard-coded Colab path /content/... was a different (empty) DB.
storage = optuna.storages.RDBStorage(storage_name)
app = wsgi(storage)
httpd = make_server("localhost", port, app)
# daemon thread so the server does not keep the kernel process alive on shutdown
thread = threading.Thread(target=httpd.serve_forever, daemon=True)
thread.start()
time.sleep(3) # Wait until the server startup

try:
    from google.colab import output
    output.serve_kernel_port_as_window(port, path='/dashboard/') # follow the link in the output cell below to view the dashboard
except ImportError:
    # not running on Colab (e.g. Kaggle): the server is only reachable from the kernel's host
    print(f"Optuna dashboard served at http://localhost:{port}/dashboard/")

**Alternatively, run the cells below to see the output of the hyperparameter tuning**

In [None]:
# A Pareto front only exists for multi-objective studies (plot_pareto_front rejects single-objective ones).
# Otherwise show the optimisation history.
if len(study.directions) > 1:
    fig = optuna.visualization.plot_pareto_front(study, target_names=study.metric_names)
else:
    fig = plot_optimization_history(study, target_name=(study.metric_names or ["Objective Value"])[0])
fig

In [None]:
# label the axis with the study's own metric name (the objective is F1, not accuracy)
optuna.visualization.plot_param_importances(
    study, target=lambda t: t.values[0], target_name=(study.metric_names or ["Objective Value"])[0]
)

In [None]:
# the second objective only exists in a multi-objective study (t.values[1] raises IndexError otherwise)
if len(study.directions) > 1:
    display(optuna.visualization.plot_param_importances(
        study, target=lambda t: t.values[1], target_name="train_score-val_score"
    ))
else:
    print("Single-objective study: there is no second objective to plot.")

In [None]:
optuna.visualization.plot_contour(study, target=lambda t: t.values[0], target_name=(study.metric_names or ["Objective Value"])[0])

In [None]:
# the second objective only exists in a multi-objective study (t.values[1] raises IndexError otherwise)
if len(study.directions) > 1:
    display(optuna.visualization.plot_contour(study, target=lambda t: t.values[1], target_name="train_score-val_score"))
else:
    print("Single-objective study: there is no second objective to plot.")

In [None]:
plot_parallel_coordinate(study, target=lambda t: t.values[0], target_name=(study.metric_names or ["Objective Value"])[0])

In [None]:
# the second objective only exists in a multi-objective study (t.values[1] raises IndexError otherwise)
if len(study.directions) > 1:
    display(plot_parallel_coordinate(study, target=lambda t: t.values[1], target_name="train_score-val_score"))
else:
    print("Single-objective study: there is no second objective to plot.")

## Discussion and Conclusion



## References
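1. Google Machine Learning Crash Course, "Imbalanced datasets": https://developers.google.com/machine-learning/crash-course/overfitting/imbalanced-datasets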