# Training a model

> Train a small fully-connected autoencoder on MNIST with PyTorch Lightning.

In [None]:
#| default_exp basic

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#| export
# Root directory for the MNIST download. A relative path resolves against the
# current working directory, which differs between this notebook and code that
# imports the exported module, so it can be overridden via MNIST_DATA_DIR.
# The default is unchanged.
DATA_DIR = os.environ.get("MNIST_DATA_DIR", "../data")
dataset = MNIST(DATA_DIR, download=True, transform=transforms.ToTensor())

In [None]:
# List the raw files torchvision downloaded: the .gz archives and their extracted copies.
! ls ../data/MNIST/raw

t10k-images-idx3-ubyte	   train-images-idx3-ubyte
t10k-images-idx3-ubyte.gz  train-images-idx3-ubyte.gz
t10k-labels-idx1-ubyte	   train-labels-idx1-ubyte
t10k-labels-idx1-ubyte.gz  train-labels-idx1-ubyte.gz


In [None]:
# Confirm which dataset class was constructed above.
dataset.__class__

torchvision.datasets.mnist.MNIST

In [None]:
#| export
# shuffle=True reshuffles the samples every epoch. With the default
# (shuffle=False) the model would see the identical batch sequence each epoch.
train_loader = DataLoader(dataset, batch_size=10, shuffle=True)

In [None]:
#| export
class Encoder(nn.Module):
    """Map a batch of inputs to a low-dimensional latent code.

    Each input is flattened (all dims after the batch dim) and passed through
    Linear(in_features, hidden_dim) -> ReLU -> Linear(hidden_dim, latent_dim).
    The defaults reproduce the original MNIST setup (784 -> 64 -> 3).

    Args:
        in_features: number of values per flattened input (28*28 for MNIST).
        hidden_dim: width of the hidden layer.
        latent_dim: size of the latent code produced per sample.
    """

    def __init__(self, in_features=28*28, hidden_dim=64, latent_dim=3):
        super().__init__()
        # The attribute name `l1` and the layer order are kept so existing
        # state_dict keys (`l1.1.*`, `l1.3.*`) still load.
        self.l1 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )

    def forward(self, x):
        """Return the latent code, shape `(batch, latent_dim)`."""
        return self.l1(x)

In [None]:
#| export
class Decoder(nn.Module):
    """Map a latent code back to a flat reconstruction.

    Applies Linear(latent_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, out_features).
    The output is flat, shape `(batch, out_features)`; callers reshape it to the
    input's shape themselves. The defaults reproduce the original MNIST setup
    (3 -> 64 -> 784).

    Args:
        latent_dim: size of the incoming latent code.
        hidden_dim: width of the hidden layer.
        out_features: number of values produced per sample (28*28 for MNIST).
    """

    def __init__(self, latent_dim=3, hidden_dim=64, out_features=28*28):
        super().__init__()
        # The attribute name `l1` and the layer order are kept so existing
        # state_dict keys (`l1.0.*`, `l1.2.*`) still load.
        self.l1 = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_features),
        )

    def forward(self, x):
        """Return the flat reconstruction, shape `(batch, out_features)`."""
        return self.l1(x)

Grab one batch from the training loader and check the shapes flowing through the encoder and decoder:

In [None]:
# Peek at one mini-batch: `xb` holds the images, `yb` the digit labels.
first_batch = next(iter(train_loader))
xb, yb = first_batch

In [None]:
# Shapes of the image batch and the label batch.
tuple(t.shape for t in (xb, yb))

(torch.Size([10, 1, 28, 28]), torch.Size([10]))

In [None]:
# Push the batch through a freshly initialised encoder: one latent code per image.
encoder_acts = Encoder()(xb)
encoder_acts.size()

torch.Size([10, 3])

In [None]:
# Decode the latent codes: the result is flat, one row of pixels per image.
decoder_acts = Decoder()(encoder_acts)
decoder_acts.size()

torch.Size([10, 784])

In [None]:
# The reconstruction has to be reshaped back to the input's image shape.
out_shape = xb.size()
out_shape

torch.Size([10, 1, 28, 28])

In [None]:
# Viewing the flat output with the input's shape must succeed and match it.
decoder_acts.view(*out_shape).size()

torch.Size([10, 1, 28, 28])

In [None]:
# Reconstruction loss of the untrained encoder/decoder on this batch.
loss = F.mse_loss(input=decoder_acts.view(out_shape), target=xb)
loss

tensor(0.1172, grad_fn=<MseLossBackward0>)

In [None]:
#| export
class LitAutoEncoder(pl.LightningModule):
    """Lightning module that trains an encoder/decoder pair as an autoencoder.

    The loss is the mean-squared error between the decoder output, reshaped to
    the input's shape, and the input itself. Labels in the batch are ignored.

    Args:
        encoder: module mapping an input batch to a latent code.
        decoder: module mapping a latent code back to a tensor with as many
            elements per sample as the input.
        lr: learning rate for the Adam optimizer (default 1e-3, as before).
    """

    def __init__(self, encoder, decoder, lr=1e-3):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.lr = lr

    def training_step(self, batch, batch_idx):
        """Compute, log and return the reconstruction loss for one batch.

        Assumes `batch` is an `(inputs, labels)` pair as yielded by the MNIST
        DataLoader; the labels are unused.
        """
        x, _ = batch
        z = self.encoder(x)
        x_hat = self.decoder(z)
        # The decoder output may be flat; view_as restores the input's shape
        # so the MSE compares element-for-element.
        loss = F.mse_loss(x_hat.view_as(x), x)
        # Send the loss to the Trainer's logger (lightning_logs/); returning it
        # alone only drives the optimizer and the progress-bar display.
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with the configured lr."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [None]:

# Seed the Python, NumPy and torch RNGs (e.g. weight initialisation) so a fresh
# Restart & Run All reproduces the same training run.
pl.seed_everything(42)

# model
autoencoder = LitAutoEncoder(Encoder(), Decoder())

# train model
trainer = pl.Trainer(max_epochs=2)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
Missing logger folder: /mnt/data/lightning-level-up/nbs/lightning_logs

  | Name    | Type    | Params
------------------------------------
0 | encoder | Encoder | 50.4 K
1 | decoder | Decoder | 51.2 K
------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)
  rank_zero_warn(


Epoch 1: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6000/6000 [01:12<00:00, 82.30it/s, loss=0.0396, v_num=0]

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6000/6000 [01:12<00:00, 82.28it/s, loss=0.0396, v_num=0]


In [None]:
#| hide
# Regenerate the library's Python modules from the `#| export` cells
# (this notebook's cells go to the module named by `#| default_exp`).
import nbdev
nbdev.nbdev_export()