# In [None]:  (Jupyter cell prompt left over from the notebook export)
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.models import vgg19
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

# Define the LightningModule
class VGG19Classifier(pl.LightningModule):
    """Fine-tune a torchvision VGG19 as an image classifier.

    The last layer of the VGG19 classifier head is replaced by a fresh
    linear layer with ``num_classes`` outputs; every parameter of the
    network is then trained with Adam and cross-entropy loss.

    Args:
        num_classes: Number of target classes (output width of the new head).
        lr: Initial learning rate handed to Adam.
        pretrained: Load the ImageNet weights when ``True``; start from
            random initialisation when ``False`` (no download needed).
    """

    def __init__(self, num_classes: int, lr: float = 1e-4, pretrained: bool = True):
        super().__init__()
        # Record the constructor arguments in ``self.hparams`` and in every
        # checkpoint, so ``load_from_checkpoint`` can rebuild the module
        # without the caller passing them again.
        self.save_hyperparameters()
        self.lr = lr

        self.model = self._build_backbone(pretrained)
        # Replace the final classifier layer. Its input width is read from
        # the layer being replaced instead of being hard-coded.
        old_head = self.model.classifier[6]
        self.model.classifier[6] = nn.Linear(old_head.in_features, num_classes)

        self.criterion = nn.CrossEntropyLoss()

    @staticmethod
    def _build_backbone(pretrained: bool) -> nn.Module:
        """Create the VGG19 network, with or without ImageNet weights."""
        try:
            from torchvision.models import VGG19_Weights
        except ImportError:
            # Older torchvision releases only know the boolean flag.
            return vgg19(pretrained=pretrained)
        # ``IMAGENET1K_V1`` is the checkpoint the legacy ``pretrained=True``
        # flag maps to, so the loaded weights are unchanged.
        weights = VGG19_Weights.IMAGENET1K_V1 if pretrained else None
        return vgg19(weights=weights)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the class logits produced by the wrapped VGG19 for ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss of one training batch.

        Args:
            batch: An ``(inputs, targets)`` pair from the train dataloader.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The loss tensor that Lightning back-propagates.
        """
        inputs, targets = batch
        outputs = self(inputs)
        loss = self.criterion(outputs, targets)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log loss and top-1 accuracy of one validation batch.

        Args:
            batch: An ``(inputs, targets)`` pair from the val dataloader.
            batch_idx: Index of the batch within the epoch (unused).
        """
        inputs, targets = batch
        outputs = self(inputs)
        loss = self.criterion(outputs, targets)
        # Fraction of samples whose highest-scoring class equals the target.
        acc = (outputs.argmax(dim=1) == targets).float().mean()
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

    def configure_optimizers(self):
        """Return Adam plus a StepLR schedule (x0.1 every 10 scheduler steps).

        Returns:
            ``([optimizer], [scheduler])`` in the two-list form Lightning
            accepts.
        """
        optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        return [optimizer], [scheduler]

# Data preparation
# Per-channel statistics of the ImageNet training set. The pretrained VGG19
# weights were trained on inputs normalised with these values, so the same
# normalisation is applied here to keep the input distribution the backbone
# expects.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to 224x224 for VGG19
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

def prepare_data(data_dir='./data', batch_size=64, num_workers=4):
    """Build the CIFAR-10 train and validation dataloaders.

    Both datasets are downloaded into ``data_dir`` if they are not already
    there and use the module-level ``transform`` pipeline.

    Args:
        data_dir: Directory where CIFAR-10 is stored / downloaded to.
        batch_size: Number of samples per batch for both loaders.
        num_workers: Worker processes used by each loader.

    Returns:
        A ``(train_loader, val_loader)`` tuple. Only the training loader
        shuffles its samples.
    """
    train_dataset = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=transform
    )
    # CIFAR-10 ships no separate validation split; the test split is used.
    val_dataset = torchvision.datasets.CIFAR10(
        root=data_dir, train=False, download=True, transform=transform
    )
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    return train_loader, val_loader

# Main training script
def main():
    """Train VGG19 on CIFAR-10 for 20 epochs and keep the best checkpoint."""
    train_loader, val_loader = prepare_data()

    model = VGG19Classifier(num_classes=10)  # CIFAR10 has 10 classes

    # Keep only the single checkpoint with the lowest validation loss.
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath='./checkpoints',
        filename='vgg19-cifar10-{epoch:02d}-{val_loss:.2f}',
        save_top_k=1,
        mode='min',
    )

    logger = TensorBoardLogger("tb_logs", name="vgg19-cifar10")

    trainer = Trainer(
        max_epochs=20,
        # 'auto' selects the GPU when one is available and otherwise falls
        # back to another supported device instead of raising at start-up.
        accelerator='auto',
        devices=1,
        callbacks=[checkpoint_callback],
        logger=logger,
    )

    trainer.fit(model, train_loader, val_loader)


if __name__ == "__main__":
    main()
