In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core import LightningModule
from torchmetrics import functional as FM
import matplotlib.pyplot as plt

In [2]:
# FashionMNIST train split first, then test split; both rooted at ./data
# (download=True lets torchvision fetch the archives there) and transformed with ToTensor.
training_data, test_data = (
    datasets.FashionMNIST(root='data', train=is_train, download=True, transform=ToTensor())
    for is_train in (True, False)
)

In [4]:
# Hold out VAL_SIZE of the 60,000 training images for validation.
# A seeded generator makes the split identical across kernel restarts; without it,
# every Restart & Run All would validate on a different subset.
VAL_SIZE = 5000
SPLIT_SEED = 42
train_dataset, val_dataset = random_split(
    training_data,
    [len(training_data) - VAL_SIZE, VAL_SIZE],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [3]:
# Training hyperparameters used by the cells below.
learning_rate = 1e-3  # Adam step size, read by LitModel.configure_optimizers
batch_size = 64       # samples per batch for the train/val/test DataLoaders
epochs = 10           # number of training epochs (presumably passed to the Trainer)

In [6]:
# Only the training loader is shuffled, so each epoch sees batches in a new order.
# Validation and test loaders keep a fixed order so their metrics are comparable run to run.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

In [8]:
from typing import Any


from pytorch_lightning.utilities.types import STEP_OUTPUT


class LitModel(LightningModule):
    """Fully connected classifier wrapped as a Lightning module.

    Each input is flattened (28*28 = 784 features) and passed through three
    Linear -> BatchNorm1d -> ReLU stages (784 -> 512 -> 256 -> 64). A final
    Linear layer then emits 10 class logits.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 10),
        )

    def forward(self, x):
        """Return the unnormalised class logits for the batch ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the cross-entropy loss for one training batch."""
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        # Logged so the training curve can be compared against val_loss.
        self.log('train_loss', loss)
        return loss

    def _evaluate(self, batch, prefix):
        """Log ``{prefix}_acc`` and ``{prefix}_loss`` for one evaluation batch."""
        x, y = batch
        logits = self(x)
        # Top-1 accuracy computed directly from the logits. FM.accuracy's signature
        # changed across torchmetrics releases (a `task` argument is required from
        # 0.11 on), so calling it without arguments breaks on current versions.
        acc = (logits.argmax(dim=1) == y).float().mean()
        loss = F.cross_entropy(logits, y)
        self.log_dict({f'{prefix}_acc': acc, f'{prefix}_loss': loss})

    def validation_step(self, batch, batch_idx):
        """Log validation accuracy and loss as ``val_acc`` / ``val_loss``."""
        self._evaluate(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Log test accuracy and loss as ``test_acc`` / ``test_loss``."""
        self._evaluate(batch, 'test')

    def configure_optimizers(self):
        """Return an Adam optimizer over all model parameters."""
        # NOTE(review): reads the notebook-global `learning_rate` from the config cell,
        # so that cell must run before training starts.
        return torch.optim.Adam(self.parameters(), lr=learning_rate)