In [None]:
import sys

# Add the parent directory to the module search path so that packages living
# one level above this notebook can be imported. The membership check keeps
# the cell idempotent: re-running it does not pile up duplicate entries.
if '..' not in sys.path:
    sys.path.append('..')

In [1]:
from collections import defaultdict, namedtuple

import torch

from allometry.util import finished, started
from allometry.model_util import load_model_state, get_model
from allometry.allometry_sheet import AllometrySheet
from allometry.const import DATA_DIR, BBox
from allometry.characters import IDX_TO_CHAR
from torch.utils.data import DataLoader

In [2]:
# Trained weights: directory, checkpoint file name and architecture name
# handed to get_model().
MODEL_DIR = DATA_DIR / 'model'
MODEL_STATE = 'best_resnext101_2021-04-08_context_c.pth'
MODEL_ARCH = 'resnext101'

# Inference settings: target device plus DataLoader batch size / worker count.
DEVICE = 'cuda:0'
BATCH_SIZE = 20
WORKERS = 4

# The sheet image to run through the model, and the rotation value passed to
# AllometrySheet when loading it.
SHEET = (
    DATA_DIR
    / 'allometry_sheets'
    / 'Biomass_Fish_Families_FamilyProgram'
    / '00004.tif'
)
ROTATE = 90

# One predicted character paired with its bounding box.
Char = namedtuple('Char', 'char bbox')

In [3]:
def get_loader():
    """Build the DataLoader for the configured sheet.

    The dataset is an AllometrySheet created from SHEET with rotate=ROTATE;
    batch size and worker count come from BATCH_SIZE and WORKERS.
    """
    sheet_data = AllometrySheet(SHEET, rotate=ROTATE)
    loader = DataLoader(
        sheet_data,
        num_workers=WORKERS,
        batch_size=BATCH_SIZE,
    )
    return loader

In [4]:
def batches(model, device, loader, page):
    """Run the model over every batch from the loader and record the results.

    The model is switched to eval mode and gradients are disabled. Each batch
    is moved to ``device``, passed through the model, and the outputs are
    handed to save_predictions(), which fills ``page`` in place.
    """
    model.eval()
    with torch.no_grad():
        for images, boxes in loader:
            images = images.to(device)
            save_predictions(images, model(images), boxes, page)

In [5]:
def save_predictions(x_s, preds, boxes, page):
    """Store the predicted character and bounding box of every batch item.

    x_s: the input batch; accepted for the caller's benefit but not used here.
    preds: model output tensor; for each item the index of its largest value
        is looked up in IDX_TO_CHAR.
    boxes: object exposing left/top/right/bottom tensors, one entry per item.
    page: mapping that is updated in place; each Char is appended to the list
        stored under its box's ``top`` value.
    """
    def to_array(tensor):
        """Copy a tensor to host memory as an independent NumPy array."""
        return tensor.cpu().numpy().copy()

    scores = to_array(preds)
    edges = zip(
        to_array(boxes.left),
        to_array(boxes.top),
        to_array(boxes.right),
        to_array(boxes.bottom),
    )

    for row, (left, top, right, bottom) in zip(scores, edges):
        best = row.argmax(0)
        page[top].append(Char(IDX_TO_CHAR[best], BBox(left, top, right, bottom)))

In [6]:
def test():
    """Run the trained network over the configured sheet.

    Builds the MODEL_ARCH model, loads the MODEL_STATE weights from MODEL_DIR,
    moves the model to DEVICE and feeds it every batch from get_loader().

    Returns the defaultdict(list) filled by batches(): lists of Char entries
    keyed by the ``top`` value of their bounding box.
    """
    net = get_model(MODEL_ARCH)
    load_model_state(MODEL_DIR, MODEL_STATE, net)

    device = torch.device(DEVICE)
    net.to(device)

    page = defaultdict(list)
    batches(net, device, get_loader(), page)
    return page

In [7]:
# Run inference on the configured sheet; `sheet` is the mapping returned by
# test() (lists of Char entries keyed by bounding-box top).
sheet = test()

In [8]:
# Print the sheet one text line at a time. Within a line, a horizontal gap of
# more than 40 (in bbox units) between one character's right edge and the next
# character's left edge is rendered as a tab.
for chars in sheet.values():
    line = []
    # Pair every character with its predecessor; the first one has none.
    for prev, curr in zip([None] + chars, chars):
        if prev is not None and curr.bbox.left - prev.bbox.right > 40:
            line.append('\t')
        line.append(curr.char)
    print(''.join(line))

--------	-	---	--	----------	-	--------------	-----------=-
S	T	A	T	I	S	T	I	C	A	L	A	N	A	L	Y	S	I	S	S	Y	S	T	E	M	5
14:52	WEDNESDAY*	DECEMBER	19.	1979
GENERAL	LINEAR	MODELS	PROCEDURE
DEPEN5ENT	VARIA3LE:	BIOLOG
O8SERVATION	OBSERVED	PREDICTED	RESIDUAL	LOWER	95%	CL	UPPER	95%	CL
VALUF	VALUE	INDIVIDUAL	INDIVIDUAL
48	-0.41793664	-0.24468947	-0.17324717	-0.69556891	0.20618997
49	-1.05403930	-0.97691026	-0.07712904	-1.42659930	-0.52722122
50	-1.24718357	-1.10989618	-0.13728739	-1.55993934	-0.65985333
51	-1.63264408	-1:56821501	-0.06442907	-2.02081407	-1.11561594
52	-0.55595520	-0.49205337	-0.06390183	-0.94193674	-0.04217000
53	0.12483015	0.43249837	-0.30766822	-0.02416529	0.88916203
54	-0.11345288	-0.11294318	-0.00050970	-0.56459898	0.33871262
55	-0.68824614	-0.66932736	-0.01891878	-1.11886921	-0.21978551
56	-6.60205999	-0.60340994	0.00134995	-1.05304239	-0.15377748
57	-0.60205999	-0.65277259	0.05071260	-1.10233315	-0.20321204
58	-0.63827216	-0.61456107	-0.02371109	-1.06417517	-0.16494697
59	-1.76