In [1]:
from pathlib import Path
from shutil import rmtree
import pdb
import re

from tqdm import tqdm
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import pandas as pd

import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

import transforms as T
import utils

from label_babel_dataset import LabelBabelDataset

In [2]:
# --- Filesystem layout (everything lives under ./data) ----------------------
DATA_DIR = Path('.') / 'data'

TRN_DIR = DATA_DIR / 'train'
VAL_DIR = DATA_DIR / 'valid'
MODEL_DIR = DATA_DIR / 'models'
HAND_DIR = DATA_DIR / 'handwritten'
TYPE_DIR = DATA_DIR / 'typewritten'

VAL_CSV = DATA_DIR / 'valid.csv'
TRN_CSV = DATA_DIR / 'train.csv'

# --- DataFrame column names -------------------------------------------------
BOX = 'box'
CAT = 'category'
ORI = 'original'
PATH = 'path'
CROP = 'predicted_box_original'
CLASS = 'class'
SUB_ID = 'subject_id'
PRED_BOX = 'predicted_box'
PRED_CAT = 'predicted_category'
PRED_CLASS = 'predicted_class'

# --- Compute device ---------------------------------------------------------
# Prefer the GPU, but fall back to the CPU so the notebook still runs (slowly)
# on a machine without CUDA instead of failing at model.to(DEVICE).
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True

# --- Label categories; the list position is the integer class id ------------
HANDWRITTEN = 'handwritten'
TYPEWRITTEN = 'typewritten'
CATS = ['background', HANDWRITTEN, TYPEWRITTEN]
CLASSES = len(CATS)

# File-name template for saved model checkpoints.
CHECKPOINT = 'checkpoint_{}.pth.tar'

# Multiplier applied to predicted box coordinates before cropping the original
# image. NOTE(review): presumably the size ratio between the original images
# and the downsized ones fed to the model -- confirm.
SCALE = 4.0

In [3]:
def load_df(df_path):
    """Read a subject CSV and blank out its prediction columns.

    The file is read with ``subject_id`` as the index and the index is then
    reset, which leaves ``subject_id`` as the first ordinary column on top of
    a fresh 0..n-1 row index. The four prediction columns are created (or
    overwritten, if the CSV already contains them) with ``None``.

    Args:
        df_path: Path of the CSV file to read.

    Returns:
        The loaded DataFrame with empty prediction columns.
    """
    frame = pd.read_csv(df_path, index_col='subject_id').reset_index()
    for column in (PRED_CLASS, PRED_CAT, PRED_BOX, CROP):
        frame[column] = None
    return frame

In [4]:
# Load the validation and training subject tables (validation first).
val_df, trn_df = (load_df(csv_path) for csv_path in (VAL_CSV, TRN_CSV))

In [5]:
# Start from torchvision's pretrained Faster R-CNN (ResNet-50 + FPN backbone).
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# Swap the box-classification head for one sized to this project's classes
# (background + handwritten + typewritten), keeping the same input width.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(
    in_channels=in_features,
    num_classes=CLASSES,
)

In [6]:
# Inference only needs the images converted to tensors -- no augmentation.
transforms = T.Compose([T.ToTensor()])

# One dataset per split, both sharing the same transform pipeline.
val_dataset, trn_dataset = (
    LabelBabelDataset(frame, transforms) for frame in (val_df, trn_df)
)

In [7]:
# Move the model to the compute device and switch it to inference mode.
# Both calls return the module itself; the trailing ';' hides its repr.
model.to(DEVICE).eval();

In [8]:
def _checkpoint_sort_key(path):
    """Sort key that compares the numbers embedded in a file name numerically.

    Plain string ordering would put 'checkpoint_10...' before
    'checkpoint_5...'; this key orders them 5 < 10.
    """
    # re.split with a capturing group yields text at even positions and digit
    # runs at odd positions, so keys always compare like with like.
    parts = re.split(r'(\d+)', path.name)
    numeric_aware = [int(part) if pos % 2 else part
                     for pos, part in enumerate(parts)]
    # The raw name breaks ties (e.g. '05' vs '5') deterministically.
    return numeric_aware, path.name


checkpoints = sorted(MODEL_DIR.glob(CHECKPOINT.format('*')),
                     key=_checkpoint_sort_key)
if not checkpoints:
    raise FileNotFoundError(
        'No checkpoint matching {!r} found in {}'.format(
            CHECKPOINT.format('*'), MODEL_DIR))
last_checkpoint = checkpoints[-1]

# map_location lets a checkpoint saved on a GPU be restored on whatever device
# this session uses. NOTE: torch.load unpickles the file, so only load
# checkpoints from a trusted source.
state = torch.load(last_checkpoint, map_location=DEVICE)

model.load_state_dict(state['state_dict'])

# Show which epoch the restored weights come from.
state['epoch']

5

In [9]:
def get_predictions(dataset, df):
    """Run the detector over ``dataset`` and record one prediction per row.

    For each image the first detection returned by the model is kept
    (NOTE(review): torchvision detection models are expected to return
    detections ordered by score, so this is presumably the most confident
    one -- confirm). Its box and class id are written into the ``PRED_BOX``
    and ``PRED_CLASS`` columns of ``df``. Images with no detection get
    class 0 and an all-zero box.

    Args:
        dataset: Iterable of ``(image, target)`` pairs; only the image is used.
        df: DataFrame updated in place. Rows are addressed by the dataset
            position, so its index labels must match the enumeration order.
    """
    for row, (image, _target) in tqdm(enumerate(dataset)):
        with torch.no_grad():
            output = model([image.to(DEVICE)])[0]

        labels = output['labels'].cpu()
        if len(labels) > 0:
            best_class = labels[0].numpy()
            best_box = output['boxes'].cpu()[0].numpy()
        else:
            # Nothing detected: fall back to class 0 and an empty box.
            best_class = 0
            best_box = [0, 0, 0, 0]

        df.at[row, PRED_BOX] = best_box
        df.at[row, PRED_CLASS] = best_class


# Fill the PRED_BOX / PRED_CLASS columns of both frames in place
# (validation split first, then training split).
get_predictions(val_dataset, val_df)
get_predictions(trn_dataset, trn_df)

1216it [02:19,  8.70it/s]
4865it [08:58,  9.03it/s]


In [10]:
# Lookup table: integer class id (the Series index) -> category name.
cats = pd.Series(CATS)

# Translate the predicted class ids of both frames into category names.
trn_df[PRED_CAT], val_df[PRED_CAT] = (
    frame[PRED_CLASS].astype(int).map(cats) for frame in (trn_df, val_df)
)

val_df.head()

Unnamed: 0,subject_id,category,class,box,path,original,predicted_class,predicted_category,predicted_box,predicted_text,predicted_box_original
0,2995202,typewritten,2,"[228, 459, 367, 568]",data/valid/2995202.jpg,data/images/2995202.jpg,2,typewritten,"[230.9855, 463.58728, 365.75626, 570.29736]",FLORA OF ARK:\n\nCORNACEAE Cornus drummondii €...,
1,2995203,typewritten,2,"[238, 462, 377, 572]",data/valid/2995203.jpg,data/images/2995203.jpg,2,typewritten,"[235.45955, 464.4702, 376.1618, 573.8489]",FLORA OF ARKANSAS.\nCULTIVATED!\nCORNACEAE\n\n...,
2,2995205,typewritten,2,"[252, 486, 373, 564]",data/valid/2995205.jpg,data/images/2995205.jpg,1,handwritten,"[255.06477, 481.98413, 373.10254, 568.11176]",,
3,2995213,typewritten,2,"[224, 474, 373, 566]",data/valid/2995213.jpg,data/images/2995213.jpg,2,typewritten,"[224.79485, 476.49863, 370.26642, 565.2874]",CORNACEAE\nHERBARIUM OF THE UNIVERSITY OF ARKA...,
4,2995216,handwritten,1,"[239, 462, 367, 562]",data/valid/2995216.jpg,data/images/2995216.jpg,1,handwritten,"[235.82883, 460.7853, 369.38535, 562.74536]",,


In [11]:
# Start from empty output folders: remove whatever a previous run left behind.
# ignore_errors=True also covers the case where the folders do not exist yet.
rmtree(HAND_DIR, ignore_errors=True)
rmtree(TYPE_DIR, ignore_errors=True)

# mkdir() is called without exist_ok, so it raises FileExistsError if a folder
# could not be removed above.
HAND_DIR.mkdir()
TYPE_DIR.mkdir()

In [12]:
def crop_image(df):
    """Cut each row's predicted label out of its original image and save it.

    For every row the predicted box (``PRED_BOX``) is multiplied by ``SCALE``,
    rounded to whole pixels and stored in the ``CROP`` column. Rows whose
    scaled box sums to zero (no detection) are skipped. Otherwise the original
    image (``ORI``) is opened, cropped to the box and saved under its original
    file name in ``TYPE_DIR`` when the predicted category is ``TYPEWRITTEN``
    and in ``HAND_DIR`` for anything else.

    Args:
        df: DataFrame with ``ORI``, ``PRED_CAT`` and ``PRED_BOX`` columns.
            Its ``CROP`` column is filled in place.
    """
    for idx, row in tqdm(df.iterrows(), total=len(df)):
        path = Path(row.at[ORI])
        out_dir = TYPE_DIR if row.at[PRED_CAT] == TYPEWRITTEN else HAND_DIR

        box = [int(round(x * SCALE)) for x in row.at[PRED_BOX]]
        df.at[idx, CROP] = box
        if sum(box) == 0:
            continue

        # The context manager closes the underlying file as soon as the pixel
        # data has been copied into the converted image.
        with Image.open(str(path)) as source:
            image = source.convert('RGB')

        # Orientation is judged on the full image, before cropping.
        width, height = image.size
        image = image.crop(box)
        if width > height:
            # Crops taken from landscape images are rotated before saving.
            image = image.transpose(Image.ROTATE_270)
        image.save(out_dir / path.name)


# Write the cropped labels for both splits into HAND_DIR / TYPE_DIR and fill
# the CROP column of each frame in place.
crop_image(val_df)
crop_image(trn_df)

1216it [01:13, 16.50it/s]
4865it [04:29, 18.05it/s]


In [13]:
# Persist the prediction columns by overwriting the CSVs the frames were
# loaded from; index=False keeps the row index out of the files.
val_df.to_csv(VAL_CSV, index=False)
trn_df.to_csv(TRN_CSV, index=False)