Runtime > Change Runtime Type > Select GPU > Save

Then run this cell so that PyTorch uses the GPU.

In [None]:
"""
>>>>> DETERMINE EXECUTION DEVICE <<<<<
"""
import torch

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")

print(cuda_available)
print('device count', torch.cuda.device_count())
if cuda_available:
    # These two queries raise when no CUDA device is present, so they are
    # only issued on a GPU runtime.
    print('current', torch.cuda.current_device())
    print('GPU', torch.cuda.get_device_name(0))
else:
    print('No GPU detected - running on CPU')


Zip the data folder and upload as data.zip, then run this cell to unzip it

In [None]:
"""
>>>>> COLAB UNZIPPING <<<<<
(optional)
"""
#!unzip data.zip

Run the next 4 cells to start training

In [None]:
"""
>>>>> HELPER FUNCTIONS <<<<<
"""
from matplotlib import pyplot as plt
def show_image_mask(img, mask, cmap='gray'):
    """Show an image and a mask next to each other in one figure.

    Args:
        img: array-like accepted by ``plt.imshow``; drawn in the left panel.
        mask: array-like accepted by ``plt.imshow``; drawn in the right panel.
        cmap: matplotlib colour map applied to both panels.
    """
    plt.figure(figsize=(5, 5))
    # Panel 1 holds the image, panel 2 the mask; both are drawn without axes.
    for position, panel in enumerate((img, mask), start=1):
        plt.subplot(1, 2, position)
        plt.imshow(panel, cmap=cmap)
        plt.axis('off')
    plt.show()  # render right away instead of waiting for the cell to end

In [None]:
"""
>>>>> DATA LOADERS <<<<<
(including fix for linux glob)
"""
import torch
import torch.utils.data as data
import cv2
import os
from glob import glob
import natsort

class TrainDataset(data.Dataset):
    """Dataset of (image, mask) pairs read from ``<root>/image`` and ``<root>/mask``.

    Every ``<root>/image/<name>.png`` is paired with
    ``<root>/mask/<name>_mask.png``. The file list is sorted so that indices
    refer to the same sample on every platform (glob order is otherwise
    arbitrary).
    """

    def __init__(self, root=''):
        """Collect image paths under ``root`` and derive the matching mask paths."""
        super(TrainDataset, self).__init__()
        self.img_files = sorted(glob(os.path.join(root, 'image', '*.png')))
        # To train on a subset, slice here, e.g. self.img_files[0:10].
        self.mask_files = [
            os.path.join(
                root, 'mask',
                os.path.splitext(os.path.basename(img_path))[0] + '_mask.png')
            for img_path in self.img_files
        ]

    def __getitem__(self, index):
        """Return ``(image, mask)`` for ``index`` as float tensors.

        Raises:
            FileNotFoundError: if either file is missing or cannot be decoded.
        """
        img_path = self.img_files[index]
        mask_path = self.mask_files[index]
        image = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        label = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        # cv2.imread signals failure by returning None rather than raising;
        # surface that as a clear error naming the offending file.
        if image is None:
            raise FileNotFoundError("Could not read image: " + img_path)
        if label is None:
            raise FileNotFoundError("Could not read mask: " + mask_path)
        return torch.from_numpy(image).float(), torch.from_numpy(label).float()

    def __len__(self):
        """Number of image files found under ``<root>/image``."""
        return len(self.img_files)

    

    

class TestDataset(data.Dataset):
    """Image-only dataset over ``<root>/image/*.png`` in natural sort order."""

    def __init__(self, root=''):
        """Collect the PNG paths under ``<root>/image``, naturally sorted."""
        super().__init__()
        pattern = os.path.join(root, 'image', '*.png')
        # Natural sorting keeps numbered files in numeric order; a plain glob
        # returns them in arbitrary order.
        self.img_files = natsort.natsorted(glob(pattern))

    def __getitem__(self, index):
        """Load image ``index``, announce its path, and return it as a float tensor."""
        path = self.img_files[index]
        print("get " + path)
        pixels = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        return torch.from_numpy(pixels).float()

    def __len__(self):
        """Number of image files found."""
        return len(self.img_files)

    def __insert__(self, newimgfiles, newmaskfiles):
        """Placeholder: ignores both arguments and always returns 1."""
        return 1

    


U-Net Refined Version:

In [None]:
"""
>>>>> MODEL DEFINITION <<<<<
"""

import torch
from torch import nn
import torch.nn.functional as F


class UNet(nn.Module):
    """U-Net with 5x5 convolutions, four pooling stages and bilinear up-sampling.

    Input:  ``(N, 1, H, W)`` single-channel images.
    Output: ``(N, 4, H, W)`` per-pixel class logits.

    ``H`` and ``W`` should be divisible by 16: the encoder halves the
    resolution four times and each decoder stage doubles it again, so the
    skip connections only line up when no pooling step has to round down.

    Every convolution uses "same" padding (``padding=2`` for the 5x5 kernel)
    so that decoder feature maps match the encoder maps they are concatenated
    with.

    The transposed-convolution layers (``dec_4t`` .. ``dec_1t``) are an
    alternative to the bilinear up-samplers and are not used by ``forward``;
    they are still registered so the set of sub-modules (and therefore the
    ``state_dict`` keys) stays the same.
    """

    def __init__(self):
        super(UNet, self).__init__()

        kernel_size = 5
        padding = 2          # keeps H x W unchanged for a 5x5 kernel
        dropout_rate = 0.5

        def conv_block(in_channels, out_channels):
            """Conv -> BatchNorm -> ReLU, returned as a list of layers."""
            return [
                nn.Conv2d(in_channels=in_channels,
                          out_channels=out_channels,
                          kernel_size=kernel_size, padding=padding),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]

        def up_conv(channels):
            """Learned 2x up-sampling (unused alternative to bilinear)."""
            return nn.Sequential(
                nn.ConvTranspose2d(in_channels=channels, out_channels=channels,
                                   kernel_size=2, stride=2))

        def up_bilinear():
            """Fixed 2x bilinear up-sampling."""
            return nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))

        # Down Layers

        self.enc_11 = nn.Sequential(*conv_block(1, 64))
        self.enc_12 = nn.Sequential(*conv_block(64, 64))

        self.enc_21 = nn.Sequential(*conv_block(64, 128))
        self.enc_22 = nn.Sequential(*conv_block(128, 128))

        self.enc_31 = nn.Sequential(*conv_block(128, 256))
        self.enc_32 = nn.Sequential(*conv_block(256, 256),
                                    nn.Dropout2d(dropout_rate))

        self.enc_41 = nn.Sequential(*conv_block(256, 512))
        self.enc_42 = nn.Sequential(*conv_block(512, 512))

        self.enc_51 = nn.Sequential(*conv_block(512, 1024))
        self.enc_52 = nn.Sequential(*conv_block(1024, 1024),
                                    nn.Dropout2d(dropout_rate))

        # Up Layers
        # dec_Nu = bilinear up-sampling (used), dec_Nt = transposed conv (unused).
        # dec_N1 halves the channel count before the skip concatenation;
        # dec_N2 processes the concatenated tensor.

        self.dec_4t = up_conv(1024)
        self.dec_4u = up_bilinear()
        self.dec_41 = nn.Sequential(*conv_block(1024, 512))
        self.dec_42 = nn.Sequential(*conv_block(1024, 512),
                                    *conv_block(512, 512),
                                    nn.Dropout2d(dropout_rate))

        self.dec_3t = up_conv(512)
        self.dec_3u = up_bilinear()
        self.dec_31 = nn.Sequential(*conv_block(512, 256))
        self.dec_32 = nn.Sequential(*conv_block(512, 256),
                                    *conv_block(256, 256),
                                    nn.Dropout2d(dropout_rate))

        self.dec_2t = up_conv(256)
        self.dec_2u = up_bilinear()
        self.dec_21 = nn.Sequential(*conv_block(256, 128))
        self.dec_22 = nn.Sequential(*conv_block(256, 128),
                                    *conv_block(128, 128),
                                    nn.Dropout2d(dropout_rate))

        self.dec_1t = up_conv(128)
        self.dec_1u = up_bilinear()
        self.dec_11 = nn.Sequential(*conv_block(128, 64))
        self.dec_12 = nn.Sequential(*conv_block(128, 64),
                                    *conv_block(64, 64),
                                    nn.Dropout2d(dropout_rate))

        # 1x1 convolution mapping the 64 feature channels to 4 class logits.
        self.final = nn.Conv2d(64, 4, kernel_size=1)

        self.init_kaiming_weights()

    def forward(self, x):
        """Run the encoder/decoder and return ``(N, 4, H, W)`` logits."""
        # Encoder: two conv blocks per level, 2x2 max-pool between levels.
        enc12 = self.enc_12(self.enc_11(x))
        pool1 = F.max_pool2d(enc12, kernel_size=2, stride=2)

        enc22 = self.enc_22(self.enc_21(pool1))
        pool2 = F.max_pool2d(enc22, kernel_size=2, stride=2)

        enc32 = self.enc_32(self.enc_31(pool2))
        pool3 = F.max_pool2d(enc32, kernel_size=2, stride=2)

        enc42 = self.enc_42(self.enc_41(pool3))
        pool4 = F.max_pool2d(enc42, kernel_size=2, stride=2)

        enc52 = self.enc_52(self.enc_51(pool4))

        # Decoder: up-sample, reduce channels, concatenate the skip tensor
        # from the matching encoder level, then refine.
        dec41 = self.dec_41(self.dec_4u(enc52))
        dec42 = self.dec_42(torch.cat([dec41, enc42], dim=1))

        # Stage 3 continues from the stage-4 decoder output (dec42), not from
        # the encoder tensor, so the bottleneck path contributes to the result.
        dec31 = self.dec_31(self.dec_3u(dec42))
        dec32 = self.dec_32(torch.cat([dec31, enc32], dim=1))

        dec21 = self.dec_21(self.dec_2u(dec32))
        dec22 = self.dec_22(torch.cat([dec21, enc22], dim=1))

        dec11 = self.dec_11(self.dec_1u(dec22))
        dec12 = self.dec_12(torch.cat([dec11, enc12], dim=1))

        return self.final(dec12)

    def init_kaiming_weights(self):
        """Apply Kaiming-normal initialisation to every ``nn.Conv2d`` weight.

        This covers all encoder/decoder convolutions (including the second
        convolution inside each ``dec_N2`` block) and the final 1x1
        convolution. Transposed convolutions keep PyTorch's default init.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
        

Training and evaluation:

In [None]:
"""
>>>>> DICE SCORE COMPUTATION <<<<<
(modified from tutorial, added optional printing, removed background score)
"""
import numpy as np

def categorical_dice(mask1, mask2, label_class=1):
    """
    Dice overlap for a single class between two label masks.

    Masks hold class indices (not one-hot encodings). A stack of 2D slices
    is scored as one volume.

    Args:
        mask1: numpy array of class labels, e.g. shaped (H, W, N)
        mask2: numpy array of class labels with the same shape as mask1
        label_class: the class index to score

    Returns:
        2 * |A ∩ B| / (|A| + |B|) for the pixels equal to label_class,
        or 0 when the class appears in neither mask.
    """
    hits_a = (mask1 == label_class).astype(np.float32)
    hits_b = (mask2 == label_class).astype(np.float32)

    total = np.sum(hits_a) + np.sum(hits_b)
    # Guard against dividing by zero when the class is absent from both masks.
    if int(total) == 0:
        return 0

    overlap = np.sum(hits_a * hits_b)
    return 2 * overlap / total

def dice_class_score(mask1, mask2):
    """Return the Dice scores for classes 0, 1, 2 and 3, in that order."""
    return [categorical_dice(mask1, mask2, label) for label in range(4)]

def average_dice(mask1, mask2, verbose=False):
    """Mean Dice score over all classes returned by ``dice_class_score``.

    Args:
        mask1: numpy array of class labels.
        mask2: numpy array of class labels, same shape as ``mask1``.
        verbose: when true, print the score of each class.

    Returns:
        The arithmetic mean of the per-class Dice scores.
    """
    dice_scores = dice_class_score(mask1, mask2)
    if verbose:
        # The list starts at class 0, so the list index is the class label.
        for label, score in enumerate(dice_scores):
            print("=> Class {} = {}".format(label, score))

    return sum(dice_scores) / len(dice_scores)



In [None]:
"""
>>>>> TRAINING & EVAL FUNCTION <<<<<
"""
import torch.optim as optim
from torch.utils.data import DataLoader

def run_validation(model, dataloader):
    """Score ``model`` on every sample yielded by ``dataloader``.

    The model is switched to evaluation mode and left in that mode. Inputs
    are moved to the module-level ``device`` before the forward pass.

    Args:
        model: network mapping ``(N, 1, H, W)`` images to per-class logits.
        dataloader: iterable of ``(image_batch, mask_batch)`` pairs.

    Returns:
        A list with one entry per image; each entry is the list of per-class
        Dice scores produced by ``dice_class_score``.
    """
    model.eval()
    per_image_scores = []

    with torch.no_grad():
        for img, mask in dataloader:
            # Add the channel axis and run the batch on the execution device.
            batch = img.unsqueeze(1).to(device)
            logits = model(batch)

            target = mask.type(torch.LongTensor)
            # The most likely class per pixel, back on the CPU for numpy.
            predicted = torch.argmax(logits, dim=1).detach().cpu()

            # Score image by image; the last batch may be smaller.
            for idx in range(target.shape[0]):
                per_image_scores.append(
                    dice_class_score(target[idx, ...].numpy(),
                                     predicted[idx, ...].numpy()))

    return per_image_scores

def avg_and_save_validation(filename, dices):
    """Average per-image Dice scores per class, log them, return the overall mean.

    Args:
        filename: path of the log file; one comma-separated line is appended.
        dices: non-empty list of per-image score lists (one value per class).

    Returns:
        The mean of the per-class averages.
    """
    class_count = len(dices[0])
    running = [0] * class_count
    for per_image in dices:
        for idx, value in enumerate(per_image):
            running[idx] += value

    sample_count = len(dices)
    class_means = [total / sample_count for total in running]

    # One line per call: the class averages, echoed to stdout and appended.
    line = ", ".join(map(str, class_means))
    with open(filename, "a") as log:
        print(line)
        log.write(line + "\n")

    return sum(class_means) / len(class_means)
    

def train_eval():
    """Train a ``UNet`` and evaluate it on the train and validation sets each epoch.

    Data is read from ``./data/train`` and ``./data/val``. Files written to
    the working directory:

    * ``trainlog.csv`` / ``validlog.csv`` - per-class mean Dice, one row per epoch
    * ``best_dice.txt`` + ``trained_best_dice.pth`` - refreshed whenever the
      validation Dice improves
    * ``trained_<epoch>e.pth`` - a checkpoint every 25 epochs (not at epoch 0)
    """
    # Hyper-parameters and data locations.
    epochs = 200
    lr = 0.01
    batch_size = 10
    val_batch_size = 2
    num_workers = 4
    train_data_path = './data/train'
    validate_data_path = './data/val'

    # Build the model once, together with its loss and optimiser.
    model = UNet()
    loss_fn = torch.nn.CrossEntropyLoss()
    optimiser = optim.Adam(model.parameters(), lr=lr)

    # CUDA Setup
    # NOTE: run_validation() moves its inputs to the module-level `device`,
    # which is chosen with the same rule as this local one.
    device = torch.device("cpu")
    if torch.cuda.is_available():
        print("CUDA")
        device = torch.device("cuda")
    model = model.to(device)
    loss_fn = loss_fn.to(device)

    # DataLoaders for the training and validation data.
    train_set = TrainDataset(train_data_path)
    validate_set = TrainDataset(validate_data_path)
    validation_data_loader = DataLoader(dataset=validate_set, num_workers=num_workers,
                                        batch_size=val_batch_size, shuffle=True)
    training_data_loader = DataLoader(dataset=train_set, num_workers=num_workers,
                                      batch_size=batch_size, shuffle=True)

    print('=> Training Data Length:', len(training_data_loader))
    best_dice = 0

    # Start both logs afresh with a header row.
    CSV_HEADER = "class0, class1, class2, class3\n"
    for log_name in ("trainlog.csv", "validlog.csv"):
        with open(log_name, "w") as f:
            f.write(CSV_HEADER)

    for epoch in range(epochs):

        ##### TRAIN CYCLE #####
        # run_validation() leaves the model in eval mode, so switch back.
        model.train()
        for img, mask in training_data_loader:
            # Add the channel axis: (B, H, W) -> (B, 1, H, W). Unlike a
            # fixed-size view this also works for a smaller final batch.
            inp_img = img.unsqueeze(1).to(device)

            # CrossEntropyLoss expects integer class indices as the target.
            mask = mask.long().to(device)

            optimiser.zero_grad()
            mask_pred = model(inp_img)
            loss = loss_fn(mask_pred, mask)
            loss.backward()
            optimiser.step()

        ##### VALIDATE EVERY EPOCH #####
        dices = run_validation(model, training_data_loader)
        print("Epoch:", epoch, "train raw: ", end='')
        dice_score = avg_and_save_validation("trainlog.csv", dices)
        print("Epoch:", epoch, "training dice:  ", dice_score)

        dices = run_validation(model, validation_data_loader)
        print("Epoch:", epoch, "valid raw: ", end='')
        dice_score = avg_and_save_validation("validlog.csv", dices)
        print("Epoch:", epoch, "validation dice:", dice_score)

        ##### SAVE BEST #####
        # "Best" is judged on the validation score computed just above.
        if dice_score > best_dice:
            print("=> New Personal Best! {}".format(dice_score))
            print("=> Saving...")
            best_dice = dice_score
            with open("./best_dice.txt", "w") as f:
                f.write("avg_dice={}\n".format(best_dice))
                f.write("epoch={}\n".format(epoch))
                f.write("totalepochs={}\n".format(epochs))
                f.write("lr={}\n".format(lr))
                f.write("train_data={}\n".format(train_data_path))
                f.write("validation_data={}\n".format(validate_data_path))
                f.write("train_batch_size={}\n".format(batch_size))
            torch.save(model.state_dict(), './trained_best_dice.pth')
            print("=> Saved Personal Best")

        # Keep a periodic checkpoint every 25 epochs for later testing.
        if epoch % 25 == 0 and epoch != 0:
            PATH = './trained_{0}e.pth'.format(epoch)
            torch.save(model.state_dict(), PATH)
            

In [None]:
"""
>>>>> RUN TRAINING WITH EVALUATION <<<<<
"""
# Start the full training run; logs and checkpoints are written to the working directory.
train_eval()


In [None]:
"""
>>>>> LOAD SAVED MODEL <<<<<
"""
# Rebuild the architecture, then restore the weights saved as the best checkpoint.
model = UNet()
# map_location="cpu" lets a checkpoint that was saved on a GPU runtime load on
# a CPU-only one; the model itself is still on the CPU at this point.
#model.load_state_dict(torch.load("sgd_training/trained_99e.pth"))  # this is cool
model.load_state_dict(torch.load("trained_best_dice.pth", map_location="cpu"))
#model.load_state_dict(torch.load("trained_25e.pth"))

In [None]:

"""
>>>>> COMPUTE VALIDATION DICE SCORE <<<<<
(should be almost identical to the one stated in best_dice.txt)
"""
from torch.utils.data import DataLoader

# Score the loaded model on the validation set, showing every prediction
# next to its ground-truth mask.
data_path = './data/val'
num_workers = 0
batch_size = 2
val_set = TrainDataset(data_path)
val_data_loader = DataLoader(dataset=val_set, num_workers=num_workers, batch_size=batch_size)

model.eval() # switch model to evaluation mode
model.to(device)
with torch.no_grad():
    dices = []
    # Fetch images and labels.
    for img, mask in val_data_loader:

        # forward: add the channel axis and run on the execution device
        img = img.unsqueeze(1)
        img = img.to(device)
        outputs = model(img)

        mask = mask.type(torch.LongTensor)

        # convert output to predicted class so it can be visualised
        mask_pred = torch.argmax(outputs, dim=1).detach().cpu()

        # Use the real size of this batch: the last one may hold fewer than
        # batch_size images.
        for i in range(mask.shape[0]):
            print('+++++')
            show_image_mask(mask_pred[i,...].squeeze(), mask[i,...].squeeze())
            dice = average_dice(mask[i,...].numpy(), mask_pred[i,...].numpy(),True)
            dices.append(dice)
            print('avg dice=', dice)

    print("<<<<<>>>>>")
    if dices:
        avg = sum(dices) / len(dices)
        print('Done val! DICE: ', avg)
    else:
        # Nothing was scored, so there is no average to report.
        print('No validation images found in', data_path)

In [None]:
"""
>>>>> CREATE MASKS <<<<<
"""

from torch.utils.data import DataLoader

# In this block you are expected to write code to load saved model and deploy it to all data in test set to 
# produce segmentation masks in png images valued 0,1,2,3, which will be used for the submission to Kaggle.
# Run the loaded model over every test image and save the predicted class
# map (pixel values 0..3) as <image name>_mask.png under <data_path>/mask.
data_path = './data/test'
num_workers = 0
batch_size = 1
mask_dir = os.path.join(data_path, 'mask')

test_set = TestDataset(data_path)
test_data_loader = DataLoader(dataset=test_set, num_workers=num_workers,batch_size=batch_size, shuffle=False)

# Make sure the output directory exists before anything is written to it.
os.makedirs(mask_dir, exist_ok=True)

model.cpu()
model.eval() # switch model to evaluation mode

with torch.no_grad():

    # With batch_size=1 and shuffle=False, `iteration` is also the index of
    # the current image in test_set.img_files.
    for iteration, sample in enumerate(test_data_loader):
        # forward: add the channel axis
        img = sample.unsqueeze(1)
        outputs = model(img)

        # convert output to predicted class so it can be visualised
        pred_class = torch.argmax(outputs, dim=1)
        #show_image_mask(img[0,...].squeeze(), pred_class[0,...].squeeze())

        # Name the mask after its source image rather than assuming a fixed
        # numbering scheme for the test files.
        source_name = os.path.basename(test_set.img_files[iteration])
        mask_filename = os.path.splitext(source_name)[0] + '_mask.png'
        print('saving as', mask_filename)

        # argmax yields int64; store the labels as 8-bit values for the PNG.
        pred_img = pred_class[0,...].squeeze().numpy().astype('uint8')
        mask_path = os.path.join(mask_dir, mask_filename)
        # cv2.imwrite reports failure through its return value.
        if not cv2.imwrite(mask_path, pred_img):
            raise OSError('Could not write mask: ' + mask_path)

print('Done test!')


In [None]:
"""
>>>>> DISPLAY MASKS <<<<<
"""
from torch.utils.data import DataLoader

# Show every test image next to the mask stored for it under ./data/test/mask.
data_path = './data/test'
num_workers = 0
batch_size = 1
val_set = TrainDataset(data_path)
val_data_loader = DataLoader(dataset=val_set, batch_size=batch_size, num_workers=num_workers)

total_dice = 0.0  # not read anywhere in this cell

model.eval() # switch model to evaluation mode

with torch.no_grad():

    # Each sample is an (image, mask) pair with a leading batch axis of 1.
    for img, mask in val_data_loader:
        img = img.squeeze()
        mask = mask.squeeze()
        show_image_mask(img, mask)

In [None]:
"""
>>>>> CREATE SUBMISSION <<<<<
"""
import numpy as np
import os
import cv2

def rle_encoding(x):
    '''
    *** Credit to https://www.kaggle.com/rakhlin/fast-run-length-encoding-python ***
    Run-length encode a binary mask, scanning column by column.

    x: numpy array of shape (height, width), 1 - mask, 0 - background
    Returns a flat list [start, length, start, length, ...] where each start
    is a 1-based index into the column-major flattened mask.
    '''
    column_major = x.T.flatten()
    positions = np.where(column_major == 1)[0]

    runs = []
    last = -2  # sentinel so the first mask pixel always opens a new run
    for pos in positions:
        if pos - last > 1:
            # Gap since the previous mask pixel: open a run with length 0.
            runs.append(pos + 1)
            runs.append(0)
        runs[-1] = runs[-1] + 1
        last = pos
    return runs


def submission_converter(mask_directory, path_to_save):
    """Write ``submission.csv`` with run-length encodings of classes 1, 2 and 3.

    Every ``.png`` file in ``mask_directory`` contributes three rows, one per
    class, with the id ``<file name without extension><class>``. Files are
    processed in sorted order so the output is reproducible; entries that are
    not PNG files are ignored.

    Args:
        mask_directory: directory holding the predicted mask images.
        path_to_save: directory in which ``submission.csv`` is created.

    Raises:
        ValueError: if a PNG file cannot be decoded.
    """
    mask_names = sorted(
        entry for entry in os.listdir(mask_directory)
        if entry.lower().endswith('.png'))

    # The context manager closes the file even if encoding a mask fails.
    with open(os.path.join(path_to_save, "submission.csv"), 'w') as writer:
        writer.write('id,encoding\n')

        for file_name in mask_names:
            name = os.path.splitext(file_name)[0]
            mask = cv2.imread(os.path.join(mask_directory, file_name),
                              cv2.IMREAD_UNCHANGED)
            # cv2.imread returns None instead of raising on a bad file.
            if mask is None:
                raise ValueError("Could not read mask: " + file_name)

            for label in (1, 2, 3):
                encoded = ' '.join(str(e) for e in rle_encoding(mask == label))
                writer.write(name + str(label) + ',' + encoded + "\n")
    
# Encode the masks stored in ./data/test/mask into ./submission.csv.
submission_converter('./data/test/mask', './')
print('Submission done!')