# Preamble

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

# Data Augmentation
import albumentations as A
# import albumentations.augmentations.functional as F
from albumentations.pytorch import ToTensorV2

# torchvision
import torchvision.transforms.v2 as TF
from torchviz import make_dot

import numpy as np
from PIL import Image
import cv2

# to have a progress bar
from tqdm import tqdm

# To use pretrained segmentation models (implement in PyTorch)
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import get_preprocessing_fn

# visualization
import matplotlib.pyplot as plt
from pprint import pprint


# OS/File/Path management
import os
import sys

# load my custom Classes/Functions/etc.
from utils.models.unet import UnetScratch
from utils.dataset import SegmentaionDataset
from utils.inferencing import inference_segmentation, img_to_inference_tensor
from utils.visualization import torch_tensor_for_plt, plot_segmentation_inference


# from utils import (
#     load_checkpoint,
#     save_checkpoint,
#     get_loaders,
#     check_accuracy,
#     save_pred_as_imgs
# )

2023-07-06 00:20:28.859098: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Set up GPU use

In [2]:
print('PyTorch ver:', torch.__version__)

# force pytorch to use GPU
# use "model.to(device)" later on to force a model use Cuda GPU
print('Can I use GPU?', torch.cuda.is_available())
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Device used for calculation (CPU\Cuda):', device)

PyTorch ver: 2.0.1+cu118
Can I use GPU? True
Device used for calculation (CPU\Cuda): cuda:0


## Set Hyperparameters

In [3]:
device = device
lr = 1e-4
batch_size = 2
epochs = 5
num_workers = 2
image_height = 384
image_width = 640
num_classes = 1
pin_mem = True
load_model = False
load_model_path = os.path.join('models')

# define path(s)
train_img_dir = os.path.join("data", "traincrop", "img")
train_mask_dir = os.path.join("data", "traincrop", "mask")
val_img_dir = os.path.join("data", "valcrop", "img")
val_mask_dir = os.path.join("data", "valcrop", "mask")

# Create Dataset

In [4]:
from utils.dataset import get_loaders


# define transformers (resize, rescale, augmentation, etc.)
train_transform = A.Compose(
    [
        A.Resize(height=image_height, width=image_width),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        # this is for z-score
        # for pretrained models on, for instance, imagenet need other values for mean and std.
        A.Normalize(
            mean=[0.0, 0.0, 0, 0], std=[1.0, 1.0, 1.0], max_pixel_value=255.0
        ),
        # ToTensorV2(),
    ],
    # TODO: this options seems risky
    # is_check_shapes=False,
)

# we don't want TTA, just some resize, normalization, etc.
val_transform = A.Compose(
    [
        A.Resize(height=image_height, width=image_width),
        A.Normalize(
            mean=[0.0, 0.0, 0, 0], std=[1.0, 1.0, 1.0], max_pixel_value=255.0
        ),
        # ToTensorV2(),
    ],
    # TODO: this options seems risky
    # is_check_shapes=False,
)


# Create Datatset by data loaders
train_loader, val_loader = get_loaders(
    train_img_dir,
    train_mask_dir,
    val_img_dir,
    val_mask_dir,
    batch_size,
    train_transform,
    val_transform,
    num_workers,
    pin_mem,
)

In [5]:
print('# train images:', sum(len(batch) for batch in train_loader))
print('# val images:', sum(len(batch) for batch in val_loader))

# train images: 6
# val images: 8


# Define Model(s)

## Import a Pretrained Segmentaion model (e.g., UNET)

In [6]:
# the downloaded model are located in "~/.cache/torch/hub/checkpoints/"
# backbone_model_name = 'resnet152'
backbone_model_name = 'mobilenet_v2'

# Segmentation model is just a PyTorch nn.Module
# model = smp.FPN(
model = smp.Unet(
    encoder_name=backbone_model_name,        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7 
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=1,                      # model output channels (number of classes in your dataset)
)

# load preprocessing func of the loaded model, so our data goes through the same transformation
# hint 1: it needs the input img to be in channels_last format
# hint 2: it outputs the image in channels_last format! (e.g., toch.Size[640, 360, 3])
preprocess_input = get_preprocessing_fn(backbone_model_name, pretrained='imagenet')

### Visualize the architecture

In [10]:
# the attribs of the model
# vars(model)


# plot the model arch
# create a dumy channels_first img (m, C, H, W)
x = torch.zeros(1, 3, 224, 224, dtype=torch.float, requires_grad=False)
yhat = model(x)
print(yhat.shape)

# save the architecture of the model as an image
# make_dot(yhat, params=dict(list(model.named_parameters()))).render("./outputs/model_architecture/Unet_pretrained", format="png")

torch.Size([1, 1, 224, 224])


### Test the model with an img

## Load a Segmentation Model (i.e., UNET), written from scratch

In [6]:
model = UnetScratch(in_channels=3, num_classes=num_classes)

# Train

In [7]:
from utils.training import train_model
from utils.metrics import accuracy_segment, dice_segment

metrics = ("accuracy", "dice")
# a dict to map metrics' name to correspoing fn
metrics_fn = {
    "dice": dice_segment,
    "accuracy": accuracy_segment,
}

train_model(
    model,
    train_loader,
    val_loader,
    epochs=epochs,
    lr=lr,
    device=device,
    metrics=metrics,
    metrics_fn=metrics_fn,
)

  0%|          | 0/3 [00:00<?, ?it/s]

100%|██████████| 3/3 [00:00<00:00, 14.07it/s]


------------------------ epoch 0's metric(s) (training) -----------------------
accuracy    0.53
dice        1.90
----------------------- epoch 0's metric(s) (validation) ----------------------
val_accuracy:  0.05
val_dice:   0.09
-------------------------------------------------------------------------------


100%|██████████| 3/3 [00:00<00:00, 14.86it/s]


------------------------ epoch 1's metric(s) (training) -----------------------
accuracy    0.72
dice       -10.01
----------------------- epoch 1's metric(s) (validation) ----------------------
val_accuracy:  0.05
val_dice:   0.09
-------------------------------------------------------------------------------


100%|██████████| 3/3 [00:00<00:00, 18.39it/s]


------------------------ epoch 2's metric(s) (training) -----------------------
accuracy    0.86
dice       -2.24
----------------------- epoch 2's metric(s) (validation) ----------------------
val_accuracy:  0.05
val_dice:   0.09
-------------------------------------------------------------------------------


100%|██████████| 3/3 [00:00<00:00, 18.51it/s]


------------------------ epoch 3's metric(s) (training) -----------------------
accuracy    0.93
dice       -1.28
----------------------- epoch 3's metric(s) (validation) ----------------------
val_accuracy:  0.24
val_dice:   0.08
-------------------------------------------------------------------------------


100%|██████████| 3/3 [00:00<00:00, 17.58it/s]


------------------------ epoch 4's metric(s) (training) -----------------------
accuracy    0.97
dice       -0.85
----------------------- epoch 4's metric(s) (validation) ----------------------
val_accuracy:  0.93
val_dice:   0.01
-------------------------------------------------------------------------------
----------------------- Saving Checkpoint (In progress) -----------------------

Checkpoint was saved as: ./models/2023.07.06@00-20-53-model_checkpoint.pth.tar

--------------------------- Saving Checkpoint (Done) --------------------------


# Inference

In [None]:
img_addr = os.path.join('images', '') + '20160222_081011_1_721.jpg'
mask_addr = os.path.join('images', '') + '20160222_081011_1_721_mask.png'

# inference_semantic_seg(img_addr, model=model, thresh=0.5, )

img = Image.open(img_addr)
mask = Image.open(mask_addr)

# normalize [-1, 1] the PIL image as a torch.Tensor
img_tensor = (TF.functional.pil_to_tensor(img) / 127.5 - 1)

yhat = model(img_tensor.unsqueeze(dim=0))
yhat_c_last = yhat[0].moveaxis(0, -1).detach().numpy()
thresh = 0.5
mask_batch = np.where(yhat_c_last >= thresh, 255, 0)

fig, axes = plt.subplots(1, 3, figsize=(30, 10))
axes[0].imshow(img)
axes[0].set_title('Image', fontsize=16)
axes[1].imshow(mask_batch, cmap='gray')
axes[1].set_title('Mask (predicted)', fontsize=16)
axes[2].imshow(mask, cmap='gray')
axes[2].set_title('Mask (target)', fontsize=16)
plt.show()

In [16]:
img_addr = os.path.join('images', '') + '20160222_081011_1_721.jpg'
mask_addr = os.path.join('images', '') + '20160222_081011_1_721_mask.png'

# only for one image, if multiple use utils.dataset
test_img_batch = img_to_inference_tensor(img_addr, size=(image_height, image_width))
test_mask_batch = img_to_inference_tensor(mask_addr, size=(image_height, image_width))

# make an inference
yhat_mask = inference_segmentation(test_img_batch, model=model, normalize=True)

# make torch.Tensors ready for pyplot
test_img_batch = torch_tensor_for_plt(test_img_batch)
test_mask_batch = torch_tensor_for_plt(test_mask_batch)
yhat_mask = torch_tensor_for_plt(yhat_mask)

# plot the inference
plot_segmentation_inference(test_img_batch, test_mask_batch, yhat_mask)

RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

# 