In [1]:
import sys

# Add the parent directory to the module search path. Presumably this is what
# lets the `allometry` package imports resolve -- confirm against the repo layout.
sys.path.append('..')

In [2]:
from collections import defaultdict, namedtuple

import torch

from allometry.util import finished, started
from allometry.model_util import load_model_state, get_model
from allometry.allometry_sheet import AllometrySheet
from allometry.const import DATA_DIR, BBox
from allometry.characters import IDX_TO_CHAR
from torch.utils.data import DataLoader

In [3]:
# --- Model -------------------------------------------------------------
# Architecture name handed to get_model() and the saved-state file that
# load_model_state() reads from MODEL_DIR.
MODEL_ARCH = 'resnext101'
MODEL_DIR = DATA_DIR / 'model'
MODEL_STATE = 'resnext101_2021-04-08_context_c.pth'

# --- Runtime -----------------------------------------------------------
# Device string given to torch.device(); batch size and worker count are
# forwarded to the DataLoader.
DEVICE = 'cuda:0'
BATCH_SIZE = 20
WORKERS = 4

# --- Input -------------------------------------------------------------
# The single sheet image that is run through the model.
SHEET = (
    DATA_DIR
    / 'allometry_sheets'
    / 'Biomass_Fish_Families_FamilyProgram'
    / '00001.tif'
)

# Passed as `rotate=` to AllometrySheet (presumably degrees -- TODO confirm).
ROTATE = 90

# One predicted character together with the bounding box it came from.
Char = namedtuple('Char', 'char bbox')

In [4]:
def get_loader():
    """Build the DataLoader over the configured allometry sheet.

    The dataset is an AllometrySheet created from SHEET with rotate=ROTATE;
    batching and worker count come from BATCH_SIZE and WORKERS.

    Returns:
        A torch DataLoader wrapping that dataset.
    """
    sheet_data = AllometrySheet(SHEET, rotate=ROTATE)
    loader = DataLoader(
        sheet_data,
        batch_size=BATCH_SIZE,
        num_workers=WORKERS,
    )
    return loader

In [5]:
def batches(model, device, loader, page):
    """Run the model over every batch and record its predictions in `page`.

    Args:
        model: The network to evaluate; it is switched to eval mode here.
        device: Device each input batch is moved to before the forward pass.
        loader: Iterable yielding (inputs, boxes) pairs.
        page: Container handed to save_predictions(), which fills it in place.
    """
    model.eval()

    # Inference only, so gradient tracking is turned off for the whole loop.
    with torch.no_grad():
        for inputs, boxes in loader:
            outputs = model(inputs.to(device))
            save_predictions(outputs, boxes, page)

In [6]:
def save_predictions(preds, boxes, page):
    """Convert one batch of model output to characters and file them in `page`.

    Args:
        preds: Batch of model outputs (a tensor); each row is reduced with
            argmax to an index that is looked up in IDX_TO_CHAR.
        boxes: Batch of box values (a tensor); each row is unpacked into a BBox.
        page: Mapping from a box's `top` value to a list of Char entries.
            It is updated in place.
    """
    # Move both tensors to host memory and take independent NumPy copies.
    pred_rows = preds.cpu().numpy().copy()
    box_rows = boxes.cpu().numpy().copy()

    for scores, coords in zip(pred_rows, box_rows):
        best_idx = scores.argmax(0)
        letter = IDX_TO_CHAR[best_idx]
        bbox = BBox(*coords)
        # Characters sharing the same `top` value end up in the same list.
        page[bbox.top].append(Char(letter, bbox))

In [7]:
def test():
    """Load the trained model and run it over the configured sheet.

    Builds the MODEL_ARCH network, loads the weights stored at
    MODEL_DIR / MODEL_STATE, moves the model to DEVICE, and feeds it every
    batch from get_loader().

    Returns:
        A defaultdict(list) filled by batches() with the predicted characters.
    """
    model = get_model(MODEL_ARCH)
    load_model_state(MODEL_DIR / MODEL_STATE, model)

    # Use one device object for both the model and the input batches.
    device = torch.device(DEVICE)
    model.to(device)

    page = defaultdict(list)
    batches(model, device, get_loader(), page)
    return page

In [8]:
# Run inference on the configured sheet. `sheet` is the defaultdict returned by
# test(): each key is a bounding box `top` value, each value a list of Char.
sheet = test()

In [9]:
# for chars in sheet.values():
#     line = []
#     prev = None
#     for curr in chars:
#         if prev and curr.bbox.left - prev.bbox.right > 40:
#             line.append('\t')
#         line.append(curr.char)
#         prev = curr
#     print(''.join(line))