In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import lightning as L
from lightning import Trainer
import torch.nn.functional as F

In [2]:
# Hyperparameters shared by the model definition and training cells below.
input_size = 28 * 28  # 784 features: one flattened 28 x 28 MNIST image
hidden_size = 100     # width of the single hidden layer
num_class = 10        # number of output logits (passed as output_size)
num_epochs = 30       # NOTE(review): not referenced by the visible Trainer call, which uses max_epochs=2
batch_size = 100      # samples per batch in both DataLoaders
lr = 0.001            # learning rate handed to Adam in configure_optimizers

In [38]:
class NuralNet(L.LightningModule):
    """Single-hidden-layer fully connected classifier trained on MNIST.

    The network is Linear -> ReLU -> Linear and outputs raw logits; the loss
    is cross-entropy. The module also owns its MNIST train/validation
    DataLoaders, so ``trainer.fit(model)`` needs no extra arguments.

    The learning rate and batch size are read from the notebook-level names
    ``lr`` and ``batch_size`` at the time the optimizer / loaders are built.

    Args:
        input_size: Number of features per flattened input sample.
        hidden_size: Number of units in the hidden layer.
        output_size: Number of output logits (classes).
        num_workers: Worker processes per DataLoader. The default of 11 keeps
            the original behaviour; pass 0 to load data in the main process.
        data_dir: Directory where the MNIST files are stored / downloaded.
    """

    def __init__(self, input_size, hidden_size, output_size, num_workers=11, data_dir='./data'):
        super().__init__()
        self.l1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(hidden_size, output_size)
        self.num_workers = num_workers
        self.data_dir = data_dir

    def forward(self, x):
        """Return logits for ``x``, whose last dimension must equal ``input_size``."""
        out = self.l1(x)
        out = self.relu(out)
        out = self.l2(out)
        return out

    def _shared_step(self, batch, log_name):
        """Compute the cross-entropy loss for one batch and log it as ``log_name``.

        Shared by ``training_step`` and ``validation_step`` so both stay in sync.
        """
        images, labels = batch
        # Flatten every sample to a vector while keeping the batch dimension
        # fixed. A size mismatch with the first layer then raises in ``l1``
        # instead of silently regrouping pixels across samples.
        images = images.flatten(start_dim=1)

        loss = F.cross_entropy(self(images), labels)

        # Aggregate per epoch (not per step); show in the progress bar and logger.
        self.log(log_name, loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for ``batch``; logged per epoch as ``"loss"``."""
        return self._shared_step(batch, "loss")

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for ``batch``; logged per epoch as ``"val_loss"``."""
        return self._shared_step(batch, "val_loss")

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters, using the global ``lr``."""
        return torch.optim.Adam(self.parameters(), lr=lr)

    def _mnist_loader(self, train):
        """Build a DataLoader over the MNIST train (``train=True``) or test split.

        ``download=True`` is passed for both splits so that whichever loader is
        built first works on an empty ``data_dir``.
        """
        dataset = torchvision.datasets.MNIST(
            root=self.data_dir,
            train=train,
            transform=transforms.ToTensor(),
            download=True,
        )
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=train,  # shuffle only the training split
            num_workers=self.num_workers,
            # DataLoader rejects persistent_workers=True when there are no workers.
            persistent_workers=self.num_workers > 0,
        )

    def train_dataloader(self):
        """Return the shuffled DataLoader over the MNIST training split."""
        return self._mnist_loader(train=True)

    def val_dataloader(self):
        """Return the unshuffled DataLoader over the MNIST test split."""
        return self._mnist_loader(train=False)
    
   

In [39]:
# Seed Python, NumPy and torch RNGs (and DataLoader workers) before the model
# is built, so weight initialisation and shuffling repeat across
# Restart-&-Run-All executions.
L.seed_everything(42, workers=True)

model = NuralNet(input_size, hidden_size, num_class)

# Short 2-epoch run.
# NOTE(review): the config cell defines num_epochs = 30 but it is not used
# here; confirm which epoch count is intended.
trainer = Trainer(max_epochs=2)
trainer.fit(model)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name | Type   | Params
--------------------------------
0 | l1   | Linear | 78.5 K
1 | relu | ReLU   | 0     
2 | l2   | Linear | 1.0 K 
--------------------------------
79.5 K    Trainable params
0         Non-trainable params
79.5 K    Total params
0.318     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.
