In [1]:
#|default_exp conv

# Convolutions

In [22]:
#|export
import torch
import torch.nn.functional as F
from torch import nn

from torch.utils.data import default_collate
from typing import Mapping

from miniai.training import *
from miniai.datasets import *

In [23]:
import pickle,gzip,math,os,time,shutil,torch,matplotlib as mpl, numpy as np
import pandas as pd,matplotlib.pyplot as plt
from pathlib import Path
from torch import tensor

from torch.utils.data import DataLoader
from typing import Mapping

In [4]:
# Use the 'gray' colormap as matplotlib's default for images.
mpl.rcParams['image.cmap'] = 'gray'

In [6]:
# Location of the gzipped MNIST pickle, relative to the notebook's working directory.
path_data = Path('../data')
path_gz = path_data/'mnist.pkl.gz'

# NOTE: pickle.load can execute arbitrary code; only use it on a data file you trust.
# The pickle holds three (inputs, labels) pairs; the third one is discarded here.
with gzip.open(path_gz, 'rb') as f:
    (x_train, y_train), (x_valid, y_valid), _ = pickle.load(f, encoding='latin-1')

# Convert the four loaded arrays to tensors.
x_train, y_train, x_valid, y_valid = [tensor(arr) for arr in (x_train, y_train, x_valid, y_valid)]

## Creating a CNN

In [11]:
# View each flat row as a 28x28 image (no channel axis yet); -1 lets PyTorch infer the image count.
x_imgs = x_train.view(-1, 28, 28)
xv_imgs = x_valid.view(-1, 28, 28)

In [12]:
# Sample batch: the first 16 images with a channel axis inserted at dim 1 -> (16, 1, 28, 28).
xb = x_imgs[:16][:, None]
xb.shape

torch.Size([16, 1, 28, 28])

In [7]:
# n: number of rows in x_train; m: number of values per row.
n, m  = x_train.shape
# One more than the largest label value; presumably the class count
# (assumes labels start at 0 — TODO confirm). The result is a 0-dim tensor, not a Python int.
c = y_train.max() + 1
# Hidden-layer width used by the linear model defined in the next cell.
nh = 50

In [8]:
# Fully connected baseline: m inputs -> nh hidden units (ReLU) -> 10 outputs.
model = nn.Sequential(nn.Linear(m, nh), nn.ReLU(), nn.Linear(nh, 10))

In [9]:
# First attempt at a CNN. Both 3x3 convs use padding=1 and the default stride, so the
# 28x28 grid is preserved: the output is a 10-channel 28x28 map per image instead of
# 10 values per image, which is why this model is called "broken".
broken_cnn = nn.Sequential(
    nn.Conv2d(1, 30, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(30, 10, kernel_size=3, padding=1)
)

In [13]:
# Inspect the output shape on the sample batch to see the unreduced 28x28 grid.
broken_cnn(xb).shape

torch.Size([16, 10, 28, 28])

In [14]:
#|export
def conv(ni, nf, ks=3, stride=2, act=True):
    """Build a 2D convolution from `ni` to `nf` channels, optionally followed by ReLU.

    The kernel size is `ks`, the stride is `stride` and the padding is `ks//2`.
    When `act` is truthy the result is `nn.Sequential(conv, nn.ReLU())`;
    otherwise the bare `nn.Conv2d` layer is returned.
    """
    layer = nn.Conv2d(ni, nf, kernel_size=ks, stride=stride, padding=ks//2)
    if not act: return layer
    return nn.Sequential(layer, nn.ReLU())

In [45]:
# Stack of `conv` layers using the default stride of 2: each one shrinks the spatial grid
# (sizes noted per line) until it is 1x1 with 10 channels; Flatten then yields (batch, 10).
# The last conv has act=False, so no ReLU is applied to the final outputs.
simple_cnn = nn.Sequential(
    conv(1,   4),            # 14x14
    conv(4,   8),            # 7x7
    conv(8,  16),            # 4x4
    conv(16, 16),            # 2x2
    conv(16, 10, act=False), # 1x1
    nn.Flatten()
)

In [46]:
# Confirm the model maps the sample batch to one row of 10 outputs per image.
simple_cnn(xb).shape

torch.Size([16, 10])

In [47]:
# Rebind x_imgs/xv_imgs, this time with an explicit channel axis: (N, 1, 28, 28).
x_imgs = x_train.view(-1, 1, 28, 28)
xv_imgs = x_valid.view(-1, 1, 28, 28)
# `Dataset` is not defined in this notebook; presumably it comes from the miniai
# star imports — TODO confirm.
train_ds, valid_ds = Dataset(x_imgs, y_train), Dataset(xv_imgs, y_valid)

In [48]:
#|export
# Default device for the helpers below: CUDA when available, otherwise the CPU.
def_device = 'cuda' if torch.cuda.is_available() else 'cpu'

def to_device(x, device=def_device):
    """Recursively move every tensor found in `x` to `device`.

    Args:
        x: A tensor, a mapping, an iterable container (list, tuple, namedtuple, ...)
            nested to any depth, or a non-tensor leaf.
        device: Target device passed to `Tensor.to`; defaults to `def_device`.

    Returns:
        A tensor on `device`; a `dict` for any mapping; a container of the same type
        as `x` for other iterables. Strings, bytes and non-iterable leaves
        (numbers, None, ...) are returned unchanged.
    """
    if isinstance(x, torch.Tensor): return x.to(device)
    # Recurse into the values so a mapping may hold nested containers or non-tensor leaves.
    if isinstance(x, Mapping): return {k: to_device(v, device) for k, v in x.items()}
    # Strings/bytes and non-iterable leaves carry nothing to move; return them as they are.
    if isinstance(x, (str, bytes)) or not hasattr(x, '__iter__'): return x
    moved = (to_device(o, device) for o in x)
    # Namedtuples take their fields positionally rather than as a single iterable.
    if isinstance(x, tuple) and hasattr(x, '_fields'): return type(x)(*moved)
    return type(x)(moved)

def collate_device(b):
    "Collate the samples in `b` with `default_collate`, then pass the batch through `to_device`."
    batch = default_collate(b)
    return to_device(batch)

In [51]:
from torch import optim

# Batch size and learning rate for the first training run.
bs = 256
lr = 0.4
# `get_dls` is not defined in this notebook; presumably it comes from the miniai star
# imports — TODO confirm. collate_device is passed so batches go through `to_device`.
train_dl, valid_dl = get_dls(train_ds, valid_ds, bs, collate_fn=collate_device)
# NOTE(review): the optimizer is built before the model is moved with `.to(def_device)`
# in the fit cell — confirm it still tracks the moved parameters.
opt = optim.SGD(simple_cnn.parameters(), lr=lr)

In [52]:
# Train with cross-entropy loss; `fit` is presumably from miniai.training (TODO confirm).
# The first argument is 10 and the output below shows ten rows (0-9).
loss, acc = fit(10, simple_cnn.to(def_device), F.cross_entropy, opt, train_dl, valid_dl)

0 2.0411722873687745 0.583200000667572
1 0.2505298261165619 0.921500000667572
2 0.15392718484401702 0.954300000667572
3 0.11525343742370606 0.9642999995231628
4 0.12151137762069703 0.9636999992370605
5 0.16605360834598543 0.947300000667572
6 0.10968214440345764 0.9685000007629394
7 0.08582237071990967 0.974299999332428
8 0.10325235047340393 0.9681999994277954
9 0.09661253645420075 0.9737999997138977


In [53]:
# Keep training the same model with a fresh SGD optimizer at a quarter of the learning rate.
opt = optim.SGD(simple_cnn.parameters(), lr=lr/4)
loss, acc = fit(10, simple_cnn.to(def_device), F.cross_entropy, opt, train_dl, valid_dl)

0 0.07459476077556611 0.9782999996185303
1 0.07574831893444062 0.9788999997138977
2 0.08015529987812042 0.9770999996185302
3 0.07648587174415589 0.9781999997138977
4 0.07417190508842468 0.9782999996185303
5 0.07762444217205047 0.9784999996185303
6 0.0799566460609436 0.9780999997138977
7 0.07907814426422119 0.9790999997138977
8 0.07875795543193817 0.9784999996185303
9 0.07825491766929626 0.978699999332428


# Export -

In [54]:
# Run nbdev's export step; presumably this writes the `#|export` cells to the module
# named by `#|default_exp` at the top of the notebook — TODO confirm.
import nbdev
nbdev.nbdev_export()