In [1]:
import logging
from IPython.core.interactiveshell import InteractiveShell
%load_ext autoreload
InteractiveShell.ast_node_interactivity = "all"

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

In [2]:
import os
import sys

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Put the parent directory on the import path (presumably where the local
# `peak_detection_2d` package lives), adding it only once.
module_path = os.path.abspath("..")
if module_path not in sys.path:
    sys.path.append(module_path)

2024-06-07 10:08:28,651 - numexpr.utils - INFO - Note: NumExpr detected 32 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2024-06-07 10:08:28,653 - numexpr.utils - INFO - NumExpr defaulting to 8 threads.


# Load data

In [3]:
# Root of the single-shot DDA brain results and the specific run analysed here.
result_parent_dir = "/cmnfs/proj/ORIGINS/data/brain/FreshFrozenBrain/SingleShot/DDA/"
result_base_dir = "frame0_1830_ssDDA_P064428_Fresh1_5ug_R1_BD5_1_4921_ScanByScan_RTtol0.9_threshold_missabthres0.5_convergence_NoIntercept_pred_mzBinDigits2_imPeakWidth4_deltaMobilityThres80"
result_dir = os.path.join(result_parent_dir, result_base_dir)

# Sub-directory of `result_dir` where the training loop writes its checkpoints.
peak_data_dir = "peak_detection_mask_data_rt_full_overlap"

# A previous peak-selection run; its checkpoint is loaded as the starting weights.
peak_selection_spec_dir = "IMRT_fulloverlap_data_peak_selection_seg_model_1out32_lr0.005_bs256_comboloss_bce1_dice4_focal1_metric_wdice_channel1_0.5"
peak_selection_dir = os.path.join(result_dir, peak_selection_spec_dir)
best_model_path = os.path.join(peak_selection_dir, "bst_model_0.7089.bin")

In [8]:
%autoreload 2
import torch
from torchvision.transforms import Compose
from peak_detection_2d.dataset import MultiHDF5_MaskDataset, Mask_Padding, Mask_AddLogChannel, Mask_ToTensor, Mask_LogTransform

patience = 10
batch_size = 32

random_state = 42


hdf5_files = [
    os.path.join(os.path.join(result_dir, "peak_detection_mask_data"), file)
    for file in os.listdir(os.path.join(result_dir, "peak_detection_mask_data"))
    if file.endswith(".h5")
]

# Define transformations (if any)
transformation = Compose([Mask_Padding((258, 258)),Mask_AddLogChannel(), Mask_ToTensor()])

# Create the dataset
dataset = MultiHDF5_MaskDataset(hdf5_files, transforms=transformation)

# Split the dataset into training and testing sets
train_val_dataset, test_dataset = dataset.split_dataset(
    train_ratio=0.9, seed=random_state
)
train_dataset, val_dataset = train_val_dataset.split_dataset(
    train_ratio=0.9, seed=random_state
)

# Example usage
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=False
)
val_dataloader = torch.utils.data.DataLoader(
    val_dataset, batch_size=128, shuffle=False)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=128, shuffle=False
)

# Training

In [5]:
import torch

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [14]:
%autoreload 2
from peak_detection_2d.utils import EarlyStopping
from peak_detection_2d.loss import per_image_weighted_dice_metric
from peak_detection_2d.seg_model import train_one_epoch, evaluate, UNET
from torch.optim.lr_scheduler import ReduceLROnPlateau
from peak_detection_2d.combo_loss import ComboLoss
TRAIN_MODEL = True
EVALUATE = True
EPOCHS = 20
patience = 10
initial_lr = 0.005
model = UNET(2, 32, 1, padding=1, downhill=4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=initial_lr)
if best_model_path is not None:
    model.load_state_dict(torch.load(best_model_path))
    logging.info("Loaded model from %s", best_model_path)
scheduler = ReduceLROnPlateau(
    optimizer, mode="max", factor=0.1, patience=3, min_lr=0.000001
)
#criterion = nn.BCEWithLogitsLoss()
criterion = ComboLoss(**{'weights':{'bce':1, 'dice':4, 'focal':1}, "channel_weights": [1, 0.5]})
es = EarlyStopping(patience=patience, mode="max")
train_loss = []
metric = {"train":[], "val":[]}
if TRAIN_MODEL:
    for epoch in range(EPOCHS):
        loss = train_one_epoch(train_dataloader, model, optimizer, criterion, use_image_as_input=True)
        train_metric = evaluate(train_dataloader, model, metric=per_image_weighted_dice_metric, use_image_for_metric=True, device=device, channel = 0)
        val_metric = evaluate(val_dataloader, model, metric=per_image_weighted_dice_metric, use_image_for_metric=True, device = device, channel = 0)
        metric["train"].append(train_metric)
        metric["val"].append(val_metric)
        train_loss.append(loss)
        scheduler.step(val_metric)
        print(f"EPOCH: {epoch}, TRAIN LOSS: {loss}, TRAIN DICE: {train_metric}, VAL DICE: {val_metric}")
        es(
            val_metric,
            model,
            model_path=os.path.join(
                result_dir, peak_data_dir, f"bst_model_{np.round(val_metric,4)}.bin"
            ),
        )
        best_model = os.path.join(
            result_dir, peak_data_dir, f"bst_model_{np.round(es.best_score,4)}.bin"
        )
        if es.early_stop:
            print("\n\n -------------- EARLY STOPPING -------------- \n\n")
            break


<All keys matched successfully>

2024-06-07 10:33:37,160 - root - INFO - Loaded model from /cmnfs/proj/ORIGINS/data/brain/FreshFrozenBrain/SingleShot/DDA/frame0_1830_ssDDA_P064428_Fresh1_5ug_R1_BD5_1_4921_ScanByScan_RTtol0.9_threshold_missabthres0.5_convergence_NoIntercept_pred_mzBinDigits2_imPeakWidth4_deltaMobilityThres80/IMRT_fulloverlap_data_peak_selection_seg_model_1out32_lr0.005_bs256_comboloss_bce1_dice4_focal1_metric_wdice_channel1_0.5/bst_model_0.7089.bin
  7%|▋         | 58/786 [00:52<10:59,  1.10it/s, learning_rate=0.005, loss=1.8] 


KeyboardInterrupt: 

In [None]:
# Plot one test sample (rank 128620) under two models; judging by the names, one is
# a log-transformed variant and the other a plain one.
# NOTE(review): `ind_all`, `test_dataset_log`, `log_trans_model`, `normal_model` and
# `DEVICE` are not defined in any visible cell, and `plot_sample_predictions` is only
# imported in a later cell, so this cell raises NameError on a fresh kernel.
testset_idx = ind_all.loc[ind_all["ranks"] == 128620].index[0]
# Plot sample predictions
plot_sample_predictions(
    test_dataset_log,
    model=log_trans_model,
    sample_indices=[testset_idx],
    # n = 10,
    save_dir=None,
    metric_list=["mask_wiou", "wdice", "dice"],
    use_hint=False,
    zoom_in=False,
    label="mask",
    device=DEVICE,
    # save_dir=os.path.join(log_trans_dir, "sample_predictions_highest_wiou"),
)
plot_sample_predictions(
    test_dataset,
    model=normal_model,
    sample_indices=[testset_idx],
    # n = 10,
    save_dir=None,
    metric_list=["mask_wiou", "wdice", "dice"],
    use_hint=False,
    zoom_in=False,
    label="mask",
    device=DEVICE,
    # save_dir=os.path.join(log_trans_dir, "sample_predictions_highest_wiou"),
)

In [None]:
%autoreload 2
from peak_detection_2d.utils import plot_sample_predictions
from peak_detection_2d.loss import per_image_weighted_iou_metric
bst_model = UNET(2, 32, 1, padding=1, downhill=4).to(device)
checkpoint=torch.load(best_model, map_location=device)
bst_model.load_state_dict(checkpoint)
ind_avg_wiou, ind_all_wiou_log = evaluate(
    model=bst_model,
    valid_loader=test_dataloader,
    device=device,
    metric=per_image_weighted_iou_metric,
    save_all_loss=True,
    use_image_for_metric=True,
    channel = 0
)
ind_all_wiou_df = pd.DataFrame(ind_all_wiou_log)

In [None]:
from peak_detection_2d.utils import plot_per_image_metric_distr

# Distribution of the per-image weighted IoU on the test split; the values are stored
# under the "losses" key of the log returned by `evaluate(..., save_all_loss=True)`.
wiou_per_image = ind_all_wiou_log["losses"]
plot_per_image_metric_distr(wiou_per_image, "Weighted_IoU", save_dir=None)

In [None]:
np.greater(np.array([0, 2, 2]), 1)  # scratch check: elementwise comparison -> boolean mask

In [None]:
# Find the row for rank 34051 in the per-image results and plot the best model's
# prediction for it. Presumably the row position equals the test_dataset index, since
# the test loader does not shuffle.
# NOTE(review): `channel=1` here differs from the `channel=0` used when evaluating,
# and no `device` is passed — confirm both are intended.
target_rank = 34051
rank_matches = ind_all_wiou_df["ranks"] == target_rank
testset_idx = ind_all_wiou_df.loc[rank_matches].index[0]
plot_sample_predictions(
    test_dataset,
    model=bst_model,
    sample_indices=[testset_idx],
    metric_list=["wdice", "mask_wiou", "dice"],
    label="mask",
    channel=1,
    use_hint=False,
    zoom_in=False,
    save_dir=None,
)

In [None]:
%autoreload 2
if EVALUATE:
    test_score = evaluate(test_dataloader, model, metric=weighted_dice_loss)
    print(f"Valid dice score: {test_score}")

In [None]:
%autoreload 2
from peak_detection_2d.loss import WeightedBoundingBoxIoULoss
import logging
import json
import os

import torch
from torch import nn
from torch.optim.lr_scheduler import ReduceLROnPlateau

from peak_detection_2d.utils import (
    plot_sample_predictions,
    plot_history,
)
from peak_detection_2d.dataset import MultiHDF5Dataset, ToTensor, Padding
from peak_detection_2d.seg_model import (
    UNET,
    train_val_step,
    train_one_epoch)
num_epoch = 1
patience = 10
inital_lr = 0.001
batch_size = 4

random_state = 42
device = "cuda" if torch.cuda.is_available() else "cpu"

loss_fn = torch.nn.BCEWithLogitsLoss()
scaler = torch.cuda.amp.GradScaler()
model = UNET(1, 16, 1, padding=1, downhill=3).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=inital_lr)
scheduler = ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=3, min_lr=0.000001
)

loss_tracking = {"train": [], "val": []}
iou_tracking = {"train": [], "val": []}
best_loss = float("inf")
for epoch in range(num_epoch):
    logging.info("Epoch %d/%d", epoch + 1, num_epoch)

    training_loss, trainig_iou = train_one_epoch(
        train_dataloader, model, loss_fn, optimizer
    )
    loss_tracking["train"].append(training_loss)
    iou_tracking["train"].append(trainig_iou)

    with torch.inference_mode():
        val_loss, val_iou = train_val_step(val_dataloader, model, loss_fn, None)
        loss_tracking["val"].append(val_loss)
        iou_tracking["val"].append(vala_iou)
        if val_loss < best_loss:
            logging.info("Saving best model")
            torch.save(
                model.state_dict(), os.path.join(result_dir, peak_data_dir, "best_model.pt")
            )
            best_loss = val_loss
            current_patience = patience
        else:
            current_patience -= 1
            if current_patience == 0:
                logging.info("Early stopping")
                break
        scheduler.step(val_loss)
        logging.info(
            "Last learning rate: %s",
            scheduler.get_last_lr(),
        )

    logging.info("Training loss: %.6f, IoU: %.2f", training_loss, trainig_iou)
    logging.info("Validation loss: %.6f, IoU: %.6f", val_loss, val_iou)

In [None]:
# Plot predictions for 10 test samples, with figures directed to
# peak_selection_dir/sample_predictions.
# NOTE(review): `model` is whichever model the most recently run training cell bound.
sample_pred_dir = os.path.join(peak_selection_dir, "sample_predictions")
plot_sample_predictions(test_dataset, model=model, n=10, save_dir=sample_pred_dir)

In [None]:
import gc

# First collect unreachable Python objects, including tensors kept alive only by
# reference cycles. Then release cached GPU blocks: empty_cache can only return
# memory that no live tensor still occupies.
gc.collect()
torch.cuda.empty_cache()

In [None]:
from peak_detection_2d.utils import plot_data_points

# Draw n test samples, plot the model's predicted box against the label, and save each
# figure as sample_<idx>.png.
# NOTE(review): `iou_batch` is not imported in any visible cell — import it before
# running this cell.
# NOTE(review): unpacking `(image, hint, label)` and plotting `bbox` presumably targets
# the older bounding-box dataset; the mask datasets built above may return a
# different tuple — confirm.
n = 5
# Seeded generator so the same samples are drawn on every run.
sample_rng = np.random.default_rng(random_state)
sample_indices = sample_rng.choice(len(test_dataset), n, replace=False)
# Inference only: skip building autograd graphs.
with torch.no_grad():
    for i in sample_indices:
        image, hint, label = test_dataset[i]
        # The model lives on `device`, so its inputs must be moved there too.
        output = model(
            image.unsqueeze(0).float().to(device),
            hint.unsqueeze(0).float().to(device),
        )
        iou = iou_batch(output, label.unsqueeze(0).to(device))
        to_plot = {"data": image[0].cpu(), "hint_idx": hint.cpu(), "bbox": label.cpu()}
        plot_data_points(to_plot, pred_bbox=output[0].cpu().detach().numpy(), zoom_in=True)
        plt.title(f"IoU: {iou:.2f}")
        plt.savefig(f"sample_{i}.png", dpi=300)
        plt.close()