In [1]:
import sys

sys.path.append('..')

In [2]:
from collections import defaultdict, namedtuple

import torch

from allometry.util import finished, started
from allometry.model_util import load_model_state, get_model
from allometry.allometry_sheet import AllometrySheet
from allometry.const import DATA_DIR, BBox
from allometry.characters import IDX_TO_CHAR
from torch.utils.data import DataLoader

In [3]:
# Trained model checkpoint and architecture used for inference.
MODEL_DIR = DATA_DIR / 'model'
MODEL_STATE = 'best_resnext101_2021-04-08_context_c.pth'
MODEL_ARCH = 'resnext101'

# Fall back to the CPU so Restart & Run All still works on machines without CUDA.
# NOTE(review): if load_model_state does not pass map_location to torch.load,
# a CUDA-saved checkpoint may still fail to load on CPU — confirm.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 20
WORKERS = 4  # DataLoader worker processes

# The allometry sheet image to transcribe.
SHEET = (
    DATA_DIR
    / 'allometry_sheets'
    / 'Biomass_Fish_Families_FamilyProgram'
    / '00001.tif'
)

# One recognized character together with its bounding box on the sheet.
Char = namedtuple('Char', 'char bbox')

# Rotation handed to AllometrySheet — presumably degrees; TODO confirm.
ROTATE = 90

In [4]:
def get_loader():
    """Build the DataLoader for the configured allometry sheet.

    Returns:
        A DataLoader over ``AllometrySheet(SHEET, rotate=ROTATE)`` that yields
        batches of ``BATCH_SIZE`` items using ``WORKERS`` worker processes.
    """
    sheet_dataset = AllometrySheet(SHEET, rotate=ROTATE)
    loader = DataLoader(
        sheet_dataset,
        batch_size=BATCH_SIZE,
        num_workers=WORKERS,
    )
    return loader

In [5]:
def batches(model, device, loader, page):
    """Run the model over every batch from ``loader`` and record the results.

    The model is put in eval mode and gradients are disabled. Each image batch
    is moved to ``device``, passed through the model, and the outputs are given
    to ``save_predictions`` together with the batch's boxes and ``page``.
    """
    model.eval()
    with torch.no_grad():
        for images, box_batch in loader:
            images = images.to(device)
            save_predictions(images, model(images), box_batch, page)

In [6]:
def save_predictions(x_s, preds, boxes, page):
    """Append each predicted character to ``page``, grouped by box top.

    Args:
        x_s: The input batch (not used here).
        preds: Model outputs; for each item, the index of the highest score
            along axis 0 is decoded through IDX_TO_CHAR.
        boxes: Bounding-box coordinates, one row per item, unpacked into BBox.
        page: Mapping from a box's ``top`` value to a list of Char records;
            must create missing keys on access (e.g. ``defaultdict(list)``).
    """
    pred_array = preds.cpu().numpy().copy()
    box_array = boxes.cpu().numpy().copy()

    for scores, coords in zip(pred_array, box_array):
        predicted = IDX_TO_CHAR[scores.argmax(0)]
        bbox = BBox(*coords)
        page[bbox.top].append(Char(predicted, bbox))

In [7]:
def test():
    """Transcribe SHEET with the trained model.

    Builds the MODEL_ARCH network, loads MODEL_STATE from MODEL_DIR, moves it
    to DEVICE, and runs it over every batch from ``get_loader()``.

    Returns:
        A ``defaultdict(list)`` mapping each box's top coordinate to the list
        of Char predictions recorded for that top value.
    """
    model = get_model(MODEL_ARCH)
    load_model_state(MODEL_DIR, MODEL_STATE, model)

    # Use the same torch.device object for the model and the input batches so
    # the two cannot drift apart.
    device = torch.device(DEVICE)
    model.to(device)

    loader = get_loader()

    page = defaultdict(list)
    batches(model, device, loader, page)

    return page

In [8]:
# Run inference over the whole sheet. `sheet` maps each box's top coordinate
# to the list of Char predictions with that top value.
sheet = test()

In [9]:
# Rebuild the text of each line. A tab is inserted between neighbouring
# characters whose boxes are more than TAB_GAP apart (presumably pixels —
# TODO confirm against AllometrySheet's coordinate system).
TAB_GAP = 40

for top in sorted(sheet):
    # The gap test compares each char with its left neighbour, so order the
    # line by left edge explicitly instead of relying on loader order.
    chars = sorted(sheet[top], key=lambda c: c.bbox.left)
    line = []
    prev = None
    for curr in chars:
        if prev is not None and curr.bbox.left - prev.bbox.right > TAB_GAP:
            line.append('\t')
        line.append(curr.char)
        prev = curr
    print(''.join(line))

-	---	==	===
NN	NN	OOOOOOOOOOOO	NN	NN	PPPPPPPPPPP	EEEEEEEEEEEE	RRRRRRRRRRR	CCCCCCCCCC
NNN	NN	OOOOOOOOOOOO	NNN	NN	PPPPP9PPPPPP	EEEFEEEEEFEF	RRRRRRRRRRRR	CCCCCCCCCCCC
NNNN	NN	OO	OO	NNNN	NN	PP	PP	EE	RR	RR	CC	CC
NN	NN	NN	OO	OC	NN	NN	NN	PP	PP	EE	RR	RR	CC
NN	NN	NN	OO	OO	NN	NN	NN	PP	PP	EF	RR	RR	CC
NN	NN	NN	OO	OO	NN	NN	NN	PPPPPPPPPPPP	EEEEEEEE	RRRRRRRRRRRR	CC
NN	NN	NN	OO	OO	NN	NN	NN	PPPPPPPPPPP	EEEEEEEE	RRRRRRRRRRR	CC
NN	NN	NN	OO	OO	NN	NN	NN	PP	EE	&	RR	RR	CC
NN	NNNN	OO	OO	NN	NNNN	PP	EF	RR	RR	CC
NN	NNN	OO	OO	NN	NNN	PP	E6	PR	RR	CC	CC
NN	NN	COOOO6OO0O00	NN	NN	PP	EEEEEEEEEEEE	RR	RR	CCCCCCCCCCCC
NN	N	OOOOOOOOOOOO	NN	N	PP	EEEEEEEEEEEE	RR	RR	CCCCCCCCCC
JJJJJJJJJJ	6666666666	11	6666666666	777777777777	AAAAAAAAAA
JJJJJJJJJJ	666666666666	111	666666666666	77777777777	AAAAAAAAAAAA
JJ	66	66	1111	66	66	77	77	AA	AA
JJ	66	11	66	77	AA	AA
JJ	66	11	66	77	AA	AA
JJ	66666666666	13	66666666666	77	AAAAAAAAAAAA
JJ	666666666666	11	666666666666	77	AAAAAAAAAAAA
JJ	66	66	11	66	66	77	AA	AA
JJ	JJ	66	66	11	66	66	77	AA	AA
JJ	