# Training UNet Architecture to Segment Salt Refineries

## Importing necessary packages

In [1]:
from configurations.dataset import SegmentationDataset
from configurations.model import UNet
from configurations import config
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from torchvision import transforms
from imutils import paths
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import time
import os

## Loading Dataset

In [2]:
# Gather every image path and every mask path. Both listings are sorted so
# that position i in one list pairs with position i in the other
# (assumes images and masks share file names -- TODO confirm).
imagePaths = sorted(paths.list_images(config.IMAGE_DATASET_PATH))
maskPaths = sorted(paths.list_images(config.MASK_DATASET_PATH))

## Train/test splitting

In [3]:
# Partition the data into training and testing splits. The fraction held
# out for testing is config.TEST_SPLIT, and the fixed random_state makes
# the partition reproducible across runs.
split = train_test_split(
    imagePaths, maskPaths, test_size=config.TEST_SPLIT, random_state=42
)

# train_test_split returns [train, test] for each input array, in input
# order: images first, then masks
(trainImages, testImages) = split[:2]
(trainMasks, testMasks) = split[2:]

# Write the testing image paths to disk (one path per line) so that we can
# use them when evaluating/testing our model
print("[INFO] saving testing image paths...")
# The context manager closes the file even if the write raises
with open(config.TEST_PATHS, "w") as f:
    f.write("\n".join(testImages))

[INFO] saving testing image paths...


## Define transformations

In [4]:
# Re-import the torchvision transforms module under a cell-local alias.
# The assignment below rebinds the name `transforms` to the composed
# pipeline, so building the pipeline through `transforms.<...>` would raise
# AttributeError the second time this cell is executed. Going through the
# alias keeps the cell re-runnable.
from torchvision import transforms as T

# Preprocessing pipeline applied by the datasets: convert the input to a PIL
# image, resize it to the configured input size, then convert it to a tensor
transforms = T.Compose([
    T.ToPILImage(),
    T.Resize((config.INPUT_IMAGE_HEIGHT, config.INPUT_IMAGE_WIDTH)),
    T.ToTensor(),
])

## Create the train and test datasets

In [5]:
# Build the training and test datasets from their (images, masks) pairs;
# both share the same preprocessing pipeline
trainDS, testDS = (
    SegmentationDataset(imagePaths=images, maskPaths=masks, transforms=transforms)
    for (images, masks) in ((trainImages, trainMasks), (testImages, testMasks))
)

# Report how many examples ended up in each split
print(f"[INFO] found {len(trainDS)} examples in the training set...")
print(f"[INFO] found {len(testDS)} examples in the test set...")

[INFO] found 3400 examples in the training set...
[INFO] found 600 examples in the test set...


## Create the training and test data loaders

In [6]:
# os.cpu_count() returns None when the CPU count cannot be determined, and
# DataLoader requires an integer num_workers; fall back to 0 (load batches
# in the main process) in that case
numWorkers = os.cpu_count() or 0

# Settings shared by both loaders
loaderArgs = dict(
    batch_size=config.BATCH_SIZE,
    pin_memory=config.PIN_MEMORY,
    num_workers=numWorkers,
)

# Shuffle only the training data; keep the test order fixed
trainLoader = DataLoader(trainDS, shuffle=True, **loaderArgs)
testLoader = DataLoader(testDS, shuffle=False, **loaderArgs)

## Loading Model

In [7]:
# Initialize our UNet model
unet = UNet().to(config.DEVICE)

# Initialize loss function and optimizer
lossFunc = BCEWithLogitsLoss()
opt = Adam(unet.parameters(), lr=config.INIT_LR)

# Calculate steps per epoch for training and test set
trainSteps = len(trainDS) // config.BATCH_SIZE
testSteps = len(testDS) // config.BATCH_SIZE

# Initialize a dictionary to store training history
H = {"train_loss": [], "test_loss": []}

## Training Epochs

In [None]:
# Train for config.NUM_EPOCHS epochs, evaluating on the test set after
# every epoch and recording both average losses in H
print("[INFO] training the network...")
startTime = time.time()
for e in tqdm(range(config.NUM_EPOCHS)):
    # Set the model in training mode
    unet.train()

    # Running sums of the per-batch losses for this epoch, kept as plain
    # Python floats
    totalTrainLoss = 0.0
    totalTestLoss = 0.0

    # Loop over the training set
    for (x, y) in trainLoader:
        # Send the input to the device
        (x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))

        # Perform a forward pass and calculate the training loss
        pred = unet(x)
        loss = lossFunc(pred, y)

        # First, zero out any previously accumulated gradients, then
        # perform backpropagation, and then update model parameters
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Accumulate the scalar value only: adding the loss tensor itself
        # would keep autograd history attached to the running total
        totalTrainLoss += loss.item()

    # Set the model in evaluation mode and switch off autograd
    unet.eval()
    with torch.no_grad():
        # Loop over the validation set
        for (x, y) in testLoader:
            # Send the input to the device
            (x, y) = (x.to(config.DEVICE), y.to(config.DEVICE))

            # Make the predictions and accumulate the validation loss
            pred = unet(x)
            totalTestLoss += lossFunc(pred, y).item()

    # Calculate the average training and validation loss; max(..., 1)
    # guards against a zero step count
    avgTrainLoss = totalTrainLoss / max(trainSteps, 1)
    avgTestLoss = totalTestLoss / max(testSteps, 1)

    # Update our training history
    H["train_loss"].append(avgTrainLoss)
    H["test_loss"].append(avgTestLoss)

    # Print the model training and validation information
    print("[INFO] EPOCH: {}/{}".format(e + 1, config.NUM_EPOCHS))
    print("Train loss: {:.6f}, Test loss: {:.4f}".format(
        avgTrainLoss, avgTestLoss))

# Display the total time needed to perform the training
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(
    endTime - startTime))

[INFO] training the network...


  2%|████▏                                                                                                                                                                | 1/40 [01:43<1:07:28, 103.82s/it]

[INFO] EPOCH: 1/40
Train loss: 0.602414, Test loss: 0.6053


  5%|████████▎                                                                                                                                                            | 2/40 [03:21<1:03:22, 100.08s/it]

[INFO] EPOCH: 2/40
Train loss: 0.565257, Test loss: 0.5919


## Plotting the training graph

In [None]:
# Draw the per-epoch training and test loss curves on a single figure and
# write the figure to config.PLOT_PATH
plt.style.use("ggplot")
fig = plt.figure()
ax = fig.add_subplot(111)
for key in ("train_loss", "test_loss"):
    # Each history series is labelled with its dictionary key
    ax.plot(H[key], label=key)
ax.set_title("Training Loss on Dataset")
ax.set_xlabel("Epoch #")
ax.set_ylabel("Loss")
ax.legend(loc="lower left")
fig.savefig(config.PLOT_PATH)

## Save Model

In [None]:
# serialize the model to disk
# Persist the trained network to config.MODEL_PATH. The model object itself
# is serialized here (not only its state_dict).
torch.save(obj=unet, f=config.MODEL_PATH)