In [1]:
import sys
# Put the parent directory on the module search path.
# NOTE(review): presumably this is what lets the `allometry` imports in the
# next cell resolve when the notebook runs from its own folder -- confirm.
sys.path.append('..')

In [2]:
import argparse
import logging
import textwrap
from collections import defaultdict, namedtuple
from datetime import date
from os import makedirs
from pathlib import Path
from random import seed

import numpy as np
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import models
from tqdm import tqdm

from allometry.const import CHARS, DATA_DIR, CLASS_TO_CHAR
from allometry.allometry_sheet import AllometrySheet
from allometry.model_util import load_model_state, get_model
from allometry.util import finished, started

In [3]:
# --- Model settings ---------------------------------------------------------
MODEL_ARCH = 'resnext101'
MODEL_DIR = DATA_DIR / 'model'
MODEL_STATE = 'best_resnext101_2021-04-06_a.pth'

# --- Inference settings -----------------------------------------------------
DEVICE = 'cuda:0'
BATCH_SIZE = 32
WORKERS = 4

# --- Input sheet ------------------------------------------------------------
SHEET = (
    DATA_DIR
    / 'allometry_sheets'
    / 'Biomass_Fish_Families_FamilyProgram'
    / '00004.tif'
)

# Passed to AllometrySheet as `rotate` (presumably degrees -- confirm).
ROTATE = 90

In [4]:
# Edges of one character's bounding box on the sheet.
BBox = namedtuple('BBox', ['left', 'top', 'right', 'bottom'])

# A predicted character paired with the bounding box it was read from.
Char = namedtuple('Char', ['char', 'bbox'])

In [5]:
def get_loader():
    """Build the DataLoader over the AllometrySheet dataset for SHEET.

    The dataset is created from the module-level SHEET path with
    rotate=ROTATE; batching and worker count come from BATCH_SIZE and
    WORKERS.
    """
    return DataLoader(
        AllometrySheet(SHEET, rotate=ROTATE),
        batch_size=BATCH_SIZE,
        num_workers=WORKERS,
    )

In [6]:
def batches(model, device, loader, page):
    """Run the model over every batch in the loader and record the results.

    The model is switched to eval mode and gradients are disabled for the
    whole pass. Each batch from `loader` is unpacked as (inputs, boxes); the
    inputs are moved to `device`, fed through `model`, and the outputs are
    handed to save_predictions together with the boxes and `page`.
    """
    model.eval()
    with torch.no_grad():
        for images, boxes in loader:
            images = images.to(device)
            save_predictions(images, model(images), boxes, page)

In [7]:
def save_predictions(x_s, preds, boxes, page):
    """Append one Char per prediction to `page`, keyed by the box's top value.

    x_s:   the input batch (not used here).
    preds: model outputs; each row's argmax along axis 0 is looked up in
           CLASS_TO_CHAR to get the predicted character.
    boxes: object whose .left/.top/.right/.bottom attributes are tensors
           with one entry per prediction.
    page:  mapping of top value -> list; must support page[key].append(...).
    """
    def to_array(tensor):
        # Bring the tensor to host memory as an independent numpy array.
        return tensor.cpu().numpy().copy()

    scores = to_array(preds)
    edges = [
        to_array(getattr(boxes, side))
        for side in ('left', 'top', 'right', 'bottom')
    ]

    for row in zip(scores, *edges):
        score, coords = row[0], row[1:]
        letter = CLASS_TO_CHAR[score.argmax(0)]
        bbox = BBox(*coords)
        page[bbox.top].append(Char(letter, bbox))

In [8]:
def test():
    """Load the saved model, run it over the sheet, and return the results.

    Returns a defaultdict(list) that `batches` fills in; save_predictions
    keys it by each box's top value and stores Char tuples as the values.
    """
    model = get_model(MODEL_ARCH)
    load_model_state(MODEL_DIR, MODEL_STATE, model)

    # Build the device object once and use it for both the model and the data.
    device = torch.device(DEVICE)
    model.to(device)

    page = defaultdict(list)
    batches(model, device, get_loader(), page)
    return page

In [9]:
# Run inference on the configured sheet. `page` is the defaultdict returned by
# test(): each key is a box's top value, each value a list of Char tuples.
page = test()

In [10]:
# Print every group of characters in `page` as one text line. Whenever the
# horizontal gap between a character and the one before it exceeds 40
# (units presumably pixels -- confirm), a tab is inserted between them.
for chars in page.values():
    line = ''
    # Pair each character with its predecessor; the first one has none.
    for prev, curr in zip([None, *chars], chars):
        if prev is not None and curr.bbox.left - prev.bbox.right > 40:
            line += '\t'
        line += curr.char
    print(line)

-------------------------------------------=---------===
5	T	A	T	I	G	Y	I	C	A	L	A	N	A	L	Y	G	:	G	G	Y	8	T	Q	M	Q
I4:D8	WGQNEOQAY.	QE<ENDER	I9.	X979
GENBRAL	LINEAR	MQQELG	PRQCEQURE
OEPENQ6NT	VARIA3LE$	OIQLQG
QOGCRVATIQN	QBSERVQQ	PR<QICT6O	R5G:QUAL	LQWER	9UN	CL	UPPCR	95N	CL
VALUF	VALQE	INOIVIQUAL	INOIV:QQAL
48	.Q.4I797OQ4	-Q.2446O94F	.Q.I7Q.47I7	-Q.O95QGO9I	Q.8QQIOQ97
4Q	.I.Q54Q393Q	.Q.97Q9IQ86	-Q.QFFI89Q4	-I.QQCD9Q3Q	.Q.5QF88I8.
5Q	.I.847IO357	.X.1QQO9Q2O	-Q.I378O73Q	-I.U59939D4	-Q.OO9OQ3Q3
5I	.I.O3C644QO	.I.56OZIUQ1	.Q.QO4Q8QQ7	.Q.Q8QOI4Q7	-I.IIDOID94
U8	.Q.5659558Q	-Q.49QQN337	-Q.QGJ9QIO3	.Q.941936FA	-Q.Q4QI7QQQ
DJ	Q.I84O30I5	Q.43R49O37	-Q.3Q7QOO8Q	-Q.QA4IQDA9	Q.O89IOBQ3
D4	-Q.$I345ROO	-Q.:IRQ43IO	-Q.QQQ5Q97C	.Q.5O459OQO	Q.336F$868
55	.Q.QOO84QI4	-Q.OOQ3873O	-Q.QIO9IOTO	-I.IIOOO98I	-Q.QI97ODUI
5Q	M.OQ8Q5999	-Q.QQ34Q99Q	Q.QQ1349Q5	-I.OU3Q4B39	-Q.IDW7F74O
57	-Q.OQ8Q5QQQ	-Q.O5877BO9	Q.QUQFI86Q	-1.1Q83331D	.Q.8Q3RI8Q4
5O	.Q.O3O47QIO	-Q.OI4DOIQ7	-Q.QR37I1Q9	.$.QO427517	.Q.IO4946Q7
OQ	.I.76395O