In [1]:
import sys
sys.path.append('..')

In [2]:
import argparse
import logging
import textwrap
from collections import defaultdict, namedtuple
from datetime import date
from os import makedirs
from pathlib import Path
from random import seed

import numpy as np
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import models
from tqdm import tqdm

from allometry.const import CHARS, DATA_DIR, CLASS_TO_CHAR
from allometry.allometry_sheet import AllometrySheet
from allometry.model_util import load_state, get_model
from allometry.util import finished, started

In [3]:
# Model / checkpoint: architecture name for get_model, and the weights file
# (inside STATE_DIR) handed to load_state.
MODEL = 'resnet34'
STATE_DIR = DATA_DIR / 'state'
STATE = 'best_resnet34_2021-04-04_a.pth'

# Execution: torch device string plus DataLoader batching/worker settings.
DEVICE = 'cuda:0'
BATCH_SIZE = 32
WORKERS = 4

# Input: the scanned sheet to read, and the rotation passed to AllometrySheet
# (presumably degrees — TODO confirm against AllometrySheet).
SHEET = DATA_DIR / 'allometry_sheets' / 'Biomass_Fish_Families_FamilyProgram' / '00004.tif'
ROTATE = 90

In [4]:
# Bounding box of one recognized character: left/top/right/bottom edges.
BBox = namedtuple('BBox', ['left', 'top', 'right', 'bottom'])

# A recognized character paired with the BBox it was read from.
Char = namedtuple('Char', ['char', 'bbox'])

In [5]:
def get_loader():
    """Build a DataLoader over the configured allometry sheet.

    The image at SHEET is wrapped in an AllometrySheet (rotated by ROTATE)
    and served in batches of BATCH_SIZE using WORKERS loader workers, in
    dataset order (DataLoader does not shuffle by default).
    """
    sheet = AllometrySheet(SHEET, rotate=ROTATE)
    loader = DataLoader(sheet, batch_size=BATCH_SIZE, num_workers=WORKERS)
    return loader

In [6]:
def batches(model, device, loader, page):
    """Run the model over every batch and collect its predictions into `page`.

    The model is put in eval mode and run with gradient tracking disabled.
    Each batch of images is moved to `device` and classified. The output is
    passed to save_predictions together with that batch's boxes, so `page`
    is filled in place.
    """
    model.eval()
    with torch.no_grad():
        for images, boxes in loader:
            images = images.to(device)
            save_predictions(images, model(images), boxes, page)

In [7]:
def save_predictions(x_s, preds, boxes, page):
    """Decode one batch of predictions and file each character under its top edge.

    Args:
        x_s: the input batch; not used here, kept for the caller's signature.
        preds: model output tensor. Presumably (batch, classes) scores; the
            argmax of each row selects the class looked up in CLASS_TO_CHAR.
        boxes: batched boxes exposing `left`, `top`, `right`, `bottom` tensors.
        page: mapping of a box's `top` value to a list; a Char(char, bbox)
            is appended for every prediction (mutated in place).
    """
    def to_array(tensor):
        # Detach from the device into an independent numpy copy.
        return tensor.cpu().numpy().copy()

    scores = to_array(preds)
    columns = [to_array(getattr(boxes, side))
               for side in ('left', 'top', 'right', 'bottom')]

    for score, coords in zip(scores, zip(*columns)):
        bbox = BBox(*coords)
        page[bbox.top].append(Char(CLASS_TO_CHAR[score.argmax(0)], bbox))

In [8]:
def test():
    """Run the trained model over the configured sheet and collect the results.

    Builds the MODEL architecture, loads the STATE weights from STATE_DIR,
    places the model on DEVICE, and classifies every item produced by
    get_loader().

    Returns:
        defaultdict(list) mapping each character box's `top` coordinate to
        the Char(char, bbox) entries recognized there, in loader order.
    """
    model = get_model(MODEL)
    load_state(STATE_DIR, STATE, model)

    device = torch.device(DEVICE)
    # Use the same device object for the model as batches() uses for the inputs,
    # instead of the raw DEVICE string.
    model.to(device)

    loader = get_loader()

    page = defaultdict(list)
    batches(model, device, loader, page)

    return page

In [9]:
# Run the model over the sheet; `page` maps each character box's top coordinate
# to the list of Char(char, bbox) entries recognized at that coordinate.
page = test()

In [10]:
# Horizontal gap (in bbox coordinate units, presumably pixels) between one
# character's right edge and the next one's left edge that is rendered as a tab.
TAB_GAP = 40


def format_line(chars, gap=TAB_GAP):
    """Join one row of recognized characters into a string.

    Characters are ordered left to right by their box's left edge, so the
    result does not depend on the order the loader produced them in. A tab
    is inserted wherever the gap to the previous character exceeds `gap`.
    """
    parts = []
    prev = None
    for curr in sorted(chars, key=lambda c: c.bbox.left):
        if prev is not None and curr.bbox.left - prev.bbox.right > gap:
            parts.append('\t')
        parts.append(curr.char)
        prev = curr
    return ''.join(parts)


# Print rows top to bottom rather than in dict insertion order.
for top in sorted(page):
    print(format_line(page[top]))

--------------------------------------------------------
S	T	A	T	I	S	T	I	C	A	L	A	N	A	L	Y	S	I	S	S	Y	S	T	E	M	5
14:52	MEDNESDAY0	DECEMBER	190	1979
GENERAL	LINEAR	MCDELS	PRLCEDURE
DEPENDENT	VARIA3LE$	BIOLOG
OOSERVATION	OBSERVED	PREDICTFD	RESIDUAL	LOMER	95%	CL	UPPER	95%	CL
V4LUF	VALUE	IND1VIDU4L	INDIVIDUAL
48	=08417936O4	-0624464947	-0617324717	-O859556891	062061B997
49	-16054O3930	-0697691026	-06O77129O4	-164265993O	=O652722122
50	-1624718357	-161O989618	-0813728739	-1655993934	-O6659853O3
51	-16532644O8	=16558215O1	-O8O64429O7	-26O2O814O7	-1811561594
52	=985559552O	-06492O5337	-O6O639O183	=O694193674	-O604217O00
53	0812483015	0843249837	-OG3O766622	-06O2416529	06O8916203
54	-061134528B	=O811294318	-O800O50970	-0656459898	O633671262
55	-0668O24614	-066693273U	-O601891878	-1811OO6921	=0821978551
56	465D2O5999	-06U034O994	060O134995	-18O53O4239	-O615377748
57	-0660205999	-O665277259	O6O5O7126O	=161O233315	=O8203212O4
5O	-0663827216	=06F1456107	-O6O23711O9	-16O6417517	=O616494697
59	-16761953