# Setup for training the full model while finding proper weights for each head (Classification + Segmentation)
## Breast-Ultrasound-Segmentation

## About Dataset
Breast cancer is one of the most common causes of death among women worldwide. Early detection helps reduce the number of early deaths. The dataset contains medical images of breast cancer acquired with ultrasound scans. The Breast Ultrasound Dataset is categorized into three classes: normal, benign, and malignant images. Combined with machine learning, breast ultrasound images can produce great results in the classification, detection, and segmentation of breast cancer.

### Data
The data collected at baseline include breast ultrasound images of women aged between 25 and 75 years. The data were collected in 2018 from 600 female patients. The dataset consists of 780 images with an average image size of 500×500 pixels. The images are in PNG format. The ground-truth images are provided together with the original images. The images are categorized into three classes: normal, benign, and malignant.

If you use this dataset, please cite:
Al-Dhabyani W, Gomaa M, Khaled H, Fahmy A. Dataset of breast ultrasound images. Data in Brief. 2020 Feb;28:104863. DOI: 10.1016/j.dib.2019.104863.

## Imports

In [None]:
import os

import pyrootutils

# Locate the project root by searching upward from the notebook's parent
# directory for `.git` / `pyproject.toml`; pythonpath/dotenv ask pyrootutils
# to make the root importable and to load its .env file.
root = pyrootutils.setup_root(
    search_from=os.path.dirname(os.getcwd()),
    indicator=[".git", "pyproject.toml"],
    pythonpath=True,
    dotenv=True,
)

# Only fall back to `<root>/data` when DATA_ROOT has not been set already.
os.environ.setdefault("DATA_ROOT", f"{root}/data")

In [None]:
import torch
import torch.nn as nn

# This notebook only supports CUDA training: fail fast with a specific
# exception type (RuntimeError, still an Exception subclass) instead of a bare
# Exception when no NVIDIA GPU is visible.
if not torch.cuda.is_available():
    raise RuntimeError("No GPU Found!!")

DEVICE = "cuda"  # NVIDIA GPU
print("GPU Found!!")

In [None]:
import logging

import hydra
from hydra import compose, initialize

# Notebook-level logger.
log = logging.getLogger(name=__name__)

In [None]:
# Load environment variables from a .env file (python-dotenv IPython extension)
%load_ext dotenv
%dotenv

# Automatically reload edited project modules before each cell executes
%load_ext autoreload
%autoreload 2

## Paths setup

In [None]:
from omegaconf import DictConfig, OmegaConf

# Register a resolver for torch dtypes so configs can write e.g.
# `${torch_dtype:float16}` and receive `torch.float16`.
# `replace=True` makes the cell safe to re-execute: without it OmegaConf
# raises a ValueError because the name "torch_dtype" is already registered.
OmegaConf.register_new_resolver(
    "torch_dtype", lambda name: getattr(torch, name), replace=True
)

In [None]:
# Compose ../configs/train.yaml through Hydra's compose API (no Hydra app launch).
with initialize(version_base=None, config_path="../configs", job_name="training_setup"):
    cfg: DictConfig = compose(config_name="train.yaml")
    print(cfg)  # OmegaConf.to_yaml(cfg) gives a more readable dump

In [None]:
# Run the remaining cells from the project root so relative paths resolve there.
os.chdir(os.fspath(root))

## Loading Dataset

In [None]:
data_module = hydra.utils.instantiate(cfg.datamodule)

# Class metadata from the datamodule; the weights are passed to the
# classification loss further down.
class_weights, class_names = data_module.class_weights, data_module.classes
num_classes = len(class_names)
class_names, num_classes, class_weights

In [None]:
# Sanity check: fetch one batch from the training loader (iterator is not kept).
next(iter(data_module.train_dataloader()))

In [None]:
# Sanity check: fetch one batch from the validation loader (iterator is not kept).
next(iter(data_module.val_dataloader()))

In [None]:
# Sanity check: fetch one batch from the test loader (iterator is not kept).
next(iter(data_module.test_dataloader()))

In [None]:
# (train, validation) loaders from the datamodule's sampled-dataloader helper.
(train_dl, val_dl) = data_module.get_sampled_dataloader()

In [None]:
# Inspect one training batch: shapes, dtypes and value ranges.
# Ranges are computed over the whole batch (not only sample 0) so an
# out-of-range image or mask anywhere in the batch is visible, consistent
# with the label check below.
images, targets = next(iter(train_dl))
print(images.shape, targets["masks"].shape, targets["labels"].shape)

print(f"images:{images.dtype}, {images.min()}, {images.max()}")
print(f'masks {targets["masks"].dtype}, {targets["masks"].min()}, {targets["masks"].max()}')
print(f'labels {targets["labels"].dtype}, {targets["labels"].min()}, {targets["labels"].max()}')

In [None]:
# Inspect one validation batch: shapes, dtypes and batch-wide value ranges
# (images now use the whole batch, like the mask and label checks).
_images, _targets = next(iter(val_dl))

print(_images.shape, _targets["masks"].shape, _targets["labels"].shape)

print(f"images:{_images.dtype}, {_images.min()}, {_images.max()}")
print(f'masks {_targets["masks"].dtype}, {_targets["masks"].min()}, {_targets["masks"].max()}')
print(f'labels {_targets["labels"].dtype}, {_targets["labels"].min()}, {_targets["labels"].max()}')

## Loading and training the FCN8 model 

In [None]:
losses_cfg = cfg.losses

# Segmentation loss straight from config; classification loss additionally
# receives the datamodule's class weights.
segmentation_criterion = hydra.utils.instantiate(losses_cfg.segmentation_criterion)
classification_criterion = hydra.utils.instantiate(
    losses_cfg.classification_criterion,
    weight=class_weights,
)
classification_criterion.weight

In [None]:
import mlflow
import mlflow.pytorch

from src.utils.gpu_utils import DeviceDataLoader, get_default_device, to_device

# Free cached CUDA memory left over from earlier cells.
torch.cuda.empty_cache()
device = get_default_device()

# NOTE(review): `gpu_weights` is not used by any later cell; the classification
# criterion above was built with the original `class_weights` tensor. Confirm the
# criterion's weight ends up on `device` (e.g. if the model registers the
# criterion as a submodule and is moved with it).
gpu_weights = to_device(class_weights, device)

In [None]:
# Show the `models` config group (model and optimizer entries used below).
cfg.models

In [None]:
# Extra constructor kwargs layered on top of cfg.models.model.
criteria = {
    "segmentation_criterion": segmentation_criterion,
    "classification_criterion": classification_criterion,
}
model = hydra.utils.instantiate(cfg.models.model, **criteria)

In [None]:
# torch.compile returns a wrapped module that shares the original's parameters.
compiled_model = torch.compile(model)
model = compiled_model
model

In [None]:
from mlflow.models import infer_signature

# mlflow.start_run() refuses to start while another run is active, which happens
# when this cell is re-executed or training crashed before `mlflow.end_run()`.
# Close any stale run first.
if mlflow.active_run() is not None:
    mlflow.end_run()

task_name = cfg.task_name
mlflow.set_experiment(f"overfitting-{task_name}")
run = mlflow.start_run()

# eval() so Dropout/BatchNorm behave as at inference while probing the model.
# NOTE(review): the model is left in eval mode; confirm `fit` switches it back
# to train mode.
model.eval()
with torch.no_grad():
    images, labels = next(iter(train_dl))  # `labels` is the targets dict (unused)
    out = model(images)
    print(out["labels"][0], out["labels"].shape)
    print(out["masks"][0])
    print(out["masks"].shape)
    # .cpu() is a no-op for CPU tensors and keeps .numpy() working if this cell
    # is re-run after the model and loaders have been moved to the GPU.
    signature = infer_signature(
        model_input={"image_input": images.cpu().numpy()},
        model_output={
            "output": {
                "masks": out["masks"].cpu().numpy(),
                "labels": out["labels"].cpu().numpy(),
            }
        },
    )
signature

## GPU Training Setup

## Moving data and model into GPU memory

In [None]:
# Wrap both loaders so each batch is moved to `device` as it is yielded,
# then move the model's parameters there as well.
train_dl, val_dl = (DeviceDataLoader(loader, device) for loader in (train_dl, val_dl))
to_device(model, device)

## Overfitting the model

In [None]:
# NOTE: lr is hard-coded here and overrides the value in cfg.models.optimizer.
optimizer = hydra.utils.instantiate(
    cfg.models.optimizer,
    lr=1e-4,
    params=model.parameters(),
)

In [None]:
# Show the configured paths.
cfg.paths

In [None]:
# Show the configured output directory.
cfg.paths.output_dir

In [None]:
from torchinfo import summary

from src.utils.train_utils import fit

EPOCHS = cfg.trainer.max_epochs

# torch.compile wraps the module; the original nn.Module (same parameters) is
# kept in `_orig_mod`. Summarise/log that one, since the compiled wrapper is
# not reliably picklable. Falls back to `model` if it was not compiled.
base_model = getattr(model, "_orig_mod", model)

# Log each optimizer setting as its own param instead of an opaque
# `ValuesView` repr, plus the lr actually in use (the optimizer cell overrides
# the configured value).
optimizer_params = OmegaConf.to_container(cfg.models.optimizer, resolve=True)
mlflow.log_params(
    {
        "epochs": EPOCHS,
        "batch_size": cfg.datamodule.batch_size,
        **{f"optimizer.{key}": value for key, value in optimizer_params.items()},
        "optimizer.effective_lr": optimizer.param_groups[0]["lr"],
    }
)

# Log the model summary directly as an artifact: no stray model_summary.txt in
# the working directory, and the text is written as UTF-8 regardless of the
# platform's default encoding.
mlflow.log_text(str(summary(base_model)), "model_summary.txt")

history = fit(
    model=model,
    train_dataloader=train_dl,
    validation_dataloader=val_dl,
    epochs=EPOCHS,
    optimizer=optimizer,
    device_type=device.type,
    dtype=torch.float16,
    reduce_lr_on_plateau=torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=0.1, patience=5
    ),
)

# Save the trained weights (via the uncompiled module).
mlflow.pytorch.log_model(base_model, "model", signature=signature)

# Log every epoch's metrics with its step so MLflow shows full curves
# (previously only history[0] was logged).
for epoch, epoch_metrics in enumerate(history):
    mlflow.log_metrics(epoch_metrics, step=epoch)

In [None]:
# Close the MLflow run opened in the signature cell.
mlflow.end_run(status="FINISHED")

In [None]:
import matplotlib.pyplot as plt

seg_losses = [x["seg_loss"] for x in history]
seg_dice = [x["seg_dice"] for x in history]

# Start a new figure so re-running the cell doesn't draw over an old plot.
plt.figure()
plt.plot(seg_losses, "-bx")
plt.plot(seg_dice, "-rx")

plt.xlabel("epoch")
# Two different quantities share this axis (a loss and a Dice value), so the
# label is neutral; the legend tells them apart.
plt.ylabel("value")
plt.grid()
plt.legend(["seg_loss", "seg_dice"])
plt.title("Segmentation loss and Dice vs. no. of epochs")
plt.show()