**Important: change the file paths below to point to your own data before running.**

# Mapping
Convert the annotation image to grayscale, find which gray level corresponds to each class, then map those gray levels to class indices.

In [None]:
"""
usage: check unique RGB combination and mask them with different class indices
"""
from PIL import Image
import glob
import os
import matplotlib.pyplot as plt
import numpy as np

# Path of one annotation image used to inspect the gray-level -> class mapping.
ANNO_PATH = '/home/renping/02456 DeepLearning Project/train/cropped/anno/3_0_anno.png'

# Open the file once (and close it): keep an in-memory copy of the original
# image for display and a grayscale ("L" mode) version for the mapping.
with Image.open(ANNO_PATH) as img:
    anno = img.copy()
    anno_gray = img.convert("L")

print(anno_gray.mode)

anno_arr = np.array(anno_gray)
print(anno_arr)
print(type(anno_arr))
print(anno_arr.shape)
print(np.unique(anno_arr))

# Gray level (after "L" conversion) -> class index; 0 and 76 both become class 0.
mapping = {
    0: 0,
    76: 0,
    149: 1,
    225: 2
}

# Gray levels not covered by the mapping would be left unchanged and later
# show up as out-of-range class indices, so report them.
unmapped = np.setdiff1d(np.unique(anno_arr), list(mapping))
if unmapped.size:
    print("Warning: unmapped gray levels: {}".format(unmapped))

# Translate through a lookup table so every pixel is mapped from its ORIGINAL
# value. Sequential in-place replacement would depend on the iteration order
# whenever a target class index is also a source gray level.
# Values absent from the mapping keep their original value.
lut = np.arange(256, dtype=anno_arr.dtype)
for gray_level, class_index in mapping.items():
    lut[gray_level] = class_index
anno_arr = lut[anno_arr]

print(anno_arr)
print(np.unique(anno_arr))


plt.figure()
plt.suptitle("anno 3_0 RGB and gray")
plt.subplot(1,3,1)
plt.imshow(anno)
plt.title("3_0_anno_RGB")
plt.subplot(1,3,2)
# Single-channel images are otherwise rendered with the default colormap.
plt.imshow(anno_gray, cmap="gray")
plt.title("3_0_anno_gray")
plt.subplot(1,3,3)
plt.imshow(anno_arr)
plt.title("3_0_mapping")
plt.colorbar()
plt.show()

# Batch generator

In [None]:
"""
usage: batch generator for train data
input shape: (batch_size, num_channels, height, width)
target shape: (batch_size, height, width)
"""
import torch
from torch.utils import data
from torch.utils.data import DataLoader
import os
from PIL import Image
import numpy as np
from torchvision import transforms as T
import torchvision



class Custom_Data(data.Dataset):
    """Dataset of (raw image tensor, class-index mask) pairs for segmentation.

    Raw and annotation files are paired by position after sorting the file
    names of each directory. Both directories must therefore hold the same
    number of files, and their sorted orders must correspond (presumably names
    like ``3_0_raw.png`` / ``3_0_anno.png``; verify for your data).

    Args:
        path_train: directory containing the raw input images.
        path_anno: directory containing the annotation (colour mask) images.
        transforms: callable applied to BOTH the raw image and its annotation
            with identical random draws, so random crops/rotations stay
            aligned. If None, the raw image is only converted with
            ``T.ToTensor()``.
        normalize: optional callable (e.g. ``T.Normalize``) applied to the
            transformed raw image only.

    Raises:
        ValueError: if the two directories contain different numbers of files.
    """

    def __init__(self, path_train, path_anno, transforms=None, normalize=None):
        # os.listdir returns entries in arbitrary order; sort both listings so
        # the raw image at index i is paired with the annotation at index i.
        raw_img = sorted(os.listdir(path_train))  # raw images
        anno_img = sorted(os.listdir(path_anno))  # anno images
        if len(raw_img) != len(anno_img):
            raise ValueError(
                "Found {} raw images but {} annotation images; they must pair "
                "one-to-one.".format(len(raw_img), len(anno_img)))
        self.raw_img = [os.path.join(path_train, img) for img in raw_img]
        self.anno_img = [os.path.join(path_anno, img) for img in anno_img]
        self.transform = transforms
        self.normalize = normalize
        # Gray level of the "L"-converted annotation -> class index.
        self.mapping = {
            0: 0,
            76: 0,
            149: 1,
            225: 2,
        }

    def mask_to_class(self, mask):
        """Replace gray levels in ``mask`` with class indices, in place.

        ``mask`` is the annotation after conversion to gray ("L" mode), as a
        torch tensor or numpy array. Every selection is computed from the
        original values before any assignment, so the result does not depend
        on the iteration order of ``self.mapping``. Values missing from the
        mapping are left unchanged.

        Returns:
            The same ``mask`` object, modified in place.
        """
        selections = [(mask == gray, cls) for gray, cls in self.mapping.items()]
        for selected, cls in selections:
            mask[selected] = cls
        return mask

    def _joint_transform(self, raw_data, anno_data):
        """Apply ``self.transform`` to both inputs with the same random draws."""
        import random  # local import: only needed to snapshot Python's RNG

        # Random transforms sample new parameters on every call. Replaying the
        # RNG state (torch and Python's, since torchvision versions differ in
        # which one they use) gives the annotation exactly the same random
        # parameters as the raw image. Crop positions depend on image size,
        # so this assumes both images have the same size.
        torch_state = torch.get_rng_state()
        py_state = random.getstate()
        raw_data = self.transform(raw_data)
        torch.set_rng_state(torch_state)
        random.setstate(py_state)
        anno_data = self.transform(anno_data)
        return raw_data, anno_data

    @staticmethod
    def _load_image(path):
        """Read an image fully into memory and close its file handle."""
        with Image.open(path) as img:
            return img.copy()

    def __getitem__(self, index):
        raw_data = self._load_image(self.raw_img[index])
        anno_data = self._load_image(self.anno_img[index])

        if self.transform is not None:
            raw_data, anno_data = self._joint_transform(raw_data, anno_data)
        else:
            # The raw image must still become a tensor for DataLoader batching.
            raw_data = T.ToTensor()(raw_data)
        if self.normalize is not None:
            raw_data = self.normalize(raw_data)

        # A transform ending in ToTensor yields a tensor; convert back to a
        # PIL image so the annotation can be reduced to gray levels.
        if isinstance(anno_data, (torch.Tensor, np.ndarray)):
            anno_data = T.ToPILImage()(anno_data)
        anno_data = anno_data.convert("RGB").convert("L")

        anno_data = torch.from_numpy(np.array(anno_data))  # np.array -> torch tensor
        anno_data = self.mask_to_class(anno_data)

        return raw_data, anno_data

    def __len__(self):
        return len(self.raw_img)

# Root of the cropped training data -- change this to your own location.
data_root = '/home/renping/02456 DeepLearning Project/train/cropped'
# the path of your train data
path_train = os.path.join(data_root, 'raw')
# the path of your corresponding annotated data
path_anno = os.path.join(data_root, 'anno')

# Optional normalisation of the raw images (Compute_Mean_Std is not defined
# in this notebook section):
# normMean, normStd = Compute_Mean_Std(path_train)
# norm = T.Normalize(normMean, normStd, inplace=True)

# Augmentation pipeline; Custom_Data applies it to each raw image and to its
# annotation.
transform = T.Compose([
    T.RandomCrop(256, padding=4),
    T.RandomRotation(30),
    T.ToTensor(),
])

train_data = Custom_Data(path_train, path_anno, transforms=transform, normalize=None)
train_loader = DataLoader(train_data, batch_size=5, shuffle=True)

# Quick shape check of one batch. Use the built-in next(); recent PyTorch
# DataLoader iterators no longer provide a .next() method.
# raw_data, anno_data = next(iter(train_loader))
# print("raw dataset size = {}.".format(raw_data.size()))
# print("anno dataset size = {}.".format(anno_data.size()))

# Train
Training and simple testing

In [None]:
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt


from SegNet import SegNet
from batch_generator import train_loader

# initialize the writer
log_dir = "/home/renping/02456 DeepLearning Project/visualization"
writer = SummaryWriter(log_dir=log_dir)

# initialize the net, loss, optimizer on one device
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print("Running on GPU!" if use_cuda else "Running on CPU!")
net = SegNet(input_channels=3, output_classes=3).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)


num_epoch = 20
print("Start training!")
print("Total training epoch: {}.".format(num_epoch))

# Training mode (matters if SegNet contains dropout / batch-norm layers).
net.train()
try:
    for epoch in range(num_epoch):
        training_loss = 0.0  # summed over the epoch's mini-batches
        # Unpack directly instead of binding a variable named `data`, which
        # would shadow the torch.utils.data module in a shared kernel.
        for i, (inputs, targets) in enumerate(train_loader):
            inputs = inputs.to(device)
            # CrossEntropyLoss expects int64 class indices as targets.
            targets = targets.to(device=device, dtype=torch.int64)

            optimizer.zero_grad()
            # forward: call the module (not .forward()) so hooks are run
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            # backward
            loss.backward()
            optimizer.step()
            training_loss += loss.item()
            if i % 5 == 4:  # check the loss for every 5 mini batch
                print("epoch = {},mini batch = {} loss: {}".format(epoch+1, i+1, loss.item()))
        # check the loss for every epoch
        print("epoch = {}, loss: {}".format(epoch+1, training_loss))
        writer.add_scalar("train/loss", training_loss, epoch)
finally:
    # Flush and close the TensorBoard log even if training is interrupted.
    writer.close()
print("Finished Training")




print("*********************************************************")
from PIL import Image
import torch
import numpy as np

print("**********start predicting**********")

img = np.array(Image.open('/home/renping/02456 DeepLearning Project/test/cropped/raw/1_2_raw.png'))

print(img.shape)

img = torch.from_numpy(img).type(torch.FloatTensor).cuda()
img = img.permute(2,0,1)
print(img.shape)

img = img.unsqueeze(0)

print(img.shape)

pred = net(img)

print(pred.shape)

pred = pred.squeeze()

test = torch.argmax(pred, dim=0).cpu().numpy()

print(test.shape)

plt.imshow(test)
plt.show()



