In [1]:
from pathlib import Path
from course_ocr_t1.data import MidvPackage
from tqdm import tqdm
from matplotlib import pyplot as plt
from matplotlib import patches
import numpy as np

In [2]:
import cv2
from  torch.utils.data import Dataset, DataLoader

In [3]:
import torch
import torch.nn as nn
from tqdm.notebook import tqdm

In [4]:
# Dataset root, given as a relative path (two directory levels up, then data/midv500_compressed).
DATASET_PATH = Path('..', '..', 'data', 'midv500_compressed')
# Fail early, showing the absolute path that was looked up.
assert DATASET_PATH.exists(), DATASET_PATH.absolute()

In [5]:
# Read every package found under DATASET_PATH.
data_packs = MidvPackage.read_midv500_dataset(DATASET_PATH)
# Cell output: number of packages and the type of one entry.
(len(data_packs), type(data_packs[0]))

(50, course_ocr_t1.data.MidvPackage)

In [6]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
class task1_dataset(Dataset):
    """Segmentation dataset over a list of data packages.

    Each sample is a pair ``(image, mask)``:

    * ``image`` - float tensor of shape (3, H, W) with values scaled to [0, 1];
    * ``mask``  - float tensor of shape (1, H, W), 1 inside the ground-truth
      quadrangle and 0 elsewhere.

    Args:
        data_packs: indexable collection of packages; ``data_packs[i][j]`` must
            expose ``is_test_split()``, ``image`` (with ``.convert('RGB')``)
            and ``gt_data['quad']``.
        device: device on which the returned tensors are created.
        split: ``'train'`` keeps the items whose ``is_test_split()`` is false,
            ``'test'`` keeps the ones where it is true.
    """

    def __init__(self, data_packs, device, split = 'train'):
        assert split in ['train', 'test'], split
        self.data_packs = data_packs
        self.device = device

        # One pass serves both splits: keep the items whose test flag matches
        # the requested split.
        want_test = (split == 'test')
        self.indices = [
            (i, j)
            for i, data_pack in enumerate(data_packs)
            for j in range(len(data_pack))
            if bool(data_pack[j].is_test_split()) == want_test
        ]

    def __len__(self):
        """Number of items in the selected split."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the ``(image, mask)`` tensor pair for the ``idx``-th item of the split."""
        i, j = self.indices[idx]
        dp = self.data_packs[i][j]
        image = np.array(dp.image.convert('RGB')) / 255.

        # cv2.fillConvexPoly only accepts 32-bit integer vertices; numpy's
        # default integer type (int64 on most platforms) is rejected.
        # NOTE(review): assumes gt_data['quad'] holds pixel coordinates - confirm.
        quad = np.asarray(dp.gt_data['quad'], dtype=np.int32)
        mask = np.zeros(image.shape[:2], dtype=np.uint8)
        cv2.fillConvexPoly(mask, quad, (1,))

        # Use the device given to the constructor, not a notebook global.
        return (
            torch.tensor(image.transpose(2, 0, 1), dtype=torch.float, device=self.device),
            torch.tensor(mask[np.newaxis, ...], dtype=torch.float, device=self.device),
        )

In [8]:
# The constructor signature is task1_dataset(data_packs, device, split), so the
# device must come second and the split must be named. Passing only the split
# string would bind it to `device` and leave both datasets on the default
# 'train' split.
train_data = task1_dataset(data_packs, device, split='train') # TODO: transform
val_data = task1_dataset(data_packs, device, split='test')

In [9]:
batch_size = 4
# Both loaders share the same settings: shuffled mini-batches of `batch_size`.
train_dl, val_dl = (
    DataLoader(split_data, batch_size=batch_size, shuffle=True)
    for split_data in (train_data, val_data)
)

In [10]:
class enc_conv_block(nn.Module):
    """Two consecutive (3x3 conv -> batch norm -> ReLU) stages.

    The convolutions use padding 1 and no bias, so the spatial size of the
    input is preserved and only the channel count changes from
    ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU())
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x``."""
        features = self.block(x)
        return features

In [11]:
class upsample_block(nn.Module):
    """Decoder stage: upsample x2, pad to the skip tensor's size, concatenate, convolve.

    ``in_channels`` is the channel count *after* concatenation (skip channels
    plus upsampled channels); ``out_channels`` is what the two
    (3x3 conv -> batch norm -> ReLU) stages produce.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2)
        layers = []
        for stage_in in (in_channels, out_channels):
            layers.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU())
        self.dec_conv = nn.Sequential(*layers)

    def forward(self, u, e):
        """Upsample ``u``, pad it to the spatial size of ``e`` and fuse the two.

        Args:
            u: tensor coming from the deeper level of the network.
            e: skip-connection tensor whose dims 2 and 3 set the target size.
        """
        upsampled = self.upsample(u)
        # Size gap along dim 2 and dim 3 between the skip tensor and the upsampled one.
        gap_dim2 = e.shape[2] - upsampled.shape[2]
        gap_dim3 = e.shape[3] - upsampled.shape[3]
        before_dim3 = gap_dim3 // 2
        before_dim2 = gap_dim2 // 2
        # F.pad lists the LAST dimension first: (dim3 before, dim3 after, dim2 before, dim2 after).
        padding = [before_dim3, gap_dim3 - before_dim3, before_dim2, gap_dim2 - before_dim2]
        upsampled = nn.functional.pad(upsampled, padding)
        fused = torch.cat((e, upsampled), dim=1)
        return self.dec_conv(fused)

In [12]:
class UNet(nn.Module):
    """U-Net style encoder/decoder: 3-channel input, 1-channel output map.

    Four encoder levels (64, 128, 256, 512 channels), each followed by a 2x2
    max-pool, then a 1024-channel bottleneck and four decoder levels that
    consume the matching encoder output as a skip connection. The head is a
    3x3 convolution followed by batch norm; no sigmoid is applied here.
    """

    def __init__(self):
        super().__init__()

        # --- encoder: conv block + 2x downsampling per level ---
        self.enc_conv0 = enc_conv_block(3, 64)
        self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv1 = enc_conv_block(64, 128)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv2 = enc_conv_block(128, 256)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv3 = enc_conv_block(256, 512)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # --- bottleneck ---
        self.bottleneck_conv = enc_conv_block(512, 1024)

        # --- decoder: input channels = deeper features + skip features ---
        self.up_0 = upsample_block(1024 + 512, 512)
        self.up_1 = upsample_block(512 + 256, 256)
        self.up_2 = upsample_block(256 + 128, 128)
        self.up_3 = upsample_block(128 + 64, 64)

        # --- output head ---
        self.out = nn.Sequential(
            nn.Conv2d(64, 1, kernel_size=3, padding=1),
            nn.BatchNorm2d(1),
        )

    def forward(self, x):
        """Run the full encoder/decoder pass and return the 1-channel output of the head."""
        # Encoder: keep every level's output for the skip connections.
        skip0 = self.enc_conv0(x)
        skip1 = self.enc_conv1(self.pool0(skip0))
        skip2 = self.enc_conv2(self.pool1(skip1))
        skip3 = self.enc_conv3(self.pool2(skip2))

        features = self.bottleneck_conv(self.pool3(skip3))

        # Decoder: walk back up, pairing each stage with the encoder output of the same level.
        decoder_path = (
            (self.up_0, skip3),
            (self.up_1, skip2),
            (self.up_2, skip1),
            (self.up_3, skip0),
        )
        for up_stage, skip in decoder_path:
            features = up_stage(features, skip)

        return self.out(features)
# Build the network, then move its parameters and buffers to the selected device.
unet_model = UNet()
unet_model = unet_model.to(device)

In [13]:
from time import time

In [16]:
def train(model, opt, loss_fn, epochs, data_tr, data_val):
    """Train ``model`` and, after every epoch, plot the curves and a prediction preview.

    Args:
        model: network to optimise. Its output is treated as raw logits: the
            validation step applies ``torch.sigmoid`` and thresholds at 0.5.
        opt: optimizer constructed over ``model``'s parameters.
        loss_fn: callable invoked as ``loss_fn(prediction, target)`` on the raw
            model output and the target batch.
        epochs: number of passes over ``data_tr``.
        data_tr: sized iterable yielding ``(X, Y)`` training batches.
        data_val: sized iterable yielding ``(X, Y)`` validation batches; its
            first batch is also used for the visual preview.

    Returns:
        ``(train_loss, val_loss, val_score)`` - three lists of Python floats
        with one entry per epoch (mean train loss, mean validation loss, mean
        validation score).

    NOTE(review): ``iou_pytorch`` and ``clear_output`` are used below but are
    not defined in the cells visible here - make sure they are defined or
    imported before this function is called.
    """
    # Follow the model's own device instead of relying on a notebook global.
    model_device = next(model.parameters()).device

    # Fixed validation batch used only for the qualitative preview.
    X_val, _ = next(iter(data_val))
    X_val = X_val.to(model_device)
    X_preview = X_val.cpu()  # matplotlib needs host memory
    n_preview = min(6, len(X_preview))  # the batch may hold fewer than 6 images

    train_loss, val_loss, val_score = [], [], []

    for epoch in range(epochs):
        print('* Epoch %d/%d' % (epoch+1, epochs))

        # ---- training pass ----
        model.train()
        train_loss_sum = 0.0
        for X_batch, Y_batch in tqdm(data_tr):
            X_batch = X_batch.to(model_device)
            Y_batch = Y_batch.to(model_device)

            opt.zero_grad()
            Y_pred = model(X_batch)

            # Loss modules expect (prediction, target), in that order.
            loss = loss_fn(Y_pred, Y_batch)
            loss.backward()
            opt.step()
            # .item() stores a plain float, not a tensor that lives on the device.
            train_loss_sum += loss.item()
        avg_loss = train_loss_sum / max(len(data_tr), 1)
        print('loss: %f' % avg_loss)
        train_loss.append(avg_loss)

        # ---- validation pass (no gradients needed) ----
        model.eval()
        val_loss_sum = 0.0
        val_score_sum = 0.0
        with torch.no_grad():
            Y_hat = model(X_val).cpu()
            for X_val_batch, Y_val_batch in data_val:
                X_val_batch = X_val_batch.to(model_device)
                Y_val_batch = Y_val_batch.to(model_device)
                Y_pred_batch = model(X_val_batch)
                val_loss_sum += loss_fn(Y_pred_batch, Y_val_batch).item()
                prediction = torch.sigmoid(Y_pred_batch) > 0.5
                val_score_sum += iou_pytorch(prediction, Y_val_batch).mean().item()
        processed_size = len(data_val)
        val_loss.append(val_loss_sum/processed_size)
        val_score.append(val_score_sum/processed_size)

        # ---- visualisation ----
        # Clear once per epoch; a second clear between the two figures would
        # wipe the loss curves as soon as the preview is drawn.
        clear_output(wait=True)
        plt.figure(figsize = (16, 7))
        plt.subplot(1, 2, 1)
        plt.title('Loss')
        plt.plot(train_loss, label='train loss')
        plt.plot(val_loss, label='val loss')
        plt.legend()
        plt.subplot(1, 2, 2)
        plt.title('Score')
        plt.plot(val_score)
        plt.show()

        # Top row: input images, bottom row: raw model output for the same images.
        for k in range(n_preview):
            plt.subplot(2, n_preview, k+1)
            plt.imshow(np.rollaxis(X_preview[k].numpy(), 0, 3), cmap='gray')
            plt.title('Real')
            plt.axis('off')

            plt.subplot(2, n_preview, k+1+n_preview)
            plt.imshow(Y_hat[k, 0], cmap='gray')
            plt.title('Output')
            plt.axis('off')
        plt.suptitle('%d / %d - loss: %f' % (epoch+1, epochs, avg_loss))
        plt.show()

    return train_loss, val_loss, val_score

In [None]:
# UNet's head has no sigmoid (train() applies one itself before thresholding),
# so the loss has to accept raw logits. nn.BCELoss requires inputs already in
# [0, 1]; BCEWithLogitsLoss applies the sigmoid internally.
unet_train_loss, unet_val_loss, unet_val_score = train(
    unet_model,
    torch.optim.Adam(unet_model.parameters()),
    nn.BCEWithLogitsLoss(),
    100,
    train_dl,
    val_dl,
)

* Epoch 1/100
