In [1]:
import torch
import torchvision
from torchvision.transforms import v2 as T
from torchvision.utils import draw_bounding_boxes
from torchvision.io import read_image
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from labeled_images import LabeledImages
import matplotlib.pyplot as plt

In [2]:
def Cnn(num_classes):
    """Return a pretrained Faster R-CNN whose box head predicts `num_classes` classes.

    The detector is torchvision's ResNet-50 FPN Faster R-CNN loaded with its
    "DEFAULT" weights. Only the ROI box predictor is swapped for a new
    `FastRCNNPredictor`; it keeps the input feature width of the head it
    replaces and emits `num_classes` outputs.
    """
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")

    roi_heads = detector.roi_heads
    # Read the feature width off the existing classifier before replacing the head.
    roi_heads.box_predictor = FastRCNNPredictor(
        roi_heads.box_predictor.cls_score.in_features,
        num_classes,
    )

    return detector

In [3]:
def get_transform(train):
    """Compose the preprocessing pipeline applied to each sample.

    When `train` is truthy, a random horizontal flip (p=0.5) is placed first.
    Every pipeline then converts to scaled `torch.float` and finishes with
    `ToPureTensor`.
    """
    augmentations = [T.RandomHorizontalFlip(0.5)] if train else []
    conversions = [
        T.ToDtype(torch.float, scale=True),
        T.ToPureTensor(),
    ]
    return T.Compose(augmentations + conversions)

In [6]:
# Scan the dataset for its label vocabulary: keep the `labels_text` list of the
# sample that carries the most entries.
#
# No detector is needed for this scan, so `model` is deliberately left
# untouched here; rebinding it in this cell would replace a fine-tuned model
# with an untrained-head one whenever the cell is re-run after training.
dataset = LabeledImages('data/imgs', 'data/jsons', get_transform(train=True))

max_unique_labels = 0
# Pre-bind so the name exists even when the dataset is empty or no sample
# has any label.
unique_labels = []

# NOTE(review): this keeps the longest single-sample list, not a union across
# samples -- it presumes at least one sample lists every class; confirm
# against LabeledImages.
for idx in range(len(dataset)):
    _, target = dataset[idx]
    num_unique_labels = len(target["labels_text"])
    if max_unique_labels < num_unique_labels:
        max_unique_labels = num_unique_labels
        unique_labels = target["labels_text"]

# Report the result so the cell's recorded output is reproducible.
print(f"Unique_Labels: {unique_labels}")

Unique_Labels: ['TextBlock', 'TextRun', 'Section', 'Field', 'Widget', 'ChoiceGroup', 'ChoiceField', 'ChoiceGroupTitle', 'SectionTitle', 'Header', 'HeaderTitle', 'Image', 'Footer']


In [5]:
# Fine-tune the detector: build the train/test split, the data loaders, the
# optimizer and LR schedule, then train and evaluate once per epoch.
from engine import train_one_epoch, evaluate
import utils

device = torch.device('cpu')

# 13 label names are used in this notebook; torchvision detection heads also
# count the background class (id 0), hence 13 + 1.
num_classes = 14
# Two views of the same files: augmentation on for training, off for testing.
dataset = LabeledImages('data/imgs', 'data/jsons', get_transform(train=True))
dataset_test = LabeledImages('data/imgs', 'data/jsons', get_transform(train=False))

# Shuffle with a dedicated seeded generator so the held-out set is the same on
# every run, without touching the global RNG used elsewhere.
split_generator = torch.Generator().manual_seed(0)
indices = torch.randperm(len(dataset), generator=split_generator).tolist()

if len(indices) < 2:
    raise ValueError(
        f"Need at least 2 images to build a train/test split, found {len(indices)}"
    )

# Hold out the last 50 shuffled samples for testing. With 50 or fewer images
# that slice would swallow the whole dataset and leave nothing to train on, so
# fall back to roughly a fifth of the data (at least one sample).
num_test = 50 if len(indices) > 50 else max(1, len(indices) // 5)
dataset = torch.utils.data.Subset(dataset, indices[:-num_test])
dataset_test = torch.utils.data.Subset(dataset_test, indices[-num_test:])

data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    num_workers=4,
    collate_fn=utils.collate_fn
)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    collate_fn=utils.collate_fn
)

model = Cnn(num_classes)

model.to(device)

# Optimize only the parameters that require gradients.
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005
)

# Multiply the learning rate by 0.1 every 3 epochs.
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1
)

num_epochs = 1

for epoch in range(num_epochs):
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
    lr_scheduler.step()
    evaluate(model, data_loader_test, device=device)

Epoch: [0]  [0/1]  eta: 0:00:10  lr: 0.000005  loss: 11.7759 (11.7759)  loss_classifier: 3.3211 (3.3211)  loss_box_reg: 0.4260 (0.4260)  loss_objectness: 7.0148 (7.0148)  loss_rpn_box_reg: 1.0141 (1.0141)  time: 10.4384  data: 4.0540
Epoch: [0] Total time: 0:00:15 (15.4445 s / it)
creating index...
index created!
Test:  [0/1]  eta: 0:00:10  model_time: 5.3577 (5.3577)  evaluator_time: 0.0094 (0.0094)  time: 10.3308  data: 4.9635
Test: Total time: 0:00:15 (15.3346 s / it)
Averaged stats: model_time: 5.3577 (5.3577)  evaluator_time: 0.0094 (0.0094)
Accumulating evaluation results...
DONE (t=0.03s).
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | m

In [None]:
# Run the detector on a single image and draw the labelled predictions on it.
image = read_image("fr00001.png")

eval_transform = get_transform(train=False)

model.eval()
with torch.no_grad():
    x = eval_transform(image)
    # Keep at most the first three channels (drops e.g. an alpha channel).
    x = x[:3, ...].to(device)
    predictions = model([x, ])
    pred = predictions[0]


# Min-max rescale to 0..255 for display. A constant image has zero range, so
# skip the rescale there instead of dividing by zero.
value_range = image.max() - image.min()
if value_range > 0:
    image = (255.0 * (image - image.min()) / value_range).to(torch.uint8)
image = image[:3, ...]

unique_labels = ['TextBlock', 'TextRun', 'Section', 'Field', 'Widget', 'ChoiceGroup', 'ChoiceField', 'ChoiceGroupTitle', 'SectionTitle', 'Header', 'HeaderTitle', 'Image', 'Footer']
# torchvision detection models reserve id 0 for background, so a head with
# len(unique_labels) + 1 classes predicts foreground ids 1..len(unique_labels).
# NOTE(review): assumes LabeledImages numbers the classes 1..13 in exactly
# this order -- confirm against the dataset's label encoding.
desired_labels = range(1, len(unique_labels) + 1)
pred_labels = []
pred_boxes = []
boxes_long = pred["boxes"].long()
for label, score, box in zip(pred["labels"].tolist(), pred["scores"].tolist(), boxes_long):
    if label in desired_labels:
        # Shift by one to step over the background id when looking up the name.
        pred_labels.append(f"{unique_labels[label - 1]}: {score:.3f}")
        pred_boxes.append(box)

if pred_boxes:
    pred_boxes = torch.stack(pred_boxes)
    output_image = draw_bounding_boxes(image, pred_boxes, pred_labels, colors="red")
else:
    # torch.stack rejects an empty list; with no detections show the bare image.
    pred_boxes = torch.empty((0, 4), dtype=torch.long)
    output_image = image

plt.figure(figsize=(12, 12))
# imshow expects channels last, so move them from (C, H, W) to (H, W, C).
plt.imshow(output_image.permute(1, 2, 0))
plt.show()