In [None]:
# Record the versions of the key packages so results can be reproduced later
# (requires the third-party `watermark` IPython extension to be installed).
%load_ext watermark
%watermark -p torch,lightning,torchvision

In [2]:
import lightning as L
import torch
import torchvision
import torch.nn.functional as F
import torchmetrics
from lightning.pytorch.loggers import CSVLogger
import matplotlib.pyplot as plt
import numpy as np

from shared_utilities import LightningModel,Cifar10DataModule,plot_loss_and_acc

In [None]:
import torch

# Ask the pytorch/vision hub repo which model entrypoints it publishes and
# print only the ResNet family. force_reload=True discards any cached copy of
# the repo and downloads it afresh.
entrypoints = torch.hub.list('pytorch/vision', force_reload=True)
for e in entrypoints:
    if "resnet" not in e:
        continue
    print(e)

In [None]:
%%capture --no-display
L.pytorch.seed_everything(123)

repo = "pytorch/vision"
pytorch_model = torch.hub.load(repo, "resnet18", weights="IMAGENET1K_V1")


In [5]:
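# Alternatively, the same pretrained model can be loaded directly via torchvision:

# from torchvision.models import resnet18, ResNet18_Weights
# weights = ResNet18_Weights.IMAGENET1K_V1
# pytorch_model = resnet18(weights=weights)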

In [None]:
# Inspect the pretrained classification head; its `in_features` is the input
# width any replacement head has to accept.
pytorch_model.fc

In [7]:
# Transfer learning: train only the last (classification) layer.
#
# Freeze every pretrained parameter so the optimizer never updates the backbone.
for param in pytorch_model.parameters():
    param.requires_grad = False

# Replace the pretrained head with a fresh 10-class head (CIFAR-10 has 10 classes).
# The input width is read from the existing head rather than hard-coded as 512,
# so the cell keeps working if a backbone with a wider head is swapped in.
# The new layer is created *after* freezing, so its parameters stay trainable
# (nn.Linear parameters default to requires_grad=True).
pytorch_model.fc = torch.nn.Linear(pytorch_model.fc.in_features, 10)

In [None]:
# Feed the data through the same preprocessing that accompanies the pretrained
# weights, so CIFAR-10 images reach the backbone in the form it was trained on.
from torchvision.models import ResNet18_Weights, resnet18

weights = ResNet18_Weights.IMAGENET1K_V1
preprocess_transform = weights.transforms()

# Show the pipeline so its settings can be inspected.
preprocess_transform

In [None]:
L.pytorch.seed_everything(123)

# Use the pretrained model's preprocessing for both the training and the
# evaluation splits.
datamodule_kwargs = dict(
    batch_size=64,
    num_workers=4,
    train_transform=preprocess_transform,
    test_transform=preprocess_transform,
)
dm = Cifar10DataModule(**datamodule_kwargs)

# Call the datamodule hooks explicitly (a Trainer would normally do this) so
# that train_dataloader() can be used directly in the preview below.
dm.prepare_data()
dm.setup()

In [None]:
# Take the first training batch and preview it as a single image grid.
images, labels = next(iter(dm.train_dataloader()))

grid = torchvision.utils.make_grid(
    images[:64],
    padding=1,
    pad_value=1.0,
    normalize=True,  # min-max rescale to [0, 1] so normalized pixels display sensibly
)

fig, ax = plt.subplots(figsize=(8, 8))
ax.axis('off')
ax.set_title("training image")
ax.imshow(np.transpose(grid, (1, 2, 0)))  # channels-first -> channels-last for matplotlib
plt.show()

In [None]:
# Shape of a single preprocessed image (channels, height, width); presumably
# 3x224x224 after the pretrained model's preprocessing — confirm from the output.
images[0].shape

In [None]:
L.pytorch.seed_everything(123)

# Fresh datamodule for training, using the pretrained model's preprocessing
# for both the training and the evaluation splits.
dm = Cifar10DataModule(
    # NOTE(review): presumably the resolution the datamodule resizes images to;
    # 224x224 matches the ImageNet-style input — confirm in shared_utilities.
    height_width=(224, 224),
    batch_size=64,
    num_workers=4,
    train_transform=preprocess_transform,
    test_transform=preprocess_transform,
)

# Wrap the backbone (frozen except for the new `fc` head) in the shared
# LightningModule; 10 output classes for CIFAR-10.
l_model = LightningModel(model=pytorch_model, learning_rate=0.1, num_classes=10)

trainer = L.Trainer(
    # fast_dev_run=1,  # uncomment for a quick smoke test on a single batch
    max_epochs=30,
    accelerator='gpu',
    devices=1,
    # Metrics go to CSV files under logs/<name>/version_*/ for plotting later.
    logger=CSVLogger(save_dir='logs/', name='cifr-resnet18-last-layer'),
    deterministic=True,  # request deterministic algorithms for reproducibility
)

In [None]:
# Train; with the backbone parameters frozen, only the new `fc` head is updated.
trainer.fit(model=l_model, datamodule=dm)

In [None]:
# Presumably plots the loss/accuracy curves recorded by the CSVLogger in this
# run's log directory (see shared_utilities for the exact plots).
plot_loss_and_acc(trainer.logger.log_dir)

In [None]:
# Evaluate the trained model on the datamodule's test split.
trainer.test(model=l_model, datamodule=dm)