In [1]:
import sys
sys.path.append('..')

In [2]:
import argparse
import logging
import textwrap
from collections import defaultdict, namedtuple
from datetime import date
from os import makedirs
from pathlib import Path
from random import seed

import numpy as np
import torch
import torch.optim as optim
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import models
from tqdm import tqdm

from allometry.const import CHARS, DATA_DIR, CLASS_TO_CHAR
from allometry.allometry_sheet import AllometrySheet
from allometry.util import finished, started

In [3]:
# Saved weights that `test()` loads into the network.
STATE = DATA_DIR / 'state' / 'best_resnet50_2021-04-03_a.pth'

# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# on machines without CUDA instead of failing when the model is moved.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# DataLoader settings used by `get_loader()`.
BATCH_SIZE = 16
WORKERS = 4

# Passed to `get_model()`, which currently ignores its argument.
MODEL = 'resnet50'

# The scanned sheet to run the classifier on.
SHEET = (
    DATA_DIR
    / 'allometry_sheets'
    / 'Biomass_Fish_Families_FamilyProgram'
    / '00004.tif'
)

# Forwarded as `rotate=` to `AllometrySheet` (presumably degrees -- confirm
# against the dataset class).
ROTATE = 90

In [4]:
# Edges of one character box, in the order left, top, right, bottom.
BBox = namedtuple('BBox', ['left', 'top', 'right', 'bottom'])

# One recognized character together with the box it was read from.
Char = namedtuple('Char', ['char', 'bbox'])

In [5]:
def get_model(_):
    """Build the character classifier network.

    Creates a torchvision ResNet-50 with its default constructor arguments,
    swaps the first convolution for one that accepts a single input channel,
    and replaces the final fully connected layer so it emits one output per
    entry in ``CHARS``.

    The positional argument is accepted but ignored.

    Returns:
        The modified ``resnet50`` module.
    """
    net = models.resnet50()

    # Same geometry as the stock stem convolution here, but with 1 input
    # channel instead of the default.
    net.conv1 = nn.Conv2d(
        in_channels=1,
        out_channels=64,
        kernel_size=(7, 7),
        stride=(2, 2),
        padding=(3, 3),
        bias=False,
    )

    # Resize the classification head to the character set.
    in_features = net.fc.in_features
    net.fc = nn.Linear(in_features, len(CHARS))

    return net

In [6]:
def get_loader():
    """Create the DataLoader that feeds the configured sheet to the model.

    Wraps ``SHEET`` in an ``AllometrySheet`` dataset (with ``ROTATE`` passed
    as its ``rotate`` argument) and batches it using ``BATCH_SIZE`` and
    ``WORKERS``.

    Returns:
        A ``DataLoader`` over the sheet dataset.
    """
    sheet = AllometrySheet(SHEET, rotate=ROTATE)
    loader = DataLoader(
        sheet,
        batch_size=BATCH_SIZE,
        num_workers=WORKERS,
    )
    return loader

In [7]:
def batches(model, device, loader, page):
    """Run the model over every batch and record its predictions.

    Args:
        model: Network to evaluate; it is switched to eval mode first.
        device: Device each input batch is moved to before the forward pass.
        loader: Iterable yielding ``(inputs, boxes)`` pairs.
        page: Mapping that ``save_predictions`` fills in place.
    """
    model.eval()

    # Inference only, so gradient tracking is disabled for the whole loop.
    with torch.no_grad():
        for images, boxes in loader:
            images = images.to(device)
            scores = model(images)
            save_predictions(images, scores, boxes, page)

In [8]:
def save_predictions(x_s, preds, boxes, page):
    """Append one batch of predicted characters to ``page``.

    Args:
        x_s: The input batch; accepted but not used here.
        preds: Model outputs for the batch, one row of scores per item
            (presumably shaped (batch, classes) -- confirm against the model).
        boxes: Object exposing ``left``, ``top``, ``right`` and ``bottom``
            tensors, one value per item in the batch.
        page: Mapping of lists; each character is appended to the list keyed
            by its box's ``top`` value.
    """
    def as_array(tensor):
        # Detached host-side copy of the tensor as a numpy array.
        return tensor.cpu().numpy().copy()

    all_scores = as_array(preds)
    edges = [
        as_array(boxes.left),
        as_array(boxes.top),
        as_array(boxes.right),
        as_array(boxes.bottom),
    ]

    for scores, left, top, right, bottom in zip(all_scores, *edges):
        # The highest-scoring class index picks the character.
        letter = CLASS_TO_CHAR[scores.argmax(0)]
        where = BBox(left, top, right, bottom)
        page[top].append(Char(letter, where))

In [9]:
def test():
    """Run the trained network over the configured sheet.

    Builds the model, restores the weights stored at ``STATE``, moves the
    model to ``DEVICE`` and feeds every batch from ``get_loader()`` through
    it.

    Returns:
        A ``defaultdict(list)`` filled by ``save_predictions``: each key is a
        character box's ``top`` value and each value is the list of ``Char``
        records appended under that key.
    """
    device = torch.device(DEVICE)

    model = get_model(MODEL)

    # Map the checkpoint onto the target device while loading. Without
    # map_location the tensors are restored to the device they were saved
    # from, which fails when that device is not present on this machine.
    state = torch.load(STATE, map_location=device)
    model.load_state_dict(state)
    model.to(device)

    loader = get_loader()

    page = defaultdict(list)
    batches(model, device, loader, page)

    return page

In [10]:
# Run inference on the configured sheet. `test()` returns a dict of lists of
# `Char` records, keyed by each character box's `top` value.
page = test()

In [14]:
# Print every row of recognized characters, in the order the rows were
# inserted into `page`.
for chars in page.values():
    line, prev = [], None
    for curr in chars:
        # A horizontal gap of more than 40 between the previous box's right
        # edge and this box's left edge is rendered as a tab.
        if prev is not None and (curr.bbox.left - prev.bbox.right) > 40:
            line.append('\t')
        line.append(curr.char)
        prev = curr
    print(''.join(line))

-------------------------------------------------------------------	-
S	T	A	T	I	S	T	I	C	A	L	A	N	A	L	Y	S	I	S	S	Y	S	T	E	M	5
^	14:52	WEDNESDAY4	DECEMBER	190	19^79
GENERAL-	LINEAR	MODELS	PROCEDURE
DEPENDENT	VARIA9LE:	OIO1-OG
O8SERVATION	ONSERVED	PPEDICTED	RESIDL)AL	LOWER	95%	CL	1JPPER	95%	C^L
VALUE	VALUE	IN^DIVIDUAL	INDIVIDUAL
48	-0441793664	-0024469947	-0017324717	-0069556891	0420618997
49	--1405403930	-0497691026	-0407^712904	--1442659930	-0852722122
50	--1424718357	-1410989618	-04137287^39	-1455993934	-0465985303
51	-1863264408	--1456821501	-0406442907	--200208I4O7	-I811561594.
52	-0455^595520	-0449205337	-0806390183	-049-4193674	-040421^7000
53	0412483015	0443249837	-0430766822	-0402416529	0488916203
54	--0411345288	-0411294318	-0400050970	-0056459898	04338^712-62
55	-0868824614	-0466932736	-0401891878	-1411886921	-0421978551
56	4060205999	-0460340994	0400134995	-1405304239	--0415377748
57	-0460205999	--0065277259	0405071260	-1410233315	-0=20321204
58	-0463827216	-0461456107	-040237110