In [1]:
import os

input_dir = 'data/images'
target_dir = 'data/annotations/trimaps'
img_size = (160, 160)
num_classes = 3
batch_size = 32

input_img_paths = sorted([os.path.join(input_dir, fname) 
                            for fname in os.listdir(input_dir)
                            if fname.endswith('.jpg')])
target_img_paths = sorted([os.path.join(target_dir, fname) 
                            for fname in os.listdir(target_dir)
                            if fname.endswith('.png') and not fname.startswith('.')])


In [None]:
from IPython.display import display
import PIL
from PIL import Image

display(Image.open(input_img_paths[7]))
display(PIL.ImageOps.autocontrast(Image.open(target_img_paths[7])))
# display(Image.open(target_img_paths[7]))

In [None]:
from torch.utils.data import Dataset
from torchvision import transforms
import numpy as np
import torch

class OxfordPetDataset(Dataset):
    def __init__(self, img_size, input_img_paths, target_img_paths, transform=None, target_transform=None):
        self.img_size = img_size
        self.input_img_paths = input_img_paths
        self.target_img_paths = target_img_paths
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.input_img_paths)

    def __getitem__(self, index):
        image = Image.open(input_img_paths[index]).convert("RGB")
        target = np.array(Image.open(target_img_paths[index]))
        target = target - np.ones_like(target)
        target = torch.from_numpy(target).unsqueeze(0).to(dtype=torch.int64)

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            target = self.target_transform(target)

        return image, target
transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((160, 160))])
target_transform = transforms.Compose([transforms.Resize((160, 160))])

dataset = OxfordPetDataset(img_size=(160, 160), input_img_paths=input_img_paths, target_img_paths=target_img_paths, transform=transform, target_transform=target_transform)
print(dataset.__len__())
image, target = next(iter(dataset))
print(image ,target)
print(torch.unique(target))

In [4]:
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision import transforms

class OxfordPetDataModule(pl.LightningDataModule):

    def __init__(self, img_size, batch_size, input_img_paths, target_img_paths):
        self.img_size = img_size
        self.batch_size = batch_size
        self.input_img_paths = input_img_paths
        self.target_img_paths = target_img_paths

    def setup(self, stage=None):
        transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((160, 160))])
        target_transform = transforms.Compose([transforms.Resize((160, 160))])
        dataset = OxfordPetDataset(img_size=self.img_size, input_img_paths=self.input_img_paths, target_img_paths=self.target_img_paths, transform=transform, target_transform=target_transform)
        self.train_dataset, self.valid_dataset = random_split(dataset, [7000, 390])

    def train_dataloader(self):
        oxford_train = DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=8, shuffle=True)
        return oxford_train

    def val_dataloader(self):
        oxford_valid = DataLoader(self.valid_dataset, batch_size=self.batch_size, num_workers=8, shuffle=False)
        return oxford_valid

img_size = (160, 160)
batch_size = 64
dm = OxfordPetDataModule(img_size, batch_size, input_img_paths, target_img_paths)
dm.setup()
dataloader = dm.train_dataloader()
# for x, y in dataloader:
#     print(x.shape, y.shape)



In [5]:
from pytorch_lightning import LightningModule
import torch.nn as nn
import torch.nn.functional as F

def conv_block(in_channels, out_channels):
    out = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(num_features=out_channels),
        nn.ReLU()
    )
    return out

def conv_block_2x(in_channels, out_channels):
    out = nn.Sequential(
        conv_block(in_channels, out_channels),
        conv_block(out_channels, out_channels)
    )
    return out

def downsample():
    return nn.MaxPool2d(kernel_size=2, stride=2)

def upsample(in_channels, out_channels):
    return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)



class UNet(pl.LightningModule):
    def __init__(self, input_dim, num_classes, num_channels):
        super().__init__()
        self.save_hyperparameters()
        # conv_block_2x
        self.encoder_1 = conv_block_2x(input_dim, num_channels)
        # downsample
        self.down_sample_1 = downsample()
        # conv_block_2x
        self.encoder_2 = conv_block_2x(num_channels, num_channels * 2)
        # downsample
        self.down_sample_2 = downsample()
        # conv_block_2x
        self.encoder_3 = conv_block_2x(num_channels * 2, num_channels * 4)
        # downsample
        self.down_sample_3 = downsample()
        # conv_block_2x
        self.encoder_4 = conv_block_2x(num_channels * 4, num_channels * 8)
        # downsample
        self.down_sample_4 = downsample()
        # bridge
        self.bridge = conv_block_2x(num_channels * 8, num_channels * 16)
        # upsample
        self.up_sample_1 = upsample(num_channels * 16, num_channels * 8)
        # conv_block_2x
        self.decoder_1 = conv_block_2x(num_channels * 16, num_channels * 8)
        # upsample
        self.up_sample_2 = upsample(num_channels * 8, num_channels * 4)
        # conv_block_2x
        self.decoder_2 = conv_block_2x(num_channels * 8, num_channels * 4)
        # upsample
        self.up_sample_3 = upsample(num_channels * 4, num_channels * 2)
        # conv_block_2x
        self.decoder_3 = conv_block_2x(num_channels * 4, num_channels * 2)
        # upsample
        self.up_sample_4 = upsample(num_channels * 2, num_channels)
        # conv_block_2x
        self.decoder_4 = conv_block_2x(num_channels * 2, num_channels)
        # 1X1 conv
        self.classifier = nn.Sequential(
            nn.Conv2d(num_channels, num_classes, kernel_size=1, stride=1),
            nn.Tanh()
        )
    
    def forward(self, input):
        encoder_feat_1 = self.encoder_1(input)
        down_feat_1 = self.down_sample_1(encoder_feat_1)
        encoder_feat_2 = self.encoder_2(down_feat_1)
        down_feat_2 = self.down_sample_2(encoder_feat_2)
        encoder_feat_3 = self.encoder_3(down_feat_2)
        down_feat_3 = self.down_sample_3(encoder_feat_3)
        encoder_feat_4 = self.encoder_4(down_feat_3)
        down_feat_4 = self.down_sample_4(encoder_feat_4)

        bridge_feat = self.bridge(down_feat_4)

        up_feat_1 = self.up_sample_1(bridge_feat)
        decoder_feat_1 = self.decoder_1(torch.cat((encoder_feat_4, up_feat_1), dim=1))
        up_feat_2 = self.up_sample_2(decoder_feat_1)
        decoder_feat_2 = self.decoder_2(torch.cat((encoder_feat_3, up_feat_2), dim=1))
        up_feat_3 = self.up_sample_3(decoder_feat_2)
        decoder_feat_3 = self.decoder_3(torch.cat((encoder_feat_2, up_feat_3), dim=1))
        up_feat_4 = self.up_sample_4(decoder_feat_3)
        decoder_feat_4 = self.decoder_4(torch.cat((encoder_feat_1, up_feat_4), dim=1))

        out = self.classifier(decoder_feat_4)

        return out


    def training_step(self, batch):
        x, y = batch
        pred = self(x)
        loss = F.cross_entropy(pred, y.squeeze())
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        pred = self(x)
        loss = F.cross_entropy(pred, y.squeeze())
        self.log('val_loss', loss)
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=2e-4)
        return optimizer
        

# unet = UNet(input_dim=3, num_classes=3, num_channels=64)
# for x, y in dataloader:
#     print(x.shape, y.shape)
#     print(unet(x).shape)

In [6]:
from gc import callbacks
from pytorch_lightning.callbacks import ModelCheckpoint

model = UNet(input_dim=3, num_classes=3, num_channels=64)
data_module = OxfordPetDataModule(img_size, batch_size, input_img_paths, target_img_paths)
ckpt_callback = ModelCheckpoint(save_top_k=2, monitor='val_loss', mode='min')
trainer = pl.Trainer(gpus=1, max_epochs=10, callbacks=[ckpt_callback])
trainer.fit(model, data_module)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

   | Name          | Type            | Params
---------------------------------------------------
0  | encoder_1     | Sequential      | 39.0 K
1  | down_sample_1 | MaxPool2d       | 0     
2  | encoder_2     | Sequential      | 221 K 
3  | down_sample_2 | MaxPool2d       | 0     
4  | encoder_3     | Sequential      | 886 K 
5  | down_sample_3 | MaxPool2d       | 0     
6  | encoder_4     | Sequential      | 3.5 M 
7  | down_sample_4 | MaxPool2d       | 0     
8  | bridge        | Sequential      | 14.2 M
9  | up_sample_1   | ConvTranspose2d | 2.1 M 
10 | decoder_1     | Sequential      | 7.1 M 
11 | up_sample_2   | ConvTranspose2d | 524 K 
12 | decoder_2     | Sequential      | 1.8 M 
13 | up_sample_3   | ConvTranspose2d | 131 K 
14 | decoder_3     | Sequential      | 443 K 
15 | up_sample_4   | ConvTranspose2d | 32.8 K
16 | decoder

Epoch 9: 100%|██████████| 117/117 [01:31<00:00,  1.28it/s, loss=0.387, v_num=1]


In [None]:
test_dir = './test'

test_img_paths = sorted([os.path.join(test_dir, fname)
                            for fname in os.listdir(test_dir)
                            if fname.endswith('.jpg')])
print(test_img_paths)

model = UNet.load_from_checkpoint('./lightning_logs/version_1/checkpoints/epoch=9-step=1099.ckpt')

for img in test_img_paths:                            
    image = Image.open(img).convert("RGB")
    display(image)
    image_size = image.size
    transform = transforms.Compose([transforms.ToTensor(), transforms.Resize((160, 160))])
    image = transform(image).unsqueeze(0)

    transform = transforms.Compose([transforms.ToPILImage(), transforms.Resize((image_size[1], image_size[0]))])
    pred_mask = transform(torch.argmax(model(image), dim=1).to(dtype=torch.float))
    display(PIL.ImageOps.autocontrast(pred_mask))