In [1]:
import torch
import torch.nn as nn
import numpy as np
import cv2
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import os
import glob
import albumentations as A
import segmentation_models_pytorch as smp
from unet_vanilla import RoadSegmentData, DiceLoss

In [2]:
img_transform = A.Compose(
        [
            A.VerticalFlip(),
            A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=30, p=0.6),
            A.RGBShift(r_shift_limit=25, g_shift_limit=25, b_shift_limit=25, p=0.6),
            A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.6),
            A.RandomRotate90(p=0.8),
            ToTensorV2(transpose_mask=True)
        ]
    )

In [3]:
image_names = [img.split('/')[-1].split('\\')[-1] for img in glob.glob("./Vanilla Dataset/training/images/*")]
image_path = './Vanilla Dataset/training/images/'
mask_path = './Vanilla Dataset/training/groundtruth/'
# if os.name == 'nt':
#     image_path = './mass-data/'
#     mask_path = './mass-data/'

train_image_names, validate_image_names = train_test_split(image_names, train_size=0.8)
train_data = RoadSegmentData(train_image_names, image_path, mask_path, img_transform)
validate_data = RoadSegmentData(validate_image_names, image_path, mask_path, img_transform)
print(validate_image_names)

['satimage_90.png', 'satimage_17.png', 'satimage_139.png', 'satimage_112.png', 'satimage_143.png', 'satimage_68.png', 'satimage_93.png', 'satimage_56.png', 'satimage_28.png', 'satimage_114.png', 'satimage_141.png', 'satimage_113.png', 'satimage_23.png', 'satimage_34.png', 'satimage_57.png', 'satimage_9.png', 'satimage_115.png', 'satimage_87.png', 'satimage_61.png', 'satimage_24.png', 'satimage_103.png', 'satimage_25.png', 'satimage_118.png', 'satimage_1.png', 'satimage_39.png', 'satimage_131.png', 'satimage_8.png', 'satimage_96.png', 'satimage_52.png']


In [4]:
# Hyperparameters
learning_rate = 1e-4
batch_size = 1
epochs = 1

In [5]:
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
validate_dataloader = DataLoader(validate_data, batch_size=batch_size)

In [19]:
model = smp.Unet(
    in_channels=3,
    encoder_name='resnet50',
    classes=1,
    activation='sigmoid'
)
model.cuda()
loss_fn = DiceLoss()
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)

In [20]:
def reset_all_weights(model: nn.Module) -> None:
    """
    refs:
        - https://discuss.pytorch.org/t/how-to-re-set-alll-parameters-in-a-network/20819/6
        - https://stackoverflow.com/questions/63627997/reset-parameters-of-a-neural-network-in-pytorch
        - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    """

    @torch.no_grad()
    def weight_reset(m: nn.Module):
        # - check if the current module has reset_parameters & if it's callabed called it on m
        reset_parameters = getattr(m, "reset_parameters", None)
        if callable(reset_parameters):
            m.reset_parameters()

    # Applies fn recursively to every submodule see: https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    model.apply(fn=weight_reset)


In [21]:
reset_all_weights(model)

In [22]:
for epoch in range(epochs):
        model.train()
        for batch, (x, y_true) in enumerate(train_dataloader):
            x = x.float().cuda()
            y_true = y_true.float().cuda()
            y_pred = model(x)
            assert y_true.numel() == y_pred.numel()
            loss = loss_fn(y_true, y_pred)

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Validation
        validation_loss = 0.0
        model.eval()
        with torch.no_grad():
            for batch, (x, y_true) in enumerate(validate_dataloader):
                x = x.float().cuda()
                y_true = y_true.float().cuda()
                y_pred = model(x)
                validation_loss += loss_fn(y_true, y_pred).item()
        validation_loss /= (batch+1)
        print(f'Epoch {(epoch+1)} Validation Loss: {validation_loss:5.4}')
        if (epoch+1)%10 == 0:
            torch.save(model, f'Model Epoch {epoch+1}')

./Vanilla Dataset/training/images/satimage_120.png


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 26 but got size 25 for tensor number 1 in the list.

In [14]:
model

UnetPlusPlus(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential