In [None]:
# Mount Google Drive under /content/gdrive so the project folder in MyDrive is reachable.
from google.colab import drive

drive.mount("/content/gdrive")

In [None]:
!cd /content/gdrive/MyDrive/ieAI

In [None]:
!git clone https://github.com/satchitchatterji/DistillNAM.git

In [None]:
!cd DistillNAM
!pip install -r requirements.txt

In [18]:
import os

import models
import pandas as pd
import pytorch_lightning as L
import seaborn as sns
import torch
from pl_bolts.datamodules import MNISTDataModule
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger
from torch.nn import functional as F
from torchmetrics.functional import accuracy

In [19]:
%config InlineBackend.figure_format = 'svg'

In [20]:
# Reproducibility: seed the global RNGs once, early.
# `workers=True` also gives DataLoader worker processes deterministic seeds,
# which matters because NUM_WORKERS > 0 in the configuration cell.
SEED = 42
L.seed_everything(SEED, workers=True)


Global seed set to 42


42

In [21]:
# Run configuration.
# Dataset root; can be overridden with the PATH_DATASETS environment variable.
PATH_DATASETS = os.environ.get("PATH_DATASETS", "~/datasets")
# Larger batches when a GPU is available.
BATCH_SIZE = 256 if torch.cuda.is_available() else 64
# os.cpu_count() may return None; fall back to 0 workers instead of raising
# TypeError. Floor division matches the original int(n / 2) for n >= 0.
NUM_WORKERS = (os.cpu_count() or 0) // 2
# Root for Lightning checkpoints and CSV logs.
PL_ROOT_DIR = "./MNIST_CNN"
print(
    f"Datasets root: {PATH_DATASETS} batch size: {BATCH_SIZE} n_workers: {NUM_WORKERS} lightning dir: {PL_ROOT_DIR}"
)

Datasets root: ~/datasets batch size: 64 n_workers: 4 lightning dir: ./MNIST_CNN


In [22]:
# structure taken from
# https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/cifar10-baseline.html
class LitMNIST(L.LightningModule):
    """LightningModule wrapping ``models.MNISTCnn`` for MNIST classification.

    ``forward`` returns per-class log-probabilities (``log_softmax`` over dim 1).
    Training and evaluation compute ``F.nll_loss`` on those log-probabilities.

    Args:
        lr: Adam learning rate. The default ``1e-3`` equals ``torch.optim.Adam``'s
            own default, so ``LitMNIST()`` optimizes exactly as before.
        num_classes: Number of target classes, passed to the accuracy metric.
            The printed MNISTCnn classifier has 10 outputs.
    """

    def __init__(self, lr: float = 1e-3, num_classes: int = 10):
        super().__init__()
        # Stores lr / num_classes in self.hparams (and in checkpoints).
        self.save_hyperparameters()
        self.model = models.MNISTCnn()

    def forward(self, x):
        """Return log-probabilities over classes for input batch ``x``."""
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        """Compute and log the NLL training loss for one ``(x, y)`` batch."""
        x, y = batch
        # Go through self(...) so nll_loss receives log-probabilities. Calling
        # self.model(x) directly fed raw classifier outputs to nll_loss, which
        # is not a valid loss. If MNISTCnn already log-softmaxes its output,
        # the extra log_softmax is a no-op.
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        self.log("train_loss", loss)
        return loss

    def evaluate(self, batch, stage=None):
        """Log ``{stage}_loss`` and ``{stage}_acc`` for a batch when ``stage`` is given."""
        x, y = batch
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        preds = torch.argmax(log_probs, dim=1)
        # Recent torchmetrics requires an explicit task / num_classes.
        acc = accuracy(
            preds, y, task="multiclass", num_classes=self.hparams.num_classes
        )
        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

In [23]:
# MNIST datamodule with input normalisation, using the dataset root, batch size
# and worker count from the configuration cell.
mnist_dm = MNISTDataModule(
    data_dir=PATH_DATASETS,
    normalize=True,
    num_workers=NUM_WORKERS,
    batch_size=BATCH_SIZE,
)
mnist_dm


<pl_bolts.datamodules.mnist_datamodule.MNISTDataModule at 0x7f612c03ead0>

In [24]:
# Build a fresh LightningModule and display its layer structure.
model = LitMNIST()
model


LitMNIST(
  (model): MNISTCnn(
    (conv1): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
      (1): ReLU()
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (conv2): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
      (1): ReLU()
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (dropout): Dropout(p=0.5, inplace=False)
    (classifier): Linear(in_features=1600, out_features=10, bias=True)
  )
)

In [25]:
# Short training run that logs metrics to CSV under PL_ROOT_DIR.
trainer = L.Trainer(
    default_root_dir=PL_ROOT_DIR,
    max_epochs=3,
    accelerator="auto",
    # Exactly one device (one GPU, or one CPU process) on any accelerator.
    # `devices=None` is not accepted on CPU by newer Lightning releases.
    devices=1,
    logger=CSVLogger(save_dir=PL_ROOT_DIR),
    callbacks=[TQDMProgressBar(refresh_rate=10)],
)
trainer


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


<pytorch_lightning.trainer.trainer.Trainer at 0x7f612c05ae90>

In [26]:
# Fit `model` on `mnist_dm`; validation_step reports val_loss / val_acc in the progress bar.
trainer.fit(model=model, datamodule=mnist_dm)



  | Name  | Type     | Params
-----------------------------------
0 | model | MNISTCnn | 34.8 K
-----------------------------------
34.8 K    Trainable params
0         Non-trainable params
34.8 K    Total params
0.139     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [None]:
# Run test_step over the datamodule's test data; logs test_loss / test_acc.
trainer.test(model=model, datamodule=mnist_dm)


In [None]:
# Load the metrics the trainer's CSVLogger wrote during fit/test.
# Requires `import pandas as pd` in the imports cell.
metrics = pd.read_csv(os.path.join(trainer.logger.log_dir, "metrics.csv"))
metrics


In [None]:
# Drop the step counter and index rows by epoch. This version is re-runnable:
# the previous `del` / `set_index(inplace=True)` raised KeyError on a second run.
metrics = metrics.drop(columns="step", errors="ignore")
if "epoch" in metrics.columns:
    metrics = metrics.set_index("epoch")
# Preview, hiding columns that are NaN in every row.
display(metrics.dropna(axis=1, how="all").head())


In [None]:
# Wide-form line plot: one line per logged metric column, x = the epoch index.
grid = sns.relplot(data=metrics, kind="line")
grid.set(title="MNIST CNN training metrics")
# Trailing `;` suppresses the FacetGrid repr under the figure.
grid.set_axis_labels("epoch", "metric value");
