In [43]:
import os

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from torch import nn
from torch.utils.data import random_split, Subset, DataLoader

from torchvision import models, datasets, transforms

In [44]:
class ImageNet10DataModule(pl.LightningDataModule):
    def __init__(
        self,
        dataset_dir=os.path.join(".", "imagenet-10-dataset"),
        batch_size=16,
        num_workers=2,
        train_test_split=0.8
    ):
        super(ImageNet10DataModule, self).__init__()
        self.dataset_dir = dataset_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_test_split = train_test_split
        
    def setup(self, stage=None):
        self.train_transforms = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.RandomCrop(240),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
        self.test_transforms = transforms.Compose([
            transforms.Resize((240, 240)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
        full_dataset = datasets.ImageFolder(root=self.dataset_dir)
        
        n_total = len(full_dataset)
        n_train = int(self.train_test_split * n_total)
        n_test = n_total - n_train
        
        train_subset, test_subset = random_split(full_dataset, [n_train, n_test])

        train_indices = train_subset.indices
        test_indices = test_subset.indices

        train_dataset = datasets.ImageFolder(root=self.dataset_dir, transform=self.train_transforms)
        test_dataset = datasets.ImageFolder(root=self.dataset_dir, transform=self.test_transforms)

        self.train_dataset = Subset(train_dataset, train_indices)
        self.test_dataset = Subset(test_dataset, test_indices)

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4,
            persistent_workers=True
        )

    def val_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=4,
            persistent_workers=True
        )

In [45]:
class EfficientNetLightningModel(pl.LightningModule):
    def __init__(self, n_classes=10):
        super(EfficientNetLightningModel, self).__init__()
        self.model = models.efficientnet_b1(weights=models.EfficientNet_B1_Weights.IMAGENET1K_V2)
        self.model.classifier[1] = nn.Linear(self.model.classifier[1].in_features, n_classes)
        self.criterion = nn.CrossEntropyLoss()
        
    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        logits = self(inputs)
        loss = self.criterion(logits, labels)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        images, labels = batch
        logits = self(images)
        loss = self.criterion(logits, labels)
        predictions = torch.argmax(logits, dim=1)
        acc = torch.sum(predictions == labels.data).item() / len(labels)

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)
        
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

In [46]:
model = EfficientNetLightningModel()

In [47]:
data_module = ImageNet10DataModule()

In [48]:
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath="./checkpoints",
    filename="best-checkpoint",
    save_top_k=1,
    mode="min"
)

trainer = pl.Trainer(
    max_epochs=25,
    devices=1,
    accelerator="gpu",
    callbacks=[checkpoint_callback]
)

trainer.fit(model, data_module)

Epoch 20: 100%|██████████| 650/650 [01:47<00:00,  6.03it/s, v_num=7, train_loss_step=0.00434, val_loss=0.122, val_acc=0.970, train_loss_epoch=0.034] 
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/163 [00:00<?, ?it/s][A
Validation DataLoader 0:   1%|          | 1/163 [00:00<00:07, 20.97it/s][A
Validation DataLoader 0:   1%|          | 2/163 [00:00<00:08, 19.29it/s][A
Validation DataLoader 0:   2%|▏         | 3/163 [00:00<00:08, 18.92it/s][A
Validation DataLoader 0:   2%|▏         | 4/163 [00:00<00:08, 18.87it/s][A
Validation DataLoader 0:   3%|▎         | 5/163 [00:00<00:08, 18.71it/s][A
Validation DataLoader 0:   4%|▎         | 6/163 [00:00<00:08, 18.75it/s][A
Validation DataLoader 0:   4%|▍         | 7/163 [00:00<00:08, 18.72it/s][A
Validation DataLoader 0:   5%|▍         | 8/163 [00:00<00:08, 18.76it/s][A
Validation DataLoader 0:   6%|▌         | 9/163 [00:00<00:08, 18.79it/s][A


`Trainer.fit` stopped: `max_epochs=25` reached.


Epoch 24: 100%|██████████| 650/650 [02:00<00:00,  5.39it/s, v_num=7, train_loss_step=0.0355, val_loss=0.149, val_acc=0.959, train_loss_epoch=0.0382]


MisconfigurationException: No `test_step()` method defined to run `Trainer.test`.

In [50]:
best_model = EfficientNetLightningModel.load_from_checkpoint('./checkpoints/best-checkpoint-v1.ckpt')

In [51]:
best_model.eval()

EfficientNetLightningModel(
  (model): EfficientNet(
    (features): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): SiLU(inplace=True)
      )
      (1): Sequential(
        (0): MBConv(
          (block): Sequential(
            (0): Conv2dNormActivation(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
              (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
            (1): SqueezeExcitation(
              (avgpool): AdaptiveAvgPool2d(output_size=1)
              (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
              (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
              (activation): SiLU(inplace=True)
         

In [52]:
trainer.validate(best_model, data_module)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation DataLoader 0: 100%|██████████| 163/163 [00:08<00:00, 18.78it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         val_acc            0.9903846383094788
        val_loss            0.03340562433004379
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val_loss': 0.03340562433004379, 'val_acc': 0.9903846383094788}]