In [1]:
import os
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

class MNISTDataset:
    """Downloads MNIST and exposes train / validation / test splits.

    The MNIST *training* set is split into training and validation subsets with
    ``random_split`` (seeded, so the split is reproducible across runs); the
    separate official MNIST test set is used unchanged as the test split. All
    splits share the same transform: ``ToTensor`` followed by ``Normalize``.

    Args:
        root_dir: Directory where torchvision stores (and looks for) the data.
        val_size: Number of training images held out for validation.
        seed: Seed for the train/validation split; pass ``None`` to draw from
            PyTorch's global RNG instead (non-reproducible).
    """

    def __init__(self, root_dir='data', val_size=5000, seed=42):
        self.root_dir = root_dir
        self.val_size = val_size
        self.seed = seed
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            # Single-channel mean/std; presumably the MNIST training-set pixel
            # statistics — confirm if the data source changes.
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        self.download_dataset()

    def download_dataset(self):
        """Download MNIST if needed, load it, and build the three splits."""
        self.dataset = datasets.MNIST(self.root_dir, train=True, download=True, transform=self.transform)
        # Hold out part of the *training* set for validation; the test split
        # comes from the separate official test set loaded below.
        self.train_dataset, self.val_dataset = self._split_train_val(self.dataset, self.val_size, self.seed)
        self.test_dataset = datasets.MNIST(self.root_dir, train=False, download=True, transform=self.transform)

    @staticmethod
    def _split_train_val(dataset, val_size, seed):
        """Randomly split ``dataset`` into ``(train, val)`` subsets.

        Args:
            dataset: Any map-style dataset supporting ``len()``.
            val_size: Number of items placed in the validation subset.
            seed: Integer seed for a dedicated generator, or ``None`` to use
                PyTorch's global RNG.

        Returns:
            A two-element sequence ``[train_subset, val_subset]``.

        Raises:
            ValueError: if ``val_size`` is negative or exceeds ``len(dataset)``.
        """
        total = len(dataset)
        if not 0 <= val_size <= total:
            raise ValueError(f'val_size must be in [0, {total}], got {val_size}')
        # Derive the train size from the dataset instead of hard-coding it, so
        # the lengths always sum to len(dataset) as random_split requires.
        lengths = [total - val_size, val_size]
        if seed is None:
            return random_split(dataset, lengths)
        # A private generator makes the split reproducible without touching
        # the global RNG state used elsewhere (e.g. weight init, shuffling).
        generator = torch.Generator().manual_seed(seed)
        return random_split(dataset, lengths, generator=generator)

    def get_dataloaders(self, batch_size=32, num_workers=0):
        """Build DataLoaders for the train, validation and test splits.

        Only the training loader shuffles; validation and test iterate in a
        fixed order so their metrics are comparable between epochs.

        Args:
            batch_size: Samples per batch for all three loaders.
            num_workers: Worker processes per loader (0 = load in main process).

        Returns:
            Tuple ``(train_loader, val_loader, test_loader)``.
        """
        train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        test_loader = DataLoader(self.test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        return train_loader, val_loader, test_loader


In [2]:
import pytorch_lightning as pl
import torch.nn.functional as F
from torch import nn

class MNISTModel(pl.LightningModule):
    """Three-layer fully connected classifier for 28x28 MNIST digits.

    ``forward`` returns per-class log-probabilities (``log_softmax``), so every
    step uses ``F.nll_loss``; the pair is equivalent to cross-entropy on raw
    scores.

    Args:
        lr: Learning rate for the Adam optimizer.
        hidden_1: Width of the first hidden layer.
        hidden_2: Width of the second hidden layer.
    """

    def __init__(self, lr=1e-3, hidden_1=128, hidden_2=256):
        super().__init__()
        # Stores lr / hidden sizes in self.hparams and in checkpoints, so
        # load_from_checkpoint can rebuild the model with the same settings.
        self.save_hyperparameters()
        self.layer_1 = nn.Linear(28 * 28, hidden_1)
        self.layer_2 = nn.Linear(hidden_1, hidden_2)
        self.layer_3 = nn.Linear(hidden_2, 10)

    def forward(self, x):
        """Map a batch of images to class log-probabilities of shape (batch, 10).

        All dimensions after the first are flattened, so each sample must hold
        28 * 28 = 784 values (e.g. input of shape (batch, 1, 28, 28)).
        """
        # flatten() also works on non-contiguous tensors, unlike view().
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.layer_1(x))
        x = F.relu(self.layer_2(x))
        return F.log_softmax(self.layer_3(x), dim=1)

    def configure_optimizers(self):
        """Adam over all parameters, using the learning rate passed to __init__."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

    def _shared_step(self, batch, stage, prog_bar=False):
        """Compute and log NLL loss and accuracy for one ``(x, y)`` batch.

        Metrics are logged as ``f'{stage}_loss'`` and ``f'{stage}_acc'``.
        Returns the loss tensor.
        """
        x, y = batch
        # forward() already applies log_softmax, so these are log-probs, not logits.
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        acc = (log_probs.argmax(dim=1) == y).float().mean()
        self.log(f'{stage}_loss', loss, prog_bar=prog_bar)
        self.log(f'{stage}_acc', acc, prog_bar=prog_bar)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        self._shared_step(batch, 'val', prog_bar=True)

    def test_step(self, batch, batch_idx):
        self._shared_step(batch, 'test')


ModuleNotFoundError: No module named 'pytorch_lightning'