In [2]:
from torchvision.models.segmentation import deeplabv3_resnet50, DeepLabV3_ResNet50_Weights
from torchvision.models.segmentation.deeplabv3 import DeepLabHead
from torchvision.utils import draw_segmentation_masks
import torchvision.transforms.functional as F

import logging

import torch 
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch.nn import DataParallel
from torch.nn import CrossEntropyLoss, MSELoss
import torch.optim as optim
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
from torchvision.io import read_image, ImageReadMode
from torchvision.datasets.vision import VisionDataset
import torchvision.transforms.functional as F

import cv2 as cv
import numpy as np
import time
import os
import logging
import matplotlib as mpl
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import random_split
from tqdm import tqdm
import pickle

import requests
import copy
import sys
import json

from pathlib import Path
from PIL import Image
from imutils.paths import list_images, list_files
import os

# Make the local pv-vision checkout and the project's helper scripts
# (imported below as `pv_vision`, `tutorials` and `functions`) importable.
# sys.path entries must be plain strings, hence the str() around each Path.
pv_vision_dir = str(Path.home() / 'pv-vision')
functions_dir = str(Path.home() / 'el_img_cracks_ec' / 'my_scripts')

sys.path.extend([pv_vision_dir, functions_dir])

from pv_vision.nn import ModelHandler
from tutorials.unet_model import construct_unet
import functions

from scipy import ndimage

In [4]:
# TODO: move this class into a shared utils module. Note: the data-loading cell uses functions.SolarDataset, not this definition.
class SolarDataset(VisionDataset):
    """Dataset that reads paired images and masks directly from two folders.

    Images and masks are listed with ``imutils.paths.list_images`` and paired
    by position after sorting, so both folders must contain files whose base
    names (without extension) match one-to-one.

    Args:
        root: Dataset root directory.
        image_folder: Sub-folder of ``root`` containing the input images.
        mask_folder: Sub-folder of ``root`` containing the masks.
        transforms: Callable ``transforms(image, mask) -> (image, mask)``
            applied in ``__getitem__``. If ``None``, the raw PIL pair is
            returned unchanged.
        mode: Currently unused; kept for API compatibility.
        random_seed: Currently unused (shuffling is disabled); kept for API
            compatibility.

    Raises:
        OSError: If the image folder or the mask folder does not exist.
    """
    def __init__(self,
                 root,
                 image_folder,
                 mask_folder,
                 transforms,
                 mode="train",
                 random_seed=42):
        super().__init__(root, transforms)
        self.image_path = Path(self.root) / image_folder
        self.mask_path = Path(self.root) / mask_folder

        if not self.image_path.exists():
            raise OSError(f"{self.image_path} not found.")

        if not self.mask_path.exists():
            raise OSError(f"{self.mask_path} not found.")

        # Sorting both listings is what pairs each image with its mask;
        # __getname__ verifies the pairing per index.
        self.image_list = np.array(sorted(list_images(self.image_path)))
        self.mask_list = np.array(sorted(list_images(self.mask_path)))

    def __len__(self):
        """Number of images (assumed equal to the number of masks)."""
        return len(self.image_list)

    def __getname__(self, index):
        """Return the shared base name of pair ``index``.

        Returns:
            The file name without extension when the image and mask names
            agree, otherwise ``False``.
        """
        image_name = os.path.splitext(os.path.basename(self.image_list[index]))[0]
        mask_name = os.path.splitext(os.path.basename(self.mask_list[index]))[0]
        return image_name if image_name == mask_name else False

    def __getraw__(self, index):
        """Load pair ``index`` as PIL images, without applying transforms.

        The mask is converted to single-channel ``'L'`` mode.

        Raises:
            ValueError: If the image and mask base names do not match.
        """
        if not self.__getname__(index):
            raise ValueError("{}: Image doesn't match with mask".format(os.path.split(self.image_list[index])[-1]))
        # Image.open is lazy and keeps the file open until the pixels are
        # read; load eagerly inside `with` so the handles are released here.
        with Image.open(self.image_list[index]) as image_file:
            image = image_file.copy()
        with Image.open(self.mask_list[index]) as mask_file:
            mask = mask_file.convert('L')

        return image, mask

    def __getitem__(self, index):
        """Return the (optionally transformed) image/mask pair at ``index``."""
        image, mask = self.__getraw__(index)
        if self.transforms is not None:
            image, mask = self.transforms(image, mask)

        return image, mask

# will put into utils in the future
class Compose:
    """Chain paired transforms, each mapping ``(image, target)`` to a new pair."""

    def __init__(self, transforms):
        """
        transforms: sequence of callables ``t(image, target) -> (image, target)``,
            applied in the given order.
        """
        self.transforms = transforms

    def __call__(self, image, target):
        """
        Run ``image`` (input image) and ``target`` (input mask) through every
        transform in order and return the final ``(image, target)`` tuple.
        """
        current_image, current_target = image, target
        for step in self.transforms:
            current_image, current_target = step(current_image, current_target)
        return current_image, current_target

class FixResize:
    """Resize an image/mask pair to a fixed output size.

    The image uses bilinear interpolation; the mask uses nearest-neighbour
    interpolation so class indices are never blended into new values.

    Note: UNet requires the input size to be a multiple of 16; this class
    does not check that.

    Args:
        size: An int for a square ``size x size`` output (original behaviour),
            or a ``(height, width)`` tuple/list for a non-square output.
    """
    def __init__(self, size):
        self.size = size

    def __call__(self, image, target):
        if isinstance(self.size, (tuple, list)):
            height, width = self.size
        else:
            height = width = self.size
        out_size = (height, width)
        image = F.resize(image, out_size, interpolation=transforms.InterpolationMode.BILINEAR)
        target = F.resize(target, out_size, interpolation=transforms.InterpolationMode.NEAREST)
        return image, target

class ToTensor:
    """Turn an ``(image, target)`` pair into tensors.

    image  -> float32 tensor scaled to [0, 1] (torchvision's ``to_tensor``)
    target -> int64 tensor holding the mask values
    """
    def __call__(self, image, target):
        image_tensor = F.to_tensor(image)
        target_tensor = torch.as_tensor(np.array(target), dtype=torch.int64)
        return image_tensor, target_tensor

class PILToTensor:
    """Turn an ``(image, target)`` pair into tensors without rescaling.

    image  -> tensor with the PIL image's raw dtype (``F.pil_to_tensor``)
    target -> int64 tensor holding the mask values
    """
    def __call__(self, image, target):
        raw_image = F.pil_to_tensor(image)
        mask_values = np.array(target)
        return raw_image, torch.as_tensor(mask_values, dtype=torch.int64)

class Normalize:
    """Normalize the image tensor channel-wise; the target passes through.

    The default mean/std look like the common ImageNet statistics (presumably
    chosen to match a pretrained backbone — confirm).
    """
    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.mean, self.std = mean, std

    def __call__(self, image, target):
        normalized = F.normalize(image, mean=self.mean, std=self.std)
        return normalized, target

In [3]:
# Dataset root on the shared project filesystem.
root = Path('/projects/wg-psel-ml/EL_images/eccoope')

# Shared preprocessing for both splits: resize to 256x256, convert to
# tensors, then normalize. These come from the `functions` helper module.
transformers = functions.Compose([
    functions.FixResize(256),
    functions.ToTensor(),
    functions.Normalize(),
])

trainset = functions.SolarDataset(
    root,
    image_folder="img/train",
    mask_folder="ann/train",
    transforms=transformers,
)

valset = functions.SolarDataset(
    root,
    image_folder="img/val",
    mask_folder="ann/val",
    transforms=transformers,
)

In [9]:
# Save a side-by-side figure (raw image | resized class mask) for every
# training sample so the annotations can be checked by eye.
NUM_CLASSES = 5
CLASS_LABELS = ['(0) empty', '(1) dark', '(2) cross', '(3) crack', '(4) busbar']
target_dir = os.path.join(Path.home(), 'el_img_cracks_ec', 'asu_targets')
os.makedirs(target_dir, exist_ok=True)

# One discrete colour per class: cmaplist[k] is the colour of class k.
cmap = mpl.colormaps['viridis'].resampled(NUM_CLASSES)
cmaplist = [cmap(k) for k in range(NUM_CLASSES)]

for i in range(len(trainset)):
    img, mask = trainset[i]
    mask_cpu = mask.cpu().numpy()
    raw_img, _ = trainset.__getraw__(i)

    fig, ax = plt.subplots(ncols=2, figsize=(12, 7), layout='compressed')

    ax[0].imshow(raw_img, cmap='gray')
    ax[0].axis('off')

    # Pin the colour scale to the full class range so class k is always drawn
    # in cmaplist[k], even when a mask lacks some classes (autoscaling would
    # otherwise stretch each mask's own min/max over the whole colormap).
    ax[1].imshow(mask_cpu, cmap=cmap, vmin=-0.5, vmax=NUM_CLASSES - 0.5,
                 interpolation='nearest')
    ax[1].axis('off')

    handles = [mpatches.Patch(color=c, label=label, ec='k')
               for c, label in zip(cmaplist, CLASS_LABELS)]
    ax[1].legend(handles=handles, fontsize='x-small')

    fig.savefig(os.path.join(target_dir, trainset.__getname__(i) + '.jpg'))
    plt.close(fig)  # free each figure; hundreds are created in this loop

In [10]:
# Save a side-by-side figure (raw image | resized class mask) for every
# validation sample. Files are named after the *validation* sample, so they
# no longer overwrite the training figures in the same folder.
NUM_CLASSES = 5
CLASS_LABELS = ['(0) empty', '(1) dark', '(2) cross', '(3) crack', '(4) busbar']
target_dir = os.path.join(Path.home(), 'el_img_cracks_ec', 'asu_targets')
os.makedirs(target_dir, exist_ok=True)

# One discrete colour per class: cmaplist[k] is the colour of class k.
cmap = mpl.colormaps['viridis'].resampled(NUM_CLASSES)
cmaplist = [cmap(k) for k in range(NUM_CLASSES)]

for i in range(len(valset)):
    img, mask = valset[i]
    mask_cpu = mask.cpu().numpy()
    raw_img, _ = valset.__getraw__(i)

    fig, ax = plt.subplots(ncols=2, figsize=(12, 7), layout='compressed')

    ax[0].imshow(raw_img, cmap='gray')
    ax[0].axis('off')

    # Pin the colour scale to the full class range so class k is always drawn
    # in cmaplist[k], even when a mask lacks some classes.
    ax[1].imshow(mask_cpu, cmap=cmap, vmin=-0.5, vmax=NUM_CLASSES - 0.5,
                 interpolation='nearest')
    ax[1].axis('off')

    handles = [mpatches.Patch(color=c, label=label, ec='k')
               for c, label in zip(cmaplist, CLASS_LABELS)]
    ax[1].legend(handles=handles, fontsize='x-small')

    fig.savefig(os.path.join(target_dir, valset.__getname__(i) + '.jpg'))
    plt.close(fig)  # free each figure; many are created in this loop

In [11]:
# NOTE(review): scratch arithmetic left in the notebook; what 264 and 4 refer
# to is not evident from this code — explain it in markdown or remove the cell.
264/4

66.0