In [None]:
!pip install torch~=1.7.0 torchvision pytorch-lightning

In [None]:
# Diagnostic only: show which NVIDIA GPUs (if any) are visible to this machine.
!nvidia-smi

In [None]:
import torch
# Sanity check: the installed torch should satisfy the `torch~=1.7.0` pin (i.e. 1.7.x).
torch.__version__

In [4]:
from torchvision.models import resnet18
from torch import nn
from torch.utils.data import DataLoader

In [6]:
import torchvision
from torchvision import datasets 
from torchvision.transforms import ToTensor

In [7]:
# Root of the image dataset; each split lives in its own sub-folder.
_DATA_ROOT = '../input/bwcolor/MNIST-Data-BW&Color'
path_to_train = f'{_DATA_ROOT}/Train'
path_to_test = f'{_DATA_ROOT}/Val'  # the 'Val' split serves as the held-out test set

In [8]:
# Per-channel normalisation statistics (the widely used ImageNet values).
# There are three entries each, so the pipeline expects 3-channel images.
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]


def make_classifier_transforms():
    """Build the ToTensor + Normalize pipeline shared by the train and test sets.

    Returns a fresh ``Compose`` on every call. The train and test pipelines are
    therefore independent objects, and train-only augmentation can later be
    added to one without touching the other.
    """
    return torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD),
    ])


train_classifier_transforms = make_classifier_transforms()
test_transforms = make_classifier_transforms()

In [9]:
train_set = datasets.ImageFolder(root=path_to_train, transform=train_classifier_transforms)
test_set = datasets.ImageFolder(root=path_to_test, transform=test_transforms)

# ImageFolder labels are indices into its sorted list of class folders.
# Train and test must therefore expose the same folders; otherwise one index
# would mean different digits in each split.
assert train_set.class_to_idx == test_set.class_to_idx, (
    f"class folders differ: {train_set.classes} vs {test_set.classes}"
)
# ResNetMNIST builds a 10-way classifier (resnet18(num_classes=10)).
assert len(train_set.classes) == 10, f"expected 10 classes, got {train_set.classes}"

In [10]:
# Both loaders share one batch size. Only the training loader is shuffled, so
# the evaluation order (and the collected predictions) stays stable.
BATCH_SIZE = 512
train_dl = DataLoader(dataset=train_set, batch_size=BATCH_SIZE, shuffle=True)
test_dl = DataLoader(dataset=test_set, batch_size=BATCH_SIZE, shuffle=False)

In [11]:
import pytorch_lightning as pl
from pytorch_lightning.core.decorators import auto_move_data

In [None]:
class ResNetMNIST(pl.LightningModule):
  """Digit classifier: a torchvision ResNet-18 wrapped as a LightningModule.

  Args:
    num_classes: size of the final classification layer (10 digit classes).
    lr: learning rate for the RMSprop optimizer.
  """

  def __init__(self, num_classes: int = 10, lr: float = 0.005):
    super().__init__()
    # Store num_classes/lr in self.hparams and in every checkpoint, so
    # load_from_checkpoint rebuilds the module with the same settings.
    self.save_hyperparameters()
    self.model = resnet18(num_classes=num_classes)
    self.loss = nn.CrossEntropyLoss()

  @auto_move_data
  def forward(self, x):
    """Return raw class logits for the image batch ``x``.

    ``auto_move_data`` first moves ``x`` onto this module's device, so callers
    can pass CPU tensors.
    """
    return self.model(x)

  def training_step(self, batch, batch_no):
    """Compute the cross-entropy loss for one ``(images, labels)`` batch."""
    x, y = batch
    logits = self(x)
    loss = self.loss(logits, y)
    # Record the loss so training progress is visible; the loss was not logged anywhere before.
    self.log("train_loss", loss, prog_bar=True)
    return loss

  def configure_optimizers(self):
    """Optimise all parameters with RMSprop at the configured learning rate."""
    return torch.optim.RMSprop(self.parameters(), lr=self.hparams.lr)

In [13]:
# Seed the Python, NumPy and torch RNGs before creating the model. Weight
# initialisation and the training loader's shuffle order then repeat across runs.
pl.seed_everything(42)
model = ResNetMNIST()

In [None]:
# Train on a single GPU when CUDA is available; otherwise fall back to the CPU.
# The earlier nvidia-smi check suggests a GPU may be attached, but the previous
# hard-coded gpus=0 always ignored it.
trainer = pl.Trainer(
    gpus=1 if torch.cuda.is_available() else 0,
    max_epochs=1,
    progress_bar_refresh_rate=20
)
trainer.fit(model, train_dl)

In [15]:
# Persist the trained module as a full Lightning checkpoint on local disk.
trainer.save_checkpoint(filepath="resnet18_mnist.pt")

In [16]:
def get_prediction(x, model: pl.LightningModule):
  """Classify the batch ``x`` with ``model``.

  Freezes ``model`` as a side effect, to prepare it for inference.

  Returns:
    ``(predicted_class, probabilities)``: the per-row arg-max class index and
    the softmax over dim 1 of the model's output.
  """
  model.freeze()
  scores = model(x)
  probs = scores.softmax(dim=1)
  return probs.argmax(dim=1), probs

In [17]:
from tqdm.autonotebook import tqdm

In [18]:
# Rebuild a fresh ResNetMNIST from the checkpoint file saved after training.
inference_model = ResNetMNIST.load_from_checkpoint(checkpoint_path="./resnet18_mnist.pt")

In [19]:
# Collect ground-truth and predicted labels for the whole test set as plain ints.
true_y, pred_y = [], []
for x, y in tqdm(test_dl, total=len(test_dl)):
  # .tolist() yields Python ints instead of one 0-d tensor per sample,
  # which sklearn would otherwise have to convert element by element.
  true_y.extend(y.tolist())
  preds, _ = get_prediction(x, inference_model)
  pred_y.extend(preds.cpu().tolist())

  0%|          | 0/20 [00:00<?, ?it/s]

In [20]:
from sklearn.metrics import classification_report

In [21]:
# Label rows with the class-folder names rather than bare indices. Passing
# `labels` explicitly keeps every class in the report, even one that never
# occurs, so the rows always line up with `target_names`.
print(classification_report(
    true_y,
    pred_y,
    labels=list(range(len(test_set.classes))),
    target_names=test_set.classes,
    digits=2,
))

              precision    recall  f1-score   support

           0       0.97      0.55      0.70       980
           1       0.33      1.00      0.49      1135
           2       0.95      0.74      0.83      1032
           3       0.98      0.63      0.76      1010
           4       0.37      0.60      0.46       982
           5       0.97      0.70      0.82       892
           6       0.62      0.71      0.66       958
           7       0.97      0.56      0.71      1028
           8       0.90      0.53      0.67       974
           9       0.90      0.01      0.02      1009

    accuracy                           0.61     10000
   macro avg       0.79      0.60      0.61     10000
weighted avg       0.79      0.61      0.61     10000

