In [1]:
#hide
from fastai2.vision.all import *
from utils import *
from fastai2.basics import *
from fastai2.callback.all import *

# Set matplotlib's default image colormap to 'Greys' for every image plotted below
matplotlib.rc('image', cmap='Greys')

# Under the hood: training a digit classifier

In [2]:
# Retrieve the complete MNIST dataset (URLs.MNIST) and keep the path returned by untar_data
path = untar_data(URLs.MNIST)

In [3]:
# Collection holding the path of every image file found under `path`
items= get_image_files(path)

In [4]:
# GrandparentSplitter splits the data according to the directory names passed as
# arguments. We pass 'training' and 'testing' because that is how the data is laid
# out after being downloaded from the fastai repository.
# The splitter gets its own name so that `splits` only ever refers to the resulting
# split, never to the splitter object itself.
splitter = GrandparentSplitter(train_name='training', valid_name='testing')
# `items` holds the paths of all images; applying the splitter to it produces the
# division into training and test (validation) data.
splits = splitter(items)

In [5]:
# Datasets builds (input, target) pairs from `items`: the input is the image and
# the target is the class the image belongs to (0 to 9).
# Input pipeline: create the image with PILImageBW.create.
# Target pipeline: label the image with its parent directory name, then Categorize
# turns that string into a class id.
# The transforms are applied only to the items selected by `splits`.
image_tfms = [PILImageBW.create]
label_tfms = [parent_label, Categorize]
dataset = Datasets(items, tfms=[image_tfms, label_tfms], splits=splits)

In [6]:
# Transform lists handed to the dataloader (passed as after_item / after_batch below)
# Item-level list `tfms`:
#   ToTensor()           - converts the item to a tensor
#   RandomCrop(size=28)  - randomly crops the image to size 28 (data augmentation)
#   NOTE(review): MNIST digits are normally already 28x28, so without a padding step
#   before it this crop presumably leaves the images unchanged - confirm
# Batch-level list `gpu_tfms`:
#   IntToFloatTensor()   - presumably converts integer pixel tensors to float - confirm against the fastai docs
#   Normalize()          - normalizes the batch; no explicit mean/std are supplied here
tfms = [ToTensor(), RandomCrop(size=28)]
gpu_tfms = [IntToFloatTensor(), Normalize()]

In [7]:
# Build the dataloaders with a batch size of 128: `tfms` is applied after each item
# is fetched and `gpu_tfms` after each batch is assembled
dls = dataset.dataloaders(bs=128, after_item=tfms, after_batch=gpu_tfms)

In [8]:
# Convolution layer
def conv2(ni, nf):
    """Build a ConvLayer with stride 2.

    `ni` and `nf` are forwarded unchanged as ConvLayer's first two positional
    arguments (input and output channel counts in fastai's convention).
    """
    layer = ConvLayer(ni, nf, stride=2)
    return layer

In [9]:
# ResNet block
class ResBlock(nn.Module):
    """Residual block: two ConvLayers of width `nf` plus an identity skip connection."""

    def __init__(self, nf):
        super().__init__()
        # Both layers map nf -> nf so their output can be added back onto the input
        self.conv1 = ConvLayer(nf, nf)
        self.conv2 = ConvLayer(nf, nf)

    def forward(self, x):
        """Return `x` plus the result of running `x` through conv1 and then conv2."""
        residual = self.conv1(x)
        residual = self.conv2(residual)
        return x + residual

In [10]:
# One network stage: a stride-2 convolution followed by a residual block
def conv_and_res(ni, nf):
    """Return nn.Sequential(conv2(ni, nf), ResBlock(nf))."""
    downsample = conv2(ni, nf)
    residual = ResBlock(nf)
    return nn.Sequential(downsample, residual)

In [11]:
# Neural network: four conv + residual stages, then a final stride-2 convolution
# producing 10 channels (one per digit class), flattened for the loss.
# NOTE(review): assuming ConvLayer's default 3x3 kernel with padding 1, a 28x28 input
# shrinks 28 -> 14 -> 7 -> 4 -> 2 -> 1, so Flatten yields 10 values per image - confirm
channel_plan = [(1, 8), (8, 16), (16, 32), (32, 16)]
net = nn.Sequential(
    *[conv_and_res(ni, nf) for ni, nf in channel_plan],
    conv2(16, 10),
    Flatten(),
)

In [13]:
# Train for 10 epochs with fit_one_cycle (maximum learning rate 0.1)
# Final validation accuracy in the recorded run below: 99.49%
learn = Learner(dls, net, loss_func=CrossEntropyLossFlat(), metrics=accuracy)
learn.fit_one_cycle(10, lr_max=1e-1)

epoch,train_loss,valid_loss,accuracy,time
0,0.01846,0.040945,0.9893,00:23
1,0.053786,0.080649,0.9772,00:14
2,0.054771,0.056357,0.9818,00:14
3,0.047814,0.052457,0.9836,00:15
4,0.048431,0.031511,0.9901,00:14
5,0.033389,0.029736,0.991,00:14
6,0.028485,0.031852,0.9897,00:14
7,0.015166,0.022535,0.9933,00:14
8,0.009339,0.018067,0.994,00:14
9,0.005077,0.01744,0.9949,00:14


In [14]:
# Export the trained model; a later cell loads it back from 'export.pkl'
learn.export()

In [15]:
# importe das libs
from utils import *
from fastai2.vision.widgets import *
from pathlib import *

In [16]:
# ResNet block. It has to be defined here because export.pkl references it.
class ResBlock(nn.Module):
    """Residual block: two ConvLayers of width `nf` whose output is added to the input.

    The attribute names `conv1` / `conv2` are kept as in the training-time class.
    """

    def __init__(self, nf):
        super().__init__()
        self.conv1 = ConvLayer(nf, nf)
        self.conv2 = ConvLayer(nf, nf)

    def forward(self, x):
        """Apply conv1 then conv2 to `x` and add the original `x` back."""
        out = self.conv1(x)
        out = self.conv2(out)
        return x + out

In [17]:
# Load the trained model from 'export.pkl' in the current directory (Path() is '.').
# Note: this rebinds `path`, which earlier in the notebook held the dataset location.
# NOTE(review): the export is a pickle file - only load exports from a trusted source.
path = Path()
learn_inf = load_learner(path/'export.pkl')

In [18]:
# Widgets that make up the interactive classifier
btn_upload = widgets.FileUpload()                  # file picker for the digit image
out_pl = widgets.Output()                          # area where the thumbnail is shown
lbl_pred = widgets.Label()                         # text line for the prediction
btn_run = widgets.Button(description='Classify')   # triggers the classification

def on_click_classify(change):
    """Classify the most recently uploaded image and show the result.

    Used as the click handler of `btn_run`; `change` is the argument the button
    passes to its callback and is not used. The image thumbnail is displayed in
    `out_pl`, and the predicted label with its probability is written to `lbl_pred`.
    If nothing has been uploaded yet, a hint is shown instead of classifying.
    """
    # Guard: with no upload, `btn_upload.data` is empty and `data[-1]` would raise
    # IndexError inside the callback, leaving the UI without any feedback.
    if not btn_upload.data:
        lbl_pred.value = 'Please upload an image first'
        return
    # Use the last uploaded file
    img = PILImageBW.create(btn_upload.data[-1])
    out_pl.clear_output()
    with out_pl: display(img.to_thumb(128,128))
    pred,pred_idx, probs = learn_inf.predict(img)
    lbl_pred.value = f'Prediction: {pred}; Probability: {probs[pred_idx]:.04f}'

btn_run.on_click(on_click_classify)

In [19]:
# Stack the UI vertically: prompt label, upload button, classify button,
# image preview area and prediction label. Kept as the cell's last expression so it is displayed.
VBox([widgets.Label('Choose a handwritten image from 0 to 9'), 
      btn_upload, btn_run, out_pl, lbl_pred])

VBox(children=(Label(value='Choose a handwritten image from 0 to 9'), FileUpload(value={}, description='Upload…