In [1]:
import torch
from torch.utils.data import DataLoader 
import torchvision
from torchvision import transforms
image_path = './'
transform = transforms.Compose([transforms.ToTensor()])
mnist_train_dataset = torchvision.datasets.MNIST(root=image_path, 
                                                 train=True, 
                                                 transform=transform, 
                                                 download=True)
mnist_test_dataset = torchvision.datasets.MNIST(root=image_path, 
                                                train=False, 
                                                transform=transform, 
                                                download=True)
batch_size = 64
torch.manual_seed(1)
train_dl = DataLoader(mnist_train_dataset, 
                      batch_size, shuffle=True)

In [2]:
import torch.nn as nn
hidden_units = [32, 16]
image_size = mnist_train_dataset[0][0].shape
input_size = image_size[0]*image_size[1]*image_size[2]
all_layers = [nn.Flatten()]
for hidden_unit in hidden_units:
    layer = nn.Linear(input_size, hidden_unit) 
    all_layers.append(layer)
    all_layers.append(nn.ReLU())
    input_size = hidden_unit
all_layers.append(nn.Linear(hidden_units[-1], 10))
model = nn.Sequential(*all_layers)
model

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=32, bias=True)
  (2): ReLU()
  (3): Linear(in_features=32, out_features=16, bias=True)
  (4): ReLU()
  (5): Linear(in_features=16, out_features=10, bias=True)
)

In [14]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
torch.manual_seed(1)
num_epochs = 20
for epoch in range(num_epochs): 
    accuracy_hist_train = 0 
    for x_batch, y_batch in train_dl: 
        pred = model(x_batch)
        loss = loss_fn(pred, y_batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        is_correct = (torch.argmax(pred, dim=1) == y_batch).float()
        accuracy_hist_train += is_correct.sum()
    accuracy_hist_train /= len(train_dl.dataset)
    print(f'Epoch {epoch} Accuracy ' f'{accuracy_hist_train:.4f}')
    

Epoch 0 Accuracy 0.8543
Epoch 1 Accuracy 0.9289
Epoch 2 Accuracy 0.9446
Epoch 3 Accuracy 0.9528
Epoch 4 Accuracy 0.9571
Epoch 5 Accuracy 0.9618
Epoch 6 Accuracy 0.9648
Epoch 7 Accuracy 0.9677
Epoch 8 Accuracy 0.9710
Epoch 9 Accuracy 0.9723
Epoch 10 Accuracy 0.9743
Epoch 11 Accuracy 0.9757
Epoch 12 Accuracy 0.9770
Epoch 13 Accuracy 0.9785
Epoch 14 Accuracy 0.9796
Epoch 15 Accuracy 0.9810
Epoch 16 Accuracy 0.9825
Epoch 17 Accuracy 0.9829
Epoch 18 Accuracy 0.9840
Epoch 19 Accuracy 0.9841


In [15]:
pred = model(mnist_test_dataset.data/255.)
is_correct = (
    torch.argmax(pred,dim=1) == mnist_test_dataset.targets).float()
print(f'Test accuracy: {is_correct.mean(): .4f}')


Test accuracy:  0.9672


In [14]:
import pytorch_lightning as pl
import torch 
import torch.nn as nn 

from torchmetrics import Accuracy
class MultiLayerPerceptron(pl.LightningModule):
    def __init__(self, image_shape=(1, 28, 28), hidden_units=(32, 16)):
        super().__init__()
        
        # new PL attributes:
        # update to book is task
        self.train_acc = Accuracy(task='multiclass', num_classes=10)
        self.valid_acc = Accuracy(task='multiclass', num_classes=10)
        self.test_acc = Accuracy(task='multiclass', num_classes=10)
        
        # Model similar to previous section:
        input_size = image_shape[0] * image_shape[1] * image_shape[2] 
        all_layers = [nn.Flatten()]
        for hidden_unit in hidden_units: 
            layer = nn.Linear(input_size, hidden_unit) 
            all_layers.append(layer) 
            all_layers.append(nn.ReLU()) 
            input_size = hidden_unit 
 
        all_layers.append(nn.Linear(hidden_units[-1], 10)) 
        self.model = nn.Sequential(*all_layers)

    def forward(self, x):
        x = self.model(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.train_acc.update(preds, y)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def training_epoch_end(self, outs):
        self.log("train_acc", self.train_acc.compute())
        self.train_acc.reset()
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.valid_acc.update(preds, y)
        self.log("valid_loss", loss, prog_bar=True)
        return loss
    
    def validation_epoch_end(self, outs):
        self.log("valid_acc", self.valid_acc.compute(), prog_bar=True)
        self.valid_acc.reset()

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.test_acc.update(preds, y)
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", self.test_acc.compute(), prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer

In [15]:
from torch.utils.data import DataLoader
from torch.utils.data import random_split
 
from torchvision.datasets import MNIST
from torchvision import transforms

class MnistDataModule(pl.LightningDataModule):
    def __init__(self, data_path='./'):
        super().__init__()
        self.data_path = data_path
        self.transform = transforms.Compose([transforms.ToTensor()])
        
    def prepare_data(self):
        MNIST(root=self.data_path, download=True) 

    def setup(self, stage=None):
        # stage is either 'fit', 'validate', 'test', or 'predict'
        # here note relevant
        mnist_all = MNIST( 
            root=self.data_path,
            train=True,
            transform=self.transform,  
            download=False
        ) 

        self.train, self.val = random_split(
            mnist_all, [55000, 5000], generator=torch.Generator().manual_seed(1)
        )

        self.test = MNIST( 
            root=self.data_path,
            train=False,
            transform=self.transform,  
            download=False
        ) 

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=64, num_workers=4)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=64, num_workers=4)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=64, num_workers=4)
    
    
torch.manual_seed(1) 
mnist_dm = MnistDataModule()

In [16]:
# init data module class 
torch.manual_seed(1)
mnist_dm = MnistDataModule()

In [17]:
from pytorch_lightning.callbacks import ModelCheckpoint


mnistclassifier = MultiLayerPerceptron()

trainer = pl.Trainer(max_epochs=10)

trainer.fit(model=mnistclassifier, datamodule=mnist_dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type               | Params
-------------------------------------------------
0 | train_acc | MulticlassAccuracy | 0     
1 | valid_acc | MulticlassAccuracy | 0     
2 | test_acc  | MulticlassAccuracy | 0     
3 | model     | Sequential         | 25.8 K
-------------------------------------------------
25.8 K    Trainable params
0         Non-trainable params
25.8 K    Total params
0.103     Total estimated model params size (MB)


Epoch 9: 100%|██████████| 939/939 [00:09<00:00, 98.05it/s, loss=0.1, v_num=1, train_loss=0.260, valid_loss=0.166, valid_acc=0.949]      

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 939/939 [00:09<00:00, 97.95it/s, loss=0.1, v_num=1, train_loss=0.260, valid_loss=0.166, valid_acc=0.949]


In [25]:
#https://github.com/pytorch/pytorch/issues/40500

In [23]:
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 53331), started 0:03:57 ago. (Use '!kill 53331' to kill it.)