In [1]:
!pip install --no-deps d2l

Collecting d2l
  Obtaining dependency information for d2l from https://files.pythonhosted.org/packages/8b/39/418ef003ed7ec0f2a071e24ec3f58c7b1f179ef44bec5224dcca276876e3/d2l-1.0.3-py3-none-any.whl.metadata
  Downloading d2l-1.0.3-py3-none-any.whl.metadata (556 bytes)
Downloading d2l-1.0.3-py3-none-any.whl (111 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m111.7/111.7 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: d2l
Successfully installed d2l-1.0.3


In [2]:
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.tuner import Tuner
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision.datasets import FashionMNIST
from torchvision.transforms import Compose, Resize, ToTensor


# Select the "medium" float32 matmul precision mode for the whole process.
# NOTE(review): presumably chosen to trade some matmul precision for GPU
# throughput — confirm the exact semantics against the PyTorch docs.
torch.set_float32_matmul_precision("medium")



In [3]:
# VGG layouts keyed by depth label. Each value is a tuple of
# (num_convs, out_channels) pairs, one pair per convolutional block; every
# variant shares the same channel progression and differs only in conv counts.
conv_arch = {
    depth: tuple(zip(convs_per_block, (64, 128, 256, 512, 512)))
    for depth, convs_per_block in (
        ("11", (1, 1, 2, 2, 2)),
        ("13", (2, 2, 2, 2, 2)),
        ("16", (2, 2, 3, 3, 3)),
        ("19", (2, 2, 4, 4, 4)),
    )
}

# Channel reduction factor for the slimmed-down variants below.
ratio = 4
# Same layouts with every block's channel count integer-divided by `ratio`.
small_conv_arch = {
    depth: [(num_convs, channels // ratio) for num_convs, channels in blocks]
    for depth, blocks in conv_arch.items()
}

In [4]:
class FashionMNISTDataModel(pl.LightningDataModule):
    """Lightning data module serving FashionMNIST images resized to 224 pixels.

    The official training set is split 80/20 into train and validation
    subsets; the official test set is used as-is.

    Args:
        batch_size: Number of samples per batch for every dataloader.
        data_dir: Directory the dataset is downloaded to and read from.
        num_workers: Worker processes used by every dataloader.
        split_seed: Seed for the train/validation split. A fixed seed makes
            the split reproducible across runs and identical in every process
            that calls ``setup`` (e.g. each DDP rank), so validation samples
            cannot end up in another process's training subset.
    """

    def __init__(self, batch_size=128, data_dir="../data", num_workers=4, split_seed=42):
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = data_dir
        self.num_workers = num_workers
        self.split_seed = split_seed
        # Convert to a tensor first, then resize to 224.
        self.trans = Compose([ToTensor(), Resize(224, antialias=True)])

    def prepare_data(self):
        """Download the dataset files into ``data_dir``."""
        FashionMNIST(root=self.data_dir, download=True)

    def setup(self, stage=None):
        """Create the datasets needed for ``stage``.

        Args:
            stage: ``"fit"`` or ``"validate"`` builds the train/validation
                split, ``"test"`` builds the test set, and ``None`` builds
                everything.
        """
        # Two independent `if`s (not if/elif): with stage=None both the
        # train/val split and the test set must be created.
        if stage in ("fit", "validate", None):
            full_train = FashionMNIST(root=self.data_dir, train=True, transform=self.trans)
            self.train_data, self.val_data = random_split(
                full_train,
                [0.8, 0.2],
                generator=torch.Generator().manual_seed(self.split_seed),
            )
        if stage in ("test", None):
            self.test_data = FashionMNIST(root=self.data_dir, train=False, transform=self.trans)

    def _loader(self, dataset, shuffle):
        """Build a DataLoader over ``dataset`` with the module's shared settings."""
        return DataLoader(
            dataset,
            self.batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        """Return the shuffled loader over the training subset."""
        return self._loader(self.train_data, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled loader over the validation subset."""
        return self._loader(self.val_data, shuffle=False)

    def test_dataloader(self):
        """Return the unshuffled loader over the test set."""
        return self._loader(self.test_data, shuffle=False)

In [5]:
def vgg_block(num_convs, in_channels, out_channels):
    """Build one VGG block: ``num_convs`` conv+ReLU pairs, then a 2x2 max-pool.

    Args:
        num_convs: Number of 3x3 convolutions (each followed by a ReLU).
        in_channels: Channel count fed to the first convolution.
        out_channels: Channel count produced by every convolution.

    Returns:
        An ``nn.Sequential`` holding the convolutions, activations and pooling.
    """
    stack = []
    for idx in range(num_convs):
        # Only the first convolution sees the block's input width; the rest
        # consume the previous convolution's output.
        conv_in = in_channels if idx == 0 else out_channels
        stack.extend(
            (nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=1), nn.ReLU())
        )
    stack.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*stack)

In [6]:
class Vgg(pl.LightningModule):
    """VGG-style convolutional classifier for 10-class, single-channel images.

    Args:
        conv_arch: Iterable of ``(num_convs, out_channels)`` pairs, one per
            VGG block.
        lr: Learning rate passed to SGD.
        weight_decay: Weight decay passed to SGD.
    """

    def __init__(self, conv_arch, lr=0.05, weight_decay=0):
        super().__init__()
        # Exposes the constructor arguments as self.hparams (read back in
        # configure_optimizers and when logging hyperparameters).
        self.save_hyperparameters()

        conv_blks = []
        in_channels = 1  # the network takes single-channel input
        for num_convs, out_channels in conv_arch:
            conv_blks.append(vgg_block(num_convs, in_channels, out_channels))
            in_channels = out_channels

        self.features = nn.Sequential(*conv_blks)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            # After the loop `in_channels` is the last block's width. Unlike
            # the loop variable `out_channels`, it is also defined when
            # conv_arch is empty. The 7 * 7 factor hard-codes the spatial size
            # reaching the classifier (e.g. 224-pixel input halved by five
            # pooling blocks).
            nn.Linear(in_channels * 7 * 7, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 10),
        )

        # Kaiming-normal weights and zero biases for every conv/linear layer.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode="fan_in", nonlinearity="relu")
                nn.init.constant_(m.bias, 0)

        self.loss = nn.CrossEntropyLoss()
        # One metric object per stage so their accumulated states stay separate.
        self.train_acc = Accuracy(task="multiclass", num_classes=10, average="micro")
        self.val_acc = Accuracy(task="multiclass", num_classes=10, average="micro")
        self.test_acc = Accuracy(task="multiclass", num_classes=10, average="micro")

    def forward(self, X):
        """Return class logits for the image batch ``X``."""
        X = self.features(X)
        X = self.classifier(X)
        return X

    def _shared_step(self, batch, metric, stage):
        """Run one batch, log ``{stage}_loss`` / ``{stage}_acc``, return the loss."""
        X, y = batch
        y_hat = self(X)
        loss = self.loss(y_hat, y)
        # Calling the metric returns this batch's accuracy and also folds the
        # batch into the metric's accumulated state (read later via compute()).
        acc = metric(y_hat, y)
        metrics = {f"{stage}_loss": loss, f"{stage}_acc": acc}
        self.log_dict(metrics, prog_bar=True, sync_dist=True)
        return loss

    def _log_hparam_metrics(self):
        """Log the hyperparameters together with the accumulated accuracies."""
        # Nothing to write to when the trainer runs without a logger.
        if self.logger is None:
            return
        # All three metrics are reported, including ones not updated in the
        # current stage (e.g. test_acc during fit).
        # NOTE(review): torchmetrics may warn about compute() before update()
        # for those — confirm whether they should be omitted instead.
        self.logger.log_hyperparams(
            self.hparams,
            {
                "hp/train_acc": self.train_acc.compute(),
                "hp/val_acc": self.val_acc.compute(),
                "hp/test_acc": self.test_acc.compute(),
            },
        )

    # The metrics are called manually and only plain tensors are logged, so
    # nothing resets their accumulated state automatically. Reset each one at
    # the start of its epoch so compute() reflects the current epoch only
    # instead of pooling every epoch (and the sanity-check batches).
    def on_train_epoch_start(self):
        """Clear the training accuracy accumulated in the previous epoch."""
        self.train_acc.reset()

    def on_validation_epoch_start(self):
        """Clear the validation accuracy accumulated in the previous run."""
        self.val_acc.reset()

    def on_test_epoch_start(self):
        """Clear the test accuracy accumulated in the previous run."""
        self.test_acc.reset()

    def training_step(self, batch):
        """Compute and log the training loss/accuracy for ``batch``."""
        return self._shared_step(batch, self.train_acc, "train")

    def validation_step(self, batch):
        """Compute and log the validation loss/accuracy for ``batch``."""
        return self._shared_step(batch, self.val_acc, "val")

    def on_validation_end(self):
        """Record hyperparameters and accuracies after each validation run."""
        self._log_hparam_metrics()

    def test_step(self, batch):
        """Compute and log the test loss/accuracy for ``batch``."""
        return self._shared_step(batch, self.test_acc, "test")

    def on_test_end(self):
        """Record hyperparameters and accuracies after a test run."""
        self._log_hparam_metrics()

    def configure_optimizers(self):
        """Return an SGD optimizer configured from the saved hyperparameters."""
        return torch.optim.SGD(
            self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay
        )

In [7]:
# Data module instance later handed to trainer.fit; batch size set explicitly
# to 128 (the same value as the constructor default).
data = FashionMNISTDataModel(batch_size=128)

In [8]:
# Slim 11-layer layout: the "11" architecture with channels divided by `ratio`.
model = Vgg(small_conv_arch["11"])

# Ten epochs on two GPUs using the notebook-compatible DDP strategy.
# NOTE(review): logger=None presumably leaves Lightning's default logger in
# place rather than disabling logging (the model's hooks call self.logger) —
# confirm. devices=2 requires two visible GPUs.
trainer = pl.Trainer(
    max_epochs=10,
    logger=None,
    accelerator="gpu",
    devices=2,
    strategy="ddp_notebook",
)

In [9]:
%%time
# Train `model` on the FashionMNIST data module; the %%time cell magic reports
# CPU and wall-clock time for the whole fit.
trainer.fit(model, datamodule=data)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ../data/FashionMNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 26421880/26421880 [00:01<00:00, 17956162.85it/s]


Extracting ../data/FashionMNIST/raw/train-images-idx3-ubyte.gz to ../data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 29515/29515 [00:00<00:00, 299664.94it/s]


Extracting ../data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ../data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 4422102/4422102 [00:00<00:00, 5538664.08it/s]


Extracting ../data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ../data/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 5148/5148 [00:00<00:00, 30411657.74it/s]


Extracting ../data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/FashionMNIST/raw



Sanity Checking: |          | 0/? [00:00<?, ?it/s]



Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

CPU times: user 7.45 s, sys: 3.07 s, total: 10.5 s
Wall time: 7min 14s
