In [3]:
%load_ext autoreload
%autoreload 2

import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
import os

from scipy.ndimage.morphology import binary_dilation as dilate
from scipy.ndimage.morphology import binary_erosion as erode
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.cluster import MiniBatchKMeans
from data_helper import *
from model_helper import *
from skimage import transform
from tqdm import tqdm
#from efficientnet_pytorch import EfficientNet

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
# Run on the first GPU when one is available, otherwise fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# ResNeXt-101 32x8d WSL backbone fetched through torch.hub.
# NOTE(review): torch.hub.load executes code from the remote repository.
model = torch.hub.load('facebookresearch/WSL-Images', 'resnext101_32x8d_wsl')
model = model.to(device)
# The model is only used for feature extraction: switch BatchNorm/Dropout to
# inference behaviour. Harmless if get_feature (model_helper) already does it.
model.eval()

data_set = 'wood'   # dataset folder name
defect = 'hole'     # defect sub-folder used for testing
patch_dim = 16      # side length of the square patches, in pixels
train_size = 1      # passed to give_train (presumably number of training images — TODO confirm)
n_clusters = 10     # number of k-means clusters = number of feature-dictionary entries
stride = 4          # step between neighbouring patches, in pixels

Using cache found in /home/ubuntu/.cache/torch/hub/facebookresearch_WSL-Images_master


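## Training
Builds the feature dictionary in four steps:
- extract CNN features from the training patches;
- reduce them with PCA;
- cluster them with k-means;
- keep, for each cluster, the training feature closest to its centroid as a dictionary entry.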

In [5]:
# Load the training data and cut it into overlapping patch_dim x patch_dim
# patches placed `stride` pixels apart.
train = give_train(data_set, train_size)
patches = get_patches(patch_dim, stride, train)
print(np.shape(patches))

100%|██████████| 1/1 [00:00<00:00,  6.58it/s]


512
(15625, 16, 16, 3)


In [None]:
# Extract CNN features with one get_feature call per element of `patches`.
# NOTE(review): this makes one forward pass per patch (~10 min here); if
# get_feature accepts a stack of patches, chunking would be much faster — TODO
# check model_helper.
feature_chunks = []
for batch in tqdm(patches):
    feature_chunks.append(get_feature(batch, model, device=device))
features = np.concatenate(feature_chunks)
features.shape

100%|██████████| 15625/15625 [10:25<00:00, 24.97it/s]


In [None]:
# Reduce the CNN features with PCA, keeping enough components to explain 95%
# of the variance, then standardize with a single global mean/std. Both are
# scalars, so this is a translation plus uniform scaling and distances between
# features are preserved up to a common factor.
pca = PCA(n_components=0.95, svd_solver='full')
X = pca.fit_transform(features)
print(X.shape)
X = (X - X.mean()) / X.std()

# Pin n_init (its default changed across scikit-learn versions) and the seed,
# so the feature dictionary built from these clusters is reproducible.
kmeans = KMeans(n_clusters=n_clusters, init='k-means++', n_init=10, random_state=0).fit(X)

In [None]:
patch_labels = kmeans.labels_
centroids = kmeans.cluster_centers_

# Group the standardized training features by k-means label. Boolean masking
# keeps the original order inside each cluster and avoids the quadratic cost
# of growing an array with np.append once per feature.
clusters = [X[patch_labels == i] for i in range(n_clusters)]

# Feature dictionary: for each cluster, the actual training feature that is
# closest (squared Euclidean distance) to the cluster centroid.
feature_dict = np.zeros((n_clusters, X.shape[1]))
for i, cluster in enumerate(clusters):
    represent = cluster - centroids[i]
    # Row-wise squared norms; same values as diag(represent @ represent.T)
    # without materialising an n x n matrix.
    sq_dist = np.einsum('ij,ij->i', represent, represent)
    feature_dict[i] = cluster[np.argmin(sq_dist)]

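## Testing

Tests the algorithm on one image:
- score each patch of the test image by its distance to the feature dictionary;
- turn the scores into a pixel-level anomaly map;
- threshold the map and compare the result with the ground-truth mask.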

In [None]:
# USE `index` TO SET THE IMAGE TO BE TESTED (position in the sorted file list).
index = 5

anomaly_path = os.path.join(".", data_set, "test", defect, "")
mask_path = os.path.join(".", data_set, "ground_truth", defect, "")

# os.listdir returns entries in arbitrary order; sort so that `index` selects
# the same file on every machine and every run.
files = sorted(os.listdir(anomaly_path))
image_name = files[index]
# Ground-truth masks are named "<image stem>_mask.png"; splitext works for any
# extension length.
mask_name = os.path.splitext(image_name)[0] + "_mask.png"

# Read each file once, then downsample by a factor of 2.
# NOTE(review): dim comes from the image height and is used for both sides,
# so square images are assumed.
raw_image = np.asarray(Image.open(os.path.join(anomaly_path, image_name)))
raw_mask = np.asarray(Image.open(os.path.join(mask_path, mask_name)))
dim = raw_image.shape[0] // 2

test_anomaly = resize(raw_image, (dim, dim)).reshape(dim, dim, 3)
mask_anomaly = resize(raw_mask, (dim, dim)).reshape(dim, dim)

print(dim)
plt.imshow(test_anomaly)
plt.show()
plt.imshow(mask_anomaly)
plt.show()

In [None]:
# get_patches is fed a stack of images (the training call passes `train`), so
# wrap the single test image in a leading batch axis of size 1. This works for
# any image size instead of the previous hard-coded 512x512 reshape.
test_patches = get_patches(patch_dim, stride, test_anomaly[np.newaxis])
print(test_patches.shape)

In [None]:
# Extract features for every test patch, one get_feature call per element.
test_feature_chunks = []
for batch in tqdm(test_patches):
    test_feature_chunks.append(get_feature(batch, model, device=device))
test_features = np.concatenate(test_feature_chunks)

In [None]:
# Project the test features with the PCA fitted on the training patches and
# standardize them with the *training* global mean/std (the statistics applied
# to X before k-means). Using the test image's own statistics would put the test
# features on a different scale than the dictionary, and an anomalous image
# would inflate its own std and shrink its own anomaly distances.
X_train_pca = pca.transform(features)
train_mean, train_std = X_train_pca.mean(), X_train_pca.std()
X_test = (pca.transform(test_features) - train_mean) / train_std
print(X_test.shape)

In [None]:
m = 3  # number of nearest dictionary entries averaged into each score

# Anomaly score per test patch: the mean of the m smallest *squared* Euclidean
# distances between the patch feature and the dictionary entries.
scores = []
for feature in X_test:
    offsets = feature_dict - feature
    sq_dists = np.sort(np.diagonal(offsets @ offsets.T))
    scores.append(np.mean(sq_dists[:m]))
d = np.array(scores, dtype=float).reshape(-1, 1)

In [None]:
# Turn per-patch scores into a per-pixel map: each pixel gets the mean score of
# all patches covering it. Patches are assumed to be ordered row-major with `pw`
# patches per row (the order implied by the index arithmetic below) — TODO
# confirm against get_patches.
img_h, img_w = test_anomaly.shape[:2]
pw = ((img_w - patch_dim) // stride) + 1
print(pw)

# mask_sum[0]: accumulated scores; mask_sum[1]: number of covering patches.
mask_sum = np.zeros((2, img_h, img_w))
for ind, dif in enumerate(d.ravel()):
    x = stride * (ind % pw)
    y = stride * (ind // pw)
    # Both axes use patch_dim (the old code hard-coded 16 on one side).
    mask_sum[0, y:y + patch_dim, x:x + patch_dim] += dif
    mask_sum[1, y:y + patch_dim, x:x + patch_dim] += 1

# Pixels covered by no patch (possible when size - patch_dim is not a multiple
# of stride) get score 0 instead of 0/0 = NaN.
mask_map = np.divide(mask_sum[0], mask_sum[1],
                     out=np.zeros((img_h, img_w)), where=mask_sum[1] > 0)

In [None]:
# Show the pixel-level anomaly scores next to the test image they belong to.
fig, (ax_map, ax_img) = plt.subplots(1, 2, figsize=(10, 5))
ax_map.imshow(mask_map)
ax_map.set_title("Anomaly score map")
ax_img.imshow(test_anomaly)
ax_img.set_title("Test image")
plt.show()

In [None]:
alpha = 1.25   # threshold = mean + alpha * std of the per-patch scores
iterate = 5    # iterations of dilation and then of erosion

# NOTE(review): the threshold is derived from this test image's own patch
# scores, not from training/validation data.
mu = d.mean()
sig = d.std()
th = mu + alpha * sig

# Binarise the score map, then close small gaps: dilate, then erode.
mask_final = np.heaviside((mask_map - th), 1) * 255
mask_final = erode(dilate(mask_final, iterations=iterate), iterations=iterate)

plt.imshow(mask_final)
plt.show()
plt.imshow(mask_anomaly)
plt.imshow(mask_anomaly)

In [None]:
# Binarise prediction and ground truth.
# The ground-truth mask was resized with interpolation. Casting it to uint8
# truncates every value below 1 (anti-aliased edges, or 0.999... interior values
# when resize returns floats in [0, 1]) to 0. Thresholding at half of its
# maximum avoids that and works for masks scaled to [0, 1] or [0, 255]. An
# all-zero mask stays all False.
pred = mask_final.astype(bool)
val = mask_anomaly > 0.5 * mask_anomaly.max()

intersection = np.logical_and(val, pred).sum()
union = np.logical_or(val, pred).sum()
positives = val.sum()

accuracy = np.mean(pred == val)  # fraction of pixels classified correctly
# IoU and TP rate are undefined (NaN) when their denominator is 0, e.g. for a
# defect-free image with an empty prediction.
IoU = intersection / union if union else float('nan')
tp_rate = intersection / positives if positives else float('nan')

In [None]:
# Label each metric so the output can be read without the code next to it.
print(f"pixel accuracy: {accuracy}")
print(f"IoU:            {IoU}")
print(f"TP rate:        {tp_rate}")