In [None]:
from __future__ import print_function, division
import os
import torch
import pandas as pd
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import torch.nn.functional as F
from torchvision import transforms, utils
from IPython.display import clear_output, display

In [None]:
import warnings
# Suppress every warning category for the rest of the kernel session.
# NOTE(review): this also hides deprecation/shape warnings from torch and pandas.
warnings.filterwarnings("ignore")
# Turn on matplotlib's interactive mode.
plt.ion()

In [None]:
# Not referenced by any other cell visible in this notebook.
denom = 1
# Side lengths used for the decoded mask buffer, the blank fallback sample and
# the resize target handed to ToTensor by MergeShipData.
image_width = 128
image_height = 128

In [None]:
# Fraction of the shuffled indices that get_train_and_test_datasets assigns to training.
train_split = 0.8

In [None]:
def rle_decode(rle_image, shape=None):
    """Decode run-length-encoded ``(start, length)`` pairs into a binary mask.

    Args:
        rle_image: iterable of 2-item records ``(start, length)``; the items may
            be strings or numbers. ``start`` is a 1-based position in the
            column-major flattened mask (the convention of the competition's
            ``EncodedPixels`` column) and ``length`` is the number of
            consecutive pixels in the run. An empty iterable yields an
            all-zero mask.
        shape: optional ``(rows, cols)`` of the returned mask. Defaults to the
            notebook-level ``(image_width, image_height)``.

    Returns:
        A float ``numpy`` array of the requested shape holding 1.0 for masked
        pixels and 0.0 elsewhere.

    NOTE(review): start positions index the mask at the resolution it was
    encoded at. If the source images are larger than ``shape``, runs beyond
    the buffer are silently clipped by the slice -- confirm the default size.
    """
    if shape is None:
        shape = (image_width, image_height)
    rows, cols = shape
    image_array = np.zeros(rows * cols)
    for rle_tuple in rle_image:
        # Starts are 1-based, so shift to a 0-based offset; the run then covers
        # exactly `run_length` pixels beginning at that offset.
        start = int(rle_tuple[0]) - 1
        run_length = int(rle_tuple[1])
        if start < 0 or run_length <= 0:
            # Malformed run: a negative offset would wrap around via slicing.
            continue
        image_array[start:start + run_length] = 1
    # Runs are numbered down the columns first, hence Fortran (column-major) order.
    return np.reshape(image_array, (rows, cols), order='F')

In [None]:
# Load the segmentation CSV; column 0 is read as the image file name and
# column 1 as the space-separated run-length string.
ships_frame = pd.read_csv('data/train_ship_segmentations_v2.csv')

# Row used for the manual inspection below.
n = 3
img_name = ships_frame.iloc[n, 0]
# `ships` is rebound three times: token list -> (k, 2) pair array -> decoded mask.
# NOTE(review): .split fails if row n has no encoded pixels (NaN) -- confirm row 3 has a ship.
ships = ships_frame.iloc[n, 1].split(' ')
print(ships)
ships = np.asarray(ships).reshape(-1, 2)
ships = rle_decode(ships)

print('Image name: {}'.format(img_name))
print('Landmarks shape: {}'.format(ships.shape))

In [None]:


# Show the first four (start, length) RLE pairs of row `n` and the mask they decode to.
# `ships` already holds the decoded mask at this point, so the pairs are rebuilt
# from the raw CSV column instead of slicing rows out of the mask.
first_rle_pairs = np.asarray(ships_frame.iloc[n, 1].split(' ')).reshape(-1, 2)[:4]
print(first_rle_pairs)
print(rle_decode(first_rle_pairs))

In [None]:
def show_ship_images(image, ships):
    """Draw an image and its mask side by side on the current figure.

    Args:
        image: array passed to ``plt.imshow`` for the left panel.
        ships: mask array passed to ``plt.imshow`` with the ``'gray'`` colormap
            for the right panel.
    """
    panels = ((image, ()), (ships, ('gray',)))
    for slot, (picture, imshow_args) in enumerate(panels, start=1):
        plt.subplot(1, 2, slot)
        plt.imshow(picture, *imshow_args)
    # Brief pause so the plots get refreshed.
    plt.pause(0.001)

# Open a fresh figure, then plot the image selected earlier (`img_name`) next to
# its decoded mask (`ships`); the image file is read from data/train_v2.
plt.figure()
print(img_name)
show_ship_images(io.imread(os.path.join('data/train_v2', img_name)),
               ships)
plt.show()

In [None]:
class ToTensor(object):
    """Turn the numpy arrays of a sample into torch tensors.

    Args:
        dimensions: optional ``(height, width)``. When given, both the image and
            the mask are resized to that size before conversion.
    """

    def __init__(self, dimensions=None):
        self.dimensions = dimensions

    def __call__(self, sample):
        """Convert ``sample['image']`` and ``sample['loc_image']``.

        Returns a dict with:
            ``image``: float32 tensor, channels first, values divided by 255.
            ``loc_image``: float32 tensor with a leading singleton axis, values
                divided by 255.
            ``contains_ship``: float32 one-hot tensor, ``[1, 0]`` when the mask
                has no non-zero pixel and ``[0, 1]`` otherwise.
        """
        picture = sample['image']
        mask = sample['loc_image']

        if self.dimensions:
            height, width = self.dimensions
            target_shape = (height, width)
            picture = transform.resize(picture, target_shape, anti_aliasing=True, preserve_range=True)
            mask = transform.resize(mask, target_shape, anti_aliasing=False, preserve_range=True)
            # Any pixel touched by the interpolation counts as fully masked again.
            mask[mask > 0] = 255

        # numpy image is H x W x C, torch expects C x H x W.
        picture = picture.transpose((2, 0, 1)).astype(np.float32) / 255
        rows, cols = mask.shape[0], mask.shape[1]
        mask = mask.reshape(1, rows, cols).astype(np.float32) / 255

        if np.count_nonzero(mask) == 0:
            one_hot = [1, 0]
        else:
            one_hot = [0, 1]
        contains_ship = np.asarray(one_hot, dtype=np.float32)

        return {'image': torch.from_numpy(picture),
                'loc_image': torch.from_numpy(mask),
                'contains_ship': torch.from_numpy(contains_ship)}

In [None]:
class ShipData(Dataset):
    """Dataset with one item per row of the segmentation CSV.

    Each item is a dict with the image read from ``data/train_v2/`` under
    ``'image'`` and the decoded mask of that single row under ``'loc_image'``.
    """

    def __init__(self):
        self.ships_dataframe = pd.read_csv('data/train_ship_segmentations_v2.csv')

    def __len__(self):
        row_count = len(self.ships_dataframe)
        return row_count

    def __getitem__(self, idx):
        image_name = self.ships_dataframe.iloc[idx, 0]
        image = io.imread('data/train_v2/' + image_name)

        encoded = self.ships_dataframe.iloc[idx, 1]
        if pd.isnull(encoded):
            # No encoded pixels on this row: decode an empty run list.
            rle_pairs = []
        else:
            rle_pairs = np.asarray(encoded.split(' ')).reshape(-1, 2)

        return {'image': image, 'loc_image': rle_decode(rle_pairs)}
                

class MergeShipData(Dataset):
    """Dataset with one item per image, merging all of its CSV rows into one mask.

    The CSV may hold several rows for the same image name; their decoded masks
    are OR-ed together. Items are passed through ``ToTensor`` with the
    notebook-level ``(image_height, image_width)`` as the resize target.
    """

    def __init__(self):
        self.ships_dataframe = pd.read_csv('data/train_ship_segmentations_v2.csv')
        # image name -> positional row indices of every CSV row for that image
        self.merge_ships_dict = self._merge_ships(self.ships_dataframe)
        # One row per image, in first-appearance order. `columns=` is used because
        # the positional axis argument of DataFrame.drop is rejected by pandas 2.x.
        self.merge_ships_dataframe = self.ships_dataframe.drop(columns='EncodedPixels').drop_duplicates()
        self.transform = transforms.Compose([ToTensor((image_height, image_width))])

    def __len__(self):
        # Every row of the de-duplicated frame is a valid index for __getitem__.
        return len(self.merge_ships_dataframe)

    def _get_blank_sample(self):
        """Return an all-zero image/mask pair used when a sample cannot be built."""
        image = np.zeros((image_height, image_width, 3))
        loc_image = rle_decode([])
        return {"image": image, "loc_image": loc_image, 'contains_ship': 0}

    def _get_image_mask(self, idx):
        """Decode the mask stored on CSV row ``idx`` (positional index)."""
        encoded = self.ships_dataframe.iloc[idx, 1]
        if pd.isnull(encoded):
            ship_loc_data = []
        else:
            ship_loc_data = np.asarray(encoded.split(' ')).reshape(-1, 2)
        return rle_decode(ship_loc_data)

    def __getitem__(self, idx):
        """Return the transformed sample for image number ``idx``.

        Falls back to a blank sample when the image cannot be looked up, read
        or decoded (best-effort behaviour kept from the original).
        """
        try:
            image_name = self.merge_ships_dataframe.iloc[idx, 0]
            image_idxs = self.merge_ships_dict[image_name]
            image_path = 'data/train_v2/' + image_name
            image = io.imread(image_path)
            loc_image = self._merge_images(image_idxs)
            sample = {'image': image, 'loc_image': loc_image}
        except Exception:
            # `Exception` rather than a bare except, so KeyboardInterrupt and
            # SystemExit still stop a long-running loop.
            sample = self._get_blank_sample()
        return self.transform(sample)

    def _merge_ships(self, df):
        """Group positional row indices of ``df`` by the value in its first column."""
        ships = {}
        for position, image_name in enumerate(df.iloc[:, 0]):
            ships.setdefault(image_name, []).append(position)
        return ships

    def _merge_images(self, idxs):
        """OR together the masks of rows ``idxs``; returns uint8 with values 0 or 255."""
        merged = rle_decode([]).astype(bool)
        for idx in idxs:
            merged |= self._get_image_mask(idx).astype(bool)
        return merged.astype(np.uint8) * 255
            
        

In [None]:
# Build the per-image dataset; the constructor reads the segmentation CSV.
sd = MergeShipData()

In [None]:
# Number of items the dataset reports.
len(sd)

In [None]:
# last = 0
# for i in range(len(sd)):
#     last = i
#     print(i)
#     sd[i]
#     clear_output(wait=True)

In [None]:
# Display one transformed sample (dict of tensors) as a smoke test.
sd[1]

In [None]:
def plot_sample(sample):
    """Plot a transformed sample: the image on the left, its mask on the right.

    Args:
        sample: dict with tensors under ``'image'`` and ``'loc_image'``; both are
            converted to PIL images before plotting on a new 12x6 figure.
    """
    converter = transforms.ToPILImage()
    panels = [
        (converter(sample['image']), ()),
        (converter(sample['loc_image']), ('gray',)),
    ]
    plt.figure(num=None, figsize=(12, 6), dpi=80, facecolor='w', edgecolor='k')
    for slot, (picture, imshow_args) in enumerate(panels, start=1):
        plt.subplot(1, 2, slot)
        plt.imshow(picture, *imshow_args)
    # Brief pause so the plots get refreshed.
    plt.pause(0.001)

In [None]:
# Plot samples number_from (inclusive) to number_to (exclusive); each call to
# plot_sample opens a new figure.
number_to = 40
number_from = 0
for i in range(number_from, number_to):
    plot_sample(sd[i])

In [None]:
def generate_train_and_test_idxs(dataset_length, train_split):
    """Shuffle ``range(dataset_length)`` reproducibly and cut it in two.

    Args:
        dataset_length: number of indices to generate.
        train_split: fraction of the indices placed in the first list.

    Returns:
        ``(train_indices, remaining_indices)`` as two lists.

    Note: reseeds numpy's global random generator with 42 on every call.
    """
    np.random.seed(42)
    shuffled = list(range(dataset_length))
    np.random.shuffle(shuffled)
    cut = int(np.floor(train_split * dataset_length))
    first_part = shuffled[:cut]
    second_part = shuffled[cut:]
    return first_part, second_part

def get_train_and_test_datasets(dataset, max_length=None, batch_size=8):
    """Build training and validation loaders over disjoint random index subsets.

    Args:
        dataset: the dataset both loaders draw from.
        max_length: optional upper bound on how many leading indices of
            ``dataset`` are used. ``None`` or ``0`` means the whole dataset;
            values larger than ``len(dataset)`` are capped so no out-of-range
            index is ever requested.
        batch_size: batch size of both loaders.

    Returns:
        ``(train_loader, validation_loader)``. The split fraction comes from
        the notebook-level ``train_split``.
    """
    available = len(dataset)
    dataset_length = min(max_length, available) if max_length else available
    train_indices, val_indices = generate_train_and_test_idxs(dataset_length, train_split)
    train_sampler = SubsetRandomSampler(train_indices)
    validation_sampler = SubsetRandomSampler(val_indices)
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                               sampler=train_sampler)
    validation_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                                    sampler=validation_sampler)
    return train_loader, validation_loader
    

In [None]:
class Flatten(torch.nn.Module):
    """Collapse every axis after the first into one, keeping the batch axis."""

    def forward(self, input):
        batch = input.size(0)
        return input.view(batch, -1)

In [None]:
# read https://arxiv.org/pdf/1505.04366.pdf
# perhaps combine with https://arxiv.org/pdf/1311.2524.pdf
class MaskModel(torch.nn.Module):
    """Encoder/decoder CNN mapping a 3-channel image to a 1-channel mask in (0, 1).

    The encoder applies four conv -> batchnorm -> ReLU -> 2x2 max-pool stages
    and remembers the pooling indices. The decoder mirrors them with
    max-unpool -> transposed conv stages and ends in a sigmoid. Input height
    and width must survive four halvings (e.g. multiples of 16) for the
    unpooling to restore the original size.

    Each decoder stage has its own batch-norm module, so encoder and decoder
    activations do not share running statistics or affine parameters.
    """

    def __init__(self):
        super(MaskModel, self).__init__()

        # conv layers
        self.cnn1 = torch.nn.Conv2d(3, 16, 3, padding=1)
        self.cnn2 = torch.nn.Conv2d(16, 32, 3, padding=1)
        self.cnn3 = torch.nn.Conv2d(32, 64, 3, padding=1)
        self.cnn4 = torch.nn.Conv2d(64, 128, 3, padding=1)

        # deconv layers
        self.decnn1 = torch.nn.ConvTranspose2d(128, 64, 3, padding=1)
        self.decnn2 = torch.nn.ConvTranspose2d(64, 32, 3, padding=1)
        self.decnn3 = torch.nn.ConvTranspose2d(32, 16, 3, padding=1)
        self.decnn4 = torch.nn.ConvTranspose2d(16, 1, 3, padding=1)

        # pooling layers
        self.pool2d = torch.nn.MaxPool2d(2, stride=2, return_indices=True)
        self.unpool2d = torch.nn.MaxUnpool2d(2, stride=2)

        # activation layers
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()
        self.softmax = torch.nn.Softmax2d()  # not used by forward

        # other layers
        self.dropout = torch.nn.Dropout2d()  # not used by forward
        self.batchnorm1 = torch.nn.BatchNorm2d(16)
        self.batchnorm2 = torch.nn.BatchNorm2d(32)
        self.batchnorm3 = torch.nn.BatchNorm2d(64)
        self.batchnorm4 = torch.nn.BatchNorm2d(128)

        # decoder-only batch norms (one per transposed conv that is normalised)
        self.de_batchnorm1 = torch.nn.BatchNorm2d(64)
        self.de_batchnorm2 = torch.nn.BatchNorm2d(32)
        self.de_batchnorm3 = torch.nn.BatchNorm2d(16)

    def forward(self, x):
        """Return the predicted mask, same batch and spatial size as ``x``, one channel."""
        # encoder
        result = self.cnn1(x)
        result = self.batchnorm1(result)
        result = self.relu(result)
        result, indices1 = self.pool2d(result)
        result = self.cnn2(result)
        result = self.batchnorm2(result)
        result = self.relu(result)
        result, indices2 = self.pool2d(result)
        result = self.cnn3(result)
        result = self.batchnorm3(result)
        result = self.relu(result)
        result, indices3 = self.pool2d(result)
        result = self.cnn4(result)
        result = self.batchnorm4(result)
        result = self.relu(result)
        result, indices4 = self.pool2d(result)

        # decoder: undo the pools in reverse order with the saved indices
        result = self.unpool2d(result, indices4)
        result = self.decnn1(result)
        result = self.de_batchnorm1(result)
        result = self.relu(result)
        result = self.unpool2d(result, indices3)
        result = self.decnn2(result)
        result = self.de_batchnorm2(result)
        result = self.relu(result)
        result = self.unpool2d(result, indices2)
        result = self.decnn3(result)
        result = self.de_batchnorm3(result)
        result = self.relu(result)
        result = self.unpool2d(result, indices1)
        result = self.decnn4(result)
        result = self.sigmoid(result)
        return result
    
    

In [None]:
def create_categorization_model(input_height=384, input_width=384):
    """Build a small CNN classifier with a 2-way softmax output.

    Args:
        input_height: height of the images the model will receive.
        input_width: width of the images the model will receive.
            The defaults reproduce the original hard-coded ``94*94*20``
            flattened size; pass the real image size (e.g. ``image_height``,
            ``image_width``) so the first linear layer matches the data.

    Returns:
        A ``torch.nn.Sequential`` mapping ``(batch, 3, H, W)`` to ``(batch, 2)``
        class probabilities.
    """
    def _after_conv_and_pool(size):
        # An unpadded 3x3 conv removes 2 pixels; the 2x2 stride-2 pool halves (floor).
        return (size - 2) // 2

    feature_height = _after_conv_and_pool(_after_conv_and_pool(input_height))
    feature_width = _after_conv_and_pool(_after_conv_and_pool(input_width))

    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 16, 3),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(2, stride=2),
        torch.nn.Conv2d(16, 20, 3),
        torch.nn.Dropout2d(),
        torch.nn.ReLU(),
        torch.nn.MaxPool2d(2, stride=2),
        Flatten(),
        torch.nn.Linear(feature_height * feature_width * 20, 1024),
        torch.nn.ReLU(),
        torch.nn.Linear(1024, 512),
        torch.nn.ReLU(),
        torch.nn.Linear(512, 2),
        # Explicit class axis; the implicit-dim form is deprecated.
        torch.nn.Softmax(dim=1)
    )
    return model

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Instantiate the mask-prediction network with freshly initialised weights.
model = MaskModel()

In [None]:
# Move the model's parameters and buffers to the selected device.
model = model.to(device)

In [None]:
# Adam with its default hyperparameters over all model parameters.
optimizer = torch.optim.Adam(model.parameters())
# Binary cross-entropy between predicted probabilities and target values.
criterion = torch.nn.BCELoss()

In [None]:
# Per-batch training losses; the training loop appends to this list, so
# re-running this cell resets the plotted history.
loss_history = list()

In [None]:
# Build the training/validation loaders with batches of 64.
# NOTE(review): max_length=1000000 may exceed len(sd); confirm the helper caps
# it, otherwise indices past the end of the dataset are requested.
train_loader, validation_loader = get_train_and_test_datasets(sd, max_length=1000000, batch_size=64)

In [None]:
# Train the mask model, recording and plotting the loss after every batch.
num_epochs = 1
model.train()
for epoch in range(num_epochs):
    # Train:
    for batch_index, sample in enumerate(train_loader):

        image = sample['image'].to(device)
        labels = sample['loc_image'].to(device)
        outputs = model(image)
        # Predictions and targets already share the shape (batch, 1, H, W), so
        # they are compared directly; no hard-coded batch size, which also
        # keeps a smaller final batch working.
        loss = criterion(outputs, labels)
        loss_history.append(loss.item())

        # Redraw the running loss curve in place of the previous output.
        plt.plot(loss_history, color='blue')
        plt.ylabel('loss')
        clear_output(wait=True)
        plt.show()
        print(f'loss: {loss.item()}\nbatch: {batch_index}\nepoch: {epoch}')

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

In [None]:
# Image-level validation accuracy: does the model agree with the label on
# whether an image contains a ship?
model.eval()

with torch.no_grad():
    correct = 0
    total = 0
    for index, sample in enumerate(validation_loader):
        image = sample['image'].to(device)
        labels = sample['contains_ship'].to(device)
        outputs = model(image)

        # Labels are one-hot [no_ship, ship]; the argmax is the class index.
        actual = torch.max(labels, 1)[1]
        if outputs.dim() == 2:
            # Classifier output (batch, 2): predicted class is the argmax.
            predicted = torch.max(outputs, 1)[1]
        else:
            # Mask output (batch, 1, H, W): call it a ship when any pixel
            # probability exceeds 0.5.
            peak = outputs.view(outputs.size(0), -1).max(dim=1)[0]
            predicted = (peak > 0.5).long()

        total += labels.size(0)
        correct += (predicted == actual).sum().item()

    print('Validation accuracy (ship present/absent) on {} images: {} %'.format(total, 100 * correct / total))

In [None]:
toPilImage = transforms.ToPILImage()
# NOTE: rebinds the notebook-wide `device` to the CPU; cells executed after this
# one (including a re-run of training) will see the CPU device.
device = torch.device('cpu')
model = model.to(device)
model.eval()
# Run the model on sample 50 alone: unsqueeze adds the batch axis and [0]
# removes it again from the prediction.
image = sd[50]['image'].to(device)
output = model(torch.unsqueeze(image, 0))[0]

In [None]:
# ToPILImage multiplies float tensors by 255 itself, so the prediction is only
# stretched to [0, 1] here; scaling by 255 first would overflow the uint8 range.
toPilImage((output / output.max()).detach())

In [None]:
# Display the input image that produced the prediction above.
toPilImage(image)

In [None]:
def infer_and_compare_result(model, sample):
    """Plot a sample's image, its ground-truth mask and the model's prediction.

    Args:
        model: network called with a ``(1, C, H, W)`` batch built from the sample.
        sample: dict with tensors under ``'image'`` and ``'loc_image'``.

    The input is moved to the notebook-level ``device`` before the forward
    pass. The prediction is stretched so its maximum becomes white.
    """
    model.eval()
    image = sample['image'].to(device)
    # Inference only: skip autograd bookkeeping and bring the result back to
    # the CPU for the PIL conversion.
    with torch.no_grad():
        output = model(torch.unsqueeze(image, 0))[0].cpu()

    toPilImage = transforms.ToPILImage()
    image = toPilImage(sample['image'])
    loc_image = toPilImage(sample['loc_image'])
    # ToPILImage multiplies float tensors by 255 itself; normalise to [0, 1]
    # only, otherwise the values wrap around in uint8.
    output_image = toPilImage(output / output.max())
    plt.figure(num=None, figsize=(32, 12), dpi=80, facecolor='w', edgecolor='k')
    plt.subplot(1, 3, 1)
    plt.imshow(image)
    plt.subplot(1, 3, 2)
    plt.imshow(loc_image, 'gray')
    plt.subplot(1, 3, 3)
    plt.imshow(output_image, 'gray')

    plt.pause(0.001)  # pause a bit so that plots are updated

In [None]:
# Print the largest mask value of sample 50, then plot the image, its mask
# and the model's prediction.
print(sd[50]['loc_image'].max())
infer_and_compare_result(model, sd[50])

In [None]:
# Plot every per-batch loss recorded in loss_history.
plt.plot(loss_history)
plt.ylabel('loss')
plt.show()