In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
import math
import os
import sys

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2

In [2]:
import warnings

warnings.filterwarnings("ignore")

In [3]:
# Make the DermSynth3D package (checked out two directories above the current
# working directory) importable. The membership guard keeps this cell idempotent:
# re-running it no longer prepends a duplicate entry to sys.path every time.
dermsynth_root = os.path.abspath(
    os.path.join(os.getcwd(), "../../DermSynth3D_private")
)
if dermsynth_root not in sys.path:
    sys.path.insert(0, dermsynth_root)

from dermsynth3d.datasets.datasets import (
    ImageDataset,
    SynthDataset_Detection,
    RealDataset_Detection,
)
from dermsynth3d.models.model import faster_rcnn_texture_model
from dermsynth3d.utils.evaluate_detection import (
    evaluate_detection,
)
from dermsynth3d.utils.utils import (
    MetricLogger,
    SmoothedValue,
    warmup_lr_scheduler,
    reduce_dict,
)

In [4]:
# Use the first CUDA GPU when one is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [5]:
# (height, width) that every image is resized to (see `resize_func`).
IMG_HEIGHT, IMG_WIDTH = 256, 256
img_size = (IMG_HEIGHT, IMG_WIDTH)

In [6]:
# Build the Faster R-CNN detector with two output classes and a pretrained backbone.
# Only img_size[0] is passed as the max input size, which assumes square inputs.
det_model_config = {
    "num_classes": 2,
    "max_input_size": img_size[0],
    "pretrained_backbone": True,
}
det_model = faster_rcnn_texture_model(device, **det_model_config)

In [8]:
# Place the detector on the chosen device in training mode (Module.train returns
# the module itself), then collect only the parameters that require gradients
# so the optimizer skips frozen layers.
det_model = det_model.to(device).train()
params = list(filter(lambda param: param.requires_grad, det_model.parameters()))

### Prepare the train and validation data

In [10]:
# Per-channel normalisation. Assumes the backbone was pretrained with these
# statistics (the commonly used ImageNet channel mean/std).
imagenet_mean = (0.485, 0.456, 0.406)
imagenet_std = (0.229, 0.224, 0.225)
preprocess_input = A.Normalize(mean=imagenet_mean, std=imagenet_std)

# The same normalisation wrapped in a Compose pipeline.
img_preprocess = A.Compose([preprocess_input])

# Deterministic resize to img_size = (height, width).
resize_func = A.Resize(height=img_size[0], width=img_size[1])

# Spatial augmentation intended for both the image and mask: random horizontal
# flip and 90-degree rotation, followed by the fixed resize.
spatial_ops = [A.HorizontalFlip(), A.RandomRotate90(), resize_func]
spatial_augment = A.Compose(spatial_ops)

# Photometric augmentations applied to the image: colour jitter, sensor-like
# noise, a light blur and compression artefacts.
min_v, max_v = 0.8, 1.2
jitter_range = (min_v, max_v)
img_augment = A.Compose(
    [
        A.ColorJitter(
            brightness=jitter_range,
            contrast=jitter_range,
            saturation=jitter_range,
            hue=(-0.025, 0.025),
        ),
        A.ISONoise(color_shift=(0.01, 0.1), intensity=(0.1, 0.75), always_apply=False),
        A.GaussianBlur(blur_limit=(3, 3)),
        # NOTE(review): positional args are presumably quality_lower/quality_upper
        # (albumentations 1.x); newer releases use `quality_range` — check the
        # installed version.
        A.ImageCompression(10, 100),
    ]
)

In [19]:
# Point to the synthetic data you created.
dir_images = "../data/all_data/images"
dir_targets = "../data/all_data/targets"

synth_ds = SynthDataset_Detection(
    dir_images=dir_images,
    dir_targets=dir_targets,
    name="synth_train",
    spatial_transform=spatial_augment,
    image_augment=img_augment,
    image_preprocess=preprocess_input,
    target_preprocess=None,
    target_extension=".npz",
    totensor=ToTensorV2(transpose_mask=True),
)

print(len(synth_ds))

10000


In [18]:
# Validation only applies the deterministic resize (no random augmentation).
val_spatial_augment = A.Compose([resize_func])

# Real validation images and PNG labels from the FUSeg validation split.
real_val_kwargs = {
    "dir_images": "../data/FUSeg/validation/images",
    "dir_targets": "../data/FUSeg/validation/labels",
    "name": "real_val",
    "image_extension": ".png",
    "target_extension": ".png",
    "image_augment": None,
    "spatial_transform": val_spatial_augment,
    "image_preprocess": img_preprocess,
    "totensor": ToTensorV2(transpose_mask=True),
}
real_val_ds = RealDataset_Detection(**real_val_kwargs)

num_real_val_samples = len(real_val_ds)
print(num_real_val_samples)

200


### Train the model

In [15]:
def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
    """Run one training epoch of a detection model.

    Args:
        model: Detection model that, called as ``model(images, targets)`` in
            training mode, returns a dict of loss tensors; they are summed and
            back-propagated.
        optimizer: Optimizer over the model's trainable parameters.
        data_loader: Iterable of 5-tuples. The 3rd item is a sequence of image
            tensors and the 5th a sequence of per-image target dicts of tensors.
            The 1st, 2nd and 4th items (the 4th being masks) are ignored.
        device: Device the images and targets are moved to.
        epoch: Epoch index. On epoch 0 a warm-up LR scheduler
            (``warmup_lr_scheduler``, factor 1/1000) is stepped every iteration.
        print_freq: Logging interval forwarded to ``MetricLogger.log_every``.

    Returns:
        The ``MetricLogger`` holding the logged losses and learning rate.

    Calls ``sys.exit(1)`` (raising ``SystemExit``) if the summed loss becomes
    non-finite.
    """
    # `math` is needed for the finiteness check below; it is imported here so the
    # function does not depend on the notebook's top-level imports.
    import math

    model.train()
    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = "Epoch: [{}]".format(epoch)

    lr_scheduler = None
    if epoch == 0:
        # Learning-rate warm-up during the first epoch only (at most 1000
        # iterations); the exact schedule is defined by warmup_lr_scheduler.
        warmup_factor = 1.0 / 1000
        warmup_iters = min(1000, len(data_loader) - 1)
        lr_scheduler = warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)

    for _, _, images, _masks, targets in metric_logger.log_every(
        data_loader, print_freq, header
    ):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())

        # Reduce losses over all GPUs for logging purposes only; the backward
        # pass uses the local (unreduced) losses.
        loss_dict_reduced = reduce_dict(loss_dict)
        losses_reduced = sum(loss_dict_reduced.values())
        loss_value = losses_reduced.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        if lr_scheduler is not None:
            lr_scheduler.step()

        metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    return metric_logger

In [16]:
def collate_fn(batch):
    """Transpose a batch of per-sample tuples into a tuple of per-field tuples.

    E.g. ``[(a1, b1), (a2, b2)] -> ((a1, a2), (b1, b2))``. Nothing is stacked,
    so fields whose size varies between samples are kept as separate items.
    An empty batch yields ``()``.
    """
    per_field = zip(*batch)
    return tuple(per_field)

In [20]:
# DataLoaders for training and validating the Faster R-CNN. `collate_fn`
# transposes each batch into per-field tuples instead of stacking tensors.
batch_size = 8

# Training: reshuffle every epoch and drop the incomplete final batch.
train_dataloader = DataLoader(
    synth_ds,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn,
)

# Validation: fixed order, and keep the final partial batch so that every
# validation image is evaluated (drop_last=True would silently skip up to
# batch_size - 1 images whenever the set size is not a multiple of batch_size).
val_dataloader = DataLoader(
    real_val_ds,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    collate_fn=collate_fn,
)

In [21]:
# Optimisation hyper-parameters.
num_epochs = 20  # number of train/validate cycles in the training loop
lr, momentum, weight_decay = 1e-3, 0.95, 5e-4  # SGD settings

In [22]:
# SGD with momentum and L2 weight decay over the trainable parameters only.
sgd_kwargs = {"lr": lr, "momentum": momentum, "weight_decay": weight_decay}
optimizer = torch.optim.SGD(params, **sgd_kwargs)

In [None]:
training_stats = []
max_valid_iou = 0
training_epoch = 0

dir_save_models = "/path/to/save/model"
# makedirs also creates missing parent directories (os.mkdir would raise), and
# exist_ok makes re-running this cell safe.
os.makedirs(dir_save_models, exist_ok=True)

# Alternate one training epoch with one validation pass, and save the weights
# whenever the mean validation IoU improves on the best seen so far.
for epoch in range(num_epochs):
    det_model.train()
    log = train_one_epoch(
        det_model,
        optimizer,
        train_dataloader,
        device,
        epoch,
        print_freq=20,
    )
    det_model.eval()
    centroid_result, iou_result = evaluate_detection(det_model, val_dataloader, device)

    ap = pd.DataFrame(centroid_result)["ap"].mean()
    iou = pd.DataFrame(iou_result)["iou"].mean()
    print("centroid_result, ap:", ap)
    print("iou_result, iou:", iou)

    if iou > max_valid_iou:
        max_valid_iou = iou
        # Save both the state dict and the whole pickled model for this epoch.
        print("Saving model to disk. iou={}".format(iou))
        torch.save(
            det_model.state_dict(),
            os.path.join(
                dir_save_models, "model_state_dict_patches" + str(training_epoch)
            ),
        )
        torch.save(
            det_model,
            os.path.join(dir_save_models, "model_patches" + str(training_epoch)),
        )

    training_stats.append(
        {
            "epoch": training_epoch,
            "valid_iou": iou,
            "valid_ap": ap,
            # NOTE(review): `.value` is presumably the most recent batch's loss,
            # not an epoch average; consider `global_avg` if SmoothedValue
            # provides it — confirm against dermsynth3d.utils.utils.
            "train_loss": log.meters["loss"].value,
        }
    )
    training_epoch = training_epoch + 1