In [1]:
import json
from pathlib import Path
from textwrap import dedent
import pdb
import random

from tqdm import tqdm
from ipywidgets import interact
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib import patches, patheffects

import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

import transforms as T
from engine import train_one_epoch, evaluate
import utils

from label_babel_dataset import LabelBabelDataset

In [2]:
%matplotlib inline
%reload_ext autoreload
%autoreload 2

%load_ext watermark
%watermark -i -z -v -p torch,torchvision

EST 2019-12-16T11:18:26-05:00

CPython 3.7.5
IPython 7.10.2

torch 1.3.1
torchvision 0.4.2


In [3]:
DATA_DIR = Path('.') / 'data'
VAL_DIR = DATA_DIR / 'valid'
TRN_DIR = DATA_DIR / 'train'
MODEL_DIR = DATA_DIR / 'models'

TRN_CSV = DATA_DIR / 'train.csv'
VAL_CSV = DATA_DIR / 'valid.csv'

# Column names used in the train/valid CSVs.
BOX = 'box'
CAT = 'category'
PATH = 'path'
CLASS = 'class'
SUB_ID = 'subject_id'

SEED = 23

# Fall back to the CPU when CUDA is not available so the notebook can still
# run (slowly) instead of failing at the first `.to(DEVICE)`.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Let cuDNN auto-tune convolution algorithms; only matters on CUDA.
torch.backends.cudnn.benchmark = True

# List position doubles as the label id; index 0 is the background class.
CATS = ['background', 'handwritten', 'typewritten']
CLASSES = len(CATS)

EPOCHS = 10  # NOTE(review): original note said "50" — presumably the full-run value; confirm
# Checkpoint file name template; filled with a zero-padded epoch or '*' for globbing.
CHECKPOINT = 'checkpoint_{}.pth.tar'

In [4]:
# Seed every RNG the notebook may draw from so runs are repeatable;
# the trailing ';' hides the Generator repr that manual_seed returns.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED);

<torch._C.Generator at 0x7f12ed6a2df0>

In [5]:
# Both splits share one schema. Reading SUB_ID as the index and then resetting
# it keeps the subject id as the leading column of each frame.
trn_df, val_df = (
    pd.read_csv(csv_path, index_col=SUB_ID).reset_index()
    for csv_path in (TRN_CSV, VAL_CSV))

trn_df.head()

Unnamed: 0,subject_id,category,class,box,path,original,predicted_class,predicted_category,predicted_box,box_original
0,2995300,typewritten,2,"[231, 446, 368, 564]",data/train/2995300.jpg,data/images/2995300.jpg,2,typewritten,[228.21541 446.85278 369.32605 562.8112 ],"[924, 1784, 1472, 2256]"
1,4128323,typewritten,2,"[156, 322, 250, 382]",data/train/4128323.jpg,data/images/4128323.jpg,2,typewritten,[156.92148 321.9174 248.3404 381.2727 ],"[624, 1288, 1000, 1528]"
2,4128517,handwritten,1,"[155, 321, 248, 382]",data/train/4128517.jpg,data/images/4128517.jpg,2,typewritten,[157.88994 318.32904 246.69968 381.92596],"[620, 1284, 992, 1528]"
3,11783370,handwritten,1,"[552, 1225, 966, 1483]",data/train/11783370.jpg,data/images/11783370.jpg,2,typewritten,[ 545.6113 1229.2281 966.97504 1483.6238 ],"[2208, 4900, 3864, 5932]"
4,11782469,typewritten,2,"[612, 1182, 977, 1464]",data/train/11782469.jpg,data/images/11782469.jpg,2,typewritten,[ 610.1748 1182.5092 974.12897 1460.3987 ],"[2448, 4728, 3908, 5856]"


In [6]:
# Start from a pretrained Faster R-CNN and replace only its box predictor, so
# the classification head outputs one score per entry in CATS.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# The new head must consume the same feature width as the head it replaces.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(
    in_features,
    CLASSES,
)

In [7]:
def get_transform(train):
    """Build the transform pipeline for one dataset split.

    Every split starts with ``T.ToTensor()``. When ``train`` is truthy the
    pipeline also gets ``T.RandomHorizontalFlip(0.5)``,
    ``T.RandomVerticalFlip(0.5)``, ``T.RandomRotate()`` and a ``T.ColorJitter``
    with every jitter amount set to 0.25.

    Args:
        train: Whether to append the training-time augmentations.

    Returns:
        A ``T.Compose`` wrapping the selected transforms, in order.
    """
    steps = [T.ToTensor()]
    if not train:
        return T.Compose(steps)

    # Random augmentations are only wanted for the training split.
    steps.extend([
        T.RandomHorizontalFlip(0.5),
        T.RandomVerticalFlip(0.5),
        T.RandomRotate(),
        T.ColorJitter(
            brightness=0.25, contrast=0.25, saturation=0.25, hue=0.25),
    ])
    return T.Compose(steps)

In [8]:
# Only the training split receives the random augmentations.
trn_dataset = LabelBabelDataset(
    trn_df,
    get_transform(train=True),
)
val_dataset = LabelBabelDataset(
    val_df,
    get_transform(train=False),
)

In [9]:
def make_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader with the settings both splits share.

    Batches hold a single example, are produced by 4 worker processes, and
    are assembled with ``utils.collate_fn``; only shuffling differs by split.
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=shuffle,
        num_workers=4,
        collate_fn=utils.collate_fn,
    )


trn_loader = make_loader(trn_dataset, shuffle=True)
val_loader = make_loader(val_dataset, shuffle=False)

In [10]:
model.to(DEVICE)

# Give SGD only the parameters that still require gradients.
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)

# Multiply the learning rate by 0.1 every 5 scheduler steps.
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=5,
    gamma=0.1,
)

RuntimeError: cuda runtime error (999) : unknown error at /pytorch/aten/src/THC/THCGeneral.cpp:50

In [22]:
# Resume from the newest checkpoint if there is one; otherwise start at epoch 0.
# Checkpoint names carry a zero-padded epoch, so lexical order is epoch order.
checkpoints = sorted(MODEL_DIR.glob(CHECKPOINT.format('*')))
last_checkpoint = checkpoints[-1] if checkpoints else None

first_epoch = 0
if last_checkpoint is not None:
    # map_location lets a checkpoint saved on GPU load onto the current DEVICE.
    state = torch.load(last_checkpoint, map_location=DEVICE)

    first_epoch = state['epoch']
    model.load_state_dict(state['state_dict'])
    optimizer.load_state_dict(state['optimizer'])
    # The scheduler is not stored in the checkpoint. Fast-forward its step
    # count so StepLR continues the schedule instead of restarting it.
    lr_scheduler.last_epoch = first_epoch

last_checkpoint

PosixPath('data/models/checkpoint_005.pth.tar')

In [23]:
def save_state(epoch, model, optimizer):
    """Write a training checkpoint for ``epoch`` into MODEL_DIR.

    The file name is CHECKPOINT filled with the epoch zero-padded to three
    digits. If a checkpoint for that epoch already exists it is left alone
    and nothing is written.

    Args:
        epoch: Index of the epoch that just finished.
        model: Model whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
    """
    # torch.save does not create missing parent directories.
    MODEL_DIR.mkdir(parents=True, exist_ok=True)

    state_path = MODEL_DIR / CHECKPOINT.format(str(epoch).zfill(3))
    if state_path.exists():
        return
    state = {
        # Store the *next* epoch so a resumed run starts after this one.
        'epoch': epoch + 1,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
    }
    torch.save(state, state_path)

In [24]:
def train_model():
    """Train from ``first_epoch`` up to (not including) EPOCHS.

    Each epoch runs ``train_one_epoch``, steps ``lr_scheduler``, and then
    ``evaluate``s on ``val_loader``. A checkpoint is saved on every epoch
    divisible by 10, and once more after the last epoch trained. If there
    is nothing left to train, nothing is saved.
    """
    last_trained = None
    for epoch in range(first_epoch, EPOCHS):
        train_one_epoch(
            model, optimizer, trn_loader, DEVICE, epoch, print_freq=100)
        lr_scheduler.step()
        evaluate(model, val_loader, device=DEVICE)
        # Periodic checkpoint, keyed on the epoch index.
        if epoch % 10 == 0:
            save_state(epoch, model, optimizer)
        last_trained = epoch

    # Guard against an empty range (already trained to EPOCHS), where the
    # loop variable would never have been bound.
    if last_trained is not None:
        save_state(last_trained, model, optimizer)


# train_model()  # long-running; uncomment to train