# Libraries

In [None]:
%load_ext autoreload
%autoreload 2

import os, json, cv2
import numpy as np
from glob import glob
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader

os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import sys
sys.path.append("../")
from FOTS.model.keys import keys
from FOTS.data_loader.data_loaders import SynthTextDataLoaderFactory
from FOTS.utils.util import get_center_bbox, get_center_bboxes, draw_texts, strLabelConverter

# Build dataloader

In [None]:
# Load the debug configuration and build the SynthText training loader.
# The config file is read inside a context manager so its handle is closed
# immediately instead of being left for the garbage collector.
with open("../debug.json", encoding="utf-8") as config_file:
    config = json.load(config_file)
data_loader = SynthTextDataLoaderFactory(config)
train_loader = data_loader.train()
# Built from the project's `keys`; not used by the visualization cells below.
labelConverter = strLabelConverter(keys)

# Visualize

In [None]:
# Visualize detection information for the first training batch.
# Each image gets one row of 8 panels: the de-normalized image, then the
# score map, the training mask and the five geo_map channels, each overlaid
# on the image.

# Mean/std used to undo the input normalization. Assumes the data loader
# normalizes with these ImageNet constants -- TODO confirm against the loader.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

# Only the first batch is shown; taking it directly avoids the original
# loop, whose `break` ran before `plt.show()` and made the latter unreachable.
data = next(iter(train_loader))
image_files, images, score_maps, geo_maps, training_masks, transcripts, rectangles, mappings = data

n_images = len(images)
n_cols = 8
plt.figure(figsize=(30, 5 * n_images))
print("images", images.shape)
print("score_maps", score_maps.shape)
print("geo_maps", geo_maps.shape)
print("training_masks", training_masks.shape)
print("transcripts", len(transcripts))
print("rectangles", rectangles.shape)
print("mappings", mappings.shape)
print("transcripts", transcripts)
print("mappings", mappings)

for image_idx, (image_file, image, score_map, geo_map, training_mask) in enumerate(
        zip(image_files, images, score_maps, geo_maps, training_masks)):
    # CHW tensor -> HWC array, undo normalization, then clip before the uint8
    # cast so out-of-range values saturate instead of wrapping around.
    image = image.numpy().transpose((1, 2, 0))
    image = image * IMAGENET_STD[None, None, :] + IMAGENET_MEAN[None, None, :]
    image = (255 * np.clip(image, 0.0, 1.0)).astype(np.uint8)
    h, w = image.shape[:2]

    # The score map and training mask are 0/1 label maps, so they are
    # upsampled with nearest-neighbour interpolation to keep them binary.
    score_map = cv2.resize(score_map.numpy().astype(np.uint8)[0], (w, h),
                           interpolation=cv2.INTER_NEAREST)
    training_mask = cv2.resize(training_mask.numpy().astype(np.uint8)[0], (w, h),
                               interpolation=cv2.INTER_NEAREST)
    geo_map = cv2.resize(geo_map.numpy().transpose((1, 2, 0)), (w, h))

    row_offset = n_cols * image_idx

    # Panel 1: the plain image.
    plt.subplot(n_images, n_cols, row_offset + 1)
    plt.imshow(image); plt.title(os.path.basename(image_file)); plt.axis('off')

    # Panels 2-8: (title, overlay, vmin, vmax) drawn on top of the image.
    overlays = [
        ("score_map", score_map, 0, 1),
        ("training_mask", training_mask, 0, 1),
    ]
    # geo_map channels 0-3 are scaled from 0 to their own maximum.
    for channel in range(4):
        overlays.append((f"geo_map{channel + 1}", geo_map[..., channel],
                         0, geo_map[..., channel].max()))
    # Channel 4 is shown on [0, pi/2]; presumably an angle in radians -- confirm.
    overlays.append(("geo_map5", geo_map[..., 4], 0.0, np.pi / 2))

    for panel_idx, (title, overlay, vmin, vmax) in enumerate(overlays, start=2):
        plt.subplot(n_images, n_cols, row_offset + panel_idx)
        plt.imshow(image)
        plt.imshow(overlay, alpha=0.5, vmin=vmin, vmax=vmax)
        plt.title(title)
        plt.axis('off')

plt.show()

In [None]:
# Visualize recognition information for the first training batch.
# Each image gets one row of 2 panels: the de-normalized image, and the same
# image with each transcript drawn at the center of its rectangle.

# Mean/std used to undo the input normalization. Assumes the data loader
# normalizes with these ImageNet constants -- TODO confirm against the loader.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

# Only the first batch is shown; taking it directly avoids the original
# loop, whose `break` ran before `plt.show()` and made the latter unreachable.
data = next(iter(train_loader))
image_files, images, score_maps, geo_maps, training_masks, transcripts, rectangles, mappings = data

n_images = len(images)
plt.figure(figsize=(20, 10 * n_images))
print("images", images.shape)
print("score_maps", score_maps.shape)
print("geo_maps", geo_maps.shape)
print("training_masks", training_masks.shape)
print("transcripts", len(transcripts))
print("rectangles", rectangles.shape)
print("mappings", mappings.shape)
print("transcripts", transcripts)
print("mappings", mappings)

for image_idx, (image_file, image) in enumerate(zip(image_files, images)):
    # CHW tensor -> HWC array, undo normalization, then clip before the uint8
    # cast so out-of-range values saturate instead of wrapping around.
    image = image.numpy().transpose((1, 2, 0))
    image = image * IMAGENET_STD[None, None, :] + IMAGENET_MEAN[None, None, :]
    image = (255 * np.clip(image, 0.0, 1.0)).astype(np.uint8)

    # `mappings` gives, for each box in the batch, the index of the image it
    # belongs to; select this image's transcripts and rectangles.
    bbox_indicator = (mappings == image_idx)
    transcript = transcripts[bbox_indicator]
    rectangle = rectangles[bbox_indicator]
    xc, yc = get_center_bboxes(rectangle)
    # Draw on a copy so the left panel stays clean even if draw_texts
    # modifies its input in place (not verifiable from this notebook).
    image_text = draw_texts(image.copy(), transcript, xc, yc,
                            fontsize=3, color=(255, 255, 0), thickness=3)

    plt.subplot(n_images, 2, 2 * image_idx + 1)
    plt.imshow(image); plt.title(os.path.basename(image_file)); plt.axis('off')

    plt.subplot(n_images, 2, 2 * image_idx + 2)
    plt.imshow(image_text); plt.title("OCR"); plt.axis('off')

plt.show()