In [1]:
import albumentations as A
import cv2
import numpy as np
from albumentations.pytorch import ToTensorV2

In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
import tqdm
from torch import optim
from torch.utils.data import DataLoader

from common import get_logger
from dataset import DetectionDataset
from models import get_model
from transforms import get_train_transforms

In [3]:
class Compose(object):

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, image, mask):
        for t in self.transforms:
            image, mask = t(image, mask)
        return image, mask


class Resize(object):
    def __init__(self, size):
        self.size = size

    def __call__(self, image, mask, force_apply=False):
        image_, mask_ = image.copy(), mask.copy()
        if image_.shape[0] != self.size[1] or image_.shape[1] != self.size[0]:
            image_ = cv2.resize(image_, self.size)
            mask_ = cv2.resize(mask_, self.size)
        return dict(image=image_, mask=mask_)


class Normalize(object):
    def __init__(self, mean=(0.5, 0.5, 0.5), std=(0.25, 0.25, 0.25)):
        self.mean = np.asarray(mean).reshape((1, 1, 3)).astype(np.float32)
        self.std = np.asarray(std).reshape((1, 1, 3)).astype(np.float32)

    def __call__(self, image, mask):
        image = (image - self.mean) / self.std
        return image, mask


# TODO TIP: Is default image size (256) enough for segmentation of car license plates?
def get_train_transforms(image_size):
    return A.Compose([
        A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.25, 0.25, 0.25), max_pixel_value=1.),
        Resize(size=(image_size, image_size)),
        ToTensorV2(),
    ])

In [4]:
data_path = "/home/ubuntu/datasets/segment_car_plate/data/"
epochs = 8
batch_size = 32
image_size = 256
output_dir = "out"
lr = 3e-4
load = None


In [5]:
def train(model, optimizer, criterion, train_dataloader, logger, device=None):
    model.train()

    epoch_losses = []

    tqdm_iter = tqdm.tqdm(enumerate(train_dataloader), total=len(train_dataloader))
    for i, batch in tqdm_iter:
        imgs, true_masks = batch["image"], batch["mask"].float()
        masks_pred = model(imgs.to(device)).float()
        masks_probs = torch.sigmoid(masks_pred).to(device)

        loss = criterion(masks_probs.view(-1), true_masks.view(-1).to(device)).cpu()
        epoch_losses.append(loss.item())
        tqdm_iter.set_description(f"mean loss: {np.mean(epoch_losses):.4f}")

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    logger.info(f"Epoch finished! Loss: {np.mean(epoch_losses):.5f}")

    return np.mean(epoch_losses)

In [6]:
os.makedirs(output_dir, exist_ok=True)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

logger = get_logger(os.path.join(output_dir, "train.log"))
logger.info("Start training with params:")

model = get_model()
if load is not None:
    with open(args.load, "rb") as fp:
        state_dict = torch.load(fp, map_location="cpu")
    model.load_state_dict(state_dict)
model.to(device)
logger.info(f"Model type: {model.__class__.__name__}")

optimizer = optim.Adam(model.parameters(), lr=lr)

criterion = nn.BCELoss()

train_transforms = get_train_transforms(image_size)
train_dataset = DetectionDataset(data_path, os.path.join(data_path, "train_segmentation.json"),
                                 transforms=train_transforms, split="train")
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=4,
                              pin_memory=True, shuffle=True, drop_last=True)


logger.info(f"Length of train = {len(train_dataset)}")
best_model_info = {"epoch": -1, "train_loss": np.inf}

2022-07-25 20:27:14 Start training with params:
2022-07-25 20:27:15 Model type: Unet
2022-07-25 20:27:15 Length of train = 20505


In [None]:
for epoch in range(epochs):
    logger.info(f"Starting epoch {epoch + 1}/{epochs}.")

    train_loss = train(model, optimizer, criterion, train_dataloader, logger, device)
    if train_loss < best_model_info["train_loss"]:
        with open(os.path.join(output_dir, "CP-best.pth"), "wb") as fp:
            torch.save(model.state_dict(), fp)
        logger.info(f"Train loss: {train_loss:.3f} (best)")
    else:
        logger.info(f"Train loss: {train_loss:.5f} (best {best_model_info['train_loss']:.5f})")

with open(os.path.join(output_dir, "CP-last.pth"), "wb") as fp:
    torch.save(model.state_dict(), fp)

2022-07-25 20:27:18 Starting epoch 1/8.
mean loss: 0.0630: 100%|████████████████████████████████████████████████| 640/640 [43:41<00:00,  4.10s/it]
2022-07-25 21:11:00 Epoch finished! Loss: 0.06301
2022-07-25 21:11:01 Train loss: 0.063 (best)
2022-07-25 21:11:01 Starting epoch 2/8.
mean loss: 0.0072: 100%|████████████████████████████████████████████████| 640/640 [40:52<00:00,  3.83s/it]
2022-07-25 21:51:54 Epoch finished! Loss: 0.00718
2022-07-25 21:51:54 Train loss: 0.007 (best)
2022-07-25 21:51:54 Starting epoch 3/8.
mean loss: 0.0049: 100%|████████████████████████████████████████████████| 640/640 [40:19<00:00,  3.78s/it]
2022-07-25 22:32:15 Epoch finished! Loss: 0.00488
2022-07-25 22:32:15 Train loss: 0.005 (best)
2022-07-25 22:32:15 Starting epoch 4/8.
mean loss: 0.0037: 100%|████████████████████████████████████████████████| 640/640 [40:45<00:00,  3.82s/it]
2022-07-25 23:13:00 Epoch finished! Loss: 0.00374
2022-07-25 23:13:01 Train loss: 0.004 (best)
2022-07-25 23:13:01 Starting epo