In [None]:
import os
import sys
from pathlib import Path
from torch.nn import DataParallel
from torch.utils.data import Dataset, DataLoader
import torch
import numpy as np
from scipy import ndimage
import skimage

import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from PIL import Image
import cv2
import functions
from tutorials.unet_model import construct_unet

In [None]:
# Root directory of the dataset; images live under img/<split> and masks under ann/<split>.
root = Path('/projects/wg-psel-ml/EL_images/osanghi/CornersIHDEANE')

# Preprocessing applied to every sample: resize (argument 256), tensor conversion,
# then normalisation. The exact semantics live in the project-local `functions` module.
transformers = functions.Compose([
    functions.FixResize(256),
    functions.ToTensor(),
    functions.Normalize(),
])

# Both splits share the same root and the same preprocessing pipeline.
train_dataset = functions.SolarDataset(
    root,
    transforms=transformers,
    image_folder="img/train",
    mask_folder="ann/train",
)

test_dataset = functions.SolarDataset(
    root,
    transforms=transformers,
    image_folder="img/test",
    mask_folder="ann/test",
)

In [None]:
# Use the first GPU when CUDA is available, otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the network and wrap it in DataParallel so that parameter names carry the
# "module." prefix. The argument 5 is presumably the number of output classes
# (category_mapping has five entries) -- TODO confirm against construct_unet.
unet = construct_unet(5)
unet = torch.nn.DataParallel(unet)

weight_path = '/projects/wg-psel-ml/EL_images/osanghi/CornersIHDEANE/checkpoints/retrain_corners_checkpoint3/epoch_30/model.pt'

# Deserialise onto the CPU first so the notebook also works on machines without a GPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load(weight_path, map_location=torch.device('cpu'))

# A checkpoint saved from the bare (unwrapped) model has no "module." prefix. Combined
# with strict=False below, none of its keys would match the DataParallel wrapper and the
# network would silently keep its random initialisation. Add the prefix when it is absent.
if not any(key.startswith("module.") for key in checkpoint):
    checkpoint = {"module." + key: value for key, value in checkpoint.items()}

# strict=False is kept so that a partially matching checkpoint still loads, but every
# mismatch is reported instead of being swallowed.
load_result = unet.load_state_dict(checkpoint, strict=False)
if load_result.missing_keys:
    print(f"WARNING: {len(load_result.missing_keys)} parameter(s) not found in checkpoint, "
          f"e.g. {load_result.missing_keys[:5]}")
if load_result.unexpected_keys:
    print(f"WARNING: {len(load_result.unexpected_keys)} checkpoint key(s) not used by the model, "
          f"e.g. {load_result.unexpected_keys[:5]}")

# Unwrap the DataParallel container and move the underlying network to the chosen device.
# The train/eval mode is not changed here; switch to model.eval() before running inference.
model = unet.module.to(device)

In [None]:
# Both splits are served one sample per batch.
batch_size_train = 1
batch_size_test = 1

# The training split is shuffled on every pass; the test split keeps dataset order.
train_loader = DataLoader(
    dataset=train_dataset,
    shuffle=True,
    batch_size=batch_size_train,
)
test_loader = DataLoader(
    dataset=test_dataset,
    shuffle=False,
    batch_size=batch_size_test,
)

In [None]:
# Integer class index -> display label, in index order (0 through 4).
category_mapping = dict(enumerate(("empty", "dark", "busbar", "crack", "corner")))

In [None]:
def inference_and_show(idx, retrained=False, dataset=None):
    """Run the model on one sample and plot raw image, ground truth and prediction.

    Draws a three-panel figure: the raw image in greyscale, the ground-truth mask,
    and the per-pixel argmax of the model output with a class legend.

    Parameters
    ----------
    idx : int
        Index of the sample inside the dataset.
    retrained : bool, optional
        Only selects the title of the prediction panel ("Retrained Model Prediction"
        instead of "Model Prediction").
    dataset : optional
        Dataset to take the sample from. It must support ``dataset[idx]`` returning
        ``(image, mask)`` and ``dataset.__getraw__(idx)`` returning ``(raw_image, _)``.
        Defaults to ``train_loader.dataset``.

    Notes
    -----
    Reads the notebook globals ``model``, ``device``, ``category_mapping`` and, when
    ``dataset`` is omitted, ``train_loader``. Returns nothing; the figure is shown by
    the notebook's inline backend.
    """
    if dataset is None:
        dataset = train_loader.dataset

    img, mask = dataset[idx]
    img = img.to(device)
    raw_img, _ = dataset.__getraw__(idx)

    # Predict in eval mode and without autograd: no graph is built, and layers such as
    # BatchNorm/Dropout (if the network has any) behave deterministically and do not
    # update their running statistics. The previous mode is restored afterwards so a
    # later training loop is unaffected.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # assumes img is (C, H, W) and the output is (1, classes, H, W) -- TODO confirm
            logits = model(img.unsqueeze(0))
    finally:
        model.train(was_training)

    # Collapse the class axis to one predicted class index per pixel.
    test_res = np.argmax(logits.cpu().numpy().squeeze(), axis=0)

    mask_cpu = mask.cpu().numpy()

    # One discrete colour per class. plt.get_cmap(name, lut) returns the colormap
    # resampled to `lut` entries, so no extra `matplotlib as mpl` import is needed.
    num_classes = len(category_mapping)
    cmap = plt.get_cmap('viridis', num_classes)
    cmaplist = [cmap(i) for i in range(num_classes)]

    # Both label maps share the same colour limits so a class gets the same colour
    # in the ground truth, the prediction and the legend.
    clim = (0, num_classes - 1)

    fig, ax = plt.subplots(ncols=3, figsize=(12, 12))

    ax[0].imshow(raw_img.convert('L'), cmap='gray', interpolation='None')
    ax[0].axis('off')
    ax[0].set_title("Raw Image")

    # Interpolation is disabled for label maps: blending neighbouring class indices
    # would paint colours of classes that are not actually present at region borders.
    ax[1].imshow(mask_cpu, cmap='viridis', clim=clim, interpolation='None')
    ax[1].axis('off')
    ax[1].set_title("Ground Truth Mask")

    ax[2].imshow(test_res, cmap='viridis', clim=clim, interpolation='None')

    # imshow creates no legend entries, so build one patch per class by hand.
    handles = [
        mpatches.Patch(color=c, label=f'({k}) {v}', ec='k')
        for c, (k, v) in zip(cmaplist, category_mapping.items())
    ]
    ax[2].legend(handles=handles, fontsize='x-small')
    ax[2].axis('off')
    if retrained:
        ax[2].set_title("Retrained Model Prediction")
    else:
        ax[2].set_title("Model Prediction")