In [None]:
# https://github.com/qubvel-org/segmentation_models.pytorch/blob/26e8c47b0a97cb9775170881114821c52755897f/examples/cars%20segmentation%20(camvid).ipynb

### 1. Download Data and Set Paths

In [None]:
%%capture
# Download the data: clone the SegNet-Tutorial repository into ./data.
# %%capture hides this cell's output, so a failed clone is not reported here.
!git clone https://github.com/alexgkendall/SegNet-Tutorial ./data

In [None]:
import os

# Root of the CamVid data. The download cell clones into ./data, while this
# notebook originally pointed at ../data. Use whichever directory exists so the
# notebook runs from either layout; fall back to the original path otherwise.
_CAMVID_CANDIDATES = ("../data/CamVid/", "./data/CamVid/")
DATA_DIR = next(
    (path for path in _CAMVID_CANDIDATES if os.path.isdir(path)),
    _CAMVID_CANDIDATES[0],
)

# Image folders (x_*) and their annotation-mask folders (y_*) for each split.
x_train_dir = os.path.join(DATA_DIR, "train")
y_train_dir = os.path.join(DATA_DIR, "trainannot")

x_val_dir = os.path.join(DATA_DIR, "val")
y_val_dir = os.path.join(DATA_DIR, "valannot")

x_test_dir = os.path.join(DATA_DIR, "test")
y_test_dir = os.path.join(DATA_DIR, "testannot")

### 2. Create Dataset and DataLoader

In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset, DataLoader

In [None]:
class CarDataset(Dataset):
    """Image/mask dataset that pairs files by name across two directories.

    Every file name found in ``img_dir`` is expected to exist in
    ``annotation_dir`` as well. Annotation masks are read as single-channel
    images whose pixel values are indices into ``CLASSES``.

    Args:
        img_dir: Directory containing the input images.
        annotation_dir: Directory containing the masks (same file names).
        classes: Class names (case-insensitive) to extract as binary mask
            channels. ``None`` selects every entry of ``CLASSES``.
        augmentation: Optional callable invoked as
            ``augmentation(image=..., mask=...)`` that returns a mapping with
            ``"image"`` and ``"mask"`` entries.

    Raises:
        ValueError: If a requested class name is not in ``CLASSES``.
    """

    CLASSES = [
        "sky",
        "building",
        "pole",
        "road",
        "pavement",
        "tree",
        "signsymbol",
        "fence",
        "car",
        "pedestrian",
        "bicyclist",
        "unlabelled",
    ]

    def __init__(self, img_dir: str, annotation_dir: str, classes=None, augmentation=None):
        # os.listdir order is arbitrary; sort so sample indices are reproducible.
        self.img_ids = sorted(os.listdir(img_dir))
        self.images_fps = [os.path.join(img_dir, image_id) for image_id in self.img_ids]
        self.masks_fps = [os.path.join(annotation_dir, image_id) for image_id in self.img_ids]
        # The default used to crash on iteration; treat None as "all classes".
        if classes is None:
            classes = self.CLASSES
        self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]
        self.augmentation = augmentation

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx`` in channel-first layout.

        The image is RGB with shape (3, height, width). The mask has one
        channel per selected class, shape (n_classes, height, width), with
        float values 0.0 / 1.0.

        Raises:
            FileNotFoundError: If the image or mask file cannot be read.
        """
        image = cv2.imread(self.images_fps[idx])
        # cv2.imread signals failure by returning None instead of raising.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {self.images_fps[idx]}")
        # convert BGR -> RGB
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(self.masks_fps[idx], 0)  # flag 0 = grayscale -> (height, width)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask: {self.masks_fps[idx]}")

        # extract certain classes from mask (e.g. cars): one binary channel per class
        masks = [(mask == v) for v in self.class_values]
        mask = np.stack(masks, axis=-1).astype("float")  # (height, width, n_classes)

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample["image"], sample["mask"]

        return image.transpose(2, 0, 1), mask.transpose(2, 0, 1)  # (channels, height, width)


In [None]:
def visualize(**images):
    """Show every keyword-argument image side by side in a single row.

    Each keyword name becomes its subplot title (underscores turned into
    spaces, then title-cased). The entry named ``image`` is transposed with
    axes (1, 2, 0) before display; every other entry is shown as given.
    """
    columns = len(images)
    plt.figure(figsize=(16, 5))
    for slot, (label, picture) in enumerate(images.items(), start=1):
        plt.subplot(1, columns, slot)
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace("_", " ").title())
        shown = picture.transpose(1, 2, 0) if label == "image" else picture
        plt.imshow(shown)
    plt.show()

In [None]:

# Car-only view of the training split, without augmentation.
dataset = CarDataset(img_dir=x_train_dir, annotation_dir=y_train_dir, classes=["car"])
# Preview one sample next to its car mask.
image, mask = dataset[20]
visualize(image=image, cars_mask=mask.squeeze())


In [None]:
# Inspect the container type the dataset returns for the image.
type(image)

### 3. Data Augmentation (skipped)

In [None]:
# -- torchvision transform: convert to a tensor, then resize to (384, 480).
# NOTE(review): this resizes rather than pads, and `transform` is not referenced
# by any other cell visible in this notebook; confirm whether it is still needed.
from torchvision import transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((384, 480)),
])

In [None]:
# -- data augmentation
import albumentations as A

def get_training_augmentation():
    """Build the albumentations pipeline applied to training samples.

    The pipeline flips, shifts/scales, pads to at least 320x320 and takes a
    random 320x320 crop, then applies noise, perspective, and one transform
    from each of three photometric/blur groups.

    Returns:
        An ``A.Compose`` object, called as ``pipeline(image=..., mask=...)``.
    """
    train_transform = [
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(
            scale_limit=0.5, rotate_limit=0, shift_limit=0.1, p=1, border_mode=0
        ),
        # p=1 replaces the deprecated always_apply=True: both run the transform
        # on every sample, but p=1 is accepted by current albumentations releases.
        A.PadIfNeeded(min_height=320, min_width=320, p=1),
        A.RandomCrop(height=320, width=320, p=1),
        A.GaussNoise(p=0.2),
        A.Perspective(p=0.5),
        # Each OneOf group fires with probability 0.9 and applies one member.
        A.OneOf(
            [
                A.CLAHE(p=1),
                A.RandomBrightnessContrast(p=1),
                A.RandomGamma(p=1),
            ],
            p=0.9,
        ),
        A.OneOf(
            [
                A.Sharpen(p=1),
                A.Blur(blur_limit=3, p=1),
                A.MotionBlur(blur_limit=3, p=1),
            ],
            p=0.9,
        ),
        A.OneOf(
            [
                A.RandomBrightnessContrast(p=1),
                A.HueSaturationValue(p=1),
            ],
            p=0.9,
        ),
    ]
    return A.Compose(train_transform)

def get_validation_augmentation():
    """Return the evaluation pipeline: pad samples to at least 384x480.

    Both target sizes are multiples of 32 (384 = 12 * 32, 480 = 15 * 32).
    """
    return A.Compose([A.PadIfNeeded(min_height=384, min_width=480)])


In [None]:
# Same training files as `dataset`, but passed through the training augmentations.
augmented_dataset = CarDataset(x_train_dir, y_train_dir, classes=["car"], augmentation=get_training_augmentation())

# For the first three samples, show the raw version and then the augmented one.
# `image2` (the last augmented image) is reused by later cells.
for i in range(3):
    image, mask = dataset[i]
    visualize(image=image, mask=mask.squeeze())
    image2, mask2 = augmented_dataset[i]
    visualize(image=image2, mask=mask2.squeeze())

In [None]:
# Check the type and shape of the last augmented image.
print(type(image2), image2.shape)

In [None]:
# Training samples are randomly augmented; validation/test samples are only padded.
train_dataset = CarDataset(x_train_dir, y_train_dir, classes=["car"], augmentation=get_training_augmentation())
val_dataset = CarDataset(x_val_dir, y_val_dir, classes=["car"], augmentation=get_validation_augmentation())
test_dataset = CarDataset(x_test_dir, y_test_dir, classes=["car"], augmentation=get_validation_augmentation())

# Only the training loader shuffles. The Dice loss is computed per batch, so a
# fixed batch composition keeps validation/test numbers comparable between
# evaluations.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)

In [None]:
# Iterator over validation batches, used below to inspect a single batch.
dataiter = iter(val_loader)

In [None]:
# Pull the next batch; re-running this cell advances the iterator.
images, labels = next(dataiter)

In [None]:
# Shapes of the batched images and masks.
print(images.shape, labels.shape)

### 4. Create and train model

In [None]:
from dotenv import load_dotenv

In [None]:
# Presumably loads variables from a local .env file into the process
# environment (python-dotenv) — confirm where the .env file is expected to live.
load_dotenv()
# os.getenv returns None when MLFLOW_TRACKING_URI is not set.
MLFLOW_URI = os.getenv("MLFLOW_TRACKING_URI")

In [None]:
# print(MLFLOW_URI)

In [None]:
import mlflow

# Point the MLflow client at the tracking URI read from the environment.
mlflow.set_tracking_uri(MLFLOW_URI) # set_registry_uri()

In [None]:
# experiment_name = 'CarSegmentation'
# if not mlflow.get_experiment_by_name(name=experiment_name):
#     mlflow.create_experiment(name=experiment_name)
# experiment = mlflow.get_experiment_by_name(experiment_name)

In [None]:
# Number of passes over the training loader.
NUM_EPOCHS = 1
# Number of output mask channels of the model (a single class: "car").
OUT_CLASSES = 1

In [None]:
import segmentation_models_pytorch as smp
import torch

# U-Net with a ResNet-34 encoder, 3 input channels and OUT_CLASSES output channels.
# NOTE(review): `model` is reassigned by the next two cells, so this instance is
# not the one that ends up being trained.
model = smp.Unet(
    encoder_name="resnet34", 
    encoder_weights="imagenet", 
    in_channels=3, 
    classes=OUT_CLASSES
)

In [None]:
# FPN with a ResNeXt-50 encoder. The number of output channels is passed as
# `classes`; `out_classes` is not a create_model parameter and would be
# forwarded as an unknown extra keyword instead of setting the channel count.
model = smp.create_model("FPN", encoder_name="resnext50_32x4d", in_channels=3, classes=OUT_CLASSES)

In [None]:
# FPN with an ImageNet-pretrained ResNeXt-50 encoder; this is the model trained
# below. The number of output channels is passed as `classes`; `out_classes` is
# not a create_model parameter.
model = smp.create_model(
    arch="FPN",
    encoder_name="resnext50_32x4d",
    encoder_weights="imagenet",
    in_channels=3,
    classes=OUT_CLASSES,
)

In [None]:
# -- dummy test
# Random input with values in [0, 255) and shape (1, 3, 320, 320).
mmax, mmin = 255, 0
x = (mmax - mmin) * torch.rand((1, 3, 320, 320)) + mmin
# Smoke-test the forward pass in eval mode without autograd, so the random
# input neither updates normalization running statistics nor builds a graph.
# The previous train/eval mode is restored afterwards.
_was_training = model.training
model.eval()
with torch.no_grad():
    t = model(x)
model.train(_was_training)

In [None]:
# -- dummy test 2
# Turn the last augmented image into a float32 batch of size 1.
x2 = torch.tensor(image2, dtype=torch.float32).unsqueeze(0)
print(type(image2), image2.shape, x2.shape)
# Forward pass in eval mode without autograd, so this check neither updates
# normalization running statistics nor builds a graph. The previous train/eval
# mode is restored afterwards.
_was_training = model.training
model.eval()
with torch.no_grad():
    t = model(x2)
model.train(_was_training)


In [None]:
from tqdm import tqdm
from datetime import datetime

def evaluate_model(model, validation_dataloader: DataLoader, device=None):
    """Compute the mean per-batch Dice loss of ``model`` over a dataloader.

    The model is switched to eval mode for the evaluation and returned to the
    mode it was in before the call, so it can be used inside a training loop.

    Args:
        model: Network whose outputs are treated as logits for binary masks.
        validation_dataloader: Yields ``(X, y)`` batches; both are cast to
            float32 before use.
        device: Device to run on. Defaults to the device of the model's first
            parameter (CPU for a model without parameters).

    Returns:
        dict mapping ``'dice_loss'`` to a Python float: the mean over batches,
        or ``nan`` when the dataloader yields no batches.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    dice_loss = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
    metrics_dict = {
        'dice_loss': 0.0
    }
    num_batches = 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for X, y in validation_dataloader:
                X = X.to(device=device, dtype=torch.float32)
                y = y.to(device=device, dtype=torch.float32)

                # forward pass
                outputs = model(X)

                # accumulate plain floats rather than tensors
                metrics_dict['dice_loss'] += dice_loss(outputs, y).item()
                num_batches += 1
    finally:
        # Restore the caller's mode even if evaluation fails part-way.
        model.train(was_training)

    for metric in metrics_dict:
        if num_batches:
            metrics_dict[metric] = metrics_dict[metric] / num_batches
        else:
            metrics_dict[metric] = float("nan")

    return metrics_dict


def train(model, num_epochs: int, train_dataloader: DataLoader, val_dataloader: DataLoader, loss_fn, optimizer, device):
    """Train ``model`` inside an MLflow run and return the per-batch losses.

    Every 10th epoch (starting with epoch 0) the model is logged to MLflow as
    a checkpoint and evaluated on ``val_dataloader``; the resulting metrics
    are logged with the epoch as the step.

    Args:
        model: Network to train; moved to ``device`` in place.
        num_epochs: Number of passes over ``train_dataloader``.
        train_dataloader: Yields ``(X, y)`` training batches.
        val_dataloader: Batches passed to ``evaluate_model``.
        loss_fn: Callable ``loss_fn(outputs, y)`` returning a scalar tensor.
        optimizer: Optimizer already bound to ``model``'s parameters.
        device: Device the batches and the model are moved to.

    Returns:
        list of float: the training loss of every processed batch, in order.
    """
    mlflow.pytorch.autolog()
    # Batches are moved to `device` below; the model has to live there too.
    model.to(device)

    with mlflow.start_run():
        losses = []
        for epoch in tqdm(range(num_epochs)):
            # Re-enter training mode every epoch: the validation pass below may
            # leave the model in eval mode.
            model.train()

            for batch_idx, (X, y) in enumerate(train_dataloader):
                X = X.to(device=device, dtype=torch.float32)
                y = y.to(device=device, dtype=torch.float32)

                optimizer.zero_grad()
                outputs = model(X)
                loss = loss_fn(outputs, y)
                loss.backward()
                optimizer.step()

                # Store a plain float so the autograd graph of each batch is
                # not kept alive by the history list.
                losses.append(loss.item())

                if batch_idx % 10 == 0:
                    print(f"Batch {batch_idx}/{len(train_dataloader)}")

            if epoch % 10 == 0:
                # --- Log model checkpoint with the PyTorch flavor: the model is
                # a torch module, not an mlflow.pyfunc.PythonModel.
                mlflow.pytorch.log_model(
                    pytorch_model=model,
                    name=f"checkpoint-epoch-{epoch}",
                    step=epoch,
                )

                # --- Log metrics
                model_metrics = evaluate_model(model, val_dataloader)
                for metric, value in model_metrics.items():
                    mlflow.log_metric(
                        key=metric,
                        value=float(value),
                        step=epoch,
                    )

        return losses

In [None]:

# Training setup: CPU execution, Dice loss on raw logits, Adam optimizer.
device = 'cpu'
loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=True)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-4)

# Run the training loop and keep the per-batch loss history.
losses = train(
    model=model,
    num_epochs=NUM_EPOCHS,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    loss_fn=loss_fn,
    optimizer=optimizer,
    device=device,
)

In [None]:
# for batch_idx, (X, y) in enumerate(train_loader):
#     print(X.shape)
#     X = X.to(torch.float32)
#     print(X)
#     t = model(X)
#     print(t.shape)

### 5. Test and validation

In [None]:
# model

### 6. Quantize the model + ONNX

In [None]:
# --- quantize: Post-Training Dynamic: Eager mode vs FX mode; Post-Training Static:

### EAGER MODE
from torch.quantization import quantize_dynamic
from torch import nn
# Post-training dynamic quantization (eager mode) to int8. With inplace=False
# the result is a converted copy; `model` itself is presumably left unchanged.
# NOTE(review): dynamic quantization appears to target Linear/RNN-style layers;
# confirm whether listing nn.Conv2d in qconfig_spec has any effect on this model.
model_quantized = quantize_dynamic(
    model=model, qconfig_spec={nn.Conv2d, nn.Linear}, dtype=torch.qint8, inplace=False
)

### FX MODE
# from torch.quantization import quantize_fx
# qconfig_dict = {"": torch.quantization.default_dynamic_qconfig}  # An empty key denotes the default applied to all modules
# model_prepared = quantize_fx.prepare_fx(model, qconfig_dict, example_inputs=torch.rand((1, 3, 320, 320)))
# model_quantized2 = quantize_fx.convert_fx(model_prepared)

In [None]:
# --- ONNX


### 7. Comparing models performance