In [74]:
import os
import sys
from skimage import io, filters
from sklearn.model_selection import train_test_split
import torch
from torchvision import transforms
from PIL import Image
import torchvision.transforms.functional as F
from deepchecks.vision.vision_data import BatchOutputFormat
import numpy as np
import random
from torch import nn
from torch.nn import functional as F
from torchvision.transforms import CenterCrop
from torch.utils.data import DataLoader
from deepchecks.vision import VisionData
from deepchecks.vision.suites import model_evaluation, train_test_validation, data_integrity

In [75]:
# Preprocessed image/label directories. Index order matters to later cells:
# [0] training images, [1] training labels, [2] testing images, [3] testing labels.
PROCESSED = [
    '../data/processed/datasets_processed/training_data/',
    '../data/processed/datasets_processed/training_labels/',
    '../data/processed/datasets_processed/testing_data/',
    '../data/processed/datasets_processed/testing_labels/'
]
# Fraction of each split that is actually loaded (1 = everything).
RATIO = 1
# NOTE(review): not referenced anywhere in this notebook, and the leading '/' makes it
# an absolute path from the filesystem root (other paths here are relative) — confirm.
MODEL_PATH = '/models/saved/10_50_0.1_0.25_unet_model.pth'
# Pretrained model checkpoint, downloaded on demand by the model-loading cell.
BEST_MODEL = 'https://cdn.albertovalerio.com/datasets/crop_segmentation/models/my_unet_model.pth'
# Device preference: Apple-silicon GPU first, then CUDA, otherwise CPU.
DEVICE = (
    'mps' if torch.backends.mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)
# Class index -> human-readable name for the binary segmentation masks.
LABEL_MAP = {0: 'no-weed', 1: 'weed'}

In [77]:
# macOS folder metadata file; it and every other hidden entry are skipped below.
ko = '.DS_Store'


def _list_images(directory):
    """Return the visible (non-dot) file names in `directory`, sorted.

    os.listdir returns entries in an arbitrary, filesystem-dependent order. Images and
    masks are paired *by position* below, so both listings must be sorted for the i-th
    image to line up with the i-th mask.
    NOTE(review): assumes an image and its mask have names that sort to the same
    position in their respective directories — confirm the naming scheme.
    """
    return sorted(name for name in os.listdir(directory) if not name.startswith('.'))


def _load_pairs(image_names, mask_names, image_dir, mask_dir):
    """Read [image, mask] pairs from disk, dropping pairs whose image is entirely zero.

    Each image is read once. The mask is read only for images that are kept.
    Returns mutable [image, mask] lists.
    """
    pairs = []
    for image_name, mask_name in zip(image_names, mask_names):
        image = io.imread(image_dir + image_name)
        if image.any():
            pairs.append([image, io.imread(mask_dir + mask_name)])
    return pairs


train_i = _list_images(PROCESSED[0])
train_m = _list_images(PROCESSED[1])
test_i = _list_images(PROCESSED[2])
test_m = _list_images(PROCESSED[3])
assert len(test_i) == len(test_m), 'every test image needs exactly one mask'

# NOTE(review): both the "train" and the "production" subsets are carved out of the
# *testing* directory (PROCESSED[2]/[3]); train_i/train_m are listed but never used.
# Confirm this is intended.
X_train, X_val, y_train, y_val = train_test_split(
    test_i,
    test_m,
    train_size=.7,
    test_size=.3,
    random_state=3,
    shuffle=True
)

sampleT = int(len(X_train) * RATIO)
sampleV = int(len(X_val) * RATIO)

train_d = _load_pairs(X_train[:sampleT], y_train[:sampleT], PROCESSED[2], PROCESSED[3])
prod_d = _load_pairs(X_val[:sampleV], y_val[:sampleV], PROCESSED[2], PROCESSED[3])

In [78]:
def production_data(dataset, low, high, seed=None):
    """Simulate production drift on [image, mask] pairs: Gaussian blur plus a random brightness change.

    For each pair, the image is blurred with skimage.filters.gaussian (sigma=2,
    channel_axis=-1). It is then multiplied by either `low` or `high` (each with
    probability 1/2, drawn per image), rescaled by 255, clipped to [0, 255] and cast to
    uint8. All other elements (the mask) are passed through unchanged.

    Args:
        dataset: iterable of [image, mask, ...] sequences; only element 0 is transformed.
        low: brightness multiplier for one branch (e.g. 0.5 darkens).
        high: brightness multiplier for the other branch (e.g. 2.0 brightens).
        seed: optional seed for the per-image low/high choice. None (default) seeds from
            OS entropy, so results differ between runs.

    Returns:
        A new list of [image, mask, ...] lists. The input pairs are not modified, so
        re-running the calling cell does not blur the same images a second time.
    """
    rng = random.Random(seed)  # local RNG: do not reseed the global `random` module
    drifted = []
    for pair in dataset:
        new_pair = list(pair)
        # skimage rescales integer images to floats in [0, 1] (preserve_range=False).
        blurred = filters.gaussian(new_pair[0], sigma=2, channel_axis=-1)
        factor = low if rng.random() < 0.5 else high
        # Clip before casting: with factor > 1 values exceed 255, and casting
        # out-of-range floats to uint8 wraps around instead of saturating.
        new_pair[0] = np.clip(blurred * factor * 255, 0, 255).astype(np.uint8)
        drifted.append(new_pair)
    return drifted

In [79]:
# Apply simulated drift (blur, then brightness scaled by 0.5 or 2.0) to the held-out split.
# NOTE(review): a later cell rebinds `production_data` to a VisionData object, so this
# cell fails if it is re-run after that one.
production_d = production_data(prod_d, 0.5, 2.0)

In [80]:
# numpy image -> float32 tensor (channels first) scaled to [0, 1], going through PIL.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.PILToTensor(),
    transforms.ConvertImageDtype(torch.float),
])
# Only the first channel of each transformed mask is kept, giving an HxW label tensor.
train_dataset = [[transform(image), transform(mask)[0]] for image, mask in train_d]
production_dataset = [[transform(image), transform(mask)[0]] for image, mask in production_d]

In [81]:
n_train, n_production = len(train_dataset), len(production_dataset)
print(f'Number of training images: {n_train}')
print(f'Number of test images: {n_production}')
# Inspect the first training pair to show image (C, H, W) and label (H, W) shapes.
first_image, first_label = train_dataset[0]
print(f'Example output of an image shape: {first_image.shape}')
print(f'Example output of a label shape: {first_label.shape}')

Number of training images: 71
Number of test images: 31
Example output of an image shape: torch.Size([3, 360, 480])
Example output of a label shape: torch.Size([360, 480])


In [82]:
class Block(nn.Module):
	"""Two unpadded 3x3 convolutions with a ReLU between them (no activation after the second).

	Each convolution shrinks H and W by 2, so the block shrinks them by 4 in total.
	"""
	def __init__(self, inChannels, outChannels):
		super().__init__()
		# Attribute names must stay as-is: the pretrained model is loaded with torch.load
		# as a whole pickled module, which restores submodules under these names.
		self.conv1 = nn.Conv2d(inChannels, outChannels, 3)
		self.relu = nn.ReLU()
		self.conv2 = nn.Conv2d(outChannels, outChannels, 3)

	def forward(self, x):
		hidden = self.conv1(x)
		hidden = self.relu(hidden)
		return self.conv2(hidden)

In [83]:
class Encoder(nn.Module):
	"""UNet contracting path: a stack of Blocks separated by 2x2 max-pooling.

	forward returns the output of every Block (taken before pooling), shallowest first,
	so the decoder can use them as skip connections.
	"""
	def __init__(self, channels=(3, 16, 32, 64)):
		super().__init__()
		# Consecutive channel pairs: (3, 16), (16, 32), (32, 64) for the default.
		stages = zip(channels[:-1], channels[1:])
		self.encBlocks = nn.ModuleList([Block(cin, cout) for cin, cout in stages])
		self.pool = nn.MaxPool2d(2)

	def forward(self, x):
		features = []
		for block in self.encBlocks:
			features.append(block(x))
			x = self.pool(features[-1])
		return features

In [84]:
class Decoder(nn.Module):
	"""UNet expanding path: upsample, concatenate the matching skip feature, then apply a Block."""
	def __init__(self, channels=(64, 32, 16)):
		super().__init__()
		self.channels = channels
		stages = list(zip(channels[:-1], channels[1:]))
		# 2x2 transposed convolutions with stride 2 double H and W at each stage.
		self.upconvs = nn.ModuleList(
			[nn.ConvTranspose2d(cin, cout, 2, 2) for cin, cout in stages]
		)
		self.dec_blocks = nn.ModuleList([Block(cin, cout) for cin, cout in stages])

	def forward(self, x, encFeatures):
		"""Run every decoder stage; encFeatures[i] is the skip tensor consumed by stage i."""
		for i, (upconv, block) in enumerate(zip(self.upconvs, self.dec_blocks)):
			x = upconv(x)
			skip = self.crop(encFeatures[i], x)
			x = block(torch.cat([x, skip], dim=1))
		return x

	def crop(self, encFeatures, x):
		"""Center-crop encFeatures to the spatial size (H, W) of the 4-D tensor x."""
		_, _, height, width = x.shape
		return CenterCrop([height, width])(encFeatures)

In [85]:
class UNet(nn.Module):
	"""Encoder/decoder segmentation network with a single-channel output map.

	Args:
		outSize: target spatial size passed to interpolate when retainDim is True.
		encChannels: channel progression of the encoder Blocks.
		decChannels: channel progression of the decoder stages.
		retainDim: if True, resize the output map to outSize.
	"""
	def __init__(
		self, outSize, encChannels=(3, 16, 32, 64),
		decChannels=(64, 32, 16), retainDim=True
	):
		super().__init__()
		self.encoder = Encoder(encChannels)
		self.decoder = Decoder(decChannels)
		# 1x1 convolution mapping the final decoder features to one output channel.
		self.head = nn.Conv2d(decChannels[-1], 1, 1)
		self.retainDim = retainDim
		self.outSize = outSize

	def forward(self, x):
		# The deepest encoder output starts the decoder; the shallower outputs become
		# skip connections, ordered deepest first.
		deepest, *skips = self.encoder(x)[::-1]
		output = self.head(self.decoder(deepest, skips))
		if self.retainDim:
			# Spelled out as nn.functional because the notebook binds two modules to `F`.
			output = nn.functional.interpolate(output, self.outSize)
		return output

In [86]:
import inspect
import urllib.request

# Local file name for the checkpoint, taken from the URL ('my_unet_model.pth').
MODEL_FILE = os.path.basename(BEST_MODEL)

if not os.path.exists(MODEL_FILE):
    # Pure-Python download: `wget` is not installed everywhere (e.g. stock macOS), and a
    # shell-out would fail silently. Download to a temporary name first so an interrupted
    # transfer never leaves a truncated file that later runs would treat as complete.
    partial_file = MODEL_FILE + '.part'
    urllib.request.urlretrieve(BEST_MODEL, partial_file)
    os.replace(partial_file, MODEL_FILE)

# The checkpoint is a whole pickled nn.Module (it is called and .eval()'d directly later),
# so it must be loaded with weights_only=False, which is no longer the default in recent
# PyTorch. The UNet/Encoder/Decoder/Block classes above must be defined before this cell.
# SECURITY: unpickling executes arbitrary code — only load checkpoints from a trusted source.
_load_kwargs = {'map_location': torch.device(DEVICE)}
if 'weights_only' in inspect.signature(torch.load).parameters:  # absent on old PyTorch
    _load_kwargs['weights_only'] = False
model = torch.load(MODEL_FILE, **_load_kwargs)

In [87]:
def make_predictions(model, input):
	"""Run `model` on a single image and return a binary uint8 mask (values 0 or 255).

	Args:
		model: segmentation network returning per-pixel logits; it is switched to eval
			mode, and the input batch is moved to DEVICE before the forward pass.
		input: a file path (read with skimage.io.imread) or an already-loaded image array
			accepted by ToPILImage (presumably HxWxC uint8 — TODO confirm).

	Returns:
		numpy uint8 array (singleton dims squeezed): 255 where sigmoid(logit) > 0.5, else 0.
	"""
	to_tensor = transforms.Compose([
		transforms.ToPILImage(),
		transforms.PILToTensor(),
		transforms.ConvertImageDtype(torch.float)
	])
	image = io.imread(input) if isinstance(input, str) else input
	batch = to_tensor(image).unsqueeze(0)  # add a leading batch dimension
	model.eval()
	with torch.no_grad():
		probabilities = torch.sigmoid(model(batch.to(DEVICE)).squeeze())
		foreground = probabilities.cpu().numpy() > .5
		binary_mask = (foreground * 255).astype(np.uint8)
	return binary_mask

def convert_prediction(mask):
	"""Convert a 0/255 binary mask into a 2-class score tensor for deepchecks.

	Args:
		mask: 2-D array; pixels equal to 0 are class 0 and pixels equal to 255 are class 1.

	Returns:
		float64 torch tensor of shape (2, H, W): softmax over the one-hot channels.
		Pixels equal to 0 get [e/(1+e), 1/(1+e)] (about [0.73, 0.27]), pixels equal to 255
		get the reverse, and any other value gets [0.5, 0.5].
	"""
	mask = np.asarray(mask)
	# Vectorized one-hot encoding; a per-pixel Python loop costs H*W interpreter
	# iterations per image, and this runs for every image in every batch.
	one_hot = np.stack([mask == 0, mask == 255]).astype(np.float64)
	return nn.functional.softmax(torch.tensor(one_hot), dim=0)

In [88]:
def deepchecks_collate_fn(batch) -> BatchOutputFormat:
    """Collate (image, label) tensor pairs into deepchecks' BatchOutputFormat, adding predictions.

    Args:
        batch: iterable of (image, label) pairs. Image is a float CxHxW tensor in [0, 1],
            label is a mask tensor (HxW in this notebook's datasets).

    Returns:
        BatchOutputFormat holding uint8 HxWxC images, int8 label masks, and one
        (2, H, W) class-score tensor per image from the global `model`.
    """
    # Transpose the iterable of (image, label) pairs into a tuple of images and a tuple of labels.
    image_tensors, label_tensors = tuple(zip(*batch))

    # Round (not truncate) when converting back to uint8. A value that went uint8 -> float
    # by dividing by 255 can come back just below the original integer after * 255, and
    # astype() alone would then drop it by one.
    images = [
        np.clip(np.rint(tensor.numpy().transpose((1, 2, 0)) * 255), 0, 255).astype(np.uint8)
        for tensor in image_tensors
    ]

    labels = [mask.type(torch.int8) for mask in label_tensors]

    # Predict on the same uint8 images deepchecks receives, then expand to per-class scores.
    predictions = [convert_prediction(make_predictions(model, img)) for img in images]

    return BatchOutputFormat(images=images, labels=labels, predictions=predictions)

In [89]:
# Both loaders share the same options; the collate function also runs model predictions.
_loader_options = {'shuffle': True, 'collate_fn': deepchecks_collate_fn}
train_loader = DataLoader(dataset=train_dataset, **_loader_options)
production_loader = DataLoader(dataset=production_dataset, **_loader_options)

# Wrap both loaders for deepchecks' semantic-segmentation checks.
_vision_options = {'task_type': 'semantic_segmentation', 'label_map': LABEL_MAP}
training_data = VisionData(batch_loader=train_loader, **_vision_options)
# NOTE(review): this rebinds `production_data`, shadowing the drift-simulation function of
# the same name; re-running that earlier cell after this one will fail.
production_data = VisionData(batch_loader=production_loader, **_vision_options)

In [90]:
# Preview a few samples of the production VisionData (rendered as a deepchecks widget).
production_data.head()

VBox(children=(HTML(value='<div style="display:flex; flex-direction: column; gap: 10px;">\n                <di…

In [96]:
from deepchecks.vision.checks import PredictionDrift, ImageDatasetDrift

# Compare the model's predictions on training data vs. simulated production data.
check = PredictionDrift()
result = check.run(
    train_dataset=training_data,
    test_dataset=production_data,
)

In [97]:
result.show()
# The returned path of the written report is this cell's displayed output.
result.save_as_html(
    '../reports/linters/6.2.3-report-deepchecks-production-prediction-drift.html'
)

VBox(children=(HTML(value='<h4><b>Prediction Drift</b></h4>'), HTML(value='<p>    Calculate prediction drift b…

'../reports/linters/6.2.3-report-deepchecks-production-prediction-drift.html'

In [93]:
# Image-level drift between the training and simulated-production datasets.
check = ImageDatasetDrift()
result = check.run(
    train_dataset=training_data,
    test_dataset=production_data,
)

In [95]:
result.show()
# The returned path of the written report is this cell's displayed output.
result.save_as_html(
    '../reports/linters/6.2.4-report-deepchecks-production-dataset-drift.html'
)

VBox(children=(HTML(value='<h4><b>Image Dataset Drift</b></h4>'), HTML(value='<p>Calculate drift between the e…

'../reports/linters/6.2.4-report-deepchecks-production-dataset-drift.html'