In [19]:
import os

import torch
import pytorch_lightning as pl
from torch import nn

from torchvision import models

In [18]:
class ImageNet10DataModule(pl.LightningDataModule):
    """LightningDataModule intended for the ImageNet-10 dataset.

    NOTE: only ``__init__`` is defined so far — there are no ``setup()`` or
    ``*_dataloader()`` hooks yet, so this module cannot feed a Trainer as-is.

    Args:
        batch_size: Number of samples per batch; stored on the instance.
        path: Dataset location, presumably the root directory of the image
            files (defaults to ``./imagenet-10-dataset`` relative to the
            current working directory); stored on the instance.
    """

    def __init__(self, batch_size=32, path=os.path.join('.', 'imagenet-10-dataset')):
        super().__init__()
        self.batch_size = batch_size
        # Keep the dataset location; it was previously accepted and then discarded.
        self.path = path

In [12]:
class EfficientNetLightningModel(pl.LightningModule):
    """EfficientNet-B1 classifier wrapped as a PyTorch Lightning module.

    The torchvision EfficientNet-B1 network is built with ``weights`` and its
    last classifier layer (``classifier[1]``) is replaced by a fresh
    ``nn.Linear`` producing ``n_classes`` logits. Training uses cross-entropy
    loss and the Adam optimizer.

    Args:
        n_classes: Number of output classes of the replacement head.
        lr: Learning rate passed to ``torch.optim.Adam``.
        weights: torchvision weights enum for ``models.efficientnet_b1``;
            ``None`` builds the network without pretrained weights.
    """

    def __init__(self, n_classes=10, lr=1e-3,
                 weights=models.EfficientNet_B1_Weights.IMAGENET1K_V2):
        super().__init__()
        self.lr = lr
        self.model = models.efficientnet_b1(weights=weights)
        # Swap the final linear layer so the head matches the target label count.
        self.model.classifier[1] = nn.Linear(self.model.classifier[1].in_features, n_classes)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return the raw (unnormalized) class logits for input batch ``x``."""
        return self.model(x)

    def _shared_step(self, batch):
        """Unpack ``(inputs, labels)``, run the model, and return ``(logits, labels, loss)``."""
        inputs, labels = batch
        logits = self(inputs)
        loss = self.criterion(logits, labels)
        return logits, labels, loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch; return it for backprop."""
        _, _, loss = self._shared_step(batch)
        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and top-1 accuracy for one batch; return the loss."""
        logits, labels, loss = self._shared_step(batch)
        # Fraction of samples whose argmax prediction equals the label. Kept as a
        # tensor (no .item()) so logging does not force a device->host sync per batch.
        acc = (logits.argmax(dim=1) == labels).float().mean()

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

        return loss

    def configure_optimizers(self):
        """Optimize all parameters with Adam at the configured learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [13]:
# Instantiate the classifier with an explicit 10-way output head.
model = EfficientNetLightningModel(n_classes=10)