<a href="https://colab.research.google.com/github/satani99/fastai_22_2/blob/main/Learner.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [17]:
!git clone https://github.com/fastai/course22p2.git
%cd course22p2

Cloning into 'course22p2'...
remote: Enumerating objects: 1176, done.[K
remote: Counting objects: 100% (1176/1176), done.[K
remote: Compressing objects: 100% (480/480), done.[K
remote: Total 1176 (delta 712), reused 1117 (delta 694), pack-reused 0[K
Receiving objects: 100% (1176/1176), 99.81 MiB | 24.81 MiB/s, done.
Resolving deltas: 100% (712/712), done.
/content/course22p2/course22p2


In [18]:
!pip install datasets



In [19]:
import math, torch, matplotlib.pyplot as plt
import fastcore.all as fc
from collections.abc import Mapping
from operator import attrgetter
from functools import partial
from copy import copy

from torch import optim
import torch.nn.functional as F

from miniai.conv import *

from fastprogress import progress_bar, master_bar

In [20]:
import matplotlib as mpl
import torchvision.transforms.functional as TF
from contextlib import contextmanager
from torch import nn, tensor
from datasets import load_dataset, load_dataset_builder
from miniai.conv import *
from miniai.datasets import *
import logging
from fastcore.test import test_close

In [21]:
torch.set_printoptions(precision=2, linewidth=140, sci_mode=False)
torch.manual_seed(42)
mpl.rcParams['image.cmap'] = 'gray'

In [22]:
logging.disable(logging.WARNING)

In [23]:
x, y = 'image', 'label'
name = 'fashion_mnist'
dsd = load_dataset(name)

In [24]:
@inplace
def transformi(b): b[x] = [torch.flatten(TF.to_tensor(o)) for o in b[x]]

In [25]:
bs = 1024
tds = dsd.with_transform(transformi)

In [26]:
dls = DataLoaders.from_dd(tds, bs, num_workers=4)
dt = dls.train
xb, yb = next(iter(dt))
xb.shape, yb[:10]



(torch.Size([1024, 784]), tensor([5, 7, 4, 7, 3, 8, 9, 5, 3, 1]))

In [31]:
class Learner:
  def __init__(self, model, dls, loss_func, lr, opt_func=optim.SGD): fc.store_attr()

  def one_batch(self):
    self.xb, self.yb = to_device(self.batch)
    self.preds = self.model(self.xb)
    self.loss = self.loss_func(self.preds, self.yb)
    if self.model.training:
      self.loss.backward()
      self.opt.step()
      self.opt.zero_grad()
    with torch.no_grad(): self.calc_stats()

  def calc_stats(self):
    acc = (self.preds.argmax(dim=1)==self.yb).float().sum()
    self.accs.append(acc)
    n = len(self.xb)
    self.losses.append(self.loss*n)
    self.ns.append(n)

  def one_epoch(self, train):
    self.model.training = train
    dl = self.dls.train if train else self.dls.valid
    for self.num, self.batch in enumerate(dl): self.one_batch()
    n = sum(self.ns)
    print(self.epoch, self.model.training, sum(self.losses).item()/n, sum(self.accs).item()/n)

  def fit(self, n_epochs):
    self.accs, self.losses, self.ns = [], [], []
    self.model.to(def_device)
    self.opt = self.opt_func(self.model.parameters(), self.lr)
    self.n_epochs = n_epochs
    for self.epoch in range(n_epochs):
      self.one_epoch(True)
      with torch.no_grad(): self.one_epoch(False)


In [32]:
m, nh = 28*28, 50
model = nn.Sequential(nn.Linear(m, nh), nn.ReLU(), nn.Linear(nh, 10))

In [33]:
learn = Learner(model, dls, F.cross_entropy, lr=0.2)
learn.fit(1)



0 True 1.1871040364583334 0.5921166666666666
0 False 1.1323120535714286 0.6058714285714286


In [34]:
class CancelFitException(Exception): pass
class CancelBatchException(Exception): pass
class CancelEpochException(Exception): pass

In [35]:
class Callback(): order = 0

In [70]:
def run_cbs(cbs, method_nm, learn=None):
  for cb in sorted(cbs, key=attrgetter('order')):
    method = getattr(cb, method_nm, None)
    if method is not None: method(learn)

In [56]:
class CompletionCB(Callback):
  def before_fit(self, learn): self.count = 0
  def after_batch(self, learn): self.count += 1
  def after_fit(self, learn): print(f'Completed {self.count} batches')

In [41]:
cbs = [CompletionCB()]
run_cbs(cbs, 'before_fit')
run_cbs(cbs, 'after_batch')
run_cbs(cbs, 'after_fit')

Completed 1 batches


In [46]:
class Learner():
  def __init__(self, model, dls, loss_func, lr, cbs, opt_func=optim.SGD): fc.store_attr()

  def one_batch(self):
    self.preds = self.model(self.batch[0])
    self.loss = self.loss_func(self.preds, self.batch[1])
    if self.model.training:
      self.loss.backward()
      self.opt.step()
      self.opt.zero_grad()

  def one_epoch(self, train):
    self.model.train(train)
    self.dl = self.dls.train if train else self.dls.valid
    try:
      self.callback('before_epoch')
      for self.iter, self.batch in enumerate(self.dl):
        try:
          self.callback('before_batch')
          self.one_batch()
          self.callback('after_batch')
        except CancelBatchException: pass
      self.callback('after_epoch')
    except CancelEpochException: pass

  def fit(self, n_epochs):
    self.n_epochs = n_epochs
    self.epochs = range(n_epochs)
    self.opt = self.opt_func(self.model.parameters(), self.lr)
    try:
      self.callback('before_fit')
      for self.epoch in self.epochs:
        self.one_epoch(True)
        self.one_epoch(False)
      self.callback('after_fit')
    except CancelFitException: pass

  def callback(self, method_nm): run_cbs(self.cbs, method_nm, self)

In [47]:
m, nh = 28*28, 50
def get_model(): return nn.Sequential(nn.Linear(m, nh), nn.ReLU(), nn.Linear(nh, 10))

In [48]:
model = get_model()
learn = Learner(model, dls, F.cross_entropy, lr=0.2, cbs=[CompletionCB()])
learn.fit(1)

Completed 64 batches


In [73]:
class SingleBatchCB(Callback):
  order = 1
  def after_batch(self, learn): raise CancelFitException()

In [74]:
learn = Learner(get_model(), dls, F.cross_entropy, lr=0.2, cbs=[CompletionCB(), SingleBatchCB()])
learn.fit(1)



In [75]:
class Metric:
  def __init__(self): self.reset()
  def reset(self): self.vals, self.ns = [], []
  def add(self, inp, targ=None, n=1):
    self.last = self.calc(inp, targ)
    self.vals.append(self.last)
    self.ns.append(n)
  @property
  def value(self):
    ns = tensor(self.ns)
    return (tensor(self.vals)*ns).sum()/ns.sum()
  def calc(self, inps, targs): return inps

In [76]:
class Accuracy(Metric):
  def calc(self, inps, targs): return (inps==targs).float().mean()

In [78]:
acc = Accuracy()
acc.add(tensor([0, 1, 2, 0, 1, 2]), tensor([0, 1, 1, 2, 1, 0]))
acc.add(tensor([1, 1, 2, 0, 1]), tensor([0, 1, 1, 2, 1]))
acc.value

tensor(0.45)

In [79]:
loss = Metric()
loss.add(0.6, n=32)
loss.add(0.9, n=2)
loss.value, round((0.6*32 + 0.9*2)/(32+2), 2)

(tensor(0.62), 0.62)

In [81]:
!pip install torcheval

Collecting torcheval
  Downloading torcheval-0.0.7-py3-none-any.whl (179 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m179.2/179.2 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torcheval
Successfully installed torcheval-0.0.7


In [82]:
from torcheval.metrics import MulticlassAccuracy, Mean

In [83]:
metric = MulticlassAccuracy()
metric.update(tensor([0, 2, 1, 3]), tensor([0, 1, 2, 3]))
metric.compute()

tensor(0.50)