In [1]:
import random
import os
os.environ["OMP_NUM_THREADS"] = "1"
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import torch
import cv2
import warnings
import timeit
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torch import optim
from scipy.special import expit
from skimage import io, transform, measure
from sklearn import metrics
import optuna
from config import *
from util import *
from models import *
# --- Notebook-wide display configuration -----------------------------------
# Apply the seaborn default theme first; the matplotlib overrides below are
# layered on top of it.
sns.set()

# Suppress every warning for the rest of the session.
warnings.simplefilter("ignore")

# Allow pandas to show up to 500 rows / 500 columns before truncating output.
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 500)

# Matplotlib overrides: 9x7 figures, no axis spines, tick marks kept on the
# bottom/left axes but with their labels hidden, and no grid.
rc = {"figure.figsize": (9, 7)}
rc.update({
    "axes.spines.left": False,
    "axes.spines.right": False,
    "axes.spines.bottom": False,
    "axes.spines.top": False,
})
rc.update({
    "xtick.bottom": True,
    "xtick.labelbottom": False,
    "ytick.labelleft": False,
    "ytick.left": True,
})
rc["axes.grid"] = False
plt.rcParams.update(rc)

In [None]:
# Best validation balanced accuracy seen so far in this run; `objective`
# updates it whenever a trial matches or beats it.
best_accuracy = 0

def objective(trial):
    """Train one HAMFineTune configuration and return its validation balanced accuracy.

    The trial samples the learning rate (log scale), the two dropout rates and
    the batch size, fits the model for up to 25 epochs, restores the checkpoint
    selected by the ``val/loss`` monitor, derives decision thresholds from the
    validation predictions and then scores the validation, test and external
    splits. All three balanced accuracies are attached to the trial as user
    attributes; only the validation score is returned to Optuna.

    Names such as ``seed``, ``img_dir``, ``annotations_dir``, ``metadata_file``,
    ``HAMFineTune``, ``get_predictions``, ``get_thresholds`` and
    ``find_optimal_cutoff`` are not defined in this cell -- presumably they come
    from the ``config`` / ``util`` / ``models`` star imports.

    Side effects:
        When the validation score is at least the module-level
        ``best_accuracy``, that global is updated and the model weights are
        written to ``models/baseline_model.pth`` together with the three
        prediction frames under ``results/baseline_model/``.

    Args:
        trial: The ``optuna`` trial used to sample hyper-parameters and to
            store the per-split scores.

    Returns:
        Balanced accuracy on the validation split.
    """
    torch.cuda.empty_cache()
    # Seed Python, NumPy and torch together: the Trainer runs with
    # deterministic=True, which only helps if weight initialisation and data
    # shuffling start from the same RNG state in every trial.
    pl.seed_everything(seed, workers=True)

    params = {
        # suggest_float(..., log=True) replaces the deprecated suggest_loguniform.
        'learning_rate': trial.suggest_float('learning_rate', 1e-5, 5e-4, log=True),
        'dx_dropout': trial.suggest_float('dx_dropout', 0.3, 0.7),
        'char_dropout': trial.suggest_float('char_dropout', 0.3, 0.7),
        'batch_size': trial.suggest_categorical('batch_size', [16, 32])
    }

    # NOTE(review): every trial writes into the same "checkpoints" directory;
    # checkpoints from earlier trials are presumably never cleaned up -- confirm
    # whether a per-trial sub-directory is wanted.
    checkpoint_callback = ModelCheckpoint(dirpath="checkpoints", save_top_k=1, monitor="val/loss")
    model = HAMFineTune(img_dir=img_dir, annotations_dir=annotations_dir, metadata_file=metadata_file, weighted_sampling=False,
                        batch_size=params['batch_size'], dx_dropout=params['dx_dropout'], char_dropout=params['char_dropout'],
                        learning_rate=params['learning_rate'])

    trainer = pl.Trainer(max_epochs=25, devices=[1], accelerator="gpu", deterministic=True, callbacks=[checkpoint_callback])
    trainer.fit(model)

    # Restore the best checkpoint after training. Loading onto the CPU avoids
    # holding a second copy of the weights on the GPU; load_state_dict copies
    # the values into the model's existing parameters.
    best_checkpoint_path = checkpoint_callback.best_model_path
    checkpoint = torch.load(best_checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint['state_dict'])

    # NOTE(review): validation predictions are produced with both thresholds
    # set to 0, while test/external use the thresholds tuned below -- confirm
    # that val_accuracy is meant to be computed from the untuned predictions.
    result_val = get_predictions(trainer, model, split='val', char_threshold=0, dx_threshold=0)
    char_threshold = get_thresholds(result_val)
    dx_threshold = find_optimal_cutoff(result_val['benign_malignant'], result_val['dx_pred'])
    result_test = get_predictions(trainer, model, split='test', char_threshold=char_threshold, dx_threshold=dx_threshold)
    result_external = get_predictions(trainer, model, split='external', char_threshold=char_threshold, dx_threshold=dx_threshold)

    val_accuracy = metrics.balanced_accuracy_score(result_val['benign_malignant'], result_val['prediction'])
    test_accuracy = metrics.balanced_accuracy_score(result_test['benign_malignant'], result_test['prediction'])
    external_accuracy = metrics.balanced_accuracy_score(result_external['benign_malignant'], result_external['prediction'])

    trial.set_user_attr("val_accuracy", val_accuracy)
    trial.set_user_attr("test_accuracy", test_accuracy)
    trial.set_user_attr("external_accuracy", external_accuracy)

    global best_accuracy
    if val_accuracy >= best_accuracy:
        best_accuracy = val_accuracy
        # Create the output directories up front so a missing folder cannot
        # throw away a finished training run at save time.
        results_dir = os.path.join("results", "baseline_model")
        os.makedirs("models", exist_ok=True)
        os.makedirs(results_dir, exist_ok=True)
        torch.save(model.state_dict(), os.path.join("models", "baseline_model.pth"))
        result_val.to_csv(os.path.join(results_dir, "result_val.csv"), index=False)
        result_test.to_csv(os.path.join(results_dir, "result_test.csv"), index=False)
        result_external.to_csv(os.path.join(results_dir, "result_external.csv"), index=False)

    return val_accuracy

# The study is persisted in a SQLite file under optuna/. Create the directory
# first so opening the database does not fail on a fresh checkout.
os.makedirs("optuna", exist_ok=True)

# Maximise the value returned by `objective` (validation balanced accuracy)
# with a seeded TPE sampler, then run 15 trials, garbage-collecting after each.
study = optuna.create_study(
    direction="maximize",
    sampler=optuna.samplers.TPESampler(seed=seed),
    storage="sqlite:///optuna/baseline_model.db",
)
study.optimize(objective, n_trials=15, gc_after_trial=True)

[32m[I 2024-08-09 11:04:59,267][0m A new study created in RDB with name: no-name-44f0743d-b3a8-415d-b3e3-c5bdf409a0ac[0m
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

   | Name                 | Type              | Params
------------------------------------------------------------
0  | lossD                | BCEWithLogitsLoss | 0     
1  | lossC                | BCEWithLogitsLoss | 0     
2  | lossA                | DiceLoss          | 0     
3  | base_model           | ResNet            | 23.5 M
4  | diagnosis_head       | Sequential        | 2.0 K 
5  | characteristics_head | Sequential        | 20.5 K
6  | sigmoid              | Sigmoid           | 0     
7  | accuracy             | Accuracy          | 0     
8  | auroc                | AUROC             | 0     
9  | sensitivity          | Recall            | 0     
10 | specificity          | Specificity       | 0   

Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

In [None]:
# Print the best hyper-parameters and the best objective value, one per line.
print(study.best_params, study.best_value, sep="\n")

# Rank all trials by objective value (best first), show the table in the
# notebook and keep a CSV copy next to the study database.
trials_df = study.trials_dataframe().sort_values(by='value', ascending=False)
display(trials_df)
trials_df.to_csv(
    os.path.join('optuna', 'baseline_model.csv'),
    index=False,
)