In [None]:
# default_exp model.resnet18

In [None]:
%load_ext lab_black
%load_ext autoreload
%autoreload 2

In [None]:
# export
import torch
import torch.nn as nn
import torch.nn.functional as F

import pytorch_lightning as pl

from pytorch_lightning.metrics.functional import accuracy
from loguru import logger

In [None]:
# export
"""ResNet in PyTorch.

For Pre-activation ResNet, see 'preact_resnet.py'.

Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
    Deep Residual Learning for Image Recognition. arXiv:1512.03385
"""


class BasicBlock(nn.Module):
    """Residual block built from two 3x3 convolutions.

    The main branch is conv3x3 -> BN -> ReLU -> conv3x3 -> BN. Its result is
    added to the shortcut branch and passed through a final ReLU.

    Args:
        in_planes: Number of input channels.
        planes: Number of channels produced by both convolutions.
        stride: Stride of the first convolution (and of the projection
            shortcut, when one is needed).
    """

    # Ratio of the block's output channels to ``planes``.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main branch; only the first convolution carries the stride.
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)

        # An empty Sequential acts as the identity. A 1x1 conv + BN projection
        # is used when the stride or the channel count changes.
        shapes_match = stride == 1 and in_planes == out_planes
        if shapes_match:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    out_planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Run the block on ``x`` and return the activated residual sum."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(hidden))
        residual += self.shortcut(x)
        return F.relu(residual)


class Bottleneck(nn.Module):
    """Residual block built from a 1x1 -> 3x3 -> 1x1 convolution stack.

    The last 1x1 convolution widens the output to ``expansion * planes``
    channels. The result is added to the shortcut branch and passed through
    a final ReLU.

    Args:
        in_planes: Number of input channels.
        planes: Number of channels in the two inner convolutions.
        stride: Stride of the 3x3 convolution (and of the projection
            shortcut, when one is needed).
    """

    # Ratio of the block's output channels to ``planes``.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        # Main branch; the stride is applied by the middle 3x3 convolution.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # An empty Sequential acts as the identity. A 1x1 conv + BN projection
        # is used when the stride or the channel count changes.
        shapes_match = stride == 1 and in_planes == out_planes
        if shapes_match:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    out_planes,
                    kernel_size=1,
                    stride=stride,
                    bias=False,
                ),
                nn.BatchNorm2d(out_planes),
            )

    def forward(self, x):
        """Run the block on ``x`` and return the activated residual sum."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        residual = self.bn3(self.conv3(hidden))
        residual += self.shortcut(x)
        return F.relu(residual)


class ResNet(nn.Module):
    """ResNet backbone with a 3x3 stem convolution and a linear classifier head.

    Args:
        block: Residual block class. It is instantiated as
            ``block(in_planes, planes, stride)`` and must expose an integer
            ``expansion`` class attribute (output channels / ``planes``).
        num_blocks: Sequence of four integers giving the number of blocks in
            each of the four stages.
        num_classes: Number of outputs of the final linear layer.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Channel count fed to the next block; advanced by ``_make_layer``.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage of ``num_blocks`` blocks as an ``nn.Sequential``.

        Only the first block of the stage receives ``stride``; the remaining
        blocks use stride 1. ``self.in_planes`` is updated after every block
        so the next block (or stage) sees the right input channel count.
        """
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the classifier outputs for a batch of 3-channel images ``x``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling down to 1x1. This matches the previous fixed
        # ``avg_pool2d(out, 4)`` when the final feature map is 4x4 (e.g. 32x32
        # inputs), and keeps the flattened width equal to the linear layer's
        # input size for other spatial resolutions as well.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        return self.linear(out)


# def ResNet18():
#     return ResNet(BasicBlock, [2, 2, 2, 2])


# def ResNet34():
#     return ResNet(BasicBlock, [3, 4, 6, 3])


# def ResNet50():
#     return ResNet(Bottleneck, [3, 4, 6, 3])


# def ResNet101():
#     return ResNet(Bottleneck, [3, 4, 23, 3])


# def ResNet152():
#     return ResNet(Bottleneck, [3, 8, 36, 3])

In [None]:
# export
class RESNET18(pl.LightningModule):
    """LightningModule wrapping a ``ResNet(BasicBlock, [2, 2, 2, 2])`` classifier.

    ``forward`` returns log-probabilities (``log_softmax`` over dim 1), and the
    training/validation losses are computed with ``nll_loss`` on them.

    Args:
        num_classes: Number of target classes; sizes the final linear layer.
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, num_classes, learning_rate=2e-4):
        super().__init__()

        self.save_hyperparameters()

        self.num_classes = num_classes
        self.learning_rate = learning_rate

        self.model = ResNet(BasicBlock, [2, 2, 2, 2], num_classes)

    @staticmethod
    def _epoch_mean(outputs, key):
        """Average the scalar tensors stored under ``key`` in the step outputs.

        The values are detached and stacked rather than rebuilt with
        ``torch.tensor`` so the result stays on the tensors' own device.
        Returns a NaN tensor when ``outputs`` is empty.
        """
        if not outputs:
            return torch.tensor(float("nan"))
        return torch.stack([step[key].detach() for step in outputs]).mean()

    def forward(self, x):
        """Return per-class log-probabilities for the batch ``x``."""
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        """Compute the negative log-likelihood loss for one training batch."""
        x, y = batch
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        return {"loss": loss}

    def training_epoch_end(self, outputs):
        """Log the mean training loss over the epoch as ``train_loss``."""
        loss = self._epoch_mean(outputs, "loss")

        self.log("train_loss", loss)

    def validation_step(self, batch, batch_idx):
        """Return the loss and accuracy for one validation batch."""
        x, y = batch
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        preds = torch.argmax(log_probs, dim=1)
        acc = accuracy(preds, y)

        return {"val_loss": loss, "val_acc": acc}

    def validation_epoch_end(self, outputs):
        """Log the epoch-averaged validation loss and accuracy."""
        avg_acc = self._epoch_mean(outputs, "val_acc")
        avg_loss = self._epoch_mean(outputs, "val_loss")

        self.log("val_loss", avg_loss, prog_bar=True)
        self.log("val_acc", avg_acc, prog_bar=True)

        logger.info(f"val_loss: {avg_loss}, val_acc: {avg_acc}")

    def test_step(self, batch, batch_idx):
        """Return the accuracy for one test batch."""
        x, y = batch
        log_probs = self(x)
        # Reduce to class indices before scoring, as validation_step does,
        # instead of handing raw log-probabilities to the metric.
        preds = torch.argmax(log_probs, dim=1)

        acc = accuracy(preds, y)

        return {"test_acc": acc}

    def test_epoch_end(self, outputs):
        """Log the epoch-averaged test accuracy."""
        avg_acc = self._epoch_mean(outputs, "test_acc")

        self.log("test_acc", avg_acc)

        logger.info(f"test_acc: {avg_acc}")

    def configure_optimizers(self):
        """Create the Adam optimizer over all parameters."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

In [None]:
from fractal_augmentation.data.cifar import CIFAR10DataModule

In [None]:
# Smoke test: fit the ResNet-18 module on the CIFAR-10 data module for one epoch.
dm = CIFAR10DataModule()
# The classifier head is sized from the data module's `num_classes` attribute.
model = RESNET18(num_classes=dm.num_classes)
trainer = pl.Trainer(max_epochs=1, progress_bar_refresh_rate=20)
trainer.fit(model, dm)

GPU available: False, used: False
TPU available: None, using: 0 TPU cores


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified



  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 11.2 M
---------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

2020-12-14 16:13:03.342 | INFO     | __main__:validation_epoch_end:49 - val_loss: 2.3044912815093994, val_acc: 0.0625


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…






1

In [None]:
# Run the test loop for the fitted model, using the same data module.
trainer.test(model, datamodule=dm)

Files already downloaded and verified


HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

2020-12-14 15:55:24.693 | INFO     | __main__:test_epoch_end:78 - test_acc: 0.4080471098423004



--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': tensor(0.4080),
 'val_acc': tensor(0.4017),
 'val_loss': tensor(1.7144)}
--------------------------------------------------------------------------------


[{'val_loss': 1.7144399881362915,
  'val_acc': 0.40167197585105896,
  'test_acc': 0.4080471098423004}]

In [None]:
# Scratch check of ANSI colour output: prints "red" (quotes included) between
# the red-foreground (ESC[31m) and reset (ESC[0m) escape sequences.
print('\x1b[31m"red"\x1b[0m')

[31m"red"[0m


In [None]:
from nbdev.export import notebook2script

# Run nbdev's notebook-to-module export; the saved output below lists the
# notebooks that were converted.
notebook2script()

Converted 01_model.ipynb.
Converted 0a_data.cifar.ipynb.
Converted dev_setup.ipynb.
Converted index.ipynb.
