In [1]:
import torch, math, sys, os, io, lmdb
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision import transforms
from pathlib import Path
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from PIL import Image
from tqdm import tqdm
from datetime import datetime

import warnings
warnings.filterwarnings('ignore')

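Add the project root (the parent of the notebook's working directory) to `sys.path` so that the local `utils` and `models` packages can be imported.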

In [2]:
# Project root: the parent of the directory the notebook is executed from.
PARENT_DIR = Path.cwd().parent

# Prepend the root only once, so re-running this cell does not keep stacking
# duplicate entries at the front of sys.path.
if str(PARENT_DIR) not in sys.path:
    sys.path.insert(0, str(PARENT_DIR))

print(PARENT_DIR)

C:\Users\User\Downloads\text-super-resolution-network


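Import the project helpers and define the TextZoom dataset directories (two training splits and the easy / medium / hard test splits).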

In [3]:
from utils.utils import DATASET_DIR, get_device, set_seed
from utils.ssim_psnr import calculate_psnr, SSIM
from utils.metrics import get_str_list, Accuracy
from utils.labelmaps import get_vocabulary
from models.recognizer.tps_spatial_transformer import TPSSpatialTransformer
from models.recognizer.stn_head import STNHead
from models.recognizer.recognizer_builder import RecognizerBuilder

# Name of the dataset folder located under DATASET_DIR.
DATASET = "TextZoom"

# Two training splits.
TRAIN1_DIR, TRAIN2_DIR = (
    os.path.join(DATASET_DIR, DATASET, split) for split in ("train1", "train2")
)

# Three test splits, stored under "test/<difficulty>".
TEST1_DIR, TEST2_DIR, TEST3_DIR = (
    os.path.join(DATASET_DIR, DATASET, "test", level) for level in ("easy", "medium", "hard")
)

# Show the resolved paths, one per line.
print(TRAIN1_DIR, TRAIN2_DIR, TEST1_DIR, TEST2_DIR, TEST3_DIR, sep="\n")

C:\Users\User\Downloads\text-super-resolution-network\dataset\TextZoom\train1
C:\Users\User\Downloads\text-super-resolution-network\dataset\TextZoom\train2
C:\Users\User\Downloads\text-super-resolution-network\dataset\TextZoom\test\easy
C:\Users\User\Downloads\text-super-resolution-network\dataset\TextZoom\test\medium
C:\Users\User\Downloads\text-super-resolution-network\dataset\TextZoom\test\hard


In [4]:
# Seed handed to the project's set_seed helper below.
SEED = 42
# Vocabulary type used later for AsterInfo and for filter_str during validation.
VOC_TYPE = "all"
set_seed(SEED)
# Device that models and tensors are moved to throughout the notebook.
device = get_device()

Device: cuda
CUDA available: True
CUDA version: 12.1
GPU: NVIDIA GeForce RTX 3070


### Build TextZoom dataset

Define the `TextZoomDataset` class.

In [5]:
from utils.utils import filter_str

class TextZoomDataset(Dataset):
    """LMDB-backed dataset yielding ``(img_hr, img_lr, label)`` triplets.

    Records are read from the keys ``image_hr-%09d``, ``image_lr-%09d`` and
    ``label-%09d`` (1-based), and the record count from the ``num-samples`` key.

    Args:
        data_dir: Path of the LMDB environment directory.
        voc_type: Vocabulary type forwarded to ``filter_str`` when decoding labels.
        max_len: Stored on the instance; not used by this class itself.

    Raises:
        KeyError: If the LMDB environment has no ``num-samples`` entry.
    """

    def __init__(self, data_dir=None, voc_type="upper", max_len=33):
        super().__init__()
        self.data_dir = data_dir
        self.voc_type = voc_type
        self.max_len = max_len
        # Read handle used by __getitem__; opened lazily on first access and then reused.
        self._env = None

        # lmdb.open raises on failure, so no truthiness check on the handle is needed.
        env = self._open_env()
        try:
            with env.begin(write=False) as txn:
                raw_count = txn.get(b'num-samples')
        finally:
            env.close()

        if raw_count is None:
            raise KeyError("'num-samples' entry not found in lmdb at %s" % (data_dir))
        self.num_samples = int(raw_count)

    def _open_env(self):
        """Open the LMDB environment read-only with the settings used for this dataset."""
        return lmdb.open(self.data_dir, readonly=True, lock=False, readahead=False, meminit=False)

    def __getstate__(self):
        # Drop the open handle when the dataset is pickled (e.g. sent to DataLoader
        # worker processes); each process re-opens its own handle on first access.
        state = self.__dict__.copy()
        state['_env'] = None
        return state

    def close(self):
        """Close the cached LMDB handle, if one is open."""
        if self._env is not None:
            self._env.close()
            self._env = None

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        """Return ``(img_hr, img_lr, label)`` for ``index``.

        If any part of the record is missing, the following records are tried in
        turn (wrapping around at the end of the dataset) instead of recursing
        without bound.

        Raises:
            RuntimeError: If no complete record exists in the whole dataset.
        """
        if self._env is None:
            self._env = self._open_env()

        with self._env.begin(write=False) as txn:
            for offset in range(self.num_samples):
                # LMDB keys are 1-based; the modulo wraps past the last record.
                key_id = (index + offset) % self.num_samples + 1

                hr_buffer = txn.get(b'image_hr-%09d' % key_id)
                lr_buffer = txn.get(b'image_lr-%09d' % key_id)
                label_buffer = txn.get(b'label-%09d' % key_id)

                # error handling: if data is missing, fall through to the next record
                if lr_buffer is None or hr_buffer is None or label_buffer is None:
                    continue

                # convert Bytes to PIL Image
                img_hr = Image.open(io.BytesIO(hr_buffer)).convert('RGB')
                img_lr = Image.open(io.BytesIO(lr_buffer)).convert('RGB')

                # decode label
                label = str(label_buffer.decode())
                label = filter_str(label, self.voc_type)

                return img_hr, img_lr, label

        raise RuntimeError('No complete sample found in lmdb at %s' % (self.data_dir))

Define the `resizeNormalize` class.

In [6]:
class resizeNormalize(object):
    """Resize a PIL image and convert it to a tensor, optionally appending a mask channel.

    Args:
        size: Target size passed to ``PIL.Image.resize``.
        mask: When true, a binary mask channel is concatenated after the image channels.
        interpolation: Resampling filter passed to ``PIL.Image.resize``.
    """

    def __init__(self, size, mask=False, interpolation=Image.BICUBIC):
        self.size, self.interpolation = size, interpolation
        self.toTensor = transforms.ToTensor()
        self.mask = mask

    def __call__(self, img):
        resized = img.resize(self.size, self.interpolation)
        image_tensor = self.toTensor(resized)
        if not self.mask:
            return image_tensor

        # Binary mask from the grayscale image: pixels brighter than the mean
        # intensity become 0, all others 255.
        gray = resized.convert('L')
        threshold = np.array(gray).mean()
        binary = gray.point(lambda px: 0 if px > threshold else 255)
        return torch.cat((image_tensor, self.toTensor(binary)), 0)

Define the `AlignCollate` class.

In [7]:
class AlignCollate():
    """Collate function that resizes HR/LR image pairs to fixed sizes and batches them.

    HR images are resized to ``(imgW, imgH)``; LR images to that size divided
    (integer division) by ``down_sample_scale``. ``keep_ratio`` and ``min_ratio``
    are stored but not used by ``__call__``.
    """

    def __init__(self, imgH=64, imgW=256, down_sample_scale=4, keep_ratio=False, min_ratio=1, mask=False):
        self.imgH, self.imgW = imgH, imgW
        self.keep_ratio = keep_ratio
        self.min_ratio = min_ratio
        self.down_sample_scale = down_sample_scale
        self.mask = mask

    def __call__(self, batch):
        """Return ``(hr_batch, lr_batch, labels)`` for a list of ``(img_hr, img_lr, label)`` samples."""
        hr_images, lr_images, labels = zip(*batch)

        scale = self.down_sample_scale
        to_hr = resizeNormalize((self.imgW, self.imgH), self.mask)
        to_lr = resizeNormalize((self.imgW // scale, self.imgH // scale), self.mask)

        # Stack the per-image tensors along a new leading batch dimension.
        hr_batch = torch.stack([to_hr(image) for image in hr_images], 0)
        lr_batch = torch.stack([to_lr(image) for image in lr_images], 0)

        return hr_batch, lr_batch, labels

### Build TSRN architecture

Define the `GRUBlock` class.

In [8]:
class GRUBlock(nn.Module):
    """1x1 convolution followed by a bidirectional GRU run along each row.

    The input is a ``(B, C_in, H, W)`` feature map. After the convolution every
    row of every sample is treated as a sequence of length ``W`` and passed
    through the GRU; the result is returned as ``(B, out_channels, H, W)``.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the output; must be even because the two GRU
            directions each produce ``out_channels // 2`` features.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        assert out_channels % 2 == 0
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
        self.gru = nn.GRU(out_channels, out_channels // 2, bidirectional=True, batch_first=True)

    def forward(self, x):
        # (B, C, H, W) -> (B, H, W, C), made contiguous so it can be viewed as rows.
        feats = self.conv1(x).permute(0, 2, 3, 1).contiguous()
        n, h, w, c = feats.size()

        # Every row becomes one sequence: (B * H, W, C).
        rows = feats.view(n * h, w, c)
        rows_out = self.gru(rows)[0]

        # Back to channels-first layout.
        return rows_out.view(n, h, w, c).permute(0, 3, 1, 2)

Define the `MISH` class.

In [9]:
class MISH(nn.Module):
    """Mish activation, ``x * tanh(softplus(x))``, with an on/off switch.

    When ``self.activated`` is false the input is returned unchanged.
    """

    def __init__(self):
        super().__init__()
        self.activated = True

    def forward(self, x):
        if not self.activated:
            return x
        return x * (torch.tanh(F.softplus(x)))

Define the `RecurrentResidualBlock` class.

In [10]:
class RecurrentResidualBlock(nn.Module):
    """Residual block combining two conv/BN stages with two GRU blocks.

    The residual branch is conv -> BN -> MISH -> conv -> BN, followed by a
    GRUBlock applied to the spatially transposed map (so it runs along the
    other spatial axis). The sum of input and residual is passed through a
    second GRUBlock.

    Args:
        channels: Number of input and output channels.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.gru1 = GRUBlock(channels, channels)
        self.prelu = MISH()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.gru2 = GRUBlock(channels, channels)

    def forward(self, x):
        branch = self.prelu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))

        # Swap H and W around gru1 so that it scans the other spatial axis.
        branch = self.gru1(branch.transpose(-1, -2)).transpose(-1, -2)

        return self.gru2(x + branch)

Define the `UpsampleBLock` class.

In [11]:
class UpsampleBLock(nn.Module):
    """Sub-pixel upsampling: 3x3 conv -> PixelShuffle -> MISH.

    Args:
        in_channels: Channels of the input (and of the output after the shuffle).
        up_scale: Spatial upscaling factor; the convolution expands the channels
            by ``up_scale ** 2`` so that PixelShuffle can fold them into space.
    """

    def __init__(self, in_channels, up_scale):
        super().__init__()
        expanded = in_channels * up_scale ** 2
        self.conv = nn.Conv2d(in_channels, expanded, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = MISH()

    def forward(self, x):
        return self.prelu(self.pixel_shuffle(self.conv(x)))

Define the `TSRN` class.

In [12]:
class TSRN(nn.Module):
    """Text super-resolution network built from recurrent residual blocks.

    Sub-modules are registered as ``block1`` .. ``block{srb_nums + 3}``:

    * ``block1``: 9x9 convolution + PReLU.
    * ``block2`` .. ``block{srb_nums + 1}``: ``RecurrentResidualBlock``.
    * ``block{srb_nums + 2}``: 3x3 convolution + batch norm.
    * ``block{srb_nums + 3}``: ``log2(scale_factor)`` x2 upsample blocks and a final 9x9 convolution.

    Args:
        scale_factor: Upscaling factor; its base-2 logarithm must be an integer.
        width, height: Used (divided by ``scale_factor``) as the TPS output size.
        STN: When true, a TPS spatial transformer is built and applied in training mode.
        srb_nums: Number of recurrent residual blocks.
        mask: When true the network takes/returns 4 channels instead of 3.
        hidden_units: Half of the internal feature width (features use ``2 * hidden_units`` channels).
    """

    def __init__(self, scale_factor=2, width=128, height=32, STN=False, srb_nums=5, mask=True, hidden_units=32):
        super().__init__()
        in_planes = 4 if mask else 3
        log_scale = math.log(scale_factor, 2)
        assert log_scale % 1 == 0
        upsample_block_num = int(log_scale)
        channels = 2 * hidden_units

        self.block1 = nn.Sequential(
            nn.Conv2d(in_planes, channels, kernel_size=9, padding=4),
            nn.PReLU(),
        )

        self.srb_nums = srb_nums
        for block_id in range(2, srb_nums + 2):
            setattr(self, 'block%d' % block_id, RecurrentResidualBlock(channels))

        fusion = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
        )
        setattr(self, 'block%d' % (srb_nums + 2), fusion)

        tail = [UpsampleBLock(channels, 2) for _ in range(upsample_block_num)]
        tail.append(nn.Conv2d(channels, in_planes, kernel_size=9, padding=4))
        setattr(self, 'block%d' % (srb_nums + 3), nn.Sequential(*tail))

        # Size the input is interpolated to before the STN head.
        self.tps_inputsize = [32, 64]
        tps_outputsize = (height // scale_factor, width // scale_factor)
        num_control_points = 20
        tps_margins = (0.05, 0.05)

        self.stn = STN
        if self.stn:
            self.tps = TPSSpatialTransformer(
                output_image_size=tps_outputsize,
                num_control_points=num_control_points,
                margins=tps_margins)

            self.stn_head = STNHead(
                in_planes=in_planes,
                num_ctrlpoints=num_control_points,
                activation='none')

    def forward(self, x):
        # The spatial transformer is only applied while the module is in training mode.
        if self.stn and self.training:
            x = F.interpolate(x, self.tps_inputsize, mode='bilinear', align_corners=True)
            _, ctrl_points_x = self.stn_head(x)
            x, _ = self.tps(x, ctrl_points_x)

        shallow = self.block1(x)

        # Run block2 .. block{srb_nums + 2} in sequence on top of the shallow features.
        deep = shallow
        for block_id in range(2, self.srb_nums + 3):
            deep = getattr(self, 'block%d' % block_id)(deep)

        # Global skip connection, then upsampling and tanh output.
        upsampler = getattr(self, 'block%d' % (self.srb_nums + 3))
        return torch.tanh(upsampler(shallow + deep))

### Build Aster architecture

Define the `AsterInfo` class.

In [13]:
class AsterInfo(object):
    """Vocabulary and decoding settings for the ASTER recognizer.

    Args:
        voc_type: One of ``'digit'``, ``'lower'``, ``'upper'`` or ``'all'``.

    Attributes set here: the special tokens (``EOS``, ``PADDING``, ``UNKNOWN``),
    ``max_len``, the vocabulary ``voc`` returned by ``get_vocabulary``, the two
    lookup tables ``char2id`` / ``id2char`` and ``rec_num_classes``.
    """

    def __init__(self, voc_type):
        super().__init__()
        self.voc_type = voc_type
        assert voc_type in ('digit', 'lower', 'upper', 'all')

        self.EOS = 'EOS'
        self.max_len = 100
        self.PADDING = 'PADDING'
        self.UNKNOWN = 'UNKNOWN'

        self.voc = get_vocabulary(voc_type, EOS=self.EOS, PADDING=self.PADDING, UNKNOWN=self.UNKNOWN)
        num_classes = len(self.voc)

        # Position in the vocabulary <-> symbol lookup tables.
        self.char2id = {symbol: idx for idx, symbol in enumerate(self.voc)}
        self.id2char = dict(enumerate(self.voc))
        self.rec_num_classes = num_classes

Define the `aster_init` function.

In [14]:
def aster_init():
    """Build the ASTER recognizer, load its weights and freeze it for evaluation.

    The checkpoint is read from ``aster.pth`` in ``PARENT_DIR`` and must be a
    dict with a ``'state_dict'`` entry.

    Returns:
        tuple: ``(aster, aster_info)`` where ``aster`` is the recognizer on
        ``device`` in eval mode with gradients disabled, and ``aster_info`` is
        the ``AsterInfo`` built for ``VOC_TYPE``.
    """
    ASTER_PATH = PARENT_DIR.joinpath("aster.pth")
    aster_info = AsterInfo(VOC_TYPE)
    aster = RecognizerBuilder(arch='ResNet_ASTER', rec_num_classes=aster_info.rec_num_classes,
                              sDim=512, attDim=512, max_len_labels=aster_info.max_len,
                              eos=aster_info.char2id[aster_info.EOS], STN_ON=True)

    # map_location lets a checkpoint saved on a GPU load on whatever device is in use.
    # NOTE(review): assumes `device` is a torch.device or device string — confirm get_device().
    checkpoint = torch.load(ASTER_PATH, map_location=device)
    aster.load_state_dict(checkpoint['state_dict'])
    aster = aster.to(device)

    # The recognizer is only used for evaluation, so freeze all of its weights.
    for p in aster.parameters():
        p.requires_grad = False
    aster.eval()

    # The original mixed an f-string with %-formatting of a set literal, which
    # printed "{WindowsPath(...)}" instead of the path.
    print(f"load pretrained aster model from {ASTER_PATH}")
    return aster, aster_info

Define the `parse_aster_data` function.

In [15]:
def parse_aster_data(aster_info, imgs_input):
    """Build the input dictionary passed to the ASTER recognizer.

    Args:
        aster_info: Object providing ``max_len``.
        imgs_input: Image batch; its first dimension is the batch size.

    Returns:
        dict with
        ``'images'``: the batch on ``device`` mapped through ``x * 2 - 1``,
        ``'rec_targets'``: int64 tensor of ones, shape ``(batch, max_len)``,
        ``'rec_lengths'``: int64 tensor of length ``batch`` filled with ``max_len``.
    """
    images = imgs_input.to(device)
    num_images = images.shape[0]
    max_len = aster_info.max_len

    return {
        'images': images * 2 - 1,
        'rec_targets': torch.ones(num_images, max_len, dtype=torch.long).to(device),
        'rec_lengths': torch.full((num_images,), max_len, dtype=torch.long).to(device),
    }

### Build loss function

Define the `GradientPriorLoss` class.

In [16]:
class GradientPriorLoss(nn.Module):
    """L1 distance between the gradient-magnitude maps of two image batches."""

    def __init__(self):
        super().__init__()
        self.func = nn.L1Loss()

    def forward(self, out_images, target_images):
        return self.func(self.gradient_map(out_images), self.gradient_map(target_images))

    @staticmethod
    def gradient_map(x):
        """Return the per-pixel gradient magnitude of a 4-D ``(B, C, H, W)`` tensor.

        Horizontal and vertical gradients are central differences with zero
        padding at the borders; ``1e-6`` is added under the square root.
        """
        _, _, height, width = x.size()

        # Zero-padded neighbours of every pixel in the four directions.
        right = F.pad(x, (0, 1, 0, 0))[:, :, :, 1:]
        left = F.pad(x, (1, 0, 0, 0))[:, :, :, :width]
        top = F.pad(x, (0, 0, 1, 0))[:, :, :height, :]
        bottom = F.pad(x, (0, 0, 0, 1))[:, :, 1:, :]

        grad_h = (right - left) * 0.5
        grad_v = (top - bottom) * 0.5
        return torch.pow(torch.pow(grad_h, 2) + torch.pow(grad_v, 2) + 1e-6, 0.5)

Define the `ImageLoss` class.

In [17]:
class ImageLoss(nn.Module):
    """Weighted sum of an MSE loss and, optionally, a gradient-prior loss.

    Args:
        gradient: When true, the gradient-prior term is added to the MSE term.
        loss_weight: Two weights ``[mse_weight, gradient_weight]``. Defaults to
            ``[20, 1e-4]``; a fresh list is created per instance so that
            changing one instance's weights cannot affect another's.
    """

    def __init__(self, gradient=True, loss_weight=None):
        super(ImageLoss, self).__init__()
        self.mse = nn.MSELoss()
        if gradient:
            self.GPLoss = GradientPriorLoss()
        self.gradient = gradient
        # A mutable default argument would be shared by every instance.
        self.loss_weight = [20, 1e-4] if loss_weight is None else loss_weight

    def forward(self, out_images, target_images):
        """Return the scalar loss between ``out_images`` and ``target_images``."""
        loss = self.loss_weight[0] * self.mse(out_images, target_images)
        if self.gradient:
            # The gradient prior only looks at the first three channels.
            loss = loss + self.loss_weight[1] * self.GPLoss(out_images[:, :3, :, :], target_images[:, :3, :, :])
        return loss

### Training phase

Setting up the training.

In [18]:
# Learning rate passed to the Adam optimizer.
LEARNING_RATE = 0.001
# First Adam beta coefficient (the second one is fixed to 0.999 where the optimizer is built).
BETA1 = 0.5
# Number of passes over the training set.
EPOCHS = 500
# Batch size used by both the training and the validation DataLoader.
BATCH_SIZE = 128

In [19]:
# Collate function shared by both loaders: HR images are resized to 128x32 and
# LR images to half of that (down_sample_scale=2); no mask channel is added.
align_collate = AlignCollate(imgH=32, imgW=128, down_sample_scale=2, mask=False)

# NOTE(review): the datasets use TextZoomDataset's default voc_type ("upper") while
# the rest of the notebook uses VOC_TYPE — confirm that this is intended.

# load the training dataset
train1_dataset, train2_dataset = TextZoomDataset(TRAIN1_DIR), TextZoomDataset(TRAIN2_DIR)
train_dataset = ConcatDataset([train1_dataset, train2_dataset])
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True, collate_fn=align_collate)

# load the testing dataset
test1_dataset, test2_dataset, test3_dataset = TextZoomDataset(TEST1_DIR), TextZoomDataset(TEST2_DIR), TextZoomDataset(TEST3_DIR)
val_dataset = ConcatDataset([test1_dataset, test2_dataset, test3_dataset])
# Validation is not shuffled: the metrics are averaged per batch, so a fixed
# batch composition keeps them comparable from epoch to epoch.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=True, collate_fn=align_collate)

In [20]:
# Super-resolution network: x2 upscaling, 3-channel input (mask=False), STN enabled.
model = TSRN(scale_factor=2, width=128, height=32, STN=True, mask=False, srb_nums=5, hidden_units=32).to(device)
# Frozen ASTER recognizer plus its vocabulary info, used in the validation loop.
aster, aster_info = aster_init()
# Image loss: MSE (weight 1) plus gradient-prior term (weight 1e-4).
criterion = ImageLoss(gradient=True, loss_weight=[1, 1e-4])
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=(BETA1, 0.999))

# SSIM metric object, called on (img_sr, img_hr) during validation.
calculate_ssim = SSIM()

load pred_trained aster model from {WindowsPath('C:/Users/User/Downloads/text-super-resolution-network/aster.pth')}


Start the training.

In [None]:
# One entry per epoch with the validation metrics. Per-batch PSNR/SSIM values are
# stored as plain Python floats so the history does not keep per-batch result
# objects alive for the whole run and can be serialized directly.
metrics_list = []
for epoch in range(EPOCHS):
    # ------------ training phase ------------
    model.train()
    running_loss = 0
    batch_iterator = tqdm(train_loader, leave=False, desc=f'Processing epoch {epoch+1:02d}')

    for img_hr, img_lr, _ in batch_iterator:
        img_hr = img_hr.to(device)
        img_lr = img_lr.to(device)

        # forward pass
        img_sr = model(img_lr)
        loss = criterion(img_sr, img_hr).mean() * 100

        optimizer.zero_grad()

        # backward pass (gradient norm clipped to 0.25 before the update)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
        optimizer.step()

        # update statistics
        batch_loss = loss.item()
        running_loss += batch_loss

        # update progress bar
        batch_iterator.set_postfix(loss=f"{batch_loss:.5f}", lr=f"{optimizer.param_groups[0]['lr']:.2e}")

    # calculate average metrics
    train_loss = running_loss / len(train_loader)

    # ------------ validation phase ------------
    model.eval()
    running_loss = 0
    n_correct, total_images = 0, 0
    psnr_list, ssim_list = [], []
    batch_iterator = tqdm(val_loader, leave=False, desc=f'Validating epoch {epoch+1:02d}')

    with torch.no_grad():
        for img_hr, img_lr, label in batch_iterator:
            batch_size = img_lr.shape[0]
            img_hr = img_hr.to(device)
            img_lr = img_lr.to(device)

            # forward pass
            img_sr = model(img_lr)
            loss = criterion(img_sr, img_hr).mean() * 100

            # calculate image quality metrics (converted to floats right away)
            psnr_list.append(calculate_psnr(img_sr, img_hr).item())
            ssim_list.append(calculate_ssim(img_sr, img_hr).item())

            # Text recognition on the super-resolved images. The former extra
            # ASTER pass on the LR images was never used and has been dropped.
            aster_dict_sr = parse_aster_data(aster_info, img_sr[:, :3, :, :])
            aster_output_sr = aster(aster_dict_sr)
            pred_rec_sr = aster_output_sr['output']['pred_rec']
            pred_str_sr, _ = get_str_list(pred_rec_sr, aster_dict_sr['rec_targets'], dataset=aster_info)

            # calculate text recognition accuracy
            n_correct += sum(
                pred == filter_str(target, VOC_TYPE)
                for pred, target in zip(pred_str_sr, label)
            )

            # calculate losses (for logging only)
            batch_loss = loss.item()
            running_loss += batch_loss

            total_images += batch_size

            # update progress bar
            batch_iterator.set_postfix(loss=f"{batch_loss:.4f}")

    # calculate average metrics
    val_loss = running_loss / len(val_loader)
    psnr_avg = sum(psnr_list) / len(psnr_list)
    ssim_avg = sum(ssim_list) / len(ssim_list)
    accuracy = n_correct / total_images

    # print results
    print(f'[{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}]\t'
          f'EPOCH: {epoch+1}/{EPOCHS}')
    print(f'ASTER accuracy: {accuracy * 100:.2f}%')
    print(f'train_loss {train_loss:.5f} | val_loss {val_loss:.5f}')
    print(f'PSNR {psnr_avg:.2f} | SSIM {ssim_avg:.4f}')

    metrics_list.append({
        'accuracy': round(accuracy, 4),
        'psnr_avg': round(psnr_avg, 6),
        'ssim_avg': round(ssim_avg, 6),
        'psnr': psnr_list,
        'ssim': ssim_list
    })

Validating epoch 01:  26%|████████████▌                                    | 9/35 [02:30<07:00, 16.17s/it, loss=1.6260]