In [1]:
import sys
sys.path.append('..')

In [2]:
import argparse
import logging
import textwrap
from collections import defaultdict, namedtuple
from datetime import date
from os import makedirs
from pathlib import Path
from random import seed

import numpy as np
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import models
from tqdm import tqdm

from allometry.const import CHARS, DATA_DIR, CLASS_TO_CHAR
from allometry.allometry_sheet import AllometrySheet
from allometry.model_util import load_model_state, get_model
from allometry.util import finished, started

In [3]:
# ---- Model -------------------------------------------------------------------
MODEL_DIR = DATA_DIR / 'model'
MODEL_STATE = 'best_resnet152_2021-04-06_c.pth'  # state file loaded from MODEL_DIR
MODEL_ARCH = 'resnet152'                          # architecture name passed to get_model()

# Fall back to the CPU so the notebook still runs on machines without a GPU.
# NOTE(review): if load_model_state does not pass map_location to torch.load,
# a GPU-saved checkpoint may still fail to load on CPU — confirm.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# ---- Data loading ------------------------------------------------------------
BATCH_SIZE = 32
WORKERS = 4  # DataLoader worker processes

# The single sheet image to transcribe.
SHEET = DATA_DIR / 'allometry_sheets' / 'Biomass_Fish_Families_FamilyProgram' / '00004.tif'

# Rotation handed to AllometrySheet — presumably degrees; TODO confirm units and
# direction against AllometrySheet.
ROTATE = 90

In [4]:
# Bounding box of one glyph, in the coordinates yielded by the loader
# (presumably image pixels — confirm against AllometrySheet).
BBox = namedtuple('BBox', ['left', 'top', 'right', 'bottom'])

# One recognised character paired with the BBox it was read from.
Char = namedtuple('Char', ['char', 'bbox'])

In [5]:
def get_loader():
    """Build the DataLoader that feeds the configured sheet to the model.

    The dataset is ``AllometrySheet(SHEET, rotate=ROTATE)``. Items are batched
    ``BATCH_SIZE`` at a time, in dataset order (DataLoader's default is no
    shuffling), by ``WORKERS`` worker processes.

    Returns:
        A ``torch.utils.data.DataLoader`` over the sheet.
    """
    sheet_dataset = AllometrySheet(SHEET, rotate=ROTATE)
    loader = DataLoader(
        sheet_dataset,
        batch_size=BATCH_SIZE,
        num_workers=WORKERS,
    )
    return loader

In [6]:
def batches(model, device, loader, page):
    """Run ``model`` over every batch from ``loader`` and collect its outputs.

    The model is put in evaluation mode and autograd is disabled for the whole
    pass. Each batch's inputs are moved to ``device`` before the forward pass;
    the moved inputs, the model output and the batch's boxes are then handed to
    ``save_predictions`` together with ``page``.

    Args:
        model: network to run; ``eval()`` is called on it.
        device: destination passed to ``Tensor.to`` for the inputs.
        loader: iterable of ``(inputs, boxes)`` pairs.
        page: container that ``save_predictions`` fills in place.
    """
    model.eval()
    with torch.no_grad():
        for inputs, box_batch in loader:
            on_device = inputs.to(device)
            save_predictions(on_device, model(on_device), box_batch, page)

In [7]:
def save_predictions(x_s, preds, boxes, page):
    """Decode one batch of model outputs into characters and file them by row.

    For every item in the batch, the highest-scoring class is mapped to a
    character through ``CLASS_TO_CHAR``. The character is stored, with its
    bounding box, as a ``Char`` appended to ``page[top]``. Glyphs that share a
    top coordinate therefore end up in the same list, in batch order.

    Args:
        x_s: the batch inputs; unused, kept for call compatibility.
        preds: model output with one row of class scores per item
            (indexed as ``(batch, classes)``).
        boxes: ``BBox`` whose fields are per-item coordinate tensors, as
            produced by DataLoader's default collation (presumably — confirm
            against AllometrySheet).
        page: mapping of top coordinate -> list of ``Char``. It must create
            missing keys itself (e.g. ``defaultdict(list)``) and is mutated
            in place.
    """
    # One vectorised argmax per batch instead of one per row. .tolist() yields
    # builtin Python numbers, so page keys and BBox fields are not numpy
    # scalars: their reprs are noisy, and unsigned variants would wrap on
    # the gap subtraction done when rendering lines.
    classes = preds.argmax(dim=1).cpu().tolist()
    lefts, tops, rights, bottoms = (
        coord.cpu().tolist()
        for coord in (boxes.left, boxes.top, boxes.right, boxes.bottom)
    )

    for cls, left, top, right, bottom in zip(classes, lefts, tops, rights, bottoms):
        page[top].append(Char(CLASS_TO_CHAR[cls], BBox(left, top, right, bottom)))

In [8]:
def test():
    """Test the neural net."""
    model = get_model(MODEL_ARCH)
    load_model_state(MODEL_DIR, MODEL_STATE, model)

    device = torch.device(DEVICE)
    model.to(DEVICE)

    loader = get_loader()

    page = defaultdict(list)
    batches(model, device, loader, page)

    return page

In [9]:
# Run the model over the whole sheet. `page` maps each glyph's top coordinate
# to the list of recognised characters found at that height.
page = test()

In [10]:
COLUMN_GAP = 40  # horizontal gap between glyph boxes (presumably pixels) rendered as a tab


def format_line(chars, gap=COLUMN_GAP):
    """Render one row of recognised glyphs as a string.

    Glyphs are ordered left-to-right by the left edge of their box. A tab is
    inserted wherever the space between one glyph's right edge and the next
    glyph's left edge exceeds ``gap``.

    Args:
        chars: iterable of ``Char`` records belonging to one row.
        gap: threshold above which two neighbours are treated as separate columns.

    Returns:
        The row's text, with ``'\\t'`` between columns.
    """
    ordered = sorted(chars, key=lambda c: c.bbox.left)
    pieces = []
    for prev, curr in zip([None] + ordered[:-1], ordered):
        if prev is not None and curr.bbox.left - prev.bbox.right > gap:
            pieces.append('\t')
        pieces.append(curr.char)
    return ''.join(pieces)


# Rows are keyed by their top coordinate. Sort them so lines print top-to-bottom
# regardless of the order in which the loader produced the glyphs.
for top in sorted(page):
    print(format_line(page[top]))

--------------------------------------------------------
S	T	A	T	I	S	T	I	C	A	L	A	N	A	L	Y	S	I	S	S	Y	S	T	E	M	5
14:52	WEDNESDAY.	DECEMBER	I9.	I979
GENENAL	LINEAR	MODELS	PROCEDURE
DEPENDENT	VARIA9LE:	BIOLOG
OBSERVATION	OBSERVED	PR5DICTED	R5SIDUAL	LOWER	95%	CL	UPPER	95%	CL
VALUF	VALUE	INDIVIDUAL	INDIVIOUAL
48	-O.4I793664	-O.24469947	-O.173247I7	-O.6955669I	O.2O6$6997
49	-1.O54O393O	-O.9769IO26	-O.077129O4	-I.4265993O	-O.52722I22
50	-I.247I8357	-I.1O969618	-O.I3726739	-I.55993934	-O.659853O3
51	-I.632644O8	-I.5662I5O1	-O.O64429O7	-2.O2O6I4O7	-I.1I56I594
52	-O.5559552O	-O.492O5337	-O.0639OI83	-O.94I93674	-O.042I7OOO
53	0.12483015	0.43249837	-0.3O766622	-O.O24I6529	O.689I62O3
54	-O.11345288	-O.1I2943I8	-O.OOO5097O	-O.56459698	O.J367I262
55	-O.68624614	-O.66932736	-0.0I89I678	-I.1186692I	-O.2I97655I
56	-.6O2O5999	-O.6O34O994	O.OOI34995	-I.O53O4239	-O.I5377748
57	-O.6O2O5999	-O.65277259	O.05O7I26O	-1.IO2333I5	-O.2O32I2O4
58	-O.638272I6	-0.6I4561O7	-O.O2371109	-1.O6417517	-O.I6494697
59	-I.76I953