In [None]:
#import coco dataset & annotations

!wget http://images.cocodataset.org/zips/train2017.zip
!unzip train2017.zip
!rm train2017.zip

!wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
!unzip annotations_trainval2017.zip
!rm annotations_trainval2017.zip

In [None]:
# clone the code from github
!git clone https://github.com/walterwht/Unet_Video_segmentation

In [None]:
#install cocoapi 
#!git clone https://github.com/cocodataset/cocoapi.git
from pycocotools.coco import COCO
import pycocotools._mask as coco_mask

In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchvision import transforms
from torchsummary import summary


import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

from Resnet50backbone_coco21classes import Resnet50Unet
from Dataset_Resnet50_coco21Classes import cocodataset

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Re-sequence the 21 classes of the COCO dataset used by this notebook.
# The two lists are parallel: position i of `allclassid` holds the id that is
# paired with the name at position i of `allclassnms`.
# NOTE(review): several ids differ from the category ids stored in the COCO
# annotation JSON (e.g. bird is 16 there, 15 here) -- confirm which numbering
# `cocodataset` expects.

allclassid = [
    0, 5, 2, 15, 9, 40, 6,
    3, 16, 57, 20, 61, 17, 18, 4,
    1, 59, 19, 58, 7, 63,
]

allclassnms = [
    '__background__', 'airplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
    'car', 'cat', 'chair', 'cow', 'dining table', 'dog', 'horse', 'motorcycle',
    'person', 'potted plant', 'sheep', 'couch', 'train', 'tv',
]

# Print the resulting index -> (name, id) table.
for nac, (class_name, class_id) in enumerate(zip(allclassnms, allclassid)):
    print("{}. class name:{}, class id:{}".format(nac, class_name, class_id))



In [None]:
# Build the training dataset and wrap it in a shuffling DataLoader.
batch_size = 4
trainData = cocodataset(root = 'train2017',annFile = 'annotations/instances_train2017.json',classes = allclassnms, classid = allclassid)
data_loader = torch.utils.data.DataLoader(trainData, batch_size=batch_size, shuffle=True, num_workers=0)

# len(data_loader) is the number of batches per epoch, not the number of images.
# The format string needs a placeholder, otherwise the value is silently dropped.
print("Len of the datas: {}".format(len(data_loader)))

In [None]:
# Sanity-check the dataset: show the first sample of one batch next to its mask.
for img, Tmask in data_loader:
    fig2 = plt.figure(figsize=(25, 25))

    # Collapse the first sample's mask along dim 0 into a 2-D index map.
    argm = Tmask[0].argmax(dim=0)

    image_ax = fig2.add_subplot(2, 2, 1)
    # Move dim 0 (presumably the channel axis) last, as imshow expects.
    image_ax.imshow(img[0].permute(1, 2, 0), cmap="rainbow")

    mask_ax = fig2.add_subplot(2, 2, 2)
    mask_ax.imshow(argm, cmap="rainbow")

    plt.show()
    break  # a single batch is enough for a visual check

In [None]:
# Build the 21-class model and load previously trained weights.
model = Resnet50Unet(21)

# map_location remaps the stored tensors onto the active device, so a
# checkpoint written from a CUDA session still loads on a CPU-only machine.
pretrained_state = torch.load("n_classifier0727_21classes_3.pt", map_location=device)
model.load_state_dict(pretrained_state)
model = model.to(device)


In [None]:
# Model checker: print torchsummary's overview of the model for an input of
# size (3, 520, 520).
summary_input_size = (3, 520, 520)
summary(model, input_size=summary_input_size)

In [None]:
# Custom loss function (does not work very well yet; switching to Lovasz-Softmax may give better results)

# Combines Dice loss with CrossEntropyLoss

def dice_loss(pred, target, smooth = 1.):
    """Return the mean soft Dice loss between ``pred`` and ``target``.

    Both tensors are reduced over dims 2 and 3, so they need at least four
    dimensions (presumably batch, class, height, width -- confirm with the
    caller). One Dice term is computed per remaining entry and the terms are
    averaged into a single scalar.

    Args:
        pred: Predicted scores, multiplied elementwise with ``target``.
        target: Ground-truth mask with the same shape as ``pred``.
        smooth: Constant added to numerator and denominator; keeps the ratio
            defined when both sums are zero.

    Returns:
        Scalar tensor: mean of ``1 - (2*I + smooth) / (P + T + smooth)`` where
        ``I``, ``P`` and ``T`` are the reduced sums of ``pred * target``,
        ``pred`` and ``target``.
    """
    def _sum_dims_2_3(tensor):
        # Sum dim 2 first; the former dim 3 then becomes dim 2 and is summed next.
        return tensor.sum(dim=2).sum(dim=2)

    pred = pred.contiguous()
    target = target.contiguous()

    overlap = _sum_dims_2_3(pred * target)
    total = _sum_dims_2_3(pred) + _sum_dims_2_3(target)

    dice_coeff = (2. * overlap + smooth) / (total + smooth)
    return (1 - dice_coeff).mean()

def calc_loss(pred, target, CEL_weight=0.3):
    """Blend cross-entropy and Dice loss into a single training loss.

    Args:
        pred: Raw model outputs (logits) with class scores along dim 1.
        target: Per-class mask laid out like ``pred`` (presumably one-hot --
            confirm with the dataset); its argmax over dim 1 is used as the
            class-index target for the cross-entropy term.
        CEL_weight: Weight of the cross-entropy term; the Dice term is
            weighted by ``1 - CEL_weight``.

    Returns:
        Scalar tensor ``CE * CEL_weight + Dice * (1 - CEL_weight)``.
    """
    # Cross-entropy on class indices; pixels whose target index is 0 are ignored.
    class_index = target.argmax(dim=1)
    cross_entropy = F.cross_entropy(pred, class_index, ignore_index=0)

    # Dice is computed on sigmoid-activated scores against the per-class mask.
    soft_dice = dice_loss(torch.sigmoid(pred), target)

    return cross_entropy * CEL_weight + soft_dice * (1 - CEL_weight)

In [None]:
# Model training

# Optimizer: only parameters that still require gradients are handed to Adam,
# so any layers frozen elsewhere (e.g. a pretrained backbone) are skipped.
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.00005)

LOG_EVERY = 200  # number of batches between two progress reports

for epoch in range(2):  # loop over the dataset multiple times
    model.train()
    running_loss = 0.0

    for eed, data in enumerate(data_loader, start=0):
        optimizer.zero_grad()
        imgs, labels = data
        imgs, labels = imgs.to(device), labels.to(device)  # move the batch to the active device

        predict = model(imgs)  # forward pass
        loss = calc_loss(predict, labels)
        loss.backward()

        running_loss += loss.item()  # accumulate the loss for reporting
        optimizer.step()

        if eed % LOG_EVERY == LOG_EVERY - 1:
            # Average over exactly the LOG_EVERY batches accumulated since the
            # last report; dividing by any other number mis-scales the value.
            print('[%d, %5d] loss: %.5f' % (epoch + 1, eed + 1, running_loss / LOG_EVERY))
            running_loss = 0.0


print('Finished Training')

In [None]:
# Save the trained weights (the model's state_dict) to disk.
path = "n_classifier0730_21classes.pt"
trained_state = model.state_dict()
torch.save(trained_state, path)

In [None]:
# import argmax tensor and output a color image

def argm_to_colorimg(masks):
    """Convert a 2-D map of class indices into an RGB colour image.

    Args:
        masks: 2-D integer array (height, width) of class indices in the
            range 0-20, e.g. the argmax over the class axis of a prediction.
            A torch tensor is accepted and converted to NumPy first.

    Returns:
        ``np.uint8`` array of shape (height, width, 3) in which every pixel
        holds the palette colour of its class index.

    Raises:
        ValueError: if ``masks`` is not two-dimensional.
        IndexError: if an index falls outside the 21-entry palette.
    """
    # One RGB colour code per class, indexed by class position.
    colors = np.asarray([(0, 0, 0), (255, 255, 0), (255, 0, 255), (200, 0, 0),(200, 200, 0),
                         (200, 0, 200), (150, 0, 0), (150, 150, 0), (150, 0, 150),(100, 0, 0),
                         (100, 100, 0), (100, 0, 100), (50, 0, 0), (50, 50, 0),(50, 0, 50),
                         (0, 255, 0), (0, 255, 255), (0, 200, 0), (0, 200, 200),(0, 150, 0),
                         (0, 150, 150)])

    if torch.is_tensor(masks):
        masks = masks.detach().cpu().numpy()
    masks = np.asarray(masks)
    if masks.ndim != 2:
        raise ValueError(
            "masks must be 2-D (height, width), got shape {}".format(masks.shape)
        )

    # Integer-array indexing looks up the colour of every pixel in one
    # vectorised step instead of a Python loop over height * width pixels.
    return colors[masks].astype(np.uint8)

In [None]:
# Test the segmentation on a single image.
model.eval()


fig=plt.figure(figsize=(20, 20))

# NOTE(review): RandomCrop makes the evaluated region differ between runs
# whenever the resized image is not square; consider a deterministic crop.
transform=transforms.Compose([
                    transforms.Resize(512),
                    transforms.RandomCrop(512),
                    transforms.ToTensor(),
])


img = Image.open("5.jpg").convert('RGB')
img = transform(img)
img = img.unsqueeze(0)  # add the batch dimension the model expects
img = img.to(device=device, dtype=torch.float32)


with torch.no_grad():
    output=model(img)
    output=torch.sigmoid(output)
    outputF= output.data.cpu().squeeze(0).numpy()
    # Highest-scoring class index per pixel.
    argm = np.argmax(outputF,axis=0)

    finalout =argm_to_colorimg(argm)

    fig.add_subplot(2,2,1)
    plt.imshow(finalout)
    plt.axis('off')

    fig.add_subplot(2,2,2)
    plt.imshow(argm)
    plt.axis('off')

    # Which classes appear in the argmax mask. `allclassnms` is a plain list,
    # so it has to be indexed one class index at a time; indexing it with the
    # array's buffer raises TypeError.
    b = np.unique(argm)
    print([allclassnms[class_idx] for class_idx in b])


    plt.show()