In [None]:
import math
import os
import re
import cv2
import numpy as np
import matplotlib.pyplot as plt
import random
import tensorflow as tf
import copy
from glob import glob
from random import sample
from PIL import Image

In [None]:
# Mount Google Drive under /content/drive (Colab-only API); later cells read the
# dataset archive and additional images from paths below /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# !unzip /content/drive/MyDrive/cil-road-segmentation-2021.zip -d /content/drive/MyDrive/cil-road-segmentation-2021

In [None]:
# Per-channel mean and standard deviation consumed by normalize().
# The values live on a [0, 1] intensity scale (normalize() divides by 255 before using them).
# Presumably computed over the training images in RGB channel order -- TODO confirm.
mean1 = np.array([0.330, 0.327, 0.293])
std1 = np.array([0.183, 0.176, 0.175])

def horizontal_flip(train_img, label_img, prob=0.75):
    """Mirror an image/mask pair left-to-right with probability ``prob``.

    Both arrays receive the same flip so the mask stays aligned with the
    image. When the random draw does not trigger the flip, the two inputs
    are handed back untouched.

    Returns:
        tuple: ``(train_img, label_img)``, flipped or original.
    """
    do_flip = np.random.random() < prob
    if not do_flip:
        return train_img, label_img
    # flipCode=1 mirrors around the vertical axis (left <-> right).
    flipped = [cv2.flip(arr, 1) for arr in (train_img, label_img)]
    return flipped[0], flipped[1]


def vertical_flip(train_img, label_img, prob=0.75):
    """Mirror an image/mask pair top-to-bottom with probability ``prob``.

    The image and its mask are always flipped together. If the random draw
    does not trigger the flip, the inputs are returned as they came in.

    Returns:
        tuple: ``(train_img, label_img)``, flipped or original.
    """
    do_flip = np.random.random() < prob
    if not do_flip:
        return train_img, label_img
    # flipCode=0 mirrors around the horizontal axis (top <-> bottom).
    flipped = [cv2.flip(arr, 0) for arr in (train_img, label_img)]
    return flipped[0], flipped[1]


def rotate_90s(train_img, label_img, prob=0.75):
    """Rotate an image/mask pair by a random multiple of 90 degrees.

    With probability ``prob`` an integer k is drawn from {1, 2, 3} and both
    arrays are rotated k times by 90 degrees with ``np.rot90`` (which turns
    counter-clockwise in the plane of the first two axes). Otherwise the
    inputs are returned unchanged.

    Returns:
        tuple: ``(train_img, label_img)``, rotated or original.
    """
    if not (np.random.random() < prob):
        return train_img, label_img
    # randint's upper bound is exclusive, so the draw is one of 1, 2, 3.
    quarter_turns = np.random.randint(low=1, high=4, size=1)[0]
    rotated_train = np.rot90(train_img, quarter_turns)
    rotated_label = np.rot90(label_img, quarter_turns)
    return rotated_train, rotated_label


def hue_image(train_img, label_img, min_hue_factor=-10, max_hue_factor=10, prob=0.75):
    """Shift the hue of an RGB image by a random integer offset with probability ``prob``.

    The offset is drawn uniformly from the inclusive range
    ``[min_hue_factor, max_hue_factor]`` and is expressed in the hue units
    OpenCV uses for the image dtype (half-degrees for uint8, degrees for
    float32). Hue is an angle, so the shifted value wraps around instead of
    being clipped. The mask is never modified.

    Args:
        train_img: RGB image, uint8 or float32 (the depths cv2.cvtColor accepts).
        label_img: segmentation mask, passed through unchanged.
        min_hue_factor: smallest hue offset (inclusive).
        max_hue_factor: largest hue offset (inclusive).
        prob: probability of applying the shift.

    Returns:
        tuple: ``(image, label_img)``; ``image`` is a new array when the shift
        was applied, otherwise the original ``train_img``.
    """
    if not (np.random.random() < prob):
        return train_img, label_img

    hsv = cv2.cvtColor(train_img, cv2.COLOR_RGB2HSV)
    h, s, v = cv2.split(hsv)
    hue_dtype = h.dtype
    # randint excludes `high`; +1 honours the documented inclusive range.
    delta = int(np.random.randint(low=min_hue_factor, high=max_hue_factor + 1, size=1)[0])
    # OpenCV stores hue in [0, 180) for 8-bit images and [0, 360) for float images.
    hue_period = 180 if hue_dtype == np.uint8 else 360
    # Work in float32 so negative offsets cannot underflow uint8, then wrap
    # around the colour circle (clipping would pile hues up at the range ends).
    h = np.mod(h.astype(np.float32) + delta, hue_period).astype(hue_dtype)
    final_hsv = cv2.merge((h, s, v))
    new_train_img = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2RGB)
    return new_train_img, label_img


def saturation_image(train_img, label_img, min_saturation_factor=-20, max_saturation_factor=20, prob=0.75):
    """Shift the saturation of an RGB image by a random offset with probability ``prob``.

    The offset is drawn uniformly from the inclusive integer range
    ``[min_saturation_factor, max_saturation_factor]`` and is expressed on the
    8-bit saturation scale (0..255). For float32 images, where OpenCV stores
    saturation in [0, 1], the offset is rescaled accordingly. The result is
    clipped to the valid saturation range. The mask is never modified.

    Args:
        train_img: RGB image, uint8 or float32 (the depths cv2.cvtColor accepts).
        label_img: segmentation mask, passed through unchanged.
        min_saturation_factor: smallest offset (inclusive, 8-bit scale).
        max_saturation_factor: largest offset (inclusive, 8-bit scale).
        prob: probability of applying the shift.

    Returns:
        tuple: ``(image, label_img)``; ``image`` is a new array when the shift
        was applied, otherwise the original ``train_img``.
    """
    if not (np.random.random() < prob):
        return train_img, label_img

    hsv = cv2.cvtColor(train_img, cv2.COLOR_RGB2HSV)
    h, s, v = cv2.split(hsv)
    sat_dtype = s.dtype
    # randint excludes `high`; +1 honours the documented inclusive range.
    delta = int(np.random.randint(low=min_saturation_factor, high=max_saturation_factor + 1, size=1)[0])
    # Saturation spans 0..255 for 8-bit images but 0..1 for float images.
    sat_max = 255.0 if sat_dtype == np.uint8 else 1.0
    # Add in float32 so uint8 arithmetic cannot wrap around before clipping.
    shifted = s.astype(np.float32) + delta * (sat_max / 255.0)
    s = np.clip(shifted, 0, sat_max).astype(sat_dtype)
    final_hsv = cv2.merge((h, s, v))
    new_train_img = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2RGB)
    return new_train_img, label_img


def brightness_image(train_img, label_img, min_brightness_factor=-30, max_brightness_factor=30, prob=0.75):
    """Add a random brightness offset to the image with probability ``prob``.

    The offset is sampled with ``np.random.uniform`` between
    ``min_brightness_factor`` and ``max_brightness_factor`` and applied
    through ``_contrast_and_brightness`` with a contrast gain of 1.
    The mask is passed through unchanged.

    Returns:
        tuple: ``(image, label_img)``.
    """
    if not (np.random.random() < prob):
        return train_img, label_img
    offset = np.random.uniform(min_brightness_factor, max_brightness_factor, 1)[0]
    # Gain of 1 leaves contrast alone; only the additive term changes.
    brightened = _contrast_and_brightness(train_img, 1, offset)
    return brightened, label_img


def contrast_image(train_img, label_img, min_contrast_factor=1, max_contrast_factor=1.5, prob=0.75):
    """Scale the image contrast by a random gain with probability ``prob``.

    The gain is sampled with ``np.random.uniform`` between
    ``min_contrast_factor`` and ``max_contrast_factor`` and applied through
    ``_contrast_and_brightness`` with a brightness offset of 0.
    The mask is passed through unchanged.

    Returns:
        tuple: ``(image, label_img)``.
    """
    if not (np.random.random() < prob):
        return train_img, label_img
    gain = np.random.uniform(min_contrast_factor, max_contrast_factor, 1)[0]
    # Offset of 0 leaves brightness alone; only the multiplicative term changes.
    contrasted = _contrast_and_brightness(train_img, gain, 0)
    return contrasted, label_img


def _contrast_and_brightness(img, contrast_factor, brightness_factor):
    """Blend ``img`` with an all-zero image through ``cv2.addWeighted``.

    ``img`` is weighted by ``contrast_factor``, the zero image by
    ``1 - contrast_factor``, and ``brightness_factor`` is passed as the
    scalar added to the weighted sum.
    """
    # Second blend operand: zeros with the same shape and dtype as the input.
    zero_img = np.zeros(shape=img.shape, dtype=img.dtype)
    zero_weight = 1 - contrast_factor
    return cv2.addWeighted(img, contrast_factor, zero_img, zero_weight, brightness_factor)


def random_scale(train_img, label_img, pad_reflect=False, probs=[0.7, 0.9, 1]):
    """Apply exactly one of three scale/translation augmentations, chosen at random.

    A single uniform draw is compared against the cumulative thresholds in
    ``probs``: below ``probs[0]`` the pair is cropped and scaled up, below
    ``probs[1]`` it is shifted, otherwise it is shrunk and padded
    (``pad_reflect`` selects the padding mode of that last branch).

    Returns:
        tuple: the transformed ``(train_img, label_img)`` pair.
    """
    draw = np.random.random()
    if draw < probs[0]:
        return _crop_and_scale_up(train_img, label_img)
    if draw < probs[1]:
        return _random_shift(train_img, label_img)
    return _shrink_and_pad(train_img, label_img, pad_reflect)


def _crop_and_scale_up(train_img, label_img, crop_size=[(200, 200), (250, 250), (300, 300), (350, 350)]):
    """Cut a random window out of the pair and resize it back to the original size.

    Args:
        train_img: image array whose first two axes are (height, width).
        label_img: mask aligned with ``train_img``.
        crop_size: candidate ``(height, width)`` window sizes; one is picked at random.

    Returns:
        tuple: ``(train_img, label_img)`` with the same height/width as the input.
    """
    original_h, original_w = train_img.shape[:2]

    # cropping
    random_crop_shape = random.choice(crop_size)
    train_img, label_img = _random_crop(train_img, label_img, random_crop_shape)

    # scale up
    # cv2.resize takes dsize as (width, height), i.e. reversed with respect to
    # ndarray.shape; passing shape[:2] directly is only right for square images.
    dsize = (original_w, original_h)
    train_img = cv2.resize(train_img, dsize, interpolation=cv2.INTER_LINEAR)
    label_img = cv2.resize(label_img, dsize, interpolation=cv2.INTER_LINEAR)

    return train_img, label_img


def _random_crop(train_img, label_img, crop_shape):
    original_shape = train_img.shape[:2]

    crop_h = original_shape[0]-crop_shape[0]
    crop_w = original_shape[1]-crop_shape[1]
    nh = random.randint(0, crop_h)
    nw = random.randint(0, crop_w)
    train_crop = train_img[nh:nh + crop_shape[0], nw:nw + crop_shape[1]]
    label_crop = label_img[nh:nh + crop_shape[0], nw:nw + crop_shape[1]]
    return train_crop, label_crop


def _random_shift(train_img, label_img):
    """Translate image and mask by the same random offset, zero-filling the gap.

    The vertical and horizontal offsets are drawn independently with
    ``random.randint`` and are bounded by 15% of the corresponding image side.

    Returns:
        tuple: the shifted ``(train_img, label_img)`` pair.
    """
    limits = np.multiply(0.15, train_img.shape[:2]).astype(np.int64)

    delta_h = random.randint(-limits[0], limits[0])
    delta_w = random.randint(-limits[1], limits[1])

    def translate(arr):
        # vertical shift first, then horizontal
        moved = _shift(arr, delta_h, height=True)
        return _shift(moved, delta_w, height=False)

    return translate(train_img), translate(label_img)


def _shift(img, delta, height):
    if delta == 0:
        return img
    translated_img = np.empty_like(img)
    if height:
        if delta >= 0:
            translated_img[:delta] = 0
            translated_img[delta:] = img[:-delta]
        elif delta < 0:
            translated_img[:delta] = img[-delta:]
            translated_img[delta:] = 0
        return translated_img
    else:
        if delta >= 0:
            translated_img[:, :delta] = 0
            translated_img[:, delta:] = img[:, :-delta]
        elif delta < 0:
            translated_img[:, :delta] = img[:, -delta:]
            translated_img[:, delta:] = 0
        return translated_img


def _shrink_and_pad(train_img, label_img, pad_reflect, shrink_range=(0.6, 0.95)):
    """Shrink the pair by a random ratio, then pad it back to its original size.

    The ratio is sampled with ``np.random.uniform`` between the two entries
    of ``shrink_range``; ``pad_reflect`` is forwarded to ``_random_pad``.

    Returns:
        tuple: the ``(train_img, label_img)`` pair produced by ``_random_pad``.
    """
    target_shape = train_img.shape[:2]

    ratio = np.random.uniform(shrink_range[0], shrink_range[1])
    small_train, small_label = _random_shrink(train_img, label_img, ratio)
    padded_train, padded_label = _random_pad(small_train, small_label, target_shape, pad_reflect)

    return padded_train, padded_label


def _random_shrink(train_img, label_img, ratio):
    """Resize image and mask to ``ratio`` times their original height and width.

    Args:
        train_img: image array whose first two axes are (height, width).
        label_img: mask aligned with ``train_img``.
        ratio: scale factor applied to both sides.

    Returns:
        tuple: the resized ``(train_img, label_img)`` pair.
    """
    original_h, original_w = train_img.shape[:2]
    # cv2.resize takes dsize as (width, height), the reverse of ndarray.shape.
    # Clamp to 1 so an extreme ratio cannot request an empty image.
    dsize = (max(1, int(original_w * ratio)), max(1, int(original_h * ratio)))
    # shrink
    train_img = cv2.resize(train_img, dsize, interpolation=cv2.INTER_LINEAR)
    label_img = cv2.resize(label_img, dsize, interpolation=cv2.INTER_LINEAR)
    return train_img, label_img


def _random_pad(train_img, label_img, target_shape, pad_reflect):
    """Pad image and mask up to ``target_shape`` with a randomly jittered placement.

    The input is placed near the centre of the target canvas. The jitter on
    each axis is limited both by 15% of the current side length and by half
    of the free margin. Padding is a reflection of the content when
    ``pad_reflect`` is true, otherwise a constant 0 border.

    Returns:
        tuple: the padded ``(train_img, label_img)`` pair.
    """
    current_shape = train_img.shape[:2]

    # Free space per axis between the current and the target size.
    margin = np.subtract(target_shape, current_shape)

    # Jitter bound: the smaller of half the margin and 15% of the current side.
    jitter_limit = np.minimum(margin // 2, np.multiply(0.15, current_shape)).astype(np.int64)

    # Centre the content, then displace it by a random amount within the bound.
    pad_top = margin[0] // 2 + random.randint(-jitter_limit[0], jitter_limit[0])
    pad_left = margin[1] // 2 + random.randint(-jitter_limit[1], jitter_limit[1])
    pad_bottom = margin[0] - pad_top
    pad_right = margin[1] - pad_left

    if pad_reflect:
        border_type, extra = cv2.BORDER_REFLECT, {}
    else:
        border_type, extra = cv2.BORDER_CONSTANT, {'value': 0}

    padded = tuple(
        cv2.copyMakeBorder(arr, pad_top, pad_bottom, pad_left, pad_right, border_type, **extra)
        for arr in (train_img, label_img)
    )
    return padded


def random_rotate(train_img, label_img, min_angle=-25, max_angle=25, prob=0.75):
    """Rotate image and mask by the same small random angle with probability ``prob``.

    The angle (in degrees) is sampled with ``np.random.uniform`` between
    ``min_angle`` and ``max_angle`` and applied to both arrays through
    ``_rotate_image``. Otherwise the inputs are returned unchanged.

    Returns:
        tuple: ``(train_img, label_img)``, rotated or original.
    """
    if not (np.random.random() < prob):
        return train_img, label_img
    angle = np.random.uniform(min_angle, max_angle, 1)[0]
    rotated_train = _rotate_image(train_img, angle)
    rotated_label = _rotate_image(label_img, angle)
    return rotated_train, rotated_label


def _rotate_image(img, angle):
    if -1 < angle < 1:
        return img
    shape_2d = (img.shape[1], img.shape[0])
    center_2d = (img.shape[1] / 2, img.shape[0] / 2)
    rotation_matrix = cv2.getRotationMatrix2D(center_2d, angle, 1.0)
    img = cv2.warpAffine(img, rotation_matrix, shape_2d, flags=cv2.INTER_LINEAR)
    return img


def normalize(img):
    """Scale ``img`` from 0..255 to 0..1, then standardise it with ``mean1`` and ``std1``.

    The module-level ``mean1``/``std1`` arrays are broadcast against the image.
    """
    unit_scaled = img.astype(np.float32) / 255.0
    return (unit_scaled - mean1) / std1


def discretize(gt, threshold=40):
    """Binarise ``gt`` in place: values below ``threshold`` become 0, the rest become 1.

    The two passes run in this order on purpose: the second comparison is
    evaluated on the array as already modified by the first one.

    Returns:
        The same array object that was passed in.
    """
    np.putmask(gt, gt < threshold, 0)
    np.putmask(gt, gt >= threshold, 1)
    return gt


def get_edge_mask(image):
    """Return a 0/1 edge map of ``image``; call this before the mask is binarised.

    Edges come from ``cv2.Canny`` with thresholds 0 and 255. Edge pixels
    lying where ``image`` is below 40 are discarded, and every remaining
    non-zero edge pixel is set to 1.
    """
    edges = cv2.Canny(image, 0, 255)
    # Drop edge responses located on dark pixels of the source image.
    np.putmask(edges, image < 40, 0)
    # Collapse every remaining edge response to the label 1.
    np.putmask(edges, edges != 0, 1)
    return edges

In [None]:
# from google.colab.patches import cv2_imshow

# train_img = cv2.imread('/content/drive/MyDrive/cil-project/cil-road-segmentation-2021/training/images/satImage_001.png')
# label_img = cv2.imread('/content/drive/MyDrive/cil-project/cil-road-segmentation-2021/training/groundtruth/satImage_001.png')
# imgs = np.hstack([train_img,label_img])
# cv2_imshow(imgs)

In [None]:
# !unzip /content/drive/MyDrive/cil-project/cil-road-segmentation-2021.zip -d /content/drive/MyDrive/cil-project/cil-road-segmentation-2021

In [None]:
# Notebook-wide constants used by the dataset split and the patch-based labelling/metrics below.
PATCH_SIZE = 16  # pixels per side of square patches
VAL_SIZE = 275  # size of the validation set (number of images)(nearly 20 percent of the training set)
CUTOFF = 0.25  # a patch counts as road when its mean mask value is strictly greater than this

In [None]:
# Remove any previously extracted dataset folders so the split cell below starts from scratch
# (that cell is skipped when a `validation` directory already exists).
!rm -rf /content/validation
!rm -rf /content/training

In [None]:
# unzip the dataset, split it and organize it in folders
if not os.path.isdir('validation'):  # make sure this has not been executed yet
  try:
          !unzip /content/drive/MyDrive/cil-project/cil-road-segmentation-2021.zip -d cil-road-segmentation-2021
          !mkdir training
          !mkdir training/images
          !mkdir training/groundtruth
          !mv /content/cil-road-segmentation-2021/training/training/* training
          # !cp -r /content/cil-road-segmentation-2021/training/training/* training
          # put the additional images into the training set
          !cp -r /content/drive/MyDrive/cil-project/additional_images/images/* training/images
          !cp -r /content/drive/MyDrive/cil-project/additional_images/groundtruth/* training/groundtruth
          # !rm -rf /content/cil-road-segmentation-2021/training/training
          !mkdir validation
          !mkdir validation/images
          !mkdir validation/groundtruth
          for img in sample(glob("training/images/*.png"), VAL_SIZE):
            os.rename(img, img.replace('training', 'validation'))
            mask = img.replace('images', 'groundtruth')
            os.rename(mask, mask.replace('training', 'validation'))
  except:
      print('Please upload a .zip file containing your datasets.')

In [None]:
def _load_pngs_as_float32(path):
    """Read every ``*.png`` directly inside ``path`` (sorted by name) into one float32 array.

    Raises:
        ValueError: if ``path`` contains no ``.png`` file.
    """
    files = sorted(glob(path + '/*.png'))
    if not files:
        raise ValueError(f'no .png files found in {path!r}')
    arrays = []
    for f in files:
        # The context manager closes the file handle once the pixels are copied.
        with Image.open(f) as img:
            arrays.append(np.array(img))
    return np.stack(arrays).astype(np.float32)


def load_all_from_path_255(path):
    """Load all PNGs in ``path`` as one stacked float32 array with the raw pixel values.

    Pixel values are NOT rescaled, so 8-bit images keep the range [0., 255.].
    All images must share the same shape so that they can be stacked along a
    new leading axis.
    """
    return _load_pngs_as_float32(path)


def load_all_from_path(path):
    """Load all PNGs in ``path`` as one stacked float32 array divided by 255.

    For 8-bit images this yields values in the interval [0., 1.].
    All images must share the same shape so that they can be stacked along a
    new leading axis.
    """
    return _load_pngs_as_float32(path) / 255.

In [None]:
# paths to training and validation datasets
train_path = '/content/training'
val_path = '/content/validation'

# Load with the *_255 variant: pixel values are kept unscaled (not divided by 255),
# which is the form the augmentation cell further down writes back to disk with cv2.imwrite.
train_images = load_all_from_path_255(os.path.join(train_path, 'images'))
train_masks = load_all_from_path_255(os.path.join(train_path, 'groundtruth'))
val_images = load_all_from_path_255(os.path.join(val_path, 'images'))
val_masks = load_all_from_path_255(os.path.join(val_path, 'groundtruth'))

In [None]:
print(train_images.shape)

(80, 400, 400, 3)


In [None]:
# Create scratch folders for augmented images and masks.
# NOTE(review): preprocess_saveimages() writes into <root>/images/ and <root>/groundtruth/,
# so these new_* folders look unused by the active cells -- confirm before relying on them.
!mkdir '/content/training/new_images'
!mkdir '/content/training/new_groundtruth'
!mkdir '/content/validation/new_images'
!mkdir '/content/validation/new_groundtruth'
# !rm -rf '/content/training/new_images'
# !rm -rf '/content/training/new_groundtruth'

In [None]:
import tensorflow.compat.v1 as tf
from google.colab.patches import cv2_imshow
def preprocess_saveimages(train_images, label_images, path_suffix, name_suffix):
  """Augment every image/mask pair once and write the results to disk as PNG files.

  Each pair runs through the full augmentation chain (flips, rotations,
  scaling, then the colour transforms, which only touch the image) and is
  saved as ``<path_suffix>/images/<name_suffix><index>.png`` and
  ``<path_suffix>/groundtruth/<name_suffix><index>.png``.

  Args:
    train_images: iterable of RGB images (channel-last).
    label_images: iterable of masks aligned with ``train_images``.
    path_suffix: dataset root containing ``images/`` and ``groundtruth/``.
    name_suffix: file-name prefix placed in front of the running index.

  Raises:
    IOError: if OpenCV reports that a file could not be written.
  """
  save_img_path = path_suffix + '/images/'
  save_label_path = path_suffix + '/groundtruth/'
  # enumerate supplies the running index used in the output file names.
  for cnt, (train_image, label_image) in enumerate(zip(train_images, label_images)):
    train_image, label_image = horizontal_flip(train_image, label_image)
    train_image, label_image = vertical_flip(train_image, label_image)
    train_image, label_image = rotate_90s(train_image, label_image)
    train_image, label_image = random_rotate(train_image, label_image)
    train_image, label_image = random_scale(train_image, label_image)
    train_image, label_image = contrast_image(train_image, label_image)
    train_image, label_image = hue_image(train_image, label_image)
    train_image, label_image = saturation_image(train_image, label_image)
    train_image, label_image = brightness_image(train_image, label_image)
    # cv2.imwrite interprets 3-channel data as BGR, while the pipeline works in RGB;
    # reverse the channel axis so the saved file keeps the original colours.
    if train_image.ndim == 3 and train_image.shape[2] == 3:
      train_image = np.ascontiguousarray(train_image[..., ::-1])
    file_name = name_suffix + str(cnt) + '.png'
    for target, array in ((save_img_path + file_name, train_image),
                          (save_label_path + file_name, label_image)):
      # imwrite signals failure (e.g. missing folder) only through its return value.
      if not cv2.imwrite(target, np.ascontiguousarray(array)):
        raise IOError('could not write ' + target)

In [None]:
# Write one augmented copy of every training and validation pair next to the originals;
# the 'st' prefix keeps the new file names distinct from the source files.
preprocess_saveimages(train_images, train_masks, '/content/training', 'st')
# preprocess_saveimages(train_images, train_masks, '/content/training', 'nd')
preprocess_saveimages(val_images, val_masks, '/content/validation', 'st')
# preprocess_saveimages(val_images, val_masks, '/content/validation', 'nd')

In [None]:
# !mv /content/training/new_images/* /content/training/images
# !mv /content/training/new_groundtruth/* /content/training/groundtruth
# !mv /content/validation/new_images/* /content/validation/images
# !mv /content/validation/new_groundtruth/* /content/validation/groundtruth

In [None]:
# !rm -rf '/content/training/new_images'
# !rm -rf '/content/training/new_groundtruth'
# !rm -rf '/content/validation/new_images'
# !rm -rf '/content/validation/new_groundtruth'

In [None]:
# import tensorflow.compat.v1 as tf
# from google.colab.patches import cv2_imshow
# def pre_process_and_save_images(train_images, label_images):
#   transformed_train_images = np.zeros(train_images.shape, dtype=np.float32)
#   transformed_label_images = np.zeros(label_images.shape, dtype=np.float32)
#   cnt = 0
#   for train_image, label_image in zip(train_images, label_images):
#       # train_image, label_image = horizontal_flip(train_image, label_image)
#       # train_image, label_image = vertical_flip(train_image, label_image)
#       # train_image, label_image = rotate_90s(train_image, label_image)
#       train_image, label_image = random_rotate(train_image, label_image)
#       # train_image, label_image = random_scale(train_image, label_image)
#       # train_image, label_image = hue_image(train_image, label_image)
#       # train_image, label_image = saturation_image(train_image, label_image)
#       # train_image, label_image = brightness_image(train_image, label_image)
#       # train_image, label_image = contrast_image(train_image, label_image)
#       transformed_train_images[cnt] = train_image
#       transformed_label_images[cnt] = label_image
#       cnt += 1
#   return np.vstack((train_images, transformed_train_images)), np.vstack((label_images, transformed_label_images))

In [None]:
# train_images, train_masks = pre_process_and_save_images(train_images, train_masks)
# # val_images_0, val_masks_0 = pre_process_and_save_images(val_images, val_masks)
# val_images, val_masks = pre_process_and_save_images(val_images, val_masks)
# print(val_images.shape)

In [None]:
# val_images, val_masks = pre_process_and_save_images(val_images, val_masks)
# print(val_images.shape)

In [None]:
# folder_name = '/content/drive/MyDrive/cil-project/cil-road-segmentation-2021/original_random_rotate'
# num_val = str(val_images.shape[0])
# # !mkdir $folder_name
# # np.save(folder_name+'/train_images.npy',train_images)
# # np.save(folder_name+'/train_masks.npy',train_masks)
# np.save(folder_name+'/val_images.npy',val_images)
# np.save(folder_name+'/val_masks.npy',val_masks)
# # !mkdir $folder_name/model
# # !mkdir $folder_name/predict

In [None]:
# folder_name = '/content/drive/MyDrive/cil-project/cil-road-segmentation-2021/original_random_rotate'
# num_val = str(val_images.shape[0])
# # !rm -rf $folder_name
# !mkdir $folder_name
# np.save(folder_name+'/train_images.npy',train_images)
# np.save(folder_name+'/train_masks.npy',train_masks)
# np.save(folder_name+'/val_images.npy',val_images)
# np.save(folder_name+'/val_masks.npy',val_masks)
# !mkdir $folder_name/model
# !mkdir $folder_name/predict

In [None]:
# def from_array_to_pictures(pictures_array, path_suffix):
#   cnt = 0
#   for a in pictures_array:
#     path = folder_name+path_suffix
#     cv2.imwrite(path + str(cnt) + '.png', a)
#     cnt += 1

In [None]:
# # !rm -rf $folder_name/training
# # !rm -rf $folder_name/training/images
# # !rm -rf $folder_name/training/groundtruth
# !rm -rf $folder_name/validation
# # !rm -rf $folder_name/validation/images
# # !rm -rf $folder_name/validation/groundtruth
# # !mkdir $folder_name/training
# # !mkdir $folder_name/training/images
# # !mkdir $folder_name/training/groundtruth
# !mkdir $folder_name/validation
# !mkdir $folder_name/validation/images
# !mkdir $folder_name/validation/groundtruth

In [None]:
# # from_array_to_pictures(train_images, '/training/images/')
# # from_array_to_pictures(train_masks, '/training/groundtruth/')
# from_array_to_pictures(val_images, '/validation/images/')
# from_array_to_pictures(val_masks, '/validation/groundtruth/')

In [None]:
pip install tensorboard

In [None]:
# import tensorflow.compat.v1 as tf
# from google.colab.patches import cv2_imshow
# def pre_process_images(train_images, label_images):
#   transformed_train_images = np.zeros(train_images.shape, dtype=np.float32)
#   transformed_label_images = np.zeros(label_images.shape, dtype=np.float32)
#   cnt = 0
#   with tf.Session() as sess:
#     for train_image, label_image in zip(train_images, label_images):
#       # train_image, label_image = horizontal_flip(train_image, label_image)
#       # train_image, label_image = vertical_flip(train_image, label_image)
#       # train_image, label_image = rotate_90s(train_image, label_image)
#       # train_image, label_image = random_rotate(train_image, label_image)
#       train_image, label_image = random_scale(train_image, label_image)
#       # train_image, label_image = hue_image(train_image, label_image)
#       # train_image, label_image = saturation_image(train_image, label_image)
#       # train_image, label_image = brightness_image(train_image, label_image)
#       # train_image, label_image = contrast_image(train_image, label_image)
#       transformed_train_images[cnt] = train_image
#       transformed_label_images[cnt] = label_image
#       cnt += 1
#   return transformed_train_images, transformed_label_images

In [None]:
# def squeeze_patches(images):
#   new_train_image = np.empty([50000, 3], dtype=float)
#   out_cnt = 0
#   for img in train_image_patches:
#     sum = [0, 0, 0]
#     cnt = 0
#     for i in range(16):
#       for j in range(16):
#         sum[0] += img[i][j][0]
#         sum[1] += img[i][j][1]
#         sum[2] += img[i][j][2]
#         cnt += 1   
#     sum[0] /= cnt
#     sum[1] /= cnt
#     sum[2] /= cnt
#     new_train_image[out_cnt] = sum
#     out_cnt += 1 
#   return new_train_image

In [None]:
def image_to_patches(images, masks=None):
    # takes in a 4D np.array containing images and (optionally) a 4D np.array containing the segmentation masks
    # returns a 4D np.array with an ordered sequence of patches extracted from the image and (optionally) a np.array containing labels
    n_images = images.shape[0]  # number of images
    h, w = images.shape[1:3]  # shape of images
    assert (h % PATCH_SIZE) + (w % PATCH_SIZE) == 0  # make sure images can be patched exactly

    h_patches = h // PATCH_SIZE
    w_patches = w // PATCH_SIZE
    patches = images.reshape((n_images, h_patches, PATCH_SIZE, h_patches, PATCH_SIZE, -1))
    patches = np.moveaxis(patches, 2, 3)
    patches = patches.reshape(-1, PATCH_SIZE, PATCH_SIZE, 3)
    if masks is None:
        return patches

    masks = masks.reshape((n_images, h_patches, PATCH_SIZE, h_patches, PATCH_SIZE, -1))
    masks = np.moveaxis(masks, 2, 3)
    labels = np.mean(masks, (-1, -2, -3)) > CUTOFF  # compute labels
    labels = labels.reshape(-1).astype(np.float32)
    shape = masks.shape

    #patches = squeeze_patches(patches)
    
    return patches, labels


In [None]:
import torch
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from tqdm.notebook import tqdm

def np_to_tensor(x, device):
    """Wrap a NumPy array as a torch tensor placed on ``device``.

    For ``device == 'cpu'`` the tensor from ``torch.from_numpy`` is returned
    via ``.cpu()``. For any other device the tensor is made contiguous,
    pinned, and moved with a non-blocking transfer.
    """
    tensor = torch.from_numpy(x)
    if device == 'cpu':
        return tensor.cpu()
    pinned = tensor.contiguous().pin_memory()
    return pinned.to(device=device, non_blocking=True)

def accuracy_fn(y_hat, y):
    """Fraction of elements whose rounded prediction equals the rounded target."""
    predicted = y_hat.round()
    expected = y.round()
    hits = (predicted == expected).float()
    return hits.mean()

class ImageDataset(torch.utils.data.Dataset):
    """In-memory dataset of image/mask pairs read from ``<path>/images`` and ``<path>/groundtruth``.

    Everything is loaded eagerly in the constructor. With ``use_patches`` the
    images are cut into patches with per-patch labels; otherwise they are
    resized to ``resize_to`` when their size differs from it. Samples are
    converted to tensors on ``device`` when indexed.
    """

    def __init__(self, path, device, use_patches=True, resize_to=(400, 400)):
        self.path = path
        self.device = device
        self.use_patches = use_patches
        self.resize_to = resize_to
        # Filled in by _load_data().
        self.x = None
        self.y = None
        self.n_samples = None
        self._load_data()

    def _load_data(self):
        """Read the whole dataset into memory (simple, but not scalable)."""
        images_dir = os.path.join(self.path, 'images')
        masks_dir = os.path.join(self.path, 'groundtruth')
        self.x = load_all_from_path(images_dir)
        self.y = load_all_from_path(masks_dir)

        if self.use_patches:
            # one sample per patch, with a scalar label per patch
            self.x, self.y = image_to_patches(self.x, self.y)
        elif self.resize_to != (self.x.shape[1], self.x.shape[2]):
            # one sample per image, brought to the requested size
            resized_images = [cv2.resize(img, dsize=self.resize_to) for img in self.x]
            self.x = np.stack(resized_images, 0)
            resized_masks = [cv2.resize(mask, dsize=self.resize_to) for mask in self.y]
            self.y = np.stack(resized_masks, 0)

        # PyTorch convolutions expect channels first (CHW) rather than last (HWC).
        self.x = np.moveaxis(self.x, -1, 1)
        self.n_samples = len(self.x)

    def _preprocess(self, x, y):
        """Per-sample hook; currently the identity (no transformation is applied)."""
        return x, y

    def __getitem__(self, item):
        sample = np_to_tensor(self.x[item], self.device)
        # The list index keeps a leading axis of length 1 on the target.
        target = np_to_tensor(self.y[[item]], self.device)
        return self._preprocess(sample, target)

    def __len__(self):
        return self.n_samples

In [None]:
class Block(nn.Module):
    """Two 3x3 convolutions (padding 1), each followed by ReLU, with batch norm in between.

    Layer order: Conv2d -> ReLU -> BatchNorm2d -> Conv2d -> ReLU. Spatial
    size is preserved; the channel count goes from ``in_ch`` to ``out_ch``.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = [
            nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_ch),
            nn.Conv2d(in_channels=out_ch, out_channels=out_ch, kernel_size=3, padding=1),
            nn.ReLU(),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        return out

        
class UNet(nn.Module):
    """UNet-like encoder/decoder producing a single-channel sigmoid map.

    ``chs`` lists the channel counts along the encoder. The decoder mirrors
    it (without the input channel count) and receives the encoder features
    through skip connections.
    """

    def __init__(self, chs=(3,64,128,256,512,1024)):
        super().__init__()
        enc_chs = chs  # encoder channel progression
        dec_chs = chs[::-1][:-1]  # decoder progression: reversed, input channels dropped
        enc_pairs = list(zip(enc_chs[:-1], enc_chs[1:]))
        dec_pairs = list(zip(dec_chs[:-1], dec_chs[1:]))

        # encoder blocks
        self.enc_blocks = nn.ModuleList([Block(c_in, c_out) for c_in, c_out in enc_pairs])
        # pooling has no parameters, so one instance serves every level
        self.pool = nn.MaxPool2d(2)
        # transposed convolutions that double the resolution
        self.upconvs = nn.ModuleList([nn.ConvTranspose2d(c_in, c_out, 2, 2) for c_in, c_out in dec_pairs])
        # decoder blocks
        self.dec_blocks = nn.ModuleList([Block(c_in, c_out) for c_in, c_out in dec_pairs])
        # 1x1 convolution + sigmoid producing the output map
        self.head = nn.Sequential(nn.Conv2d(dec_chs[-1], 1, 1), nn.Sigmoid())

    def forward(self, x):
        # encoder: keep each level's features for the skip connections
        skips = []
        for enc in self.enc_blocks[:-1]:
            x = enc(x)
            skips.append(x)
            x = self.pool(x)  # halve the resolution
        x = self.enc_blocks[-1](x)

        # decoder: upsample, concatenate the matching skip features, refine
        for dec, up, skip in zip(self.dec_blocks, self.upconvs, skips[::-1]):
            x = torch.cat([up(x), skip], dim=1)
            x = dec(x)
        return self.head(x)  # reduce to 1 channel


def patch_accuracy_fn(y_hat, y):
    """Accuracy measured on PATCH_SIZE x PATCH_SIZE patches instead of pixels.

    Prediction and target are averaged over each patch, thresholded with
    ``CUTOFF``, and the fraction of patches on which both agree is returned
    (the original comment names this as the Kaggle evaluation metric).
    """
    h_patches = y.shape[-2] // PATCH_SIZE
    w_patches = y.shape[-1] // PATCH_SIZE
    grid_shape = (-1, 1, h_patches, PATCH_SIZE, w_patches, PATCH_SIZE)

    def patch_labels(t):
        # average over the two within-patch axes, then threshold
        return t.reshape(*grid_shape).mean((-1, -3)) > CUTOFF

    patches_hat = patch_labels(y_hat)
    patches = patch_labels(y)
    agreement = (patches == patches_hat).float()
    return agreement.mean()

In [None]:
# Project root on the mounted Drive.
folder_name = '/content/drive/MyDrive/cil-project'
# Twice VAL_SIZE -- presumably the validation set size after each image gained one augmented copy; TODO confirm.
num_val = VAL_SIZE*2

In [None]:
import gc
def train(train_dataloader, eval_dataloader, model, loss_fn, metric_fns, optimizer, n_epochs):
    """Train `model`, validate after every epoch and keep the best weights.

    Args:
        train_dataloader: iterable of (x, y) batches used for optimization.
        eval_dataloader: iterable of (x, y) batches used for validation.
        model: network to train; it is updated in place.
        loss_fn: callable ``loss_fn(y_hat, y)`` returning a scalar tensor.
        metric_fns: dict mapping a metric name to ``fn(y_hat, y)`` returning a
            scalar tensor. It must contain the key 'acc', because the best
            epoch is selected on 'val_acc'.
        optimizer: optimizer that updates the model parameters.
        n_epochs: number of passes over `train_dataloader`.

    Returns:
        `model`, loaded with the weights of the first epoch that reached the
        highest 'val_acc' (the initial weights if no epoch exceeded 0.0).

    Side effects:
        Logs per-epoch scalars to TensorBoard, saves checkpoints under
        ``folder_name + '/model/'``, prints the metrics and shows the loss
        curves. The TensorBoard writer is closed even when training is
        interrupted, so already logged events are not lost.
    """
    logdir = '/content/drive/MyDrive/cil-project/tensorboard/net'
    writer = SummaryWriter(logdir)  # tensorboard writer (can also log images)

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    history = {}  # collects metrics at the end of each epoch

    try:
        for epoch in range(n_epochs):  # loop over the dataset multiple times

            # per-batch values, averaged at the end of the epoch
            metrics = {'loss': [], 'val_loss': []}
            for k in metric_fns:
                metrics[k] = []
                metrics['val_'+k] = []

            pbar = tqdm(train_dataloader, desc=f'Epoch {epoch+1}/{n_epochs}')
            # training
            model.train()
            for (x, y) in pbar:
                optimizer.zero_grad()  # zero out gradients
                y_hat = model(x)  # forward pass
                loss = loss_fn(y_hat, y)
                loss.backward()  # backward pass
                optimizer.step()  # optimize weights

                # log partial metrics
                metrics['loss'].append(loss.item())
                for k, fn in metric_fns.items():
                    metrics[k].append(fn(y_hat, y).item())
                # running means; validation lists are still empty here and are skipped
                pbar.set_postfix({k: sum(v)/len(v) for k, v in metrics.items() if len(v) > 0})

            # validation
            model.eval()
            with torch.no_grad():  # do not keep track of gradients
                for (x, y) in eval_dataloader:
                    y_hat = model(x)  # forward pass
                    loss = loss_fn(y_hat, y)

                    # log partial metrics
                    metrics['val_loss'].append(loss.item())
                    for k, fn in metric_fns.items():
                        metrics['val_'+k].append(fn(y_hat, y).item())

            # summarize metrics, log to tensorboard and display
            history[epoch] = {k: sum(v) / len(v) for k, v in metrics.items()}
            for k, v in history[epoch].items():
                writer.add_scalar(k, v, epoch)
            print(' '.join(['\t- '+str(k)+' = '+str(v)+'\n ' for (k, v) in history[epoch].items()]))

            # remember (and checkpoint) the weights of the best epoch so far
            if history[epoch]['val_acc'] > best_acc:
                best_acc = history[epoch]['val_acc']
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), folder_name+'/model/model_e_temp.pt')

            gc.collect()
    finally:
        # flush and release the event file, also on KeyboardInterrupt or errors
        writer.close()

    print('Finished Training')
    print(best_acc)
    # save the last-epoch weights first, then restore and save the best ones
    torch.save(model.state_dict(), folder_name+'/model/model_e'+str(n_epochs)+'_val'+str(num_val)+'.pt')
    model.load_state_dict(best_model_wts)
    torch.save(model.state_dict(), folder_name+'/model/best_val_acc_model_e'+str(n_epochs)+'_val'+str(num_val)+'.pt')
    # plot loss curves
    plt.plot([v['loss'] for v in history.values()], label='Training Loss')
    plt.plot([v['val_loss'] for v in history.values()], label='Validation Loss')
    plt.ylabel('Loss')
    plt.xlabel('Epochs')
    plt.legend()
    plt.show()
    return model

In [None]:
# # Garbage Collector - use it like gc.collect()
# import gc

# # Custom Callback To Include in Callbacks List At Training Time
# class GarbageCollectorCallback(tf.keras.callbacks.Callback):
#     def on_epoch_end(self, epoch, logs=None):
#         gc.collect()

In [None]:
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# model = UNet().to(device)
# model.load_state_dict(torch.load('/content/drive/MyDrive/cil-project/model/model_e_temp.pt'))

<All keys matched successfully>

In [None]:
# run on the GPU when one is available, otherwise on the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# reshape the image to simplify the handling of skip connections and maxpooling
# train_dataset = ImageDataset(folder_name+'/training', device, use_patches=False, resize_to=(384, 384))
# val_dataset = ImageDataset(folder_name+'/validation', device, use_patches=False, resize_to=(384, 384))
# the active variant reads from /content instead of the Drive folder used in the commented lines above
train_dataset = ImageDataset('/content/training', device, use_patches=False, resize_to=(384, 384))
val_dataset = ImageDataset('/content/validation', device, use_patches=False, resize_to=(384, 384))
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)
# NOTE(review): the validation loader is shuffled as well; confirm this is intended
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=8, shuffle=True)
model = UNet().to(device)
# binary cross-entropy on the model output
loss_fn = nn.BCELoss()
# train() selects the best epoch on 'val_acc', so the 'acc' entry is required
metric_fns = {'acc': accuracy_fn, 'patch_acc': patch_accuracy_fn}
# Adam with its default hyper-parameters
optimizer = torch.optim.Adam(model.parameters())
# NOTE(review): the recorded output below shows "Epoch x/100", so it was presumably produced with a different value - confirm
n_epochs = 60

In [None]:
# run the training loop; train() loads the best-'val_acc' weights into `model` before returning it
train(train_dataloader, val_dataloader, model, loss_fn, metric_fns, optimizer, n_epochs)

HBox(children=(FloatProgress(value=0.0, description='Epoch 1/100', max=117.0, style=ProgressStyle(description_…


	- loss = 0.3358114054824552
  	- val_loss = 0.3286077348809493
  	- acc = 0.8871972346917177
  	- val_acc = 0.8906836117568769
  	- patch_acc = 0.820252409348121
  	- val_patch_acc = 0.8395810629192152
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 2/100', max=117.0, style=ProgressStyle(description_…


	- loss = 0.3207715017418576
  	- val_loss = 0.32277032419254903
  	- acc = 0.8871915910997962
  	- val_acc = 0.8803954500901071
  	- patch_acc = 0.8192923568252825
  	- val_patch_acc = 0.7709418441119947
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 3/100', max=117.0, style=ProgressStyle(description_…


	- loss = 0.31928218251619583
  	- val_loss = 0.30521747744397115
  	- acc = 0.8871665194503262
  	- val_acc = 0.8906358088317671
  	- patch_acc = 0.8173533275596097
  	- val_patch_acc = 0.824413748163926
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 4/100', max=117.0, style=ProgressStyle(description_…


	- loss = 0.31732766419394404
  	- val_loss = 0.2961890960210248
  	- acc = 0.8871862047757858
  	- val_acc = 0.890736283440339
  	- patch_acc = 0.8245281429372282
  	- val_patch_acc = 0.8437557157717253
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 5/100', max=117.0, style=ProgressStyle(description_…


	- loss = 0.3068108030109324
  	- val_loss = 0.29840738836087677
  	- acc = 0.8879408321828923
  	- val_acc = 0.8908261750873766
  	- patch_acc = 0.8297175551072146
  	- val_patch_acc = 0.8418776487049303
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 6/100', max=117.0, style=ProgressStyle(description_…


	- loss = 0.30304672511724323
  	- val_loss = 0.30016026567471654
  	- acc = 0.8880076601973965
  	- val_acc = 0.8893365107084575
  	- patch_acc = 0.833599320334247
  	- val_patch_acc = 0.8323763593247062
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 7/100', max=117.0, style=ProgressStyle(description_…


	- loss = 0.29507267551544386
  	- val_loss = 0.3003885122506242
  	- acc = 0.8901360248908018
  	- val_acc = 0.8916921301891929
  	- patch_acc = 0.8396089372471867
  	- val_patch_acc = 0.847550199220055
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 8/100', max=117.0, style=ProgressStyle(description_…


	- loss = 0.28629668464518
  	- val_loss = 0.288387257801859
  	- acc = 0.8929995267819135
  	- val_acc = 0.9003817721417076
  	- patch_acc = 0.8468227014582381
  	- val_patch_acc = 0.8170311905835804
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 9/100', max=117.0, style=ProgressStyle(description_…


	- loss = 0.2723793391233835
  	- val_loss = 0.2582046299388534
  	- acc = 0.8989968860251272
  	- val_acc = 0.9036557501868198
  	- patch_acc = 0.8589391224404685
  	- val_patch_acc = 0.8626995635660071
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 10/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.2637513999500845
  	- val_loss = 0.26298279393660395
  	- acc = 0.9007432618711748
  	- val_acc = 0.89795488903397
  	- patch_acc = 0.8608217662216252
  	- val_patch_acc = 0.8492610140850669
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 11/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.25307041916072875
  	- val_loss = 0.23807851656487114
  	- acc = 0.9043775501414242
  	- val_acc = 0.9103661101115378
  	- patch_acc = 0.8701440849874773
  	- val_patch_acc = 0.8718549373902773
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 12/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.2414694744297582
  	- val_loss = 0.23357609267297544
  	- acc = 0.906616333203438
  	- val_acc = 0.910701866212644
  	- patch_acc = 0.875001861498906
  	- val_patch_acc = 0.8760720240442377
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 13/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.23329699383332178
  	- val_loss = 0.24540634884646065
  	- acc = 0.9104018073815566
  	- val_acc = 0.9042532098920721
  	- patch_acc = 0.8807710944077908
  	- val_patch_acc = 0.850680581833187
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 14/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.2331757974675578
  	- val_loss = 0.23651068187073657
  	- acc = 0.911073803392231
  	- val_acc = 0.9094022230098122
  	- patch_acc = 0.8832431960309672
  	- val_patch_acc = 0.8872311068208594
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 15/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.22377108954466307
  	- val_loss = 0.21755165763591466
  	- acc = 0.9146110313570398
  	- val_acc = 0.9170596113330439
  	- patch_acc = 0.8882130002364134
  	- val_patch_acc = 0.8972055880646956
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 16/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.21557820843071
  	- val_loss = 0.2274761876385463
  	- acc = 0.9168699276752961
  	- val_acc = 0.9160392080482683
  	- patch_acc = 0.8933478833263756
  	- val_patch_acc = 0.89366320559853
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 17/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.20920738991763857
  	- val_loss = 0.24415858795768336
  	- acc = 0.9191408636223557
  	- val_acc = 0.9107857785726848
  	- patch_acc = 0.8975627722903194
  	- val_patch_acc = 0.8869137450268394
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 18/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.20959017750544426
  	- val_loss = 0.2688227497825497
  	- acc = 0.9191635943885542
  	- val_acc = 0.906529923802928
  	- patch_acc = 0.8962684808633267
  	- val_patch_acc = 0.8776669957135853
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 19/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.20359412510680336
  	- val_loss = 0.21827683989938937
  	- acc = 0.921579038994944
  	- val_acc = 0.9188300543709805
  	- patch_acc = 0.9008955150587946
  	- val_patch_acc = 0.9012970344016427
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 20/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.19749648321388114
  	- val_loss = 0.20905146238050962
  	- acc = 0.924595727879777
  	- val_acc = 0.9237526231690457
  	- patch_acc = 0.9057343678596692
  	- val_patch_acc = 0.9071058254492911
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 21/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.19482189620661938
  	- val_loss = 0.21019551746155085
  	- acc = 0.9249385926458571
  	- val_acc = 0.919395093855105
  	- patch_acc = 0.9071191768360953
  	- val_patch_acc = 0.8942979229123968
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 22/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.19537055454193017
  	- val_loss = 0.19915200554226575
  	- acc = 0.9251295144741352
  	- val_acc = 0.9237950930469915
  	- patch_acc = 0.9044519459080492
  	- val_patch_acc = 0.9080938147871118
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 23/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.1883418300212958
  	- val_loss = 0.2252508595977959
  	- acc = 0.927735425468184
  	- val_acc = 0.9187853555930289
  	- patch_acc = 0.9102141230534284
  	- val_patch_acc = 0.9006264084263852
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 24/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.18398980344207877
  	- val_loss = 0.20759112407502375
  	- acc = 0.9286618548580724
  	- val_acc = 0.9244079385933123
  	- patch_acc = 0.9105164627743583
  	- val_patch_acc = 0.908878649535932
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 25/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.1801257406671842
  	- val_loss = 0.19507756005776555
  	- acc = 0.9301304837577363
  	- val_acc = 0.925758487299869
  	- patch_acc = 0.9147499004999796
  	- val_patch_acc = 0.8959932468439403
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 26/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.17434775727426904
  	- val_loss = 0.1943685931986884
  	- acc = 0.9321620016016512
  	- val_acc = 0.9291656519237318
  	- patch_acc = 0.9183701373573042
  	- val_patch_acc = 0.9058061916577188
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 27/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.1695659327456075
  	- val_loss = 0.20457242546897186
  	- acc = 0.9341131846110026
  	- val_acc = 0.9245953010885339
  	- patch_acc = 0.9204861149828658
  	- val_patch_acc = 0.9054015325872522
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 28/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.16495782843767068
  	- val_loss = 0.19077570207024874
  	- acc = 0.9360029941950089
  	- val_acc = 0.9315343671723416
  	- patch_acc = 0.9240329036345849
  	- val_patch_acc = 0.9127538988464757
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 29/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.16445634603245646
  	- val_loss = 0.22388707101345062
  	- acc = 0.9364421107830145
  	- val_acc = 0.924926199411091
  	- patch_acc = 0.923663424121009
  	- val_patch_acc = 0.906694642807308
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 30/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.15644860197590965
  	- val_loss = 0.1888030206686572
  	- acc = 0.9389042462039198
  	- val_acc = 0.9316847512596532
  	- patch_acc = 0.9280508145307883
  	- val_patch_acc = 0.9152764850541165
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 31/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.15285504348257667
  	- val_loss = 0.18912624018756966
  	- acc = 0.9402332535156837
  	- val_acc = 0.9315333084056252
  	- patch_acc = 0.929975004787119
  	- val_patch_acc = 0.9112111377088647
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 32/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.15367871605687672
  	- val_loss = 0.20866543113401062
  	- acc = 0.9401240807313186
  	- val_acc = 0.9299271294945165
  	- patch_acc = 0.9305707725704225
  	- val_patch_acc = 0.9100477507239894
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 33/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.15457387930817074
  	- val_loss = 0.20060445485930695
  	- acc = 0.9391890968013014
  	- val_acc = 0.9282839486473485
  	- patch_acc = 0.9287233739836603
  	- val_patch_acc = 0.9052767126183761
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 34/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.14527637352291337
  	- val_loss = 0.27450777139318616
  	- acc = 0.9430169906371679
  	- val_acc = 0.9198706071627768
  	- patch_acc = 0.9344955698037759
  	- val_patch_acc = 0.9003343378242693
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 35/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.13540609168191242
  	- val_loss = 0.21113041847159988
  	- acc = 0.9466830660135318
  	- val_acc = 0.931310037249013
  	- patch_acc = 0.94011900975154
  	- val_patch_acc = 0.9126625233574918
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 36/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.14907853388124043
  	- val_loss = 0.21008586687477013
  	- acc = 0.9415477937103337
  	- val_acc = 0.9235785352556329
  	- patch_acc = 0.9319822895221221
  	- val_patch_acc = 0.8957076966762543
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 37/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.13723277599893063
  	- val_loss = 0.1952277272939682
  	- acc = 0.9470084538826575
  	- val_acc = 0.9346156120300293
  	- patch_acc = 0.9403883305370299
  	- val_patch_acc = 0.9188147877392016
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 38/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.12349670520450315
  	- val_loss = 0.19642860599254308
  	- acc = 0.9512928092581594
  	- val_acc = 0.9329292460491783
  	- patch_acc = 0.9468831644098983
  	- val_patch_acc = 0.9143048164091612
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 39/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.11911514235867394
  	- val_loss = 0.2345104127338058
  	- acc = 0.9528376358187097
  	- val_acc = 0.9313735161956987
  	- patch_acc = 0.9497028640192798
  	- val_patch_acc = 0.9157129557509172
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 40/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.11297582990975462
  	- val_loss = 0.2078975097913491
  	- acc = 0.955422605204786
  	- val_acc = 0.9295373998190227
  	- patch_acc = 0.9530552656222613
  	- val_patch_acc = 0.9083434578619505
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 41/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.10709885838958952
  	- val_loss = 0.21038197275055082
  	- acc = 0.9576355794556121
  	- val_acc = 0.9326381589237013
  	- patch_acc = 0.9562251552557334
  	- val_patch_acc = 0.9165614310063814
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 42/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.10601997254495947
  	- val_loss = 0.22434887250787333
  	- acc = 0.9576821495325137
  	- val_acc = 0.933143265937504
  	- patch_acc = 0.9566476808653938
  	- val_patch_acc = 0.9163215772101754
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 43/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.1002248583568467
  	- val_loss = 0.20708566237437098
  	- acc = 0.960275674987043
  	- val_acc = 0.9336777878435034
  	- patch_acc = 0.9604478325599279
  	- val_patch_acc = 0.9177754050806949
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 44/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.09699505414718236
  	- val_loss = 0.2343260608613491
  	- acc = 0.9616972778597449
  	- val_acc = 0.9233942329883575
  	- patch_acc = 0.9619821411931616
  	- val_patch_acc = 0.9038701951503754
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 45/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.09011598787883408
  	- val_loss = 0.23173124970574127
  	- acc = 0.9645372145196311
  	- val_acc = 0.9320729092547768
  	- patch_acc = 0.9657266496593117
  	- val_patch_acc = 0.9168494265330466
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 46/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.08651639521121979
  	- val_loss = 0.2877680294607815
  	- acc = 0.9659483366542392
  	- val_acc = 0.9326914972380588
  	- patch_acc = 0.967525083794553
  	- val_patch_acc = 0.9162783309033042
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 47/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.08233775988093808
  	- val_loss = 0.2413649649212235
  	- acc = 0.9676863377929753
  	- val_acc = 0.9341947279478374
  	- patch_acc = 0.9698762542162186
  	- val_patch_acc = 0.9192936828261927
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 48/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.08260269682758893
  	- val_loss = 0.2427380496734067
  	- acc = 0.9678724036257491
  	- val_acc = 0.9331691265106201
  	- patch_acc = 0.9696239992084666
  	- val_patch_acc = 0.9184745801122565
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 49/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.07889444339606497
  	- val_loss = 0.2333555123523662
  	- acc = 0.9694601317756196
  	- val_acc = 0.9336459668059098
  	- patch_acc = 0.9716134789662484
  	- val_patch_acc = 0.9184354214291823
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 50/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.0757135029277231
  	- val_loss = 0.2628619192462218
  	- acc = 0.9706216892625532
  	- val_acc = 0.9297051821884356
  	- patch_acc = 0.9732008331861252
  	- val_patch_acc = 0.913050861735093
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 51/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.07124656206394872
  	- val_loss = 0.23470347276643702
  	- acc = 0.9726098388688177
  	- val_acc = 0.932651615456531
  	- patch_acc = 0.9750066858071548
  	- val_patch_acc = 0.9158655216819361
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 52/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.06942958031327297
  	- val_loss = 0.29437334600247833
  	- acc = 0.973206114055764
  	- val_acc = 0.9304032121834002
  	- patch_acc = 0.9759871414583973
  	- val_patch_acc = 0.9143325498229579
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 53/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.07042708894253796
  	- val_loss = 0.2450673807608454
  	- acc = 0.9728830059369405
  	- val_acc = 0.9312516090117002
  	- patch_acc = 0.9748160100390768
  	- val_patch_acc = 0.9150407016277313
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 54/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.06488018559339719
  	- val_loss = 0.26068814020407827
  	- acc = 0.9753182535497551
  	- val_acc = 0.9330332608599412
  	- patch_acc = 0.9779933151016887
  	- val_patch_acc = 0.9172271584209643
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 55/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.06275025185229433
  	- val_loss = 0.2612830247533949
  	- acc = 0.9761401206000239
  	- val_acc = 0.9337643510416934
  	- patch_acc = 0.9786955479882721
  	- val_patch_acc = 0.9190252721309662
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 56/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.060432018377842046
  	- val_loss = 0.2726055191535699
  	- acc = 0.9770924168774205
  	- val_acc = 0.9349770389105144
  	- patch_acc = 0.9799568301592118
  	- val_patch_acc = 0.9204113891250209
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 57/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.05626787221393524
  	- val_loss = 0.2647565124850524
  	- acc = 0.9790826103626153
  	- val_acc = 0.9337554809294248
  	- patch_acc = 0.9818142428357377
  	- val_patch_acc = 0.9191827303484866
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 58/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.05493665016932875
  	- val_loss = 0.2873353952247846
  	- acc = 0.979485326852554
  	- val_acc = 0.9351380075279035
  	- patch_acc = 0.9823339647716947
  	- val_patch_acc = 0.9204750264945784
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 59/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.053934701153228425
  	- val_loss = 0.28302064204686567
  	- acc = 0.9800378547774421
  	- val_acc = 0.9341606419337424
  	- patch_acc = 0.9826748814338293
  	- val_patch_acc = 0.9192414722944561
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 60/100', max=117.0, style=ProgressStyle(description…


	- loss = 0.05477976434442223
  	- val_loss = 0.2864362395515567
  	- acc = 0.9796839672276098
  	- val_acc = 0.935061807695188
  	- patch_acc = 0.9821277057003771
  	- val_patch_acc = 0.9213912141950507
 


HBox(children=(FloatProgress(value=0.0, description='Epoch 61/100', max=117.0, style=ProgressStyle(description…

KeyboardInterrupt: ignored

In [None]:
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# model = UNet().to(device)
# model.load_state_dict(torch.load('/content/drive/MyDrive/cil-project/model/model_e_temp.pt'))

<All keys matched successfully>

In [None]:
# model.load_state_dict(torch.load(folder_name+'/model/best_val_acc_model_e'+str(n_epochs)+'_val'+num_val+'.pt'))

In [None]:
test_path = '/content/drive/MyDrive/cil-project/cil-road-segmentation-2021/test_images/test_images'


def create_submission(labels, test_filenames, submission_filename, patch_size=None):
    """Write patch-level predictions to a CSV submission file.

    Args:
        labels: sequence of 2-D arrays, one per image, indexed as
            ``[patch_row, patch_col]``; each entry must be convertible with
            ``int()``. Arrays are paired with ``sorted(test_filenames)`` in order.
        test_filenames: image file names; the last run of digits in each name
            is used as the image number.
        submission_filename: path of the CSV file to create (overwritten).
        patch_size: side length of a patch in pixels. Defaults to the
            module-level PATCH_SIZE.

    The file starts with the header ``id,prediction``. Each following line is
    ``<img:03d>_<col*patch_size>_<row*patch_size>,<label>``, emitted column by
    column.
    """
    if patch_size is None:
        patch_size = PATCH_SIZE
    with open(submission_filename, 'w') as f:
        f.write('id,prediction\n')
        # use the `labels` argument, not a global prediction array
        for fn, patch_array in zip(sorted(test_filenames), labels):
            img_number = int(re.findall(r"\d+", fn)[-1])
            n_rows, n_cols = patch_array.shape[0], patch_array.shape[1]
            # outer loop over columns so each index stays within its own axis,
            # which also works for non-square patch grids
            for col in range(n_cols):
                for row in range(n_rows):
                    f.write("{:03d}_{}_{},{}\n".format(
                        img_number, col*patch_size, row*patch_size, int(patch_array[row, col])))

In [None]:
# predict on test set
test_filenames = (glob(test_path + '/*.png'))
test_images = load_all_from_path(test_path)
batch_size = test_images.shape[0]
# original resolution; assumes test_images is laid out as (N, H, W, C) - TODO confirm
size = test_images.shape[1:3]
# we also need to resize the test images. This might not be the best ideas depending on their spatial resolution.
test_images = np.stack([cv2.resize(img, dsize=(384, 384)) for img in test_images], 0)
# label_images = np.zeros(test_images.shape, dtype=np.float32)
# test_images, label_images = contrast_image(test_images, label_images)
test_images = np_to_tensor(np.moveaxis(test_images, -1, 1), device)
# train() leaves the model in training mode when it is interrupted mid-epoch,
# so switch to inference mode explicitly and skip gradient tracking
model.eval()
with torch.no_grad():
    # one image per forward pass: unsqueeze(1) makes each item a batch of size 1
    test_pred = [model(t).cpu().numpy() for t in test_images.unsqueeze(1)]
test_pred = np.concatenate(test_pred, 0)
test_pred = np.moveaxis(test_pred, 1, -1)  # CHW to HWC
# cv2.resize expects dsize as (width, height), i.e. the reverse of `size`
test_pred = np.stack([cv2.resize(img, dsize=(size[1], size[0])) for img in test_pred], 0)  # resize to original shape
# now compute labels: split rows by size[0] and columns by size[1] into patches
test_pred = test_pred.reshape((-1, size[0] // PATCH_SIZE, PATCH_SIZE, size[1] // PATCH_SIZE, PATCH_SIZE))
test_pred = np.moveaxis(test_pred, 2, 3)
# a patch is positive when its mean prediction exceeds CUTOFF
test_pred = np.round(np.mean(test_pred, (-1, -2)) > CUTOFF)

In [None]:
# num_val is VAL_SIZE*2; wrap it in str() like train() does for its file names,
# otherwise the concatenation raises a TypeError when it is a number
create_submission(test_pred, test_filenames, submission_filename=folder_name+'/predict/unet_submission_e'+str(n_epochs)+'_val'+str(num_val)+'_best.csv')

In [None]:
# scratch example: CrossEntropyLoss on a random (3, 5) input with 3 integer targets
loss = nn.CrossEntropyLoss()
# NOTE: `input` shadows the Python built-in of the same name for the rest of the notebook
input = torch.randn(3, 5, requires_grad=True)
print(input, input.shape)
# integer targets drawn with random_(5)
target = torch.empty(3, dtype=torch.long).random_(5)
print(target, target.shape)
output = loss(input, target)
print(output)
# backpropagate through the loss
output.backward()

tensor([[-1.8689,  0.1289,  0.5949, -0.0280, -1.1026],
        [ 0.8900,  0.1108, -1.1117,  0.5580, -2.3347],
        [ 0.8709,  1.0673,  0.4890, -1.2315, -0.3373]], requires_grad=True) torch.Size([3, 5])
tensor([1, 1, 4]) torch.Size([3])
tensor(1.7991, grad_fn=<NllLossBackward>)
