In [None]:
#hide
# Optional install step: the shell test only runs pip when a /content directory
# exists (presumably a Colab runtime). Uncomment to use.
# ! [ -e /content ] && pip install -Uqq fastbook
import fastbook
# fastbook's own notebook setup hook (see fastbook docs for what it configures).
fastbook.setup_book()

In [None]:
# Mini-batch size shared by the training and validation DataLoaders.
batch_size = 2**10  # = 1024

In [None]:
#hide
# Star imports supply the names used throughout this notebook
# (torch, tensor, Image, nn, SGD, DataLoader, DataLoaders, untar_data, URLs, Path, matplotlib, ...).
from fastai.vision.all import *
from fastbook import *

# Default colormap for single-channel images so the digits render in greyscale.
matplotlib.rc('image', cmap='Greys')

In [None]:
# Fetch/extract fastai's MNIST_SAMPLE dataset (digits 3 and 7 only);
# `path` is its root folder, which holds the train/ and valid/ splits.
path = untar_data(URLs.MNIST_SAMPLE)

In [None]:
#hide
# Show Path reprs relative to the dataset root, which keeps file listings short.
Path.BASE_PATH = path

In [None]:
# Training image files for each digit, sorted by filename so the order is reproducible.
threes = (path/'train'/'3').ls().sorted()
sevens = (path/'train'/'7').ls().sorted()
threes

(#6131) [Path('train/3/10.png'),Path('train/3/10000.png'),Path('train/3/10011.png'),Path('train/3/10031.png'),Path('train/3/10034.png'),Path('train/3/10042.png'),Path('train/3/10052.png'),Path('train/3/1007.png'),Path('train/3/10074.png'),Path('train/3/10091.png')...]

In [None]:
# Decode every training image into a raw pixel tensor, one list per digit.
three_tensors = [tensor(Image.open(img_path)) for img_path in threes]
seven_tensors = [tensor(Image.open(img_path)) for img_path in sevens]
len(three_tensors), len(seven_tensors)

(6131, 6265)

In [None]:
# Stack each digit's images into one (N, 28, 28) float tensor, divided by 255
# (assumes 8-bit pixels) so values land in [0, 1].
stacked_threes = torch.stack(three_tensors).float().div(255)
stacked_sevens = torch.stack(seven_tensors).float().div(255)
stacked_threes.shape

torch.Size([6131, 28, 28])

In [None]:
# Flatten each 28x28 image into a 784-long row; all 3s come first, then all 7s.
train_x = torch.cat([stacked_threes, stacked_sevens]).view(-1, 28*28)
# Label 3s as 1 and 7s as 0; unsqueeze(1) makes a (N, 1) column that matches a one-logit model output.
train_y = tensor([1]*len(threes) + [0]*len(sevens)).unsqueeze(1)
train_x.shape,train_y.shape

(torch.Size([12396, 784]), torch.Size([12396, 1]))

In [None]:
# Pair every flattened image with its label; a plain list of (x, y) tuples serves as the dataset.
dset = list(zip(train_x,train_y))
# Peek at the first item to check its shapes.
x,y = dset[0]
x.shape,y

(torch.Size([784]), tensor([1]))

In [None]:
# Shuffle the training set each epoch: `dset` is ordered (all 3s, then all 7s),
# so unshuffled mini-batches would contain almost exclusively one class.
dl = DataLoader(dset, batch_size=batch_size, shuffle=True)

In [None]:
def load_digit_tensors(folder):
    """Stack every image in `folder` into one float tensor scaled by 1/255.

    Files are sorted so the row order is deterministic, consistent with the
    `.ls().sorted()` listings used for the training images.
    """
    images = [tensor(Image.open(o)) for o in folder.ls().sorted()]
    return torch.stack(images).float()/255

valid_3_tens = load_digit_tensors(path/'valid'/'3')
valid_7_tens = load_digit_tensors(path/'valid'/'7')
valid_3_tens.shape,valid_7_tens.shape

(torch.Size([1010, 28, 28]), torch.Size([1028, 28, 28]))

In [None]:
# Same flattening and labelling (3 -> 1, 7 -> 0, column-shaped labels) as the training set.
valid_x = torch.cat([valid_3_tens, valid_7_tens]).view(-1, 28*28)
valid_y = tensor([1]*len(valid_3_tens) + [0]*len(valid_7_tens)).unsqueeze(1)
valid_dset = list(zip(valid_x,valid_y))

In [None]:
# Validation batches; no shuffling requested, since evaluation does not depend on batch order.
valid_dl = DataLoader(valid_dset, batch_size=batch_size)

In [None]:
# Bundle the training and validation loaders (in that order) into one DataLoaders object.
dls = DataLoaders(dl, valid_dl)
dls

<fastai.data.core.DataLoaders>

In [None]:
# dir(dls)

In [None]:
def mnist_loss(predictions, targets):
    """Mean distance between sigmoid(prediction) and the binary target.

    Args:
        predictions: raw model outputs (logits); squashed to (0, 1) with a sigmoid.
        targets: labels where 1 marks the positive class; any other value is negative.

    Returns:
        Scalar tensor: for positives the loss is ``1 - p``, for negatives ``p``,
        averaged over all elements.
    """
    probs = torch.sigmoid(predictions)
    per_item = torch.where(targets == 1, 1 - probs, probs)
    return per_item.mean()

def batch_accuracy(xb, yb):
    """Fraction of items whose thresholded prediction matches the label.

    Args:
        xb: raw model outputs (logits); a sigmoid above 0.5 counts as positive.
        yb: binary labels, compared element-wise with the boolean predictions.

    Returns:
        Scalar float tensor in [0, 1].
    """
    hits = (torch.sigmoid(xb) > 0.5) == yb
    return hits.float().mean()

In [None]:
# NOTE(review): torch is already in scope via the fastai star import (torch.stack is used
# earlier), so this import is redundant but harmless.
import torch
# True when PyTorch can see a CUDA device from this kernel.
torch.cuda.is_available()

True

In [None]:
# from tqdm.auto import tqdm
import time
import pandas as pd
import ipywidgets as widgets
import IPython.display as dsp
# from IPython.display import HTML, display

class MyLearner:
    """Minimal fastai-style training loop for a PyTorch model.

    Args:
        dls: object exposing ``train`` and ``valid`` iterables of ``(xb, yb)``
            mini-batches (e.g. a fastai ``DataLoaders``).
        model: ``nn.Module`` to train; its ``parameters()`` go to ``opt_func``.
        opt_func: optimizer factory, called as ``opt_func(params, lr=lr)``.
        metrics: callable ``metrics(preds, targets)`` returning a scalar;
            reported on the validation set.
        loss_func: callable ``loss_func(preds, targets)`` returning a scalar
            tensor. Required by :meth:`fit`.
        lr: default learning rate, used when :meth:`fit` gets ``lr=None``.
    """

    # Columns of the per-epoch summary table.
    _SUMMARY_COLUMNS = ['epoch', 'train_loss', 'valid_loss', 'metric', 'time']

    def __init__(self,
                 dls,
                 model: 'callable',
                 opt_func: 'callable',
                 metrics: 'callable',
                 loss_func: 'callable | None' = None,
                 lr: float = 0.001):
        self.dls = dls
        self.model = model
        self.metrics = metrics
        self.loss_func = loss_func
        self.opt_func = opt_func
        self.lr = lr
        # One row per epoch of the most recent fit() call.
        self.training_summary = pd.DataFrame(columns=self._SUMMARY_COLUMNS)

    @staticmethod
    def _to_device_of(model, *tensors):
        """Move ``tensors`` to the device of ``model``'s parameters (unchanged if it has none)."""
        first_param = next(iter(model.parameters()), None)
        if first_param is None:
            return tensors
        return tuple(t.to(first_param.device) for t in tensors)

    def _train_epoch(self, optimizer):
        """Run one optimisation pass over ``self.dls.train``.

        Returns:
            Training loss averaged over all samples (batches weighted by size),
            or ``nan`` if the loader yields no batches.
        """
        self.model.train()
        total_loss = 0.0
        n_samples = 0
        # Iterate the DataLoader (mini-batches), not the underlying dataset (single items).
        for xb, yb in self.dls.train:
            xb, yb = self._to_device_of(self.model, xb, yb)
            loss = self.loss_func(self.model(xb), yb)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            bs = len(xb)
            total_loss += loss.item() * bs
            n_samples += bs
        return total_loss / n_samples if n_samples else float('nan')

    def _validate_epoch(self, model, valid_dl):
        """Evaluate ``model`` on every batch of ``valid_dl``.

        Returns:
            ``(valid_loss, metric)`` as floats, each averaged over all samples
            (batches weighted by size, so a short final batch does not skew the
            result). Both are ``nan`` when ``valid_dl`` yields no batches.
        """
        model.eval()
        total_loss = 0.0
        total_metric = 0.0
        n_samples = 0
        with torch.no_grad():  # no backward pass here, so don't build autograd graphs
            for xb, yb in valid_dl:
                xb, yb = self._to_device_of(model, xb, yb)
                yhat = model(xb)
                bs = len(xb)
                total_loss += float(self.loss_func(yhat, yb)) * bs
                total_metric += float(self.metrics(yhat, yb)) * bs
                n_samples += bs
        if n_samples == 0:
            return float('nan'), float('nan')
        return total_loss / n_samples, total_metric / n_samples

    def debug(self):
        """Print the name and shape of every model parameter."""
        for name, param in self.model.named_parameters():
            print(name, tuple(param.shape))

    def fit(self, n_epoch: int = 10, lr: 'float | None' = None):
        """Train for ``n_epoch`` epochs, validating after each one.

        Args:
            n_epoch: number of passes over ``self.dls.train``.
            lr: learning rate for this run; ``None`` falls back to ``self.lr``
                (an explicit ``0`` is honoured, not replaced).

        Side effects:
            Replaces ``self.training_summary`` with one row per epoch and
            renders a progress bar plus the summary table in the notebook.

        Raises:
            ValueError: if no ``loss_func`` was provided.
        """
        if self.loss_func is None:
            raise ValueError('MyLearner.fit requires a loss_func')
        if lr is None:
            lr = self.lr

        # Initialize training progress display
        records = []
        self.training_summary = pd.DataFrame(columns=self._SUMMARY_COLUMNS)
        progress_bar = widgets.IntProgress(value=0, min=0, max=n_epoch, step=1,
                                           description=f'epoch [0 / {n_epoch}]')
        dsp.display(progress_bar, dsp.HTML(self.training_summary.to_html(index=False)))

        optimizer = self.opt_func(self.model.parameters(), lr=lr)

        for epoch in range(n_epoch):
            t0 = time.time()
            train_loss = self._train_epoch(optimizer)
            valid_loss, valid_metric = self._validate_epoch(self.model, self.dls.valid)

            # Update training progress display
            progress_bar.value += 1
            progress_bar.description = f'epoch [{epoch+1} / {n_epoch}]'
            records.append({'epoch': epoch, 'train_loss': train_loss,
                            'valid_loss': valid_loss, 'metric': valid_metric,
                            'time': time.time() - t0})
            # Rebuild from the record list instead of concat-ing inside the loop.
            self.training_summary = pd.DataFrame(records, columns=self._SUMMARY_COLUMNS)
            dsp.clear_output(wait=True)  # wait=True avoids flicker between clear and redraw
            dsp.display(progress_bar, dsp.HTML(self.training_summary.to_html(index=False)))
            


            

In [None]:
# Two-layer MLP: 784 pixels -> 30 hidden units -> 1 logit (positive means "3").
simple_net = nn.Sequential(nn.Linear(28*28, 30), nn.ReLU(), nn.Linear(30, 1))

learner = MyLearner(dls, simple_net,
                    opt_func=SGD, loss_func=mnist_loss, metrics=batch_accuracy)
learner.fit(n_epoch=100, lr=0.1)

IntProgress(value=101, description='epoch [100 / 100]', max=101, min=1)

epoch,train_loss,valid_loss,metric,time
0,0.049406,0.49525,0.504416,4.775875
1,0.013176,0.467802,0.518646,4.80308
2,0.002175,0.30972,0.691364,4.753184
3,0.001726,0.232236,0.770363,4.737833
4,0.001231,0.205233,0.797841,4.756465
5,0.001191,0.169007,0.831698,4.731095
6,0.001399,0.173928,0.823847,4.770639
7,0.001402,0.168797,0.831207,4.794704
8,0.001351,0.165645,0.836114,4.785752
9,0.001841,0.116815,0.881747,4.859257


In [None]:
# Capture the prints in an Output widget, then make the widget the cell's last
# expression so it is actually rendered; otherwise the captured text is never shown.
out = widgets.Output()
with out:
    for i in range(10):
        print(i, 'Hello world!')
out

In [None]:
# Inspect the signature and docstring of fastai's Learner (IPython help syntax).
?Learner

[0;31mInit signature:[0m
[0mLearner[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mdls[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmodel[0m[0;34m:[0m [0;34m'callable'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mloss_func[0m[0;34m:[0m [0;34m'callable | None'[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mopt_func[0m[0;34m=[0m[0;34m<[0m[0mfunction[0m [0mAdam[0m [0mat[0m [0;36m0x7f1fefe194c0[0m[0;34m>[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mlr[0m[0;34m=[0m[0;36m0.001[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0msplitter[0m[0;34m:[0m [0;34m'callable'[0m [0;34m=[0m [0;34m<[0m[0mfunction[0m [0mtrainable_params[0m [0mat[0m [0;36m0x7f1ff3764c10[0m[0;34m>[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mcbs[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmetrics[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mpath[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[