In [1]:
import sys
# Add the parent directory to the module search path. Presumably this is what
# lets the `allometry` imports in the next cell resolve when the notebook is
# run from its own folder -- confirm against the repository layout.
sys.path.append('..')

In [2]:
import argparse
import logging
import textwrap
from collections import defaultdict, namedtuple
from datetime import date
from os import makedirs
from pathlib import Path
from random import seed

import numpy as np
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import models
from tqdm import tqdm

from allometry.const import CHARS, DATA_DIR, CLASS_TO_CHAR
from allometry.allometry_sheet import AllometrySheet
from allometry.model_util import load_state, get_model
from allometry.util import finished, started

In [3]:
# Saved model state: the directory and file name are both handed to load_state.
STATE_DIR = DATA_DIR / 'state'
STATE = 'best_resnext101_2021-04-05_c.pth'

# Architecture name given to get_model.
# NOTE(review): presumably this has to match the network STATE was saved from.
MODEL = 'resnext101'

# Inference settings.
DEVICE = 'cuda:0'  # used for torch.device(...) and model.to(...)
BATCH_SIZE = 32    # DataLoader batch_size
WORKERS = 4        # DataLoader num_workers

# The sheet image file to read.
SHEET = (
    DATA_DIR
    / 'allometry_sheets'
    / 'Biomass_Fish_Families_FamilyProgram'
    / '00004.tif'
)

# Forwarded as AllometrySheet(rotate=...); presumably degrees -- TODO confirm.
ROTATE = 90

In [4]:
# Edges of one bounding box, in the order (left, top, right, bottom).
BBox = namedtuple('BBox', ['left', 'top', 'right', 'bottom'])

# One recognized character together with the BBox it was read from.
Char = namedtuple('Char', ['char', 'bbox'])

In [5]:
def get_loader():
    """Build the DataLoader over the configured sheet.

    The dataset is an AllometrySheet created from SHEET with rotate=ROTATE.
    It is batched BATCH_SIZE items at a time using WORKERS loader workers;
    shuffling is left at the DataLoader default (off).

    Returns:
        A torch.utils.data.DataLoader wrapping that dataset.
    """
    sheet = AllometrySheet(SHEET, rotate=ROTATE)
    loader = DataLoader(sheet, batch_size=BATCH_SIZE, num_workers=WORKERS)
    return loader

In [6]:
def batches(model, device, loader, page):
    """Run the model over every batch and record what it predicts.

    Args:
        model: Network to evaluate; it is switched to eval mode here.
        device: Device each input batch is moved to before the forward pass.
        loader: Iterable yielding (inputs, boxes) pairs.
        page: Mapping that save_predictions fills in place.

    Returns:
        None. Results are accumulated in `page`.
    """
    model.eval()

    # Inference only, so gradient tracking is switched off for the whole loop.
    with torch.no_grad():
        for images, boxes in loader:
            images = images.to(device)
            scores = model(images)
            save_predictions(images, scores, boxes, page)

In [7]:
def save_predictions(x_s, preds, boxes, page):
    """Turn one batch of model output into Char records stored in `page`.

    Args:
        x_s: The input batch. Accepted for the caller's convenience but not
            read by this function.
        preds: Tensor of model outputs with one entry per item in the batch.
            Each entry's argmax along axis 0 is looked up in CLASS_TO_CHAR.
        boxes: Object exposing `left`, `top`, `right` and `bottom` tensors,
            each with one value per item in the batch.
        page: Mapping from a box's `top` value to a list of Char records; it
            is appended to in place (the caller passes a defaultdict(list)).

    Returns:
        None.
    """
    def as_array(tensor):
        # Detach from the device and hand back an independent numpy array.
        return tensor.cpu().numpy().copy()

    scores = as_array(preds)
    edges = (
        as_array(boxes.left),
        as_array(boxes.top),
        as_array(boxes.right),
        as_array(boxes.bottom),
    )

    # Walk the batch item by item: one score row plus its four box edges.
    for row, *edge in zip(scores, *edges):
        bbox = BBox(*edge)
        glyph = Char(CLASS_TO_CHAR[row.argmax(0)], bbox)
        # Characters sharing the same top edge are collected under one key.
        page[bbox.top].append(glyph)

In [8]:
def test():
    """Run the configured model over the configured sheet.

    Builds the MODEL network, applies load_state with STATE_DIR and STATE
    (presumably restoring saved weights -- confirm in allometry.model_util),
    moves the model to DEVICE and feeds it every batch from get_loader().

    Returns:
        A defaultdict(list) mapping each box's top value to the Char records
        collected for it.
    """
    device = torch.device(DEVICE)

    model = get_model(MODEL)
    load_state(STATE_DIR, STATE, model)
    model.to(device)

    page = defaultdict(list)
    batches(model, device, get_loader(), page)
    return page

In [9]:
# Run inference once. `page` maps each box's top value to the list of Char
# records that share that top value, in the order the loader produced them.
page = test()

In [10]:
# Print every group of characters in `page` as one line of text.
for chars in page.values():
    line, prev = [], None
    for curr in chars:
        # A horizontal gap of more than 40 between the previous box's right
        # edge and this box's left edge is rendered as a tab.
        far_from_prev = bool(prev) and (curr.bbox.left - prev.bbox.right) > 40
        line += ['\t', curr.char] if far_from_prev else [curr.char]
        prev = curr
    print(''.join(line))

--------------------------------------------------------
S	T	A	T	I	S	T	I	C	A	L	A	N	A	L	Y	S	I	S	S	Y	S	T	E	M	5
14$52	WEDNESDAY.	DECEMBER	19.	1979
GENERAL	LINEAR	MODELS	PROCEDURE
DEPENDENT	VARIA9LE$	BIOLOG
O9SERVATION	OHSERVED	PREDICTED	RES(DUAL	LOWER	95%	CL	UPPER	95%	CL
VALUF	VALUE	IND1VIDUAL	INDIVIDUAL
48	-0.41797664	-0.24469947	-0.17724717	-0.69556891	0.20618997
49	-1.05407930	-0.97691026	-0.07712904	-1.42659930	-0.52722122
50	-1.2471O357	-1.10989618	-0.13728739	-1.55993934	-0.65985303
51	-1.67264408	-1.568215Q1	-0.06442907	-2.02081407	-1.11561594
52	-0.55595520	-0.49205337	-0.06J90I83	-C.94193674	-0.0421700Q
53	0.12483015	0.43249837	-0.3O766822	-0.02416529	0.88916203
54	-0.$1745288	-0.11294318	-Q.00Q50970	-0.56459698	0.37871262
55	-Q.68824614	-0.66932736	-0.01891878	-1.118O6921	-0.2197O551
56	-.6O205999	-0.60340994	0.Q0134995	-1.05304239	-0.15377748
57	-0.60205999	-0.65277259	0.05071260	-1.10233315	-0.20321204
58	-0.63827216	-0.61456107	-0.023711Q9	-1.06417517	-0.16494697
59	-1.761953