In [None]:
%%capture
!pip install pytorch-lightning

In [None]:
from typing import Callable, Tuple

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

import pandas as pd
import pytorch_lightning as pl
from sklearn.preprocessing import LabelEncoder
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

%matplotlib inline

In [None]:
# Hyperparameters used by the cells below.
BATCH_SIZE = 128  # batch size given to every DataLoader
EPOCHS = 2        # passed to the Trainer as max_epochs
LR = 0.001        # learning rate handed to the LightningModel / Adam optimizer

In [None]:
# Read the dataset index, turn the relative file paths into absolute ones and
# attach an integer class id ("y") derived from the species name ("labels").
prefix = "/kaggle/input/100-bird-species/"
le = LabelEncoder()

df = pd.read_csv("/kaggle/input/100-bird-species/birds.csv").assign(
    # Replace any backslashes with forward slashes, then prepend the dataset root.
    filepaths=lambda frame: prefix + frame["filepaths"].str.replace("\\", "/", regex=False),
    # Fit the encoder on the label column and store the encoded ids.
    y=lambda frame: le.fit_transform(frame["labels"]),
)
df.head()

In [None]:
# Row counts per split and per species, displayed together as a tuple.
(
    df["data set"].value_counts(),
    df["labels"].value_counts(),
)

In [None]:
# Show six randomly drawn images in a 3x2 grid, each titled with its label.
subset = df.sample(6).reset_index()
plt.figure(figsize=(12, 12))
for position, (path, label) in enumerate(zip(subset["filepaths"], subset["labels"]), start=1):
    plt.subplot(3, 2, position)
    plt.imshow(mpimg.imread(path))
    plt.title(label)
plt.show()

## Data

In [None]:
class Data(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs from an index frame.

    Args:
        df: Frame with a ``"filepaths"`` column (paths of the image files) and
            a ``"y"`` column (class ids). Only the underlying arrays are kept,
            so the frame's index does not matter.
    """

    def __init__(self, df: pd.DataFrame) -> None:
        # Store plain arrays so __getitem__ can index by position even when
        # ``df`` is a filtered slice with a non-contiguous index.
        self.files = df["filepaths"].values
        self.y = df["y"].values

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.y)

    def __getitem__(self, i):
        """Return the ``i``-th image scaled by 1/255 together with its label."""
        # Decode as RGB explicitly: without a mode, a grayscale or RGBA file
        # would produce a tensor with a different channel count, which breaks
        # batching and the 3-channel input the network expects.
        image = torchvision.io.read_image(
            self.files[i], mode=torchvision.io.ImageReadMode.RGB
        )
        # Dividing the decoded tensor by a float yields a floating-point image.
        return image / 255.0, self.y[i]
    
# One dataset per split, selected through the "data set" column.
train_ds = Data(df.loc[df["data set"] == "train"])
valid_ds = Data(df.loc[df["data set"] == "valid"])
test_ds = Data(df.loc[df["data set"] == "test"])

# Settings common to all three loaders.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)

# Only the training loader shuffles and drops the last incomplete batch.
train_dl = DataLoader(train_ds, shuffle=True, drop_last=True, **loader_kwargs)
valid_dl = DataLoader(valid_ds, shuffle=False, drop_last=False, **loader_kwargs)
test_dl = DataLoader(test_ds, shuffle=False, drop_last=False, **loader_kwargs)

## Model

In [None]:
# base_model = torchvision.models.resnet34(pretrained=True)
# base_model

In [None]:
class Model(nn.Module):
    """ResNet-34 backbone followed by BatchNorm1d and a linear classifier.

    Args:
        num_classes: Number of output logits per sample.
        freeze: If True, the backbone's parameters get ``requires_grad=False``
            and the backbone is kept in eval mode, even while the rest of the
            model is switched to training mode.
        pretrained: Forwarded to ``torchvision.models.resnet34``. Defaults to
            False, which keeps the previous behaviour.
            NOTE(review): with ``pretrained=False`` and ``freeze=True`` the
            frozen backbone has randomly initialised weights; pass
            ``pretrained=True`` if pretrained features are intended.
    """

    def __init__(self, num_classes, freeze=True, pretrained=False):
        super().__init__()
        self.base = torchvision.models.resnet34(pretrained=pretrained)
        # Drop the backbone's own classification layer so it outputs the
        # 512-dimensional features consumed by the layers below.
        self.base.fc = nn.Identity()
        self.bn = nn.BatchNorm1d(512)
        self.linear = nn.Linear(512, num_classes)

        # Remembered so that train() can keep the backbone in eval mode.
        self._freeze_base = freeze
        if freeze:
            self.base.eval()
            for p in self.base.parameters():
                p.requires_grad = False

    def train(self, mode: bool = True):
        """Set the training mode, keeping a frozen backbone in eval mode.

        ``nn.Module.train`` recurses into every child, so a plain call to
        ``model.train()`` would switch the backbone back to training mode and
        let its BatchNorm layers update their running statistics. This
        override puts the backbone back into eval mode afterwards.

        Returns:
            The module itself, like ``nn.Module.train``.
        """
        super().train(mode)
        # getattr guards against instances created without going through __init__.
        if getattr(self, "_freeze_base", False):
            self.base.eval()
        return self

    def forward(self, x):
        """Return the class logits for the batch ``x``."""
        return self.linear(self.bn(self.base(x)))

class LightningModel(pl.LightningModule):
    """Lightning wrapper that optimises ``model`` with Adam and logs loss and accuracy.

    Args:
        model: Network mapping an input batch to logits.
        lr: Learning rate for the Adam optimizer.
        loss_fn: Callable taking ``(logits, targets)`` and returning the loss.
    """

    def __init__(self, model: nn.Module, lr: float, loss_fn: Callable) -> None:
        super().__init__()
        self.model = model
        self.lr = lr
        self.loss_fn = loss_fn
        # Mean of the element-wise match between arg-max predictions and targets.
        self.accuracy = lambda x, y: (x.argmax(dim=-1) == y).float().mean()

    def common_step(
        self,
        batch: Tuple[torch.FloatTensor, torch.LongTensor],
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Run the model on ``batch`` and return ``(loss, accuracy)``."""
        inputs, targets = batch
        logits = self.model(inputs)
        return self.loss_fn(logits, targets), self.accuracy(logits, targets)

    def _log_metrics(self, loss_key: str, acc_key: str, loss, acc) -> None:
        """Log loss then accuracy, per step and per epoch, on the progress bar."""
        for key, value in ((loss_key, loss), (acc_key, acc)):
            self.log(key, value, on_step=True, on_epoch=True, prog_bar=True)

    def training_step(self, batch: Tuple[torch.FloatTensor, torch.LongTensor], *args) -> torch.FloatTensor:
        """Log the training metrics and return the loss."""
        loss, acc = self.common_step(batch)
        self._log_metrics("training_loss", "training_accuracy", loss, acc)
        return loss

    def validation_step(self, batch: Tuple[torch.FloatTensor, torch.LongTensor], *args) -> None:
        """Log the validation metrics; nothing is returned."""
        loss, acc = self.common_step(batch)
        self._log_metrics("validation_loss", "validation_accuracy", loss, acc)

    def configure_optimizers(self) -> torch.optim.Optimizer:
        """Create the Adam optimizer over the wrapped model's parameters."""
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

In [None]:
# One output logit per class known to the label encoder.
model = Model(num_classes=len(le.classes_))
lightning_model = LightningModel(
    model=model,
    lr=LR,
    loss_fn=nn.CrossEntropyLoss(),
)

## Training

In [None]:
# The `gpus=` Trainer argument was removed in PyTorch Lightning 2.0; choose the
# hardware through `accelerator` / `devices` instead. All visible GPUs are
# used when there are any, otherwise training runs on the CPU.
n_gpus = torch.cuda.device_count()
trainer = pl.Trainer(
    max_epochs=EPOCHS,
    accelerator="gpu" if n_gpus > 0 else "cpu",
    devices=n_gpus if n_gpus > 0 else 1,
)

trainer.fit(lightning_model, train_dl, valid_dl)

The accuracy on the test split is computed and displayed below.

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.eval().to(device)

y_preds = []
ys = []
# Inference only: disable autograd so no computation graph is built or kept
# alive for the collected predictions.
with torch.no_grad():
    for x, y in tqdm(test_dl):
        y_preds.append(model(x.to(device)).argmax(dim=-1))
        ys.append(y.to(device))

# Fraction of test samples whose predicted class equals the target.
(torch.cat(y_preds) == torch.cat(ys)).float().mean()

## Shameless Self Promotion
If you would like to see more content like this explained, consider buying my [DL course](https://www.udemy.com/course/machine-learning-and-data-science-2021/?referralCode=E79228C7436D74315787) (usually $15).