# Setup to train full model with findinng proper weights to train each head (Classification + Segmentation)
## Breast-Ultrasound-Segmentation

## About Dataset
Breast cancer is one of the most common causes of death among women worldwide. Early detection helps in reducing the number of early deaths. The data reviews the medical images of breast cancer using ultrasound scan. Breast Ultrasound Dataset is categorized into three classes: normal, benign, and malignant images. Breast ultrasound images can produce great results in classification, detection, and segmentation of breast cancer when combined with machine learning.

### Data
The data collected at baseline include breast ultrasound images among women in ages between 25 and 75 years old. This data was collected in 2018. The number of patients is 600 female patients. The dataset consists of 780 images with an average image size of 500*500 pixels. The images are in PNG format. The ground truth images are presented with original images. The images are categorized into three classes, which are normal, benign, and malignant.

If you use this dataset, please cite:
Al-Dhabyani W, Gomaa M, Khaled H, Fahmy A. Dataset of breast ultrasound images. Data in Brief. 2020 Feb;28:104863. DOI: 10.1016/j.dib.2019.104863.

## Imports

In [1]:
import os

import pyrootutils

root = pyrootutils.setup_root(
    search_from=os.path.dirname(os.getcwd()),
    indicator=[".git", "pyproject.toml"],
    pythonpath=True,
    dotenv=True,
)

if os.getenv("DATA_ROOT") is None:
    os.environ["DATA_ROOT"] = f"{root}/data"

In [2]:
import torch
import torch.nn as nn

# Setup device-agnostic code
if torch.cuda.is_available():
    DEVICE = "cuda"  # NVIDIA GPU
    print("GPU Found!!")
else:
    raise Exception("No GPU Found!!")

GPU Found!!


In [3]:
import hydra
from hydra import compose, initialize

In [4]:
# # auto reload dotenv
%load_ext dotenv
%dotenv

# auto reload libs
%load_ext autoreload
%autoreload 2

## Paths setup

In [5]:
from omegaconf import DictConfig, OmegaConf

# Register a resolver for torch dtypes
OmegaConf.register_new_resolver("torch_dtype", lambda name: getattr(torch, name))

with initialize(config_path="../configs", job_name="training_setup", version_base=None):
    cfg: DictConfig = compose(config_name="train.yaml")
    # print(OmegaConf.to_yaml(cfg))
    print(cfg.models)

{'optimizer': {'_target_': 'torch.optim.Adam', '_partial_': True, 'lr': 0.001, 'weight_decay': 0.0, 'betas': [0.9, 0.999]}, 'scheduler': {'func': None}, 'model': {'_target_': 'src.models.vggnet_fcn_segmentation_model.VGGNetFCN16SegmentationModel', 'num_classes': 3}}


In [6]:
os.chdir(root)

## Loading Dataset

In [7]:
data_module = hydra.utils.instantiate(cfg.datamodule)

class_weights = data_module.class_weights
class_names = data_module.classes
num_classes = len(class_names)
class_names, num_classes, class_weights

(['normal', 'malignant', 'benign'], 3, tensor([1.9774, 1.2494, 0.5903]))

In [8]:
import torch.optim as optim
from torch.utils.data import DataLoader, Subset

# # 1. Select a very small subset of your data
train_dataset, val_dataset = data_module.split_and_preprocess_datasets()

subset_indices = list(range(200))  # Select first 10 samples
train_small_dataset = Subset(train_dataset, subset_indices)
val_small_dataset = Subset(val_dataset, subset_indices)
train_dl = DataLoader(train_small_dataset, batch_size=32, shuffle=False)
val_dl = DataLoader(val_small_dataset, batch_size=32, shuffle=False)

In [9]:
for images, targets in train_dl:
    print(images.shape, targets["masks"].shape, targets["labels"].shape)

    print(f"images:{images.dtype}, {images[0].min()}, {images[0].max()}")
    print(
        f'masks {targets["masks"].dtype}, {targets["masks"][0].min()}, {targets["masks"][0].max()}'
    )
    print(
        f'labels {targets["labels"].dtype}, {targets["labels"].min()}, {targets["labels"].max()}'
    )
    break

torch.Size([32, 3, 224, 224]) torch.Size([32, 1, 224, 224]) torch.Size([32])
images:torch.float32, -2.1179039478302, 2.640000104904175
masks torch.uint8, 0, 1
labels torch.int64, 0, 2


In [10]:
for _images, _targets in val_dl:
    print(_images.shape, _targets["masks"].shape, _targets["labels"].shape)

    print(f"images:{_images[0].dtype}, {_images[0].min()}, {_images[0].max()}")
    print(f'masks {_targets["masks"].dtype}, {_targets["masks"].min()}, {_targets["masks"].max()}')
    print(
        f'labels {_targets["labels"].dtype}, {_targets["labels"].min()}, {_targets["labels"].max()}'
    )
    break

torch.Size([32, 3, 224, 224]) torch.Size([32, 1, 224, 224]) torch.Size([32])
images:torch.float32, -2.1179039478302, 2.5354254245758057
masks torch.uint8, 0, 1
labels torch.int64, 0, 2


## Loading and training the FCN8 model 

In [11]:
from src.utils.gpu_utils import DeviceDataLoader, get_default_device, to_device

torch.cuda.empty_cache()
device = get_default_device()

gpu_weights = to_device(class_weights, device)

In [12]:
from src.losses.dice_loss import SoftDiceLoss
from src.models.vggnet_fcn_segmentation_model import VGGNetFCNSegmentationModel

model = VGGNetFCNSegmentationModel(
    cls_num_classes=num_classes,
    seg_num_classes=1,
    vggnet_type="vgg16_bn",
    seg_weight=0.95,
    cls_weight=0.05,
    segmentation_criterion=SoftDiceLoss(),  # nn.CrossEntropyLoss(),
    classification_criterion=nn.CrossEntropyLoss(weight=gpu_weights),
)
model = torch.compile(model)
model

OptimizedModule(
  (_orig_mod): VGGNetFCNSegmentationModel(
    (segmentation_criterion): SoftDiceLoss()
    (classification_criterion): CrossEntropyLoss()
    (cls_auroc): MulticlassAUROC()
    (encoder): VGGNetEncoder(
      (vgg): VGG(
        (features): Sequential(
          (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace=True)
          (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
          (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (9): ReLU(inplace=True)
        

In [13]:
for images, labels in train_dl:
    # print(images.shape, labels)
    out = model(images)
    # l = labels['labels'][0]
    # print(l , torch.argmax(l))
    print(out["labels"][0], out["labels"].shape)
    print(out["masks"][0])
    print(out["masks"].shape)
    break

tensor([-0.0891, -0.0277,  0.6381], grad_fn=<SelectBackward0>) torch.Size([32, 3])
tensor([[[ 0.2843, -0.3723,  0.5905,  ..., -0.6889, -0.0129, -0.0735],
         [-0.2269, -0.8781, -1.0452,  ..., -0.6614, -0.0808, -0.3464],
         [-0.5564, -1.0220, -0.0271,  ...,  0.4852,  0.0378,  0.5885],
         ...,
         [ 0.9330,  0.5234,  0.6357,  ..., -1.3837, -1.0692, -0.4629],
         [-0.3920,  0.5641, -0.3922,  ...,  0.0981, -0.0103, -0.0813],
         [ 0.1880, -0.6990,  0.5471,  ...,  0.3258, -0.2395, -0.1229]]],
       grad_fn=<SelectBackward0>)
torch.Size([32, 1, 224, 224])


## GPU Training Setup

## Moving data and model into memory

In [14]:
train_dl = DeviceDataLoader(train_dl, device)
val_dl = DeviceDataLoader(val_dl, device)
to_device(model, device)
# train_dl.device

OptimizedModule(
  (_orig_mod): VGGNetFCNSegmentationModel(
    (segmentation_criterion): SoftDiceLoss()
    (classification_criterion): CrossEntropyLoss()
    (cls_auroc): MulticlassAUROC()
    (encoder): VGGNetEncoder(
      (vgg): VGG(
        (features): Sequential(
          (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace=True)
          (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
          (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (9): ReLU(inplace=True)
        

## Overfiting the model

In [15]:
from src.utils.train_utils import fit

EPOCHS = 100
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
history = fit(
    model=model,
    train_dataloader=train_dl,
    validation_dataloader=val_dl,
    epochs=EPOCHS,
    optimizer=optimizer,
    device_type=device.type,
    dtype=torch.float16,
    reduce_lr_on_plateau=torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, factor=0.1, patience=5
    ),
)

Epoch [0] Validation Results: lr=1e-03, total_loss=5.2513, cls_loss=1.0939, val_cls_loss=4.0182, cls_acc=0.3259, val_cls_acc=0.3170, val_cls_auroc=0.4839, seg_loss=0.7321, val_seg_loss=0.7907, seg_dice=0.1671, val_seg_dice=0.1663
Epoch [1] Validation Results: lr=1e-03, total_loss=4.5840, cls_loss=1.0776, val_cls_loss=2.6581, cls_acc=0.3393, val_cls_acc=0.3393, val_cls_auroc=0.6072, seg_loss=0.6326, val_seg_loss=0.7491, seg_dice=0.2361, val_seg_dice=0.1988
Epoch [2] Validation Results: lr=1e-03, total_loss=4.4180, cls_loss=1.1252, val_cls_loss=3.4555, cls_acc=0.4018, val_cls_acc=0.3214, val_cls_auroc=0.6921, seg_loss=0.6051, val_seg_loss=0.8210, seg_dice=0.2576, val_seg_dice=0.1673
Epoch [3] Validation Results: lr=1e-03, total_loss=3.9550, cls_loss=0.9968, val_cls_loss=0.9575, cls_acc=0.4107, val_cls_acc=0.4643, val_cls_auroc=0.7361, seg_loss=0.5423, val_seg_loss=0.6738, seg_dice=0.2863, val_seg_dice=0.2529
Epoch [4] Validation Results: lr=1e-03, total_loss=3.7471, cls_loss=0.8794, val_

KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt

seg_losses = [x["seg_loss"] for x in history]
seg_dice = [x["seg_dice"] for x in history]

plt.plot(seg_losses, "-bx")
plt.plot(seg_dice, "-rx")

plt.xlabel("epoch")
plt.ylabel("loss")
plt.grid()
plt.legend(["seg_loss", "seg_dice"])
plt.title("Loss vs. NO. of epochs")