In [None]:
import pytorch_lightning as pl
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision import datasets
import os
from torchmetrics import Accuracy

In [None]:
# Fashion-MNIST train/test splits, stored under ./data (downloaded if missing) and converted to tensors.
training_data = datasets.FashionMNIST(root="data", train=True, download=True, transform=transforms.ToTensor())
test_data = datasets.FashionMNIST(root="data", train=False, download=True, transform=transforms.ToTensor())

BATCH_SIZE = 64
# os.cpu_count() returns None when the CPU count cannot be determined; fall back to
# 0 workers (load in the main process) instead of crashing on `None // 2`.
NUM_WORKERS = (os.cpu_count() or 0) // 2

# Shuffle the training set so each epoch sees a fresh batch order; the test loader
# stays in fixed order so evaluation and later batch inspection are repeatable.
train_loader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

In [None]:
class LitModel(pl.LightningModule):
    """Fully connected Fashion-MNIST classifier wrapped as a LightningModule.

    Each input is flattened and passed through a 784 -> 512 -> 512 -> 10 MLP with
    ReLU activations; ``forward`` returns the raw (unnormalised) logits.

    Args:
        learning_rate: Adam learning rate used by ``configure_optimizers``.
            Defaults to 0.001, the value previously hard-coded.
    """

    def __init__(self, learning_rate: float = 0.001):
        super().__init__()
        self.learning_rate = learning_rate
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )
        # The network outputs raw logits, which is what CrossEntropyLoss consumes.
        self.loss_fn = nn.CrossEntropyLoss()
        self.test_accuracy = Accuracy(task="multiclass", num_classes=10)

    def forward(self, x):
        """Return class logits of shape (batch, 10) for a batch ``x``.

        ``nn.Flatten`` keeps dim 0 and flattens the rest, so each sample must contain
        28*28 values; presumably ``x`` is (batch, 1, 28, 28) images — confirm with caller.
        """
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one training batch."""
        X, y = batch
        # Call this module (self), not the notebook-global `model`: the step must run on
        # the instance Lightning is training, whatever name it is bound to.
        pred = self(X)
        training_loss = self.loss_fn(pred, y)
        self.log("training_loss", training_loss, prog_bar=True)
        return training_loss

    def test_step(self, batch, batch_idx):
        """Log test loss and accumulate test accuracy for one batch."""
        x, y = batch
        pred = self(x)
        # Accumulate into the metric's running state and log the Metric object itself, so
        # Lightning computes accuracy over the whole test set at epoch end and resets it.
        self.test_accuracy.update(pred.argmax(1), y)
        test_loss = self.loss_fn(pred, y)
        self.log("test_loss", test_loss, prog_bar=False)
        self.log("test_accuracy", self.test_accuracy, prog_bar=True)

    def configure_optimizers(self):
        """Optimise all parameters with Adam at ``self.learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

In [None]:
# Seed the global RNGs before building the model so weight initialisation and the
# shuffled training-batch order are reproducible across Restart & Run All.
SEED = 42
pl.seed_everything(SEED)

trainer = pl.Trainer(max_epochs=10)
model = LitModel()
trainer.fit(model, train_dataloaders=train_loader)

In [None]:
# Evaluate the fitted model on the held-out test split and display what trainer.test returns.
test_results = trainer.test(model=model, dataloaders=test_loader)
test_results

In [None]:
# Human-readable Fashion-MNIST label names, indexed by class id (0-9).
classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

# Count correct predictions over the whole test split as a sanity check on
# trainer.test. Batched inference over test_loader (the same test_data, in order)
# replaces a per-sample loop, and predictions are compared as class indices.
# LitModel has no dropout/batch-norm, so eval() does not change outputs; it just
# makes the inference intent explicit.
model.eval()
count = 0
correct = 0
with torch.no_grad():
    for x, y in test_loader:
        # Run on whatever device the model ended up on, then compare on CPU with y.
        predicted = model(x.to(model.device)).argmax(dim=1).cpu()
        correct += int((predicted == y).sum())
        count += len(y)
print(f'{correct} correct out of {count}')

In [None]:
# Draw the first two batches from the (unshuffled) test loader, show their labels,
# and count the positions where the two label tensors agree.
batch_iterator = iter(test_loader)
batch1, batch2 = next(batch_iterator), next(batch_iterator)
labels1, labels2 = batch1[1], batch2[1]
print(labels1)
print(labels2)
print((labels1 == labels2).sum())

In [None]:
model  # display the module's repr: its submodules (flatten + linear/ReLU stack, loss, metric)

In [None]:
# Run the trained model on the first test batch for interactive inspection.
batch_iterator = iter(test_loader)
x, y = next(batch_iterator)
# Inference only: no_grad avoids building (and holding onto) an autograd graph.
with torch.no_grad():
    predictions = model(x)

In [None]:
predictions.size()  # (batch, num_classes) logits for the inspected batch

In [None]:
# Accuracy on the single inspected batch, using the model's own torchmetrics Accuracy.
predicted_labels = predictions.argmax(dim=1)
model.test_accuracy(predicted_labels, y)