import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torch.utils.data import Dataset as BaseDataset
import albumentations as albu
import torch
import numpy as np
import segmentation_models_pytorch as smp
import matplotlib
import scipy
from segmentation_models_pytorch import utils as smp_utils
import pandas as pd
from os.path import exists
from natsort import natsorted
from torchviz import make_dot
import matplotlib.pyplot as plt
from multiprocessing import Pool

In [2]:
DATA_DIR = './'
x_train_dir = os.path.join(DATA_DIR, 'train')
y_train_dir = os.path.join(DATA_DIR, 'train_annot')
# NOTE(review): the validation and test splits reuse the training folders —
# confirm this is intended before reading the metrics as held-out scores.
x_valid_dir, y_valid_dir = x_train_dir, y_train_dir
x_test_dir, y_test_dir = x_train_dir, y_train_dir

In [3]:
DATA_DIR = './'
# Folder holding the per-sample, per-element PNG maps.
EDS_train_dir = os.path.join(DATA_DIR, 'EDS_output')
# Element symbols; presumably index k matches mask/prediction channel k — confirm.
names = ['C', 'Ca', 'Mg', 'Na', 'O', 'S', 'Cl']

In [4]:
# Exploratory load of the ground-truth map for sample 1, element 'C' (file
# pattern <sample>_<element>_gt.png). NOTE(review): `image` is rebound later
# and this result appears unused in the visible cells.
image = cv2.imread('./EDS_output/1_C_gt.png')

In [5]:
def accu(i, j, image_gt, radius=3):
    """Return 1 if ``image_gt`` has a pixel equal to 255 near (i, j), else 0.

    "Near" means inside the (2*radius+1) x (2*radius+1) window centred on
    (i, j), clipped at the image borders. The default radius of 3 reproduces
    the original 7x7 neighbourhood.

    Args:
        i: row of the window centre.
        j: column of the window centre.
        image_gt: 2-D mask array.
        radius: half-width of the square window.

    Returns:
        int: 1 if a 255-valued pixel is in the window, 0 otherwise.
    """
    # Clamp both slice ends at 0: a negative start/stop would otherwise wrap
    # around to the far side of the array. Stops past the edge clip naturally.
    rows = slice(max(i - radius, 0), max(i + radius + 1, 0))
    cols = slice(max(j - radius, 0), max(j + radius + 1, 0))
    return int((image_gt[rows, cols] == 255).any())

In [6]:
def fal_pos(i, j, image_gt, radius=3):
    """Return 1 if ``image_gt`` has NO pixel equal to 255 near (i, j), else 0.

    When (i, j) is a predicted signal pixel, 1 marks it as a false positive.
    The window is the (2*radius+1) x (2*radius+1) square centred on (i, j),
    clipped at the image borders. The default radius of 3 is the original
    7x7 neighbourhood.

    Args:
        i: row of the window centre.
        j: column of the window centre.
        image_gt: 2-D ground-truth mask array.
        radius: half-width of the square window.

    Returns:
        int: 1 if no 255-valued pixel is in the window, 0 otherwise.
    """
    # Clamp both slice ends at 0 so negative bounds do not wrap around.
    rows = slice(max(i - radius, 0), max(i + radius + 1, 0))
    cols = slice(max(j - radius, 0), max(j + radius + 1, 0))
    return int(not (image_gt[rows, cols] == 255).any())

In [7]:
def miss_pos(i, j, image_pr, radius=3):
    """Return 1 if ``image_pr`` has NO pixel equal to 255 near (i, j), else 0.

    When (i, j) is a ground-truth signal pixel, 1 marks it as missed by the
    prediction. The window is the (2*radius+1) x (2*radius+1) square centred
    on (i, j), clipped at the image borders. The default radius of 3 is the
    original 7x7 neighbourhood.

    Args:
        i: row of the window centre.
        j: column of the window centre.
        image_pr: 2-D predicted mask array.
        radius: half-width of the square window.

    Returns:
        int: 1 if no 255-valued pixel is in the window, 0 otherwise.
    """
    # Clamp both slice ends at 0 so negative bounds do not wrap around.
    rows = slice(max(i - radius, 0), max(i + radius + 1, 0))
    cols = slice(max(j - radius, 0), max(j + radius + 1, 0))
    return int(not (image_pr[rows, cols] == 255).any())

In [8]:
def _dilate_mask(mask, radius):
    """Return True wherever ``mask`` has a True pixel within a square window of half-width ``radius``."""
    h, w = mask.shape
    # Zero padding means out-of-image neighbours never count, which matches
    # clipping the window at the borders.
    padded = np.pad(mask, radius, mode='constant', constant_values=False)
    out = np.zeros_like(mask)
    size = 2 * radius + 1
    for dm in range(size):
        for dn in range(size):
            out |= padded[dm:dm + h, dn:dn + w]
    return out


def _safe_ratio(num, den):
    """Return ``num / den`` as a float, or ``nan`` when ``den`` is zero."""
    return num / den if den else float('nan')


def accu_fal_miss(image_gt, image_pr, radius=3):
    """Score a predicted mask against a ground-truth mask with a spatial tolerance.

    Pixels equal to 255 are "signal".
    - A predicted signal pixel is a hit when a ground-truth signal pixel lies
      within the (2*radius+1)-square window around it, clipped at the
      borders. Otherwise it is a false positive.
    - A ground-truth signal pixel is missing when no predicted signal pixel
      lies within its window.

    The computation uses boolean dilation, so it runs in vectorised numpy
    rather than a per-pixel Python loop.

    Args:
        image_gt: 2-D ground-truth mask.
        image_pr: 2-D predicted mask with the same shape as ``image_gt``.
        radius: half-width of the tolerance window. The default 3 gives the
            original 7x7 window.

    Returns:
        tuple: ``(accuracy, false_pos, missing_pos)``.
        - ``accuracy`` is hits / predicted signal pixels.
        - ``false_pos`` is false positives / predicted signal pixels.
        - ``missing_pos`` is missed / ground-truth signal pixels.
        A ratio with a zero denominator is ``nan`` instead of raising
        ``ZeroDivisionError``.

    Raises:
        ValueError: if the masks are not 2-D or their shapes differ.
    """
    gt = np.asarray(image_gt) == 255
    pr = np.asarray(image_pr) == 255
    if gt.ndim != 2 or gt.shape != pr.shape:
        raise ValueError(
            f'expected two 2-D masks of equal shape, got {gt.shape} and {pr.shape}'
        )

    near_gt = _dilate_mask(gt, radius)  # pixels with a GT signal in their window
    near_pr = _dilate_mask(pr, radius)  # pixels with a predicted signal in their window

    signals_pr = int(pr.sum())
    signals_gt = int(gt.sum())
    hits = int((pr & near_gt).sum())
    false_pos = signals_pr - hits  # every predicted pixel is either a hit or a false positive
    missing_pos = int((gt & ~near_pr).sum())

    return (
        _safe_ratio(hits, signals_pr),
        _safe_ratio(false_pos, signals_pr),
        _safe_ratio(missing_pos, signals_gt),
    )

In [9]:
# for i in range(1):
#     for j in range(7):
#         file_name_gt = os.path.join(EDS_train_dir, str(i+1)+'_'+names[j]+'_gt.png')
#         file_name_pr = os.path.join(EDS_train_dir, str(i+1)+'_'+names[j]+'_pr.png')
#         # image_gt = cv2.imread(file_name_gt)[:, :, 0]
#         # image_pr = cv2.imread(file_name_pr)[:, :, 0] # (220, 293)
#         image_gt = cv2.imread(file_name_gt)
#         image_pr = cv2.imread(file_name_pr)
#         # accuracy, false_pos, missing_pos = accu_fal_miss(image_gt, image_pr)
#         # print(accuracy, false_pos, missing_pos)

In [10]:
class Dataset(BaseDataset):
    """Segmentation dataset pairing images with MATLAB ``.mat`` label files.

    Hidden entries are dropped and both listings are natural-sorted. Images
    and label files are then paired purely by position: the i-th image goes
    with the i-th label file.
    NOTE(review): file names and counts of the two folders are not cross-checked.

    Args:
        images_dir: directory of images readable by ``cv2.imread``.
        masks_dir: directory of ``.mat`` files, each holding a ``'label'`` array.
        augmentation: optional callable invoked as
            ``augmentation(image=..., mask=...)``. It must return a dict with
            ``'image'`` and ``'mask'`` keys.
        preprocessing: optional callable with the same contract, applied after
            ``augmentation``.
    """

    def __init__(self, images_dir, masks_dir, augmentation=None, preprocessing=None):
        # Drop EVERY hidden entry (e.g. .DS_Store, .ipynb_checkpoints). The old
        # code removed at most the first one and crashed on an empty folder.
        # The inner sorted() keeps natsorted's tie-breaking independent of
        # os.listdir order.
        self.sem_ids = natsorted(
            sorted(name for name in os.listdir(images_dir) if not name.startswith('.'))
        )
        self.label_ids = natsorted(
            sorted(name for name in os.listdir(masks_dir) if not name.startswith('.'))
        )

        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.sem_ids]
        self.masks_fps = [os.path.join(masks_dir, label_id) for label_id in self.label_ids]
        self.augmentation = augmentation
        self.preprocessing = preprocessing

    def __getitem__(self, i):
        """Return ``(image, mask)`` for sample ``i`` after the optional transforms.

        Raises:
            FileNotFoundError: if ``cv2.imread`` cannot read the image file.
        """
        # `import scipy` alone does not guarantee the `scipy.io` submodule is loaded.
        from scipy.io import loadmat

        image = cv2.imread(self.images_fps[i])
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'could not read image: {self.images_fps[i]}')
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Move the first axis of the label array last; presumably
        # (C, H, W) -> (H, W, C) to line up with the image — TODO confirm.
        mask = np.transpose(loadmat(self.masks_fps[i])['label'], (1, 2, 0))

        for transform in (self.augmentation, self.preprocessing):
            if transform:
                sample = transform(image=image, mask=mask)
                image, mask = sample['image'], sample['mask']
        return image, mask

    def __len__(self):
        """Number of images; label files are assumed to match one-to-one."""
        return len(self.sem_ids)

In [11]:
# Release memory cached by PyTorch's CUDA allocator from earlier runs.
torch.cuda.empty_cache()
ENCODER = 'se_resnext50_32x4d'
ENCODER_WEIGHTS = 'imagenet'
ACTIVATION = 'sigmoid'  # could be None for logits or 'softmax2d' for multiclass segmentation
# Fall back to the CPU when no GPU is visible instead of failing at the first `.to('cuda')`.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [12]:
def get_validation_augmentation():
    """Build the evaluation-time transform: pad up to at least 384 rows x 480 columns.

    Only padding is applied; there are no random augmentations.
    """
    return albu.Compose([albu.PadIfNeeded(min_height=384, min_width=480)])

def to_tensor(x, **kwargs):
    """Convert an (H, W, C) array to a (C, H, W) float32 array.

    Extra keyword arguments are accepted and ignored, presumably because the
    albumentations ``Lambda`` wrapper passes extra parameters through.
    """
    channels_first = x.transpose((2, 0, 1))
    return channels_first.astype(np.float32)

def get_preprocessing(preprocessing_fn):
    """Build the preprocessing pipeline: convert image and mask from HWC to CHW float32.

    Args:
        preprocessing_fn: encoder normalisation function.
            NOTE(review): accepted for interface compatibility but currently
            NOT applied, because the normalisation step is disabled. Re-enable
            it only if the model was trained with it.
    """
    to_chw = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([to_chw])
# Presumably the encoder's input-normalisation function for ENCODER/ENCODER_WEIGHTS.
# NOTE(review): get_preprocessing has its normalisation step commented out, so
# this is passed along but never applied.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [13]:
# Evaluation dataset: padding-only augmentation, then HWC -> CHW float32 conversion.
test_dataset = Dataset(
    images_dir=x_test_dir,
    masks_dir=y_test_dir,
    augmentation=get_validation_augmentation(),
    preprocessing=get_preprocessing(preprocessing_fn),
)

In [14]:
# NOTE(review): `data` is not referenced in the visible cells; possibly a leftover.
data = []

In [15]:
# Run the trained model once per test image. For every element channel, cache
# the ground-truth and predicted masks scaled to {0, 255}. File <i+1>.npy
# holds an array of shape (len(names), 2, H, W):
#   [k, 0] is the ground truth for names[k]
#   [k, 1] is the prediction for names[k]
os.makedirs('metric_output_temp', exist_ok=True)
# Load the checkpoint once, not once per (image, element) pair.
# map_location places the weights on DEVICE.
# NOTE(review): this unpickles a full model object — only load trusted files.
# Newer PyTorch releases may also need `weights_only=False` here.
best_model = torch.load('./best_model.pth', map_location=DEVICE)
for i in range(len(test_dataset)):
    image, gt_mask = test_dataset[i]
    gt_mask = np.transpose(gt_mask.squeeze(), (1, 2, 0))  # (C, H, W) -> (H, W, C)
    x_tensor = torch.from_numpy(image).to(DEVICE).unsqueeze(0).float()
    # One forward pass yields every element channel; the old code repeated it per element.
    pr_mask = best_model.predict(x_tensor)
    pr_mask = np.transpose(pr_mask.squeeze().cpu().numpy().round(), (1, 2, 0))
    temp = np.asarray([
        [gt_mask[:, :, j] * 255, pr_mask[:, :, j] * 255]
        for j in range(len(names))
    ])
    with open(f'metric_output_temp/{i + 1}.npy', 'wb') as fh:
        np.save(fh, temp)

In [None]:
def f(i):
    """Compute the metrics for the cached file ``metric_output_temp/<i>.npy``.

    Args:
        i: 1-based image id used in the cache file name.

    Returns:
        list: one ``[accuracy, false_pos, missing_pos]`` list per stored
        (ground truth, prediction) pair, in file order, as computed by
        ``accu_fal_miss``.
    """
    # Distinct names: the old code shadowed `i` with its loop index and `f`
    # with the file handle.
    with open('metric_output_temp/' + str(i) + '.npy', 'rb') as fh:
        images = np.load(fh)
    # Iterate the stored pairs directly instead of assuming exactly 7 elements.
    return [list(accu_fal_miss(image_gt, image_pr)) for image_gt, image_pr in images]

# Score every cached file in parallel. Files are numbered 1..len(test_dataset)
# by the prediction cell, so derive the range from the dataset instead of
# hard-coding 625.
with Pool(12) as p:
    result = p.map(f, range(1, len(test_dataset) + 1))

In [None]:
# Sanity check: p.map returns one entry per image id it was given.
print(len(result))

In [None]:
# Stack the per-image results into one array. Each image contributes a list of
# per-element [accuracy, false_pos, missing_pos] triples, so the shape is
# (n_images, n_elements, 3) provided every image has the same element count.
metric_results = np.asarray(result)

In [None]:
# Persist the metric array, then read it back so `metric_results` holds exactly
# what was written. The handle is named `fh`: `as f` rebound the module-level
# worker function `f` to a closed file object.
with open('EDS_metric_results.npy', 'wb') as fh:
    np.save(fh, metric_results)
with open('EDS_metric_results.npy', 'rb') as fh:
    metric_results = np.load(fh)