# U-Net tumor (腫瘤) segmentation

## Dataset

In [None]:
# import the necessary packages
from torch.utils.data import Dataset
import cv2
class SegmentationDataset(Dataset):
    """Dataset that pairs each image file with its segmentation mask file.

    Args:
        imagePaths: sequence of image file paths.
        maskPaths: sequence of mask file paths; ``maskPaths[i]`` is read as
            the mask that belongs to ``imagePaths[i]``.
        transforms: callable applied separately to the image and to the
            mask, or ``None`` to return the arrays exactly as OpenCV loaded
            them.

    Raises:
        ValueError: if ``imagePaths`` and ``maskPaths`` differ in length.
    """

    def __init__(self, imagePaths, maskPaths, transforms):
        # images and masks are matched purely by position, so a length
        # mismatch can only produce wrong pairs or an IndexError later on
        if len(imagePaths) != len(maskPaths):
            raise ValueError(
                f"got {len(imagePaths)} image paths but "
                f"{len(maskPaths)} mask paths")

        # store the image and mask filepaths, and augmentation transforms
        self.imagePaths = imagePaths
        self.maskPaths = maskPaths
        self.transforms = transforms

    def __len__(self):
        """Return the number of (image, mask) pairs in the dataset."""
        return len(self.imagePaths)

    def __getitem__(self, idx):
        """Load and return the ``(image, mask)`` pair stored at ``idx``.

        The image is read in color and converted from BGR to RGB; the mask
        is read in grayscale mode.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image or the mask.
        """
        imagePath = self.imagePaths[idx]
        maskPath = self.maskPaths[idx]

        # cv2.imread signals failure by returning None instead of raising,
        # so check explicitly to report which file is the problem
        image = cv2.imread(imagePath)
        if image is None:
            raise FileNotFoundError(f"could not read image: {imagePath}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(maskPath, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"could not read mask: {maskPath}")

        # apply the same transformation pipeline to the image and its mask
        if self.transforms is not None:
            image = self.transforms(image)
            mask = self.transforms(mask)

        # return a tuple of the image and its mask
        return (image, mask)

## Config

In [None]:
# import the necessary packages
import torch
import os
import glob

# collect the image and mask files that make up the dataset
# NOTE: despite the "_PATH" suffix, both names hold *lists* of file paths
IMAGE_DATASET_PATH = glob.glob(os.path.join('dataset', 'image', '*.jpg'))
MASK_DATASET_PATH = glob.glob(os.path.join('dataset', 'mask', '*.jpg'))

# create the output directory up front so that writing the model, the plot
# and the test-path list does not fail on a fresh checkout
OUTPUT_DIR = 'output'
os.makedirs(OUTPUT_DIR, exist_ok=True)

MODEL_PATH = os.path.join(OUTPUT_DIR, "unet.pth")
PLOT_PATH = os.path.join(OUTPUT_DIR, "plot.png")
TEST_PATHS = os.path.join(OUTPUT_DIR, "test_paths.txt")

# define the test split
TEST_SPLIT = 0.15

# determine the device to be used for training and evaluation
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# determine if we will be pinning memory during data loading
PIN_MEMORY = DEVICE == "cuda"

# # define the number of channels in the input, number of classes,
# # and number of levels in the U-Net model
# NUM_CHANNELS = 3
# NUM_CLASSES = 1
# NUM_LEVELS = 3

# initialize learning rate, number of epochs to train for, and the
# batch size
INIT_LR = 0.001
NUM_EPOCHS = 40
BATCH_SIZE = 32 #64
# define the input image dimensions
INPUT_IMAGE_WIDTH = 32*7
INPUT_IMAGE_HEIGHT = 32*4
# define threshold to filter weak predictions
THRESHOLD = 0.5

In [None]:
import gc
import torch

# run a garbage-collection pass, then ask PyTorch to release the CUDA memory
# it is holding in its cache
gc.collect()
torch.cuda.empty_cache()

In [None]:
# report the CUDA allocator statistics; print() renders the multi-line table
# instead of an escaped string, and the guard avoids a CUDA initialization
# error on machines without a usable GPU
if torch.cuda.is_available():
    print(torch.cuda.memory_summary(device=None, abbreviated=False))
else:
    print("[INFO] CUDA is not available; no GPU memory summary to show")

## Train

In [None]:
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from torchvision import transforms
from imutils import paths
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import time
import os

In [None]:
# sort both listings so the i-th image and the i-th mask are paired by
# position (presumes both folders sort into matching order — TODO confirm)
imagePaths = sorted(IMAGE_DATASET_PATH)
maskPaths = sorted(MASK_DATASET_PATH)

# hold out TEST_SPLIT of the (image, mask) pairs for testing and train on
# the rest; the fixed random_state keeps the partition the same between runs
split = train_test_split(
    imagePaths,
    maskPaths,
    test_size=TEST_SPLIT,
    random_state=42,
)

In [None]:
# unpack the data split; for two input lists train_test_split returns
# [trainImages, testImages, trainMasks, testMasks]
trainImages, testImages, trainMasks, testMasks = split

# write the testing image paths to disk so that we can use them
# when evaluating/testing our model; the context manager closes the
# file even if the write fails
print("[INFO] saving testing image paths...")
with open(TEST_PATHS, "w") as f:
    f.write("\n".join(testImages))

In [None]:
# define transformations: NumPy array -> PIL image -> resize -> tensor
# (bound to its own name so the imported `transforms` module is not
# overwritten, which would make this cell fail when it is run a second time)
# NOTE(review): the same pipeline, including the interpolating resize, is
# applied to the masks — confirm that non-binary mask values are acceptable
dataTransforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((INPUT_IMAGE_HEIGHT, INPUT_IMAGE_WIDTH)),
    transforms.ToTensor(),
])

# create the train and test datasets
trainDS = SegmentationDataset(imagePaths=trainImages, maskPaths=trainMasks,
                              transforms=dataTransforms)
testDS = SegmentationDataset(imagePaths=testImages, maskPaths=testMasks,
                             transforms=dataTransforms)
print(f"[INFO] found {len(trainDS)} examples in the training set...")
print(f"[INFO] found {len(testDS)} examples in the test set...")

In [None]:
# options shared by both loaders
loaderKwargs = dict(batch_size=BATCH_SIZE, pin_memory=PIN_MEMORY)
# (num_workers=os.cpu_count() is intentionally left disabled)

# the training loader reshuffles the samples every epoch; the test loader
# keeps them in a fixed order
trainLoader = DataLoader(trainDS, shuffle=True, **loaderKwargs)
testLoader = DataLoader(testDS, shuffle=False, **loaderKwargs)

In [None]:
import segmentation_models_pytorch as smp
# build a U-Net from segmentation_models_pytorch and move it to the
# training device in one step
model = smp.Unet(
    # encoder backbone (alternatives: mobilenet_v2, efficientnet-b7, ...)
    encoder_name="resnet34",
    # initialize the encoder with `imagenet` pre-trained weights
    encoder_weights="imagenet",
    # input channels: 3 for RGB images (1 would be gray-scale)
    in_channels=3,
    # output channels, i.e. the number of classes in the dataset
    classes=1,
).to(DEVICE)


In [None]:
from segmentation_models_pytorch.encoders import get_preprocessing_fn
# fetch the preprocessing function that smp associates with the resnet34
# encoder and its imagenet weights
# NOTE(review): preprocess_input is not applied anywhere in the visible
# code — confirm whether the model inputs should be passed through it
preprocess_input = get_preprocessing_fn('resnet34', pretrained='imagenet')

In [None]:
# initialize loss function and optimizer; BCEWithLogitsLoss is given the
# raw model outputs, so no sigmoid is applied before the loss
lossFunc = BCEWithLogitsLoss()
opt = Adam(model.parameters(), lr=INIT_LR)

# number of batches per epoch for the training and test sets; len(loader)
# also counts the final partial batch, which `len(ds) // BATCH_SIZE`
# dropped (and which made the count 0 for sets smaller than one batch)
trainSteps = len(trainLoader)
testSteps = len(testLoader)

In [None]:
# initialize a dictionary to store training history (one float per epoch)
H = {"train_loss": [], "test_loss": []}

# loop over epochs
print("[INFO] training the network...")
startTime = time.time()
for e in tqdm(range(NUM_EPOCHS)):
    # set the model in training mode
    model.train()

    # running sums of the per-batch training and validation losses
    totalTrainLoss = 0.0
    totalTestLoss = 0.0

    # loop over the training set
    for (x, y) in trainLoader:
        # send the input to the device
        (x, y) = (x.to(DEVICE), y.to(DEVICE))

        # perform a forward pass and calculate the training loss
        pred = model(x)
        loss = lossFunc(pred, y)

        # first, zero out any previously accumulated gradients, then
        # perform backpropagation, and then update model parameters
        opt.zero_grad()
        loss.backward()
        opt.step()

        # accumulate a plain Python float; adding the loss tensor itself
        # would keep a reference to every batch's autograd history
        totalTrainLoss += loss.item()

    # set the model in evaluation mode and switch off autograd
    model.eval()
    with torch.no_grad():
        # loop over the validation set
        for (x, y) in testLoader:
            # send the input to the device
            (x, y) = (x.to(DEVICE), y.to(DEVICE))

            # make the predictions and calculate the validation loss
            pred = model(x)
            totalTestLoss += lossFunc(pred, y).item()

    # calculate the average training and validation loss; max(..., 1)
    # guards against a zero step count
    avgTrainLoss = totalTrainLoss / max(trainSteps, 1)
    avgTestLoss = totalTestLoss / max(testSteps, 1)

    # update our training history
    H["train_loss"].append(avgTrainLoss)
    H["test_loss"].append(avgTestLoss)

    # print the model training and validation information (the previous
    # format string had an extra "Acc" placeholder with no argument, which
    # raised IndexError at the end of the first epoch)
    print("[INFO] EPOCH: {}/{}".format(e + 1, NUM_EPOCHS))
    print("Train loss: {:.6f}, Test loss: {:.4f}".format(
        avgTrainLoss, avgTestLoss))

# display the total time needed to perform the training
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(
    endTime - startTime))

# report the best (lowest) loss seen on each set
print('min train_loss: ', min(H["train_loss"]))
print('min test_loss: ', min(H["test_loss"]))

In [None]:
   # plot the training loss
plt.style.use("ggplot")

# draw one curve per history entry on a single set of axes; the history
# key doubles as the legend label
lossFig, lossAx = plt.subplots()
for historyKey in ("train_loss", "test_loss"):
    lossAx.plot(H[historyKey], label=historyKey)

lossAx.set_title("Training Loss on Dataset")
lossAx.set_xlabel("Epoch #")
lossAx.set_ylabel("Loss")
lossAx.legend(loc="lower left")

# write the finished figure to disk
lossFig.savefig(PLOT_PATH)

In [None]:
# serialize the model to disk
# NOTE: the model object itself is handed to torch.save (not a
# state_dict); the prediction section reads it back with torch.load
torch.save(model, MODEL_PATH)

## Predict

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import cv2
import os

In [None]:
def prepare_plot(origImage, origMask, predMask):
    """Show an image, its ground-truth mask and the predicted mask side by side.

    The three inputs are drawn left to right in a single row of subplots.
    The figure is shown and also written to ``predict.jpg`` in the current
    working directory. Nothing is returned.
    """
    figure, ax = plt.subplots(nrows=1, ncols=3, figsize=(10, 10))

    # one (title, picture) entry per subplot, in left-to-right order
    panels = (
        ("Image", origImage),
        ("Original Mask", origMask),
        ("Predicted Mask", predMask),
    )
    for axis, (title, picture) in zip(ax, panels):
        axis.imshow(picture)
        axis.set_title(title)

    # tidy the layout, then display and save the figure
    figure.tight_layout()
    figure.show()
    figure.savefig('predict.jpg')

In [None]:
def make_predictions(model, imagePath):
    """Predict the mask for one image and plot it next to the ground truth.

    Args:
        model: network that is called on a ``(1, 3, H, W)`` float tensor;
            its squeezed output is passed through a sigmoid and thresholded
            at ``THRESHOLD``.
        imagePath: path of the image to segment. The ground-truth mask path
            is derived by replacing every occurrence of ``'image'`` in this
            path with ``'mask'``.

    Raises:
        FileNotFoundError: if the image or its ground-truth mask cannot be
            read by OpenCV.

    Prints the derived mask path and hands the resized image, ground-truth
    mask and binary prediction to ``prepare_plot``. Returns nothing.
    """
    # set model to evaluation mode
    model.eval()

    # turn off gradient tracking
    with torch.no_grad():
        # load the image from disk; cv2.imread returns None (it does not
        # raise) when the file is missing or unreadable
        image = cv2.imread(imagePath)
        if image is None:
            raise FileNotFoundError(f"could not read image: {imagePath}")

        # swap its color channels, cast it to float data type, and scale
        # its pixel values to [0, 1]
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = image.astype("float32") / 255.0

        # resize the image and make a copy of it for visualization
        image = cv2.resize(image, (INPUT_IMAGE_WIDTH, INPUT_IMAGE_HEIGHT))
        orig = image.copy()

        # generate the path to the ground-truth mask
        # NOTE(review): str.replace rewrites *every* 'image' substring, file
        # name included — confirm this matches how the mask files are named
        groundTruthPath = imagePath.replace('image', 'mask')
        print(groundTruthPath)

        # load the ground-truth segmentation mask in grayscale mode
        # and resize it
        gtMask = cv2.imread(groundTruthPath, 0)
        if gtMask is None:
            raise FileNotFoundError(
                f"could not read ground-truth mask: {groundTruthPath}")
        gtMask = cv2.resize(gtMask, (INPUT_IMAGE_WIDTH,
                                     INPUT_IMAGE_HEIGHT))

        # make the channel axis the leading one, add a batch dimension,
        # create a PyTorch tensor, and move it to the current device
        image = np.transpose(image, (2, 0, 1))
        image = np.expand_dims(image, 0)
        image = torch.from_numpy(image).to(DEVICE)

        # make the prediction, pass the results through the sigmoid
        # function, and convert the result to a NumPy array
        predMask = model(image).squeeze()
        predMask = torch.sigmoid(predMask)
        predMask = predMask.cpu().numpy()

        # filter out the weak predictions and convert them to a 0/255
        # uint8 mask
        predMask = (predMask > THRESHOLD) * 255
        predMask = predMask.astype(np.uint8)

        # prepare a plot for visualization
        prepare_plot(orig, gtMask, predMask)

In [None]:
# load the image paths in our testing file and randomly select one of
# them; the context manager closes the file once it has been read
print("[INFO] loading up test image paths...")
with open(TEST_PATHS) as f:
    imagePaths = f.read().strip().split("\n")
imagePath = np.random.choice(imagePaths, size=1)[0]

print(imagePath)

In [None]:
# load our model from disk and move it to the current device;
# map_location remaps the stored tensors so that a model saved on a GPU
# can still be loaded on a CPU-only machine
# NOTE(review): the file holds a fully pickled model object; recent PyTorch
# releases may additionally require weights_only=False to load it — only do
# that for files you created yourself
print("[INFO] load up model...")
unet = torch.load(MODEL_PATH, map_location=DEVICE).to(DEVICE)

make_predictions(unet, imagePath)