In [1]:
import math 
import torch.nn.functional as F
import numpy as np
from torch import nn
import torch
from torch.utils.data import Dataset, DataLoader, Subset
from PIL import Image
from sklearn.model_selection import train_test_split
from skimage import io, transform
from torchvision import transforms, utils, datasets
import numpy as np
import os 
import conv

In [3]:
def data_preprocess(images_path, masks_path):
    """Pair up image and mask files that share the same file name.

    Only the top level of each directory is scanned. A file is kept only if a
    file with the identical name exists in both directories.

    Args:
        images_path: Directory containing the input images.
        masks_path: Directory containing the corresponding masks.

    Returns:
        A tuple ``(image_files, mask_files)`` of two equally long lists of
        absolute paths. Both lists are sorted by file name, so
        ``image_files[i]`` and ``mask_files[i]`` always share a file name.

    Raises:
        FileNotFoundError: If either directory cannot be listed.
    """

    def _top_level_files(directory):
        # os.walk yields nothing at all for a missing/unreadable directory;
        # turn that into an explicit error instead of a bare StopIteration.
        first = next(os.walk(directory), None)
        if first is None:
            raise FileNotFoundError(
                "Cannot list directory: {!r}".format(directory)
            )
        return first[2]

    image_names = _top_level_files(images_path)
    mask_names = _top_level_files(masks_path)

    # Intersect with sets instead of calling list.remove() while iterating
    # (which silently skips elements), and sort so that the two returned
    # lists are aligned index-by-index regardless of directory order.
    common_names = sorted(set(image_names) & set(mask_names))

    image_files = [
        os.path.abspath(os.path.join(images_path, name)) for name in common_names
    ]
    mask_files = [
        os.path.abspath(os.path.join(masks_path, name)) for name in common_names
    ]

    return image_files, mask_files

In [4]:
class Encoder(nn.Module):
    """Convolutional encoder: 3-channel image -> 16-channel feature map.

    Channel progression: 3 -> 32 -> 32 (residual) -> 64 -> 128 -> 256 -> 16.
    A 2x2 max-pool follows each of the four blocks, so the spatial size is
    divided by 16 overall, provided the convolutions keep the spatial size
    (as the name ``Conv2dSamePadding`` suggests -- TODO confirm).

    NOTE(review): ``Conv2dSamePadding`` must be in scope when this class is
    instantiated; presumably it comes from the local ``conv`` module.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        # Stem and head convolutions (no normalisation / activation).
        self.conv = Conv2dSamePadding(3, 32, 3)
        self.conv_last = Conv2dSamePadding(256, 16, 3)

        self.block1 = self._conv_bn_relu(32, 32)
        self.block2 = self._conv_bn_relu(32, 64)
        self.block3 = self._conv_bn_relu(64, 128)
        self.block4 = self._conv_bn_relu(128, 256)

        self.pool = nn.MaxPool2d(2, 2)

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels):
        """Build one 3x3 conv -> batch-norm -> ReLU stack."""
        return nn.Sequential(
            Conv2dSamePadding(in_channels, out_channels, 3),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Tensor with 3 channels in dimension 1.

        Returns:
            Tensor with 16 channels, pooled four times by a factor of 2.
        """
        x = self.conv(x)

        # Residual connection around block1. The addition is out-of-place on
        # purpose: `x += skip` would overwrite the ReLU output in place, and
        # ReLU's backward pass needs that output unmodified.
        skip = x
        x = self.block1(x) + skip
        x = self.pool(x)

        for block in (self.block2, self.block3, self.block4):
            x = self.pool(block(x))

        return self.conv_last(x)


In [5]:
class LSTM_CNN(nn.Module):
    """Two-layer LSTM over flattened 256x256 planes, followed by two convs.

    The input is reshaped to ``(N, 3, 256*256)``; the LSTM maps the last
    dimension from 65536 to 256 features, which are reshaped to 16x16 maps
    and passed through ``conv1`` (3 -> 64) and ``conv2`` (64 -> 16).

    NOTE(review): the LSTM is created with the default ``batch_first=False``,
    so dimension 0 of the reshaped input (the batch) is treated as the time
    axis and dimension 1 (the 3 channels) as the LSTM batch. If the channels
    were meant to be the sequence, pass ``batch_first=True`` -- confirm the
    intent before changing, since it alters the outputs.

    Args:
        batch_size: Stored on the instance for backward compatibility; the
            forward pass infers the batch dimension from its input.
    """

    def __init__(self, batch_size):
        super(LSTM_CNN, self).__init__()
        self.batch_size = batch_size
        self.lstm = nn.LSTM(input_size=(256 * 256), hidden_size=256, num_layers=2)
        self.conv1 = Conv2dSamePadding(3, 64, 3)
        self.conv2 = Conv2dSamePadding(64, 16, 3)

    def forward(self, x):
        """Run the LSTM + conv stack.

        Args:
            x: Tensor whose element count is a multiple of ``3 * 256 * 256``.

        Returns:
            The output of ``conv2`` applied to the ``(N, 3, 16, 16)`` maps.
        """
        # Infer N with -1 rather than self.batch_size so a smaller final
        # batch from a DataLoader does not break the reshape.
        x = x.view(-1, 3, 256 * 256)

        # No explicit initial state: nn.LSTM defaults to zeros of the correct
        # hidden size.
        out, _ = self.lstm(x)

        # 256 hidden features per step -> one 16x16 map per step.
        out = out.view(-1, 3, 16, 16)

        out = self.conv1(out)
        out = self.conv2(out)
        return out

# model = LSTM_CNN(1)
# model(torch.randn(1, 3, 256*256))

In [6]:
class Decoder(nn.Module):
    """Encoder + transposed-convolution decoder producing a 1-channel map.

    The encoder output (16 channels) is upsampled twice by a factor of 4
    (stride-4 transposed convolutions whose kernel/padding give exactly
    ``4 * input`` per spatial dimension) and reduced to one channel.
    The output is raw logits: no sigmoid is applied here.

    Args:
        encoder: Module used to encode the input image. If ``None``, a fresh
            ``Encoder()`` is created.
    """

    def __init__(self, encoder):
        super(Decoder, self).__init__()

        # Use the encoder supplied by the caller instead of silently
        # discarding it and building a second one.
        self.encoder = encoder if encoder is not None else Encoder()

        # conv1 / conv2 are not used in forward(); they are kept so that
        # parameter names in existing state_dicts still match.
        self.conv1 = Conv2dSamePadding(16, 64, 3)
        self.conv2 = Conv2dSamePadding(64, 16, 3)

        self.conv_out = Conv2dSamePadding(16, 1, 3)

        # (in - 1) * 4 - 2 * 1 + 6 == 4 * in
        self.tconv1 = nn.ConvTranspose2d(16, 64, kernel_size=6, stride=4, padding=1)
        # (in - 1) * 4 - 2 * 2 + 8 == 4 * in
        self.tconv2 = nn.ConvTranspose2d(64, 16, kernel_size=8, stride=4, padding=2)

    def forward(self, x2):
        """Predict mask logits for a batch of images.

        Args:
            x2: Image batch passed straight to the encoder.

        Returns:
            Single-channel logits; ``(b, 1, 256, 256)`` for 256x256 inputs
            according to the original author's note.
        """
        x = self.encoder(x2)

        x = self.tconv1(x)
        x = self.tconv2(x)

        x = self.conv_out(x)  # -> (b, 1, 256, 256)

        return x

In [9]:
def train_val_dataset(dataset, val_split=None, random_state=None):
    """Randomly split a dataset into training and validation subsets.

    Args:
        dataset: Any object supporting ``len()`` and integer indexing.
        val_split: Forwarded to ``train_test_split`` as ``test_size``
            (fraction or absolute number of validation samples). ``None``
            falls back to scikit-learn's default.
        random_state: Optional seed forwarded to ``train_test_split`` so the
            split can be reproduced. ``None`` keeps the previous unseeded
            behaviour.

    Returns:
        A tuple ``(train_data, val_data)`` of ``Subset`` views on ``dataset``.
    """
    indices = list(range(len(dataset)))
    train_idx, val_idx = train_test_split(
        indices, test_size=val_split, random_state=random_state
    )
    return Subset(dataset, train_idx), Subset(dataset, val_idx)


class CustomDataset(Dataset):
    """Dataset of (image, mask) tensor pairs read from two directories.

    File pairing is delegated to ``data_preprocess``; every sample is resized
    to 256x256 when it is loaded.

    Args:
        images_path: Directory with the input images.
        masks_path: Directory with the masks.
    """

    def __init__(self, images_path, masks_path):
        self.images, self.masks = data_preprocess(images_path, masks_path)

    def __len__(self):
        """Return the number of image files."""
        return len(self.images)

    def transform(self, image, mask, output_size):
        """Resize an image/mask pair and convert both to float32 tensors.

        Args:
            image: Array of shape (H, W) or (H, W, C). Grayscale input is
                replicated to 3 channels; channels beyond the third (e.g.
                alpha) are dropped.
            mask: Array of shape (H, W) or (H, W, C); only channel 0 is used.
            output_size: ``(height, width)`` of the returned tensors.

        Returns:
            ``(image, mask)`` tensors as produced by ``to_tensor``.
        """
        nheight, nwidth = output_size

        # Bring the image to exactly 3 channels. The ndim check comes first
        # because a 2-D grayscale array has no shape[2] to inspect.
        if image.ndim == 2:
            image = np.stack((image,) * 3, axis=-1)
        elif image.shape[2] == 1:
            image = np.repeat(image, 3, axis=2)
        elif image.shape[2] > 3:
            image = image[:, :, :3]

        image = transform.resize(image, (nheight, nwidth))

        # Multi-channel masks: keep the first channel only.
        if len(mask.shape) > 2:
            mask = mask[:, :, 0]

        mask = transform.resize(mask, (nheight, nwidth))

        image = transforms.functional.to_tensor(image.astype(np.float32))
        mask = transforms.functional.to_tensor(mask.astype(np.float32))

        return image, mask

    def __getitem__(self, idx):
        """Load sample ``idx`` from disk and return it as 256x256 tensors."""
        image = io.imread(self.images[idx])
        mask = io.imread(self.masks[idx])
        return self.transform(image, mask, [256, 256])


In [None]:
# --- Hyperparameters -------------------------------------------------------
learning_rate = 0.00001
batch_size = 16
num_epochs = 10
log_every = 22  # report the loss of every 22nd batch of an epoch

# --- Data ------------------------------------------------------------------
data = CustomDataset("fake/", "masks/")

train_set, val_set = train_val_dataset(data, val_split = 0.20)

train_loader = DataLoader(dataset = train_set, shuffle = True, batch_size = batch_size)
val_loader = DataLoader(dataset = val_set, shuffle = False, batch_size = batch_size)

# --- Model, loss, optimiser ------------------------------------------------
# model1 = LSTM_CNN(batch_size)
model2 = Encoder()

model = Decoder(model2)

# The model outputs raw logits, so the sigmoid is folded into the loss.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
model.train()
print("Training")

# --- Training loop ---------------------------------------------------------
for epoch in range(num_epochs):
    # Unpack straight into (image, mask); the loop must not reuse the name
    # `data`, which holds the dataset created above.
    for i, (image, mask) in enumerate(train_loader):

        optimizer.zero_grad()

        output = model(image)
        loss = criterion(output, mask)
        loss.backward()
        optimizer.step()

        if i % log_every == log_every - 1:
            print('[%d, %5d] loss: %.3f' %
                    (epoch + 1, i + 1, loss.item()))

print("Finished Training")


Training
[1,    22] loss: 0.663
[2,    22] loss: 0.662


In [None]:
# Show one validation sample: input image, ground-truth mask, predicted mask.
model.eval()

sample_index = 5  # preferred position inside the batch

with torch.no_grad():
    # val_loader is not shuffled, so its first batch is deterministic; there
    # is no need to iterate the whole loader just to keep one batch.
    test_image, y_test = next(iter(val_loader))
    y_pred = model(test_image)  # raw logits

# Clamp so a batch with fewer than `sample_index + 1` samples still works,
# and use the same index for the image, the target and the prediction.
idx = min(sample_index, test_image.size(0) - 1)

im = transforms.ToPILImage()(test_image[idx]).convert("RGB")
display(im)

print("y_test", y_test[idx].shape)
m_test = transforms.ToPILImage()(y_test[idx])
display(m_test)

print("y_pred", y_pred[idx].shape)
# ToPILImage scales float tensors by 255, i.e. it expects values in [0, 1];
# map the logits to probabilities before rendering.
mask_pred = transforms.ToPILImage()(torch.sigmoid(y_pred[idx]))
display(mask_pred)
