In [1]:
import os
import cv2
import torch
import numpy as np
import torchvision
from pathlib import Path
import torchvision.transforms as tf
import torchvision.models.segmentation

In [2]:
image_folder = Path.home()/'Documents/datasets/drone_segmentation_dataset/dataset/original_images/'
anns_folder = Path.home()/'Documents/datasets/drone_segmentation_dataset/dataset/label_images_semantic/'
images = os.listdir(image_folder)
anns = os.listdir(anns_folder)

In [3]:
#Set some hyperparameters
learning_rate=1e-5
width = 6000  #Image width
height = 4000 #Image height
batch_size=5

In [4]:
def sort_files(file_list):
    #The images labaled from '000' to '099' can be used for testing while those from '100' to '598' can be used 
    #for training the neural network
    #First convert to integers
    extension = file_list[0][3:]
    file_list_numbers = []; training_data = []
    test_data = []
    for file in file_list:
        if file[0] == '0':
            test_data.append(file)
            continue
        name_int = int(file[:3])
        file_list_numbers.append(name_int)
    file_list_numbers = sorted(file_list_numbers)
    for file in file_list_numbers:
        file = str(file) + extension
        training_data.append(file)
    del extension,file_list_numbers
    return training_data,test_data

In [5]:
#Data transform
def data_transform():
    image_transform=tf.Compose([
                    tf.ToPILImage(),
                    tf.Resize((height,width)), 
                    tf.ToTensor(),
                    tf.Normalize((0.485, 0.456, 0.406), 
                                 (0.229, 0.224, 0.225))])
    return image_transform
def ann_transform():
    ann_transform = tf.Compose([
                    tf.ToPILImage(),
                    tf.Resize((height,width)), 
                    tf.ToTensor()
                    ])
    return ann_transform
Img_transform = data_transform()
ANN_transform = ann_transform()

In [6]:
training_images,testing_images = sort_files(images)
training_anns,testing_anns = sort_files(anns)
print(training_images[:2])
print(training_anns[:2])
def ReadRandomImage(): # First lets load random image and  the corresponding annotation
    idx=np.random.randint(0,len(training_images)) # Select random image
    Img = cv2.imread(os.path.join(image_folder,training_images[idx]))[:,:,0:3]
    Ann = cv2.imread(os.path.join(anns_folder, training_anns[idx]),0)
    AnnMap = np.zeros(Img.shape[0:2],np.float32)
    for label in range(0,24):
        AnnMap[Ann == label] = label
    Img = Img_transform(Img)
    AnnMap = ANN_transform(AnnMap)
    return Img,AnnMap

['100.jpg', '101.jpg']
['100.png', '101.png']


In [7]:
def load_batch(): # Load batch of images
    images = torch.zeros([batch_size,3,height,width])
    ann = torch.zeros([batch_size, height, width])
    for i in range(batch_size):
        images[i],ann[i]=ReadRandomImage()
    return images, ann

In [8]:
#Load the device and the model of interest
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
Net = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True) # Load net
Net.classifier[4] = torch.nn.Conv2d(256, 24, kernel_size=(1, 1), stride=(1, 1)) # Change final layer to 3 classes
Net=Net.to(device)
optimizer=torch.optim.Adam(params=Net.parameters(),lr=learning_rate) # Create adam optimizer

In [None]:
for itr in range(1): # Training loop
   images,ann=load_batch() # Load taining batch
   print(type(ann))
   print(ann.size())
   images=torch.autograd.Variable(images,requires_grad=False).to(device) # Load image
   ann = torch.autograd.Variable(ann, requires_grad=False).to(device) # Load annotation
   Pred=Net(images)['out'] # make prediction
   Net.zero_grad()
   criterion = torch.nn.CrossEntropyLoss() # Set loss function
   Loss=criterion(Pred,ann.long()) # Calculate cross entropy loss
   Loss.backward() # Backpropogate loss
   optimizer.step() # Apply gradient descent change to weight
   seg = torch.argmax(Pred[0], 0).cpu().detach().numpy()  # Get prediction classes
   print("\n")
   print(type(seg))
   unique_values = np.unique(seg)
   print(unique_values)
   count_0 = np.count_nonzero(seg == 0)
   count_1 = np.count_nonzero(seg == 1)
   count_2 = np.count_nonzero(seg == 2)
   print(count_0)
   print(count_1)
   print(count_2)

<class 'torch.Tensor'>
torch.Size([5, 4000, 6000])
