In [None]:
import os, sys
import copy as copyroot
import pandas as pd
from IPython.display import display
from matplotlib import pyplot as plt
from sklearn.preprocessing import StandardScaler

In [None]:
# Work out where the notebook is running from the USER environment variable.
# When USER is not set at all, the sentinel string 'KAGGLE' is used instead.
my_env = os.getenv('USER', 'KAGGLE')

# One boolean flag per recognised value of `my_env`.
b_kaggle, b_gcp, b_local = (my_env == tag for tag in ('KAGGLE', 'jupyter', 'user'))

# NOTE(review): the helpers are imported only on the Kaggle branch; other
# environments presumably obtain `build_df` / `eda_fig_1` some other way — confirm.
if b_kaggle:
    from mnist_helpers import build_df, eda_fig_1

In [None]:
from fastai2.basics import *
from fastai2.vision.all import *

### Setup

In [None]:
# Resolve the local location of the MNIST_TINY dataset and build the
# annotation DataFrame from it with the project helper `build_df`.
path = untar_data(URLs.MNIST_TINY)
df = build_df(path)

# Show the first two rows as the cell output.
df.head(n=2)

In [None]:
# Label columns read from `df`: (x, y) of the top-left point first, then
# (x, y) of the center point. The slices y_names[:2] and y_names[2:] used
# further down depend on this ordering.
y_names = ['point_topleft_x', 'point_topleft_y',
           'point_center_x', 'point_center_y']

# Template DataBlock: black-and-white image input, point target, a random
# split with a fixed seed, and the image file name read from the 'fn' column
# with `path` as prefix. No get_y is supplied at this stage.
db = DataBlock(
    blocks=(ImageBlock(cls=PILImageBW), PointBlock),
    splitter=RandomSplitter(seed=0),
    get_x=ColReader('fn', pref=path),
)

# Three independent deep copies of the template, one per target variant.
db_1_topleft, db_1_center, db_2 = (copyroot.deepcopy(db) for _ in range(3))

def set_get_y(db, cr):
    """Attach the target getter `cr` to an already-constructed DataBlock.

    Stores `cr` as `db.get_y` and overwrites every entry of `db.getters`
    from index `db.n_inp` onward with `L(db.get_y)`. The block is modified
    in place; nothing is returned.

    NOTE(review): this presumably mirrors what the DataBlock constructor does
    when `get_y` is passed directly — confirm against the installed fastai2.
    """
    db.get_y = cr
    target_getters = L(db.get_y)
    first_target = db.n_inp
    db.getters[first_target:] = target_getters

# Variant 1a — top-left point only (first two label columns).
set_get_y(db_1_topleft, ColReader(y_names[:2]))
dl_1_topleft = db_1_topleft.dataloaders(df)

# Variant 1b — center point only (last two label columns).
set_get_y(db_1_center, ColReader(y_names[2:]))
dl_1_center = db_1_center.dataloaders(df)

# Variant 2 — both points (all four label columns).
set_get_y(db_2, ColReader(y_names))
dl_2 = db_2.dataloaders(df)

### Fit CenterPoint

In [None]:
# !mkdir assets
# !mkdir models

In [None]:
# True: train from scratch and save the weights under `model_fn`.
# False: skip training and load the weights previously saved under `model_fn`.
b_new_fit = True

# NOTE(review): the history file is suffixed "_1" while the model is "_2" —
# confirm that the mismatch is intended.
history_fn = 'pt3_center_1.csv'
model_fn   = 'pt3_center_2'
fit_seed   = 17

# Seed BEFORE building the learner: cnn_learner attaches a freshly
# (randomly) initialised head to the pretrained body, so seeding only
# afterwards leaves the starting weights dependent on whatever RNG state the
# preceding cells happened to leave behind, and re-running this cell alone
# would start from different weights each time.
set_seed(fit_seed)

learn = cnn_learner(dl_1_center,
                    resnet18,
                    pretrained=True,
                    metrics=[mae, R2Score()],
                    cbs=CSVLogger(history_fn),   # per-epoch history goes to `history_fn`
                    y_range=(-1., 1.),
                   )

if b_new_fit:
    # Re-seed right before training so the RNG stream consumed by the fit
    # starts from `fit_seed`, exactly as it did when the seed was set only here.
    set_seed(fit_seed)

    # Train inside no_logging(); the CSVLogger callback is still attached.
    with learn.no_logging():
        learn.fine_tune(50)

    learn.save(model_fn)
else:
    learn.load(model_fn)