In [1]:
# !gdown --id 1eAymcrGvjlnGt3amBxXPvmKJODtE34ZD

In [2]:
# !unzip data.zip

### Import Modules

In [6]:
import time
import os
import torch
import scipy.io as sio
import numpy as np
import torch.utils.data as data
import random
import cv2
import copy
from collections import namedtuple

import torch.nn as nn
import torch.nn.init as init
from torchvision import models
from torchvision.models.vgg import model_urls
import torch.nn.functional as F
import torch.optim as optim

### Utils

In [7]:
# Directory that receives the training checkpoints; safe to re-run.
os.makedirs('store', exist_ok=True)

def gauss_normal_generate(d):
    """Build a ``d x d`` isotropic Gaussian template clipped to its inscribed disk.

    The template is perspective-warped onto every character / affinity box, so
    it must be symmetric about the geometric centre of the pixel grid. Pixel
    centres run from ``0`` to ``d - 1``, hence the centre is ``(d - 1) / 2``;
    using ``d / 2`` shifts the peak half a pixel and makes the map asymmetric
    (and yields an all-zero map for ``d == 1``).

    Args:
        d: side length of the square map in pixels.

    Returns:
        float64 ndarray of shape ``(d, d)``: ``exp(-r^2 / (2 sigma^2))`` with
        ``sigma = d / 3`` inside the disk of radius ``d / 2``, 0 outside it.
    """
    center = (d - 1) / 2.0
    # With sigma = d / 3 the value on the disk edge (r = d / 2) is exp(-1.125) ~ 0.32.
    sigma = d / 3.0
    rows, cols = np.ogrid[:d, :d]
    dis = (rows - center) ** 2 + (cols - center) ** 2
    gauss_map = np.exp(-0.5 * dis / sigma ** 2)
    # Zero everything outside the inscribed circle.
    gauss_map[dis > (d ** 2) / 4] = 0
    return gauss_map


def cvt2HeatmapImg(img):
    """Render a score map in [0, 1] as a JET-coloured uint8 image for display.

    Values outside [0, 1] are clipped before scaling to [0, 255].
    """
    scaled = np.clip(img, 0, 1) * 255
    return cv2.applyColorMap(scaled.astype(np.uint8), cv2.COLORMAP_JET)


def cvt2HeatmapMatrix(img):
    """Quantise a score map in [0, 1] to uint8 in [0, 255] for use as a label.

    Values outside [0, 1] are clipped; the uint8 cast truncates (does not round).
    """
    return np.multiply(np.clip(img, 0, 1), 255).astype(np.uint8)


def point_generate(x1, y1, x2, y2):
    """Compute the 4 corners of the affinity box between two adjacent char boxes.

    Each character box is split by its corner-0 / corner-2 diagonal. Two triangles
    are formed: (corner 0, corner 1, centre) and (corner 2, corner 3, centre). The
    centre is the midpoint of that diagonal. The affinity box corners are the
    centroids of these triangles, in the order:
    first-triangle(char 1), first-triangle(char 2), second-triangle(char 2),
    second-triangle(char 1).

    Args:
        x1, y1: x / y coordinates of the 4 corners of the first character box.
        x2, y2: x / y coordinates of the 4 corners of the following character box.

    Returns:
        ``(x, y)``: two lists of 4 ints each (values truncated with ``int``).
    """
    def centroids(xs, ys):
        # Box centre = midpoint of the 0-2 diagonal, truncated like the corners.
        mid_x = int((xs[0] + xs[2]) / 2)
        mid_y = int((ys[0] + ys[2]) / 2)
        first = (int((xs[0] + xs[1] + mid_x) / 3), int((ys[0] + ys[1] + mid_y) / 3))
        second = (int((xs[2] + xs[3] + mid_x) / 3), int((ys[2] + ys[3] + mid_y) / 3))
        return first, second

    first1, second1 = centroids(x1, y1)
    first2, second2 = centroids(x2, y2)
    # Both coordinates of the 4th corner come from the FIRST box (x and y must
    # agree; taking y from the second box pinned that corner to char 2's height).
    corners = [first1, first2, second2, second1]
    x = [corner[0] for corner in corners]
    y = [corner[1] for corner in corners]
    return x, y


def interval_list_generate(text):
    """Return the character offsets at which a new word starts (offset 0 excluded).

    ``text`` is an iterable of strings; each one is split on spaces and newlines
    and empty fragments are discarded. The result is the running total of word
    lengths without its final value, e.g. words "ab", "cde", "f" -> [2, 5].
    """
    words = []
    for part in text:
        fragments = part.strip().replace(' ', '\n').split('\n')
        words.extend(fragment for fragment in fragments if fragment != '')

    boundaries = []
    total = 0
    for word in words:
        total += len(word)
        boundaries.append(total)
    # The last running total marks the end of the text, not a word boundary.
    return boundaries[:-1]


class averager(object):
    """Running mean over all elements of the values passed to ``add``.

    Sums are kept as plain Python numbers. Adding a loss tensor that still
    requires grad must not keep its autograd graph (and device memory) alive
    until the next ``reset``.
    """

    def __init__(self):
        self.reset()

    def add(self, v):
        """Accumulate the element sum and element count of ``v``.

        Args:
            v: a tensor, or anything ``torch.as_tensor`` accepts (Python or
               NumPy scalars/arrays). It is detached before being read.
        """
        v = torch.as_tensor(v).detach()
        self.n_count += v.numel()
        self.sum += v.sum().item()

    def reset(self):
        """Forget everything accumulated so far."""
        self.n_count = 0
        self.sum = 0

    def val(self):
        """Return the mean as a Python number (0 when nothing has been added)."""
        res = 0
        if self.n_count != 0:
            res = self.sum / float(self.n_count)
        return res


### Dataset

In [8]:
class SynthText(object):
    """Loads the SynthText ground truth and renders CRAFT-style heatmap labels.

    Attributes (set by ``generate_information``):
        data: dict returned by ``scipy.io.loadmat`` for ``gt.mat``.
        cor_list: per-image character boxes (``charBB``), each indexed
            ``[xy][corner][char]`` by the label generators.
        text: per-image text annotations (``txt``).
        name: per-image relative image paths (``imnames``).
        gauss_map: square Gaussian template warped onto every box.
    """

    def __init__(self):
        self.generate_information()

    def generate_information(self):
        """Read ``./data/SynthText/gt.mat`` and cache the fields used for labelling."""
        self.data = sio.loadmat('./data/SynthText/gt.mat')
        char_BB = self.data['charBB']
        self.cor_list = char_BB[0]
        img_txt = self.data['txt']
        self.text = img_txt[0]
        names = self.data['imnames']
        # loadmat presumably returns each MATLAB cell array as a 2-D object
        # array; row 0 holds one entry per image.
        self.name = names[0]
        self.gauss_map = gauss_normal_generate(20)

    def len(self):
        """Number of images listed in the ground-truth file."""
        return len(self.data['charBB'][0])

    def im_read_resize(self, path):
        """Read an image, rotate portrait images to landscape, resize to 600x400.

        Args:
            path: image file path.

        Returns:
            ``(resized_img, img_size)``: ``resized_img`` has shape ``(400, 600, C)``;
            ``img_size`` is the ORIGINAL ``(h, w)`` before rotation, which the
            label generators need in order to rotate the labels the same way.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``path``.
        """
        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError('could not read image: {}'.format(path))
        img_size = (img.shape[0], img.shape[1])
        if img_size[0] > img_size[1]:
            img = np.ascontiguousarray(np.rot90(img, -1))
        # Interpolation must be passed by keyword: cv2.resize's 3rd positional
        # parameter is `dst`, so a positional flag is silently ignored and the
        # default INTER_LINEAR is used. INTER_LINEAR is kept explicitly here.
        resized_img = cv2.resize(img, (600, 400), interpolation=cv2.INTER_LINEAR)
        return resized_img, img_size

    @staticmethod
    def _as_boxes(cor_list):
        """Return ``cor_list`` as a 3-D array indexed ``[xy][corner][box]``.

        NOTE(review): MATLAB drops trailing singleton dimensions, so an image
        with a single box may be stored as ``(2, 4)``; that is treated as one box.
        """
        boxes = np.asarray(cor_list)
        if boxes.ndim == 2:
            boxes = boxes[:, :, np.newaxis]
        return boxes

    @staticmethod
    def _paste_gauss(label, gauss_map, x, y):
        """Warp ``gauss_map`` onto the quad ``(x[k], y[k])``, k = 0..3, in place.

        The warped patch is max-merged into ``label`` so overlapping boxes keep
        the higher score. Parts of the quad outside ``label`` are clipped; a quad
        with no visible area leaves ``label`` unchanged.
        """
        h, w = label.shape
        x_min = max(min(x), 0)
        x_max = min(max(x), w)
        y_min = max(min(y), 0)
        y_max = min(max(y), h)
        w_final = x_max - x_min
        h_final = y_max - y_min
        if w_final <= 0 or h_final <= 0:
            # Degenerate or fully off-image box: nothing to paste.
            return
        gh, gw = gauss_map.shape[:2]
        # Template corners (pixel centres) mapped onto the box corners, in box order.
        point1 = np.array([[0, 0], [gw - 1, 0], [gw - 1, gh - 1], [0, gh - 1]], dtype='float32')
        point2 = np.array([[x[k] - x_min, y[k] - y_min] for k in range(4)], dtype='float32')
        m = cv2.getPerspectiveTransform(point1, point2)
        # `flags` by keyword: the 4th positional parameter of warpPerspective is `dst`.
        target = cv2.warpPerspective(gauss_map, m, (w_final, h_final), flags=cv2.INTER_LINEAR)
        region = label[y_min:y_max, x_min:x_max]  # view into `label`
        np.maximum(region, target, out=region)

    @staticmethod
    def _finalize_label(label, h, w):
        """Rotate (portrait images only), resize to 300x200 and quantise to uint8."""
        if h > w:
            label = np.ascontiguousarray(np.rot90(label, -1))
        label = cv2.resize(label, (300, 200), interpolation=cv2.INTER_LINEAR)
        return cvt2HeatmapMatrix(label)

    def char_label_generate(self, gauss_map, img_size, cor_list):
        """Render the character-region map: one warped Gaussian per character box.

        Args:
            gauss_map: square Gaussian template (see ``gauss_normal_generate``).
            img_size: original ``(h, w)`` of the image the boxes refer to.
            cor_list: character boxes indexed ``[xy][corner][char]``.

        Returns:
            uint8 array of shape ``(200, 300)``, rotated like the image when ``h > w``.
        """
        h, w = img_size[0], img_size[1]
        char_label = np.zeros((h, w))
        boxes = self._as_boxes(cor_list)
        for i in range(boxes.shape[2]):
            x = [int(boxes[0][k][i]) for k in range(4)]
            y = [int(boxes[1][k][i]) for k in range(4)]
            self._paste_gauss(char_label, gauss_map, x, y)
        return self._finalize_label(char_label, h, w)

    def interval_label_generate(self, gauss_map, img_size, cor_list, interval_list):
        """Render the affinity map between consecutive characters of the same word.

        For each pair of consecutive characters (i, i + 1), a Gaussian is warped
        onto the box returned by ``point_generate``. The pair is skipped when
        character ``i + 1`` starts a new word, i.e. ``i + 1`` is in ``interval_list``.

        Args:
            gauss_map: square Gaussian template (see ``gauss_normal_generate``).
            img_size: original ``(h, w)`` of the image the boxes refer to.
            cor_list: character boxes indexed ``[xy][corner][char]``.
            interval_list: character offsets where words start
                (see ``interval_list_generate``).

        Returns:
            uint8 array of shape ``(200, 300)``, rotated like the image when ``h > w``.
        """
        h, w = img_size[0], img_size[1]
        interval_label = np.zeros((h, w))
        boxes = self._as_boxes(cor_list)
        word_starts = set(interval_list)  # O(1) membership tests
        for i in range(boxes.shape[2] - 1):
            if i + 1 in word_starts:
                continue
            x1 = [int(boxes[0][k][i]) for k in range(4)]
            y1 = [int(boxes[1][k][i]) for k in range(4)]
            x2 = [int(boxes[0][k][i + 1]) for k in range(4)]
            y2 = [int(boxes[1][k][i + 1]) for k in range(4)]
            x, y = point_generate(x1, y1, x2, y2)
            self._paste_gauss(interval_label, gauss_map, x, y)
        return self._finalize_label(interval_label, h, w)

### Dataloader

In [9]:
class ImageLoader_synthtext(data.Dataset):
    """PyTorch dataset yielding ``(image, char_label, interval_label)`` float tensors.

    The image is ``(400, 600, 3)`` in HWC layout as read by OpenCV; both labels
    are ``(200, 300)``. ``collate`` moves the image channels first.
    NOTE(review): pixel and label values are in [0, 255] (not normalised).
    """

    def __init__(self):
        self.dataset = SynthText()

    def __len__(self):
        return self.dataset.len()

    def __getitem__(self, index):
        synth = self.dataset
        img_path = './data/SynthText/' + synth.name[index].replace(" ", "")
        img, img_size = synth.im_read_resize(img_path)
        boxes = synth.cor_list[index]
        char_label = synth.char_label_generate(synth.gauss_map, img_size, boxes)
        interval_label = synth.interval_label_generate(
            synth.gauss_map, img_size, boxes, interval_list_generate(synth.text[index]))
        img, char_label, interval_label = random_augmentation(img, char_label, interval_label)
        return tuple(torch.Tensor(arr) for arr in (img, char_label, interval_label))

In [10]:
def collate(batch):
    """Stack dataset samples into batch tensors.

    Returns ``(images, char_labels, interval_labels)``. Images are permuted from
    NHWC to NCHW for the network; labels are stacked unchanged.
    """
    images = torch.stack([sample[0] for sample in batch], 0)
    char_labels = torch.stack([sample[1] for sample in batch], 0)
    interval_labels = torch.stack([sample[2] for sample in batch], 0)
    return images.permute(0, 3, 1, 2), char_labels, interval_labels


def random_augmentation(image, char_label, interval_label):
    """Apply at most one randomly chosen ``ImageTransfer`` augmentation.

    ``random.randint(0, 5)`` is drawn once. Values 1-4 select rotate, add_noise,
    change_contrast or change_hsv. Values 0 and 5 both return the inputs
    unchanged, so no augmentation happens with probability 1/3.
    """
    transfer = ImageTransfer(image, char_label, interval_label)
    choice = random.randint(0, 5)
    augmentations = {1: 'rotate', 2: 'add_noise', 3: 'change_contrast', 4: 'change_hsv'}
    method_name = augmentations.get(choice)
    if method_name is not None:
        image, char_label, interval_label = getattr(transfer, method_name)()
    return image, char_label, interval_label


class ImageTransfer(object):
    """Augmentations applied to an image together with its two label maps.

    Each public method returns a new ``(image, char_label, interval_label)``
    triple. ``rotate`` transforms the image and both labels. ``add_noise``,
    ``change_contrast`` and ``change_hsv`` only modify the image and return the
    labels unchanged.
    """

    def __init__(self, image, char_label, interval_label):
        """Store the inputs to be augmented.

        Args:
            image: ndarray of shape ``[h, w, 3]``; expected to be uint8 BGR
                (``change_hsv`` converts it with ``COLOR_BGR2HSV``).
            char_label: ndarray of shape ``[h/2, w/2]``.
            interval_label: ndarray of shape ``[h/2, w/2]``.
        """
        self.image = image
        self.char_label = char_label
        self.interval_label = interval_label

    def add_noise(self):
        """Multiply every pixel/channel by an independent factor in [0.6, 1.0)."""
        img = self.image * (np.random.rand(*self.image.shape) * 0.4 + 0.6)
        img = img.astype(np.uint8)
        return img, self.char_label, self.interval_label

    def rotate(self, angle=None, center=None, scale=1.0, angle_min=20, angle_max=180):
        """Rotate the image and both labels by the same angle about the same point.

        Args:
            angle: rotation in degrees for ``cv2.getRotationMatrix2D``; when None,
                drawn uniformly from ``[angle_min, angle_max]`` or
                ``[-angle_max, -angle_min]`` (each with probability 1/2).
            center: ``(x, y)`` rotation centre in image pixels; image centre if None.
                The labels rotate about the same point scaled to their resolution.
            scale: isotropic scale passed to ``cv2.getRotationMatrix2D``.
        """
        h, w = self.image.shape[:2]
        h1, w1 = self.char_label.shape
        if angle is None:
            if random.random() < 0.5:
                angle = random.randint(angle_min, angle_max)
            else:
                angle = random.randint(-angle_max, -angle_min)
        if center is None:
            center = (w // 2, h // 2)
            center1 = (w1 // 2, h1 // 2)
        else:
            # A caller-supplied centre must also be mapped into label coordinates
            # (otherwise center1 would be undefined).
            center1 = (center[0] * w1 / float(w), center[1] * h1 / float(h))
        M = cv2.getRotationMatrix2D(center, angle, scale)
        M1 = cv2.getRotationMatrix2D(center1, angle, scale)
        return (cv2.warpAffine(self.image, M, (w, h)),
                cv2.warpAffine(self.char_label, M1, (w1, h1)),
                cv2.warpAffine(self.interval_label, M1, (w1, h1)))

    def change_contrast(self):
        """Scale contrast about 128 by k in {0.5..0.9} or {1.1..1.5}, clipped to [0, 255]."""
        if random.random() < 0.5:
            k = random.randint(5, 9) / 10.0
        else:
            k = random.randint(11, 15) / 10.0
        b = 128 * (k - 1)
        # np.float (an alias of float64) was removed in NumPy 1.24.
        img = self.image.astype(np.float64)
        img = np.clip(k * img - b, 0, 255)
        img = img.astype(np.uint8)
        return img, self.char_label, self.interval_label

    def change_hsv(self):
        """Randomly shift hue and/or scale saturation / value in HSV space.

        With probability 1/4 each: hue only, saturation only, value only, or all
        three. Hue is shifted by +/-[2, 10] modulo 180. Saturation is scaled by
        [0.7, 0.95) and value by [0.6, 0.95).
        """
        img = cv2.cvtColor(self.image, cv2.COLOR_BGR2HSV)
        s = random.random()

        def ch_h():
            dh = random.randint(2, 10) * random.randrange(-1, 2, 2)
            # Shift in a signed dtype. Adding a negative Python int to a uint8 array
            # is promoted, wrapped or rejected depending on the NumPy version
            # (NumPy 2 raises OverflowError).
            img[:, :, 0] = (img[:, :, 0].astype(np.int16) + dh) % 180

        def ch_s():
            ds = random.random() * 0.25 + 0.7
            img[:, :, 1] = ds * img[:, :, 1]

        def ch_v():
            dv = random.random() * 0.35 + 0.6
            img[:, :, 2] = dv * img[:, :, 2]

        if s < 0.25:
            ch_h()
        elif s < 0.50:
            ch_s()
        elif s < 0.75:
            ch_v()
        else:
            ch_h()
            ch_s()
            ch_v()
        return cv2.cvtColor(img, cv2.COLOR_HSV2BGR), self.char_label, self.interval_label

In [11]:
# SynthText dataset and a shuffled loader: batches of 4, one worker process.
dataset = ImageLoader_synthtext()
data_loader = data.DataLoader(dataset, batch_size=4, shuffle=True, num_workers=1, collate_fn=collate)

### Model

In [12]:
def init_weights(modules):
    """Initialise the layers in ``modules`` in place.

    - Conv2d: Xavier-uniform weight, zero bias.
    - BatchNorm2d: weight 1, bias 0.
    - Linear: N(0, 0.01) weight, zero bias.

    Other module types are left untouched. Layers created with ``bias=False``
    and BatchNorm layers with ``affine=False`` (no weight/bias) are skipped.
    They no longer raise AttributeError.

    Args:
        modules: iterable of ``nn.Module`` (e.g. ``model.modules()``).
    """
    for m in modules:
        if isinstance(m, nn.Conv2d):
            init.xavier_uniform_(m.weight)
            if m.bias is not None:
                init.zeros_(m.bias)
        elif isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:
                init.ones_(m.weight)
            if m.bias is not None:
                init.zeros_(m.bias)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, 0, 0.01)
            if m.bias is not None:
                init.zeros_(m.bias)


class vgg16_bn(torch.nn.Module):
    """torchvision VGG16-BN ``features`` split into the stages CRAFT taps as skips.

    ``forward`` returns a ``VggOutputs`` namedtuple
    ``(fc7, relu5_3, relu4_3, relu3_2, relu2_2)``, deepest stage first.
    ``slice5`` stands in for VGG's fully connected layers. It is a stride-1
    3x3 max-pool followed by a dilated (atrous, dilation 6) 3x3 conv and a 1x1 conv.
    """

    # Created once instead of on every forward pass; type name and fields unchanged.
    VggOutputs = namedtuple("VggOutputs", ['fc7', 'relu5_3', 'relu4_3', 'relu3_2', 'relu2_2'])

    def __init__(self, pretrained=True, freeze=True):
        super(vgg16_bn, self).__init__()
        # Weights come from torchvision's default HTTPS URL. Do not rewrite it to
        # plain HTTP: that lets a network attacker swap the checkpoint, and
        # .pth files are unpickled when loaded.
        vgg_pretrained_features = models.vgg16_bn(pretrained=pretrained).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        for x in range(12):  # conv2_2
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 19):  # conv3_3
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(19, 29):  # conv4_3
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(29, 39):  # conv5_3
            self.slice4.add_module(str(x), vgg_pretrained_features[x])

        # fc6 / fc7 replacements (randomly initialised, no pretrained weights).
        self.slice5 = torch.nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6),
            nn.Conv2d(1024, 1024, kernel_size=1)
        )

        if not pretrained:
            init_weights(self.slice1.modules())
            init_weights(self.slice2.modules())
            init_weights(self.slice3.modules())
            init_weights(self.slice4.modules())

        init_weights(self.slice5.modules())  # no pretrained model for fc6 and fc7

        if freeze:
            # Freezes every layer in slice1 (feature layers 0-11).
            for param in self.slice1.parameters():
                param.requires_grad = False

    def forward(self, X):
        # NOTE(review): with torchvision's vgg16_bn layout, slices 2-4 begin with
        # an in-place ReLU, which rectifies the previous slice's output tensor
        # when the next slice runs. The returned skips rely on that aliasing, so
        # do not clone them between slices.
        h_relu2_2 = self.slice1(X)
        h_relu3_2 = self.slice2(h_relu2_2)
        h_relu4_3 = self.slice3(h_relu3_2)
        h_relu5_3 = self.slice4(h_relu4_3)
        h_fc7 = self.slice5(h_relu5_3)
        return self.VggOutputs(h_fc7, h_relu5_3, h_relu4_3, h_relu3_2, h_relu2_2)


class double_conv(nn.Module):
    """Decoder stage: 1x1 conv -> BN -> ReLU -> 3x3 conv -> BN -> ReLU.

    Expects ``in_ch + mid_ch`` input channels, produces ``out_ch`` channels and
    keeps the spatial size (3x3 conv uses padding 1).
    """

    def __init__(self, in_ch, mid_ch, out_ch):
        super(double_conv, self).__init__()
        layers = [
            nn.Conv2d(in_ch + mid_ch, mid_ch, kernel_size=1),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv(x)


class CRAFT(nn.Module):
    """CRAFT text detector: VGG16-BN encoder followed by a 4-stage U-Net decoder.

    ``forward(x)`` returns ``(scores, feature)``:
      - ``scores`` is channels-last ``(N, H', W', 2)`` at the resolution of the
        shallowest skip (``relu2_2``). The training loop regresses channel 0
        onto the character map and channel 1 onto the interval map.
      - ``feature`` is the 32-channel decoder output before the score head.
    """

    def __init__(self, pretrained=False, freeze=False):
        super(CRAFT, self).__init__()

        # Encoder.
        self.basenet = vgg16_bn(pretrained, freeze)

        # Decoder: each stage fuses the previous output with one encoder skip.
        self.upconv1 = double_conv(1024, 512, 256)
        self.upconv2 = double_conv(512, 256, 128)
        self.upconv3 = double_conv(256, 128, 64)
        self.upconv4 = double_conv(128, 64, 32)

        num_class = 2
        self.conv_cls = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, kernel_size=3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(16, 16, kernel_size=1), nn.ReLU(inplace=True),
            nn.Conv2d(16, num_class, kernel_size=1),
        )

        # Same order as the layers are declared (keeps RNG consumption identical).
        for head in (self.upconv1, self.upconv2, self.upconv3, self.upconv4, self.conv_cls):
            init_weights(head.modules())

    def forward(self, x):
        fc7, relu5_3, relu4_3, relu3_2, relu2_2 = self.basenet(x)

        y = self.upconv1(torch.cat([fc7, relu5_3], dim=1))
        # Upsample to each skip's spatial size, concatenate, fuse.
        for stage, skip in ((self.upconv2, relu4_3), (self.upconv3, relu3_2), (self.upconv4, relu2_2)):
            y = F.interpolate(y, size=skip.size()[2:], mode='bilinear', align_corners=False)
            y = stage(torch.cat([y, skip], dim=1))
        feature = y

        scores = self.conv_cls(feature)
        return scores.permute(0, 2, 3, 1), feature


### Loss and Optimizer

In [13]:
# Use the first GPU when available; the model and the loss live on `device`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Per-pixel regression loss on the two score maps.
criterion = torch.nn.MSELoss(reduction='mean').to(device)

# Backbone initialised from torchvision's pretrained VGG16-BN weights.
craft = CRAFT(pretrained=True).to(device)

optimizer = optim.Adam(craft.parameters(), lr=0.001)

cuda:0


### Train

In [14]:
def train_batch(data, div=10):
    """Run one optimisation step of the global ``craft`` model on one batch.

    Args:
        data: ``(img, char_label, interval_label)`` as produced by ``collate``.
        div: constant each MSE term is divided by (default 10, unchanged).

    Returns:
        The scalar training loss, detached from the autograd graph, so callers
        can accumulate it without keeping the graph (and its memory) alive.
    """
    craft.train()
    img, char_label, interval_label = data
    img = img.to(device)
    char_label = char_label.to(device)
    interval_label = interval_label.to(device)

    optimizer.zero_grad()
    # The input image is deliberately NOT marked requires_grad: only parameter
    # gradients are needed, and an input gradient just adds backward work/memory.
    preds, _ = craft(img)
    cost_char = criterion(preds[:, :, :, 0], char_label).sum() / div
    cost_interval = criterion(preds[:, :, :, 1], interval_label).sum() / div
    cost = cost_char + cost_interval
    cost.backward()
    optimizer.step()
    return cost.detach()


NUM_EPOCHS = 5
LOG_INTERVAL = 100  # batches between a checkpoint + log line

loss_avg = averager()

for epoch in range(NUM_EPOCHS):
    # The batch is named `batch`, not `data`: `data` is the torch.utils.data alias,
    # and rebinding it breaks re-running the DataLoader cell.
    for i, batch in enumerate(data_loader, start=1):
        time0 = time.time()
        cost = train_batch(batch)
        loss_avg.add(cost)

        if i % LOG_INTERVAL == 0:
            avg_loss = loss_avg.val()
            # Explicit float format keeps the file name numeric even if val()
            # returns a 0-d tensor.
            torch.save(craft.state_dict(),
                       '{0}/craft_{1}_{2}_{3:.4f}.pth'.format('store', epoch, i, avg_loss))
            print('[%d/%d][%d/%d] lr: %.4f Loss: %f Time: %f s' %
                  (epoch, NUM_EPOCHS, i, len(data_loader), optimizer.param_groups[0]['lr'], avg_loss,
                   time.time() - time0))
            loss_avg.reset()


[0/5][100/1113] lr: 0.0010 Loss: 154.313202 Time: 3.704025 s


KeyboardInterrupt: ignored