In [1]:
import pytorch_lightning as pl
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets
import os
from torchmetrics import Accuracy

In [2]:
class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
        )
        self.loss_fn = nn.CrossEntropyLoss()
        self.test_accuracy = Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

    def training_step(self, batch, batch_idx):
        X, y = batch
        # Compute prediction error
        pred = model(X)
        loss = self.loss_fn(pred, y)
        return loss
    
    def test_step(self, batch, batch_idx):
        # this is the test loop
        x, y = batch
        pred = model(x)
        test_accuracy = sum(pred.argmax(1) == y)/len(y)
        test_loss = self.loss_fn(pred, y)
        #self.test_accuracy.update(pred, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("test_loss", test_loss, prog_bar=True)
        self.log("test_accuracy", test_accuracy, prog_bar=True)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0001)

In [3]:
training_data = datasets.FashionMNIST(root="data", train=True, download=True, transform=transforms.ToTensor())
test_data = datasets.FashionMNIST(root="data", train=False, download=True, transform=transforms.ToTensor())
#train_loader = DataLoader(FashionMNIST('./data/', download=True, transform=transforms.ToTensor()), batch_size=64)
train_loader = DataLoader(training_data, batch_size=64, num_workers=os.cpu_count()//2)
test_loader = DataLoader(test_data, batch_size=64, num_workers=os.cpu_count()//2)

In [4]:
trainer = pl.Trainer(max_epochs=3)
model = LitModel()
trainer.fit(model, train_dataloaders=train_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name              | Type               | Params
---------------------------------------------------------
0 | flatten           | Flatten            | 0     
1 | linear_relu_stack | Sequential         | 1.2 M 
2 | loss_fn           | CrossEntropyLoss   | 0     
3 | test_accuracy     | MulticlassAccuracy | 0     
---------------------------------------------------------
1.2 M     Trainable params
0         Non-trainable params
1.2 M     Total params
4.760     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.


In [5]:
trainer.test(model, dataloaders=test_loader)

Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_accuracy          0.847599983215332
        test_loss           0.4366007447242737
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.4366007447242737, 'test_accuracy': 0.847599983215332}]

In [6]:
classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

count = 0
correct = 0
for data in test_data:
    # x, y = test_data[i][0], test_data[i][1]
    x, y = data[0], data[1]
    with torch.no_grad():
        pred = model(x)
        predicted, actual = classes[pred[0].argmax(0)], classes[y]
        # print(f'Predicted: "{predicted}", Actual: "{actual}"')
        count += 1
        if predicted == actual:
            correct += 1
print(f'{correct} correct out of {count}')

8476 correct out of 10000


In [7]:
batch_iterator = iter(test_loader)
batch1 = next(batch_iterator)
batch2 = next(batch_iterator)
print(batch1[1])
print(batch2[1])
print(sum(batch1[1] == batch2[1]))

tensor([9, 2, 1, 1, 6, 1, 4, 6, 5, 7, 4, 5, 7, 3, 4, 1, 2, 4, 8, 0, 2, 5, 7, 9,
        1, 4, 6, 0, 9, 3, 8, 8, 3, 3, 8, 0, 7, 5, 7, 9, 6, 1, 3, 7, 6, 7, 2, 1,
        2, 2, 4, 4, 5, 8, 2, 2, 8, 4, 8, 0, 7, 7, 8, 5])
tensor([1, 1, 2, 3, 9, 8, 7, 0, 2, 6, 2, 3, 1, 2, 8, 4, 1, 8, 5, 9, 5, 0, 3, 2,
        0, 6, 5, 3, 6, 7, 1, 8, 0, 1, 4, 2, 3, 6, 7, 2, 7, 8, 5, 9, 9, 4, 2, 5,
        7, 0, 5, 2, 8, 6, 7, 8, 0, 0, 9, 9, 3, 0, 8, 4])
tensor(4)
