Import dependencies.

In [1]:
import io
import os
import sys
from pathlib import Path

import lmdb
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import ConcatDataset, DataLoader, Dataset

import warnings
warnings.filterwarnings('ignore')

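Add the project root (the parent of the notebook's working directory) to `sys.path` so that local modules such as `utils` can be imported.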

In [2]:
# Make the project root (one level above this notebook's working directory)
# importable, so local packages such as `utils` resolve in later cells.
PARENT_DIR = Path.cwd().parent
sys.path[0:0] = [str(PARENT_DIR)]

print(PARENT_DIR)

C:\Users\User\Downloads\text-super-resolution-network


In [3]:
from utils.utils import DATASET_DIR

DATASET = "TextZoom"
# Training data comes in two LMDB splits; the test set is split by difficulty.
TRAIN1_DIR, TRAIN2_DIR = (
    os.path.join(DATASET_DIR, DATASET, split) for split in ("train1", "train2")
)
TEST1_DIR, TEST2_DIR, TEST3_DIR = (
    os.path.join(DATASET_DIR, DATASET, "test", level) for level in ("easy", "medium", "hard")
)

for split_dir in (TRAIN1_DIR, TRAIN2_DIR, TEST1_DIR, TEST2_DIR, TEST3_DIR):
    print(split_dir)

C:\Users\User\Downloads\text-super-resolution-network\dataset\TextZoom\train1
C:\Users\User\Downloads\text-super-resolution-network\dataset\TextZoom\train2
C:\Users\User\Downloads\text-super-resolution-network\dataset\TextZoom\test\easy
C:\Users\User\Downloads\text-super-resolution-network\dataset\TextZoom\test\medium
C:\Users\User\Downloads\text-super-resolution-network\dataset\TextZoom\test\hard


Define the `TextZoomDataset` class, which reads HR/LR image pairs and their text labels from a TextZoom LMDB split.

In [4]:
from utils.utils import filter_str

class TextZoomDataset(Dataset):
    """One TextZoom split stored as an LMDB database of (HR image, LR image, label) records.

    Records are keyed 1-based as ``image_hr-%09d``, ``image_lr-%09d`` and
    ``label-%09d``; the number of records is read from the ``num-samples`` key.

    Args:
        data_dir: Path to the LMDB directory of a single split.
        voc_type: Vocabulary name passed to ``filter_str`` to normalize labels.
        max_len: Stored on the instance but not used by this class.
            NOTE(review): labels longer than this are not filtered here.
    """

    def __init__(self, data_dir=None, voc_type="upper", max_len=33):
        super().__init__()
        self.data_dir = data_dir
        self.voc_type = voc_type
        self.max_len = max_len

        # lmdb.open raises lmdb.Error itself when the database cannot be opened,
        # so no separate "env is falsy" check is needed.
        env = self._open_env()
        try:
            with env.begin(write=False) as txn:
                raw_count = txn.get(b'num-samples')
        finally:
            env.close()
        if raw_count is None:
            raise ValueError("'num-samples' key not found in lmdb at %s" % (data_dir,))
        self.num_samples = int(raw_count)

    def _open_env(self):
        """Open this split's LMDB environment read-only; the caller must close it."""
        return lmdb.open(self.data_dir, readonly=True, lock=False, readahead=False, meminit=False)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        """Return ``(img_hr, img_lr, label)`` for the 0-based ``index``.

        Images are decoded with PIL and converted to RGB; the label is decoded
        and normalized with ``filter_str``. If the record at ``index`` is
        incomplete, the following records are tried in order (wrapping around
        to the start), so the fallback is bounded by the dataset size.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``. This
                also lets plain ``for`` iteration over the dataset terminate.
            KeyError: if the split contains no complete record at all.
        """
        requested = index
        if index < 0:
            index += self.num_samples
        if not 0 <= index < self.num_samples:
            raise IndexError(
                "index %s out of range for dataset of size %d" % (requested, self.num_samples)
            )

        # Open once per call and always close, instead of leaking one env per call.
        env = self._open_env()
        try:
            with env.begin(write=False) as txn:
                for offset in range(self.num_samples):
                    key_id = (index + offset) % self.num_samples + 1  # LMDB keys are 1-based
                    hr_buffer = txn.get(b'image_hr-%09d' % key_id)
                    lr_buffer = txn.get(b'image_lr-%09d' % key_id)
                    label_buffer = txn.get(b'label-%09d' % key_id)

                    # Incomplete record: fall through to the next one.
                    if hr_buffer is None or lr_buffer is None or label_buffer is None:
                        continue

                    img_hr = Image.open(io.BytesIO(hr_buffer)).convert('RGB')
                    img_lr = Image.open(io.BytesIO(lr_buffer)).convert('RGB')
                    label = filter_str(label_buffer.decode(), self.voc_type)
                    return img_hr, img_lr, label
        finally:
            env.close()

        raise KeyError("no complete HR/LR/label record found in %s" % (self.data_dir,))

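Define the `resizeNormalize` transform, which resizes a PIL image and converts it to a tensor (optionally appending a binary mask channel).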

In [5]:
class resizeNormalize(object):
    """Resize a PIL image to a fixed size and convert it to a tensor.

    Requires ``numpy`` imported as ``np`` (used for the mask threshold).

    Args:
        size: Target ``(width, height)`` passed to ``PIL.Image.resize``.
        mask: If True, append one extra channel holding a binary mask of the
            resized image: 1.0 where the grayscale value is <= the image's mean
            grayscale value, 0.0 elsewhere.
        interpolation: PIL resampling filter used for the resize.
    """

    def __init__(self, size, mask=False, interpolation=Image.BICUBIC):
        self.size = size
        self.interpolation = interpolation
        self.toTensor = transforms.ToTensor()
        self.mask = mask

    def __call__(self, img):
        """Return the resized image as a tensor, with the mask channel appended if enabled."""
        resized = img.resize(self.size, self.interpolation)
        img_tensor = self.toTensor(resized)
        if not self.mask:
            return img_tensor

        # Threshold the grayscale image at its own mean brightness: brighter-than-mean
        # pixels become 0, the rest 255 (-> 1.0 after ToTensor).
        gray = resized.convert('L')
        threshold = np.array(gray).mean()
        binary = gray.point(lambda px: 0 if px > threshold else 255)
        return torch.cat((img_tensor, self.toTensor(binary)), 0)

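Define the `AlignCollate` collate function, which resizes the HR and LR images of a batch to fixed sizes and stacks them into tensors.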

In [6]:
class AlignCollate:
    """Collate ``(img_hr, img_lr, label)`` samples into batched tensors.

    HR images are resized to ``(imgW, imgH)``. LR images are resized to that size
    floor-divided by ``down_sample_scale``. Both use ``resizeNormalize``.

    Args:
        imgH, imgW: HR output height and width in pixels.
        down_sample_scale: Integer factor between HR and LR output sizes.
        keep_ratio, min_ratio: Stored on the instance but not used by ``__call__``.
        mask: Forwarded to ``resizeNormalize`` (appends a mask channel when True).
    """

    def __init__(self, imgH=64, imgW=256, down_sample_scale=4, keep_ratio=False, min_ratio=1, mask=False):
        self.imgH, self.imgW = imgH, imgW
        self.keep_ratio, self.min_ratio = keep_ratio, min_ratio
        self.down_sample_scale = down_sample_scale
        self.mask = mask

    def __call__(self, batch):
        """Return ``(hr_batch, lr_batch, labels)``; images are stacked along a new dim 0."""
        hr_images, lr_images, labels = zip(*batch)
        scale = self.down_sample_scale
        hr_transform = resizeNormalize((self.imgW, self.imgH), self.mask)
        lr_transform = resizeNormalize((self.imgW // scale, self.imgH // scale), self.mask)

        hr_batch = torch.stack([hr_transform(image) for image in hr_images], dim=0)
        lr_batch = torch.stack([lr_transform(image) for image in lr_images], dim=0)
        return hr_batch, lr_batch, labels

### Playground Section

Test out the `TextZoomDataset` class on the training splits (`train1` + `train2`).

In [7]:
align_collate = AlignCollate(imgH=32, imgW=128, down_sample_scale=2, mask=False)

train1_dataset, train2_dataset = TextZoomDataset(TRAIN1_DIR), TextZoomDataset(TRAIN2_DIR)
dataset = ConcatDataset([train1_dataset, train2_dataset])
# Seed the shuffle so the previewed sample (and its printed label) is reproducible across runs.
loader = DataLoader(
    dataset, batch_size=512, shuffle=True, collate_fn=align_collate,
    generator=torch.Generator().manual_seed(0),
)

# Only the first batch is needed for a preview.
batch = next(iter(loader))
img_hr_sample = batch[0][0]  # HR image
img_lr_sample = batch[1][0]  # LR image
label_sample = batch[2][0]   # label

print(f"HR image shape: {img_hr_sample.shape}")
print(f"LR image shape: {img_lr_sample.shape}")
print(f"Label: {label_sample}")

HR image shape: torch.Size([3, 32, 128])
LR image shape: torch.Size([3, 16, 64])
Label: core


In [8]:
train_split_sizes = [
    ("train1", len(train1_dataset)),
    ("train2", len(train2_dataset)),
    ("(train1+train2)", len(dataset)),
]
for split_name, split_size in train_split_sizes:
    print(f"Total {split_name}: {split_size}")

Total train1: 14573
Total train2: 2794
Total (train1+train2): 17367


Test out the `TextZoomDataset` class on the test splits (easy, medium, hard).

In [9]:
align_collate = AlignCollate(imgH=32, imgW=128, down_sample_scale=2, mask=False)

test1_dataset, test2_dataset, test3_dataset = TextZoomDataset(TEST1_DIR), TextZoomDataset(TEST2_DIR), TextZoomDataset(TEST3_DIR)
dataset = ConcatDataset([test1_dataset, test2_dataset, test3_dataset])
# Seed the shuffle so the previewed sample (and its printed label) is reproducible across runs.
loader = DataLoader(
    dataset, batch_size=512, shuffle=True, collate_fn=align_collate,
    generator=torch.Generator().manual_seed(0),
)

# Only the first batch is needed for a preview.
batch = next(iter(loader))
img_hr_sample = batch[0][0]  # HR image
img_lr_sample = batch[1][0]  # LR image
label_sample = batch[2][0]   # label

print(f"HR image shape: {img_hr_sample.shape}")
print(f"LR image shape: {img_lr_sample.shape}")
print(f"Label: {label_sample}")

HR image shape: torch.Size([3, 32, 128])
LR image shape: torch.Size([3, 16, 64])
Label: look


In [10]:
test_split_sizes = {
    "test1": len(test1_dataset),
    "test2": len(test2_dataset),
    "test3": len(test3_dataset),
    "(test1+test2+test3)": len(dataset),
}
for split_name, split_size in test_split_sizes.items():
    print(f"Total {split_name}: {split_size}")

Total test1: 1619
Total test2: 1411
Total test3: 1343
Total (test1+test2+test3): 4373
