# Image Classification on the UC Merced Land Use dataset

### Importing necessary libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import torchvision as tv
import albumentations as A

import pytorch_lightning as pl
import torchmetrics as tm
from torchinfo import summary

import datasets

import wandb

### Loading the dataset

In [None]:
# Inspect the label feature via the dataset builder's metadata; the bare
# `labels` expression on the last line is echoed as the cell's output.
ds_builder = datasets.load_dataset_builder("SatwikKambham/uc_merced_land_use")
dataset_features = ds_builder.info.features
labels = dataset_features["label"]
labels

In [None]:
class UCMercedLandUse(Dataset):
    """Map-style PyTorch dataset over one partition of the UC Merced Land Use data.

    The Hugging Face ``train`` split is divided into two partitions with
    ``train_test_split``; ``split`` selects which partition this instance wraps.

    Args:
        img_tfms: Callable applied to every image (after ``augs``, if any); its
            result is returned under the ``"img"`` key.
        augs: Optional albumentations-style callable invoked as
            ``augs(image=array)`` that must return a dict with an ``"image"``
            key, or ``None`` to skip augmentation.
        split: ``"train"`` or ``"test"`` -- the partition to expose.
        test_size: Fraction of samples assigned to the ``"test"`` partition.
        shuffle: Whether ``train_test_split`` shuffles before splitting.
            Defaults to ``False`` (contiguous slices).
        seed: Seed forwarded to ``train_test_split``. Required when
            ``shuffle`` is true, because every instance performs its own split
            and train/test instances must agree on it.

    Raises:
        ValueError: If ``shuffle`` is true but no ``seed`` is given.
    """

    def __init__(self, img_tfms, augs, split="train", test_size=0.3, shuffle=False, seed=None):
        if shuffle and seed is None:
            # An unseeded shuffle gives each instance a different split, so the
            # train and test instances would overlap.
            raise ValueError("shuffle=True requires an explicit seed")
        # NOTE(review): with shuffle=False the partitions are contiguous slices
        # of the source split; if rows are ordered by class, some classes may
        # be absent from one partition -- confirm the row order.
        full_split = datasets.load_dataset("SatwikKambham/uc_merced_land_use")["train"]
        self.ds = full_split.train_test_split(
            test_size=test_size, shuffle=shuffle, seed=seed
        )[split]
        self.ds_builder = datasets.load_dataset_builder("SatwikKambham/uc_merced_land_use")
        self.labels = self.ds_builder.info.features["label"]
        self.img_tfms = img_tfms
        self.augs = augs

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        """Return ``{"img": transformed image, "label": label}`` for sample ``idx``."""
        # Look the row up once: each ``self.ds[idx]`` lookup materialises the
        # whole record (image included), so indexing twice doubles the work.
        row = self.ds[idx]
        img = np.array(row["img"])
        if self.augs is not None:
            img = self.augs(image=img)["image"]
        return {
            "img": self.img_tfms(img),
            "label": row["label"],
        }

In [None]:
class UCMercedLandUseDataModule(pl.LightningDataModule):
    """LightningDataModule serving the UC Merced train partition and a held-out validation partition.

    Args:
        batch_size: Samples per batch for both dataloaders.
        num_workers: Worker processes passed to each ``DataLoader``.
    """

    def __init__(self, batch_size: int = 32, num_workers=2):
        super().__init__()

        self.save_hyperparameters()

        # ToTensor -> CHW float tensor, then per-channel normalisation with fixed
        # mean/std (presumably statistics of this dataset -- confirm).
        self.img_tfms = tv.transforms.Compose(
            [
                tv.transforms.ToTensor(),
                tv.transforms.Normalize(
                    (0.48422758, 0.49005175, 0.45050276),
                    (0.17348297, 0.16352356, 0.15547496),
                ),
            ]
        )

        # Training: resize plus random 90-degree rotations.
        self.train_augs = A.Compose(
            [
                A.Resize(256, 256),
                A.RandomRotate90(),
            ]
        )
        # Evaluation: deterministic resize only.
        self.test_augs = A.Compose(
            [
                A.Resize(256, 256),
            ]
        )

        self.batch_size = batch_size
        self.num_workers = num_workers

        # Populated lazily by ``setup``.
        self.train_ds = None
        self.test_ds = None

    def setup(self, stage: str):
        """Build the train and validation datasets if they do not exist yet."""
        # ``setup`` can run more than once on the same instance (e.g. the LR
        # finder and ``fit`` both receive this module), so build only once.
        if self.train_ds is None:
            self.train_ds = UCMercedLandUse(self.img_tfms, self.train_augs, split="train")
        if self.test_ds is None:
            # Validate on the held-out "test" partition; relying on the default
            # split would evaluate on the training images.
            self.test_ds = UCMercedLandUse(self.img_tfms, self.test_augs, split="test")

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.test_ds, batch_size=self.batch_size, num_workers=self.num_workers)

### Model architecture and training code

In [None]:
class Classifier(pl.LightningModule):
    """ResNet-18 (torchvision default weights) with a new linear head, trained with SGD.

    Args:
        lr: SGD learning rate, stored as ``self.lr`` so a learning-rate finder
            can overwrite it before training.
        momentum: SGD momentum.
        weight_decay: SGD weight decay.
        num_classes: Number of output classes (21 by default).
    """

    def __init__(
        self,
        lr=1e-2,
        momentum=0.9,
        weight_decay=1e-4,
        num_classes=21,
    ):
        super().__init__()

        self.save_hyperparameters()

        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay

        self.num_classes = num_classes

        self.model = tv.models.resnet18(weights=tv.models.ResNet18_Weights.DEFAULT)
        # Size the head from the backbone's existing ``fc`` layer rather than
        # using ``LazyLinear``: no dummy forward is needed to materialise it,
        # so the pretrained BatchNorm running statistics are not updated with
        # random inputs before training starts.
        self.model.fc = nn.Sequential(
            nn.Linear(self.model.fc.in_features, self.num_classes),
        )

        # Print a layer summary; eval mode keeps the summary's forward pass
        # from touching BatchNorm running statistics.
        self.model.eval()
        print(summary(self.model, input_data=torch.randn(2, 3, 256, 256)))
        self.model.train()

        self.criterion = nn.CrossEntropyLoss()

        self.accuracy = tm.Accuracy(
            task="multiclass",
            num_classes=self.num_classes,
        )

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one training batch."""
        img, label = batch["img"], batch["label"]
        pred = self(img)
        loss = self.criterion(pred, label)

        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and accumulate validation accuracy."""
        img, label = batch["img"], batch["label"]

        pred = self(img)
        loss = self.criterion(pred, label)

        # Update the metric state and log the metric object itself, so the
        # logged value is computed over the whole validation epoch (then
        # reset) rather than averaged from per-batch accuracies.
        self.accuracy.update(pred, label)

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.accuracy, prog_bar=True)

    def configure_optimizers(self):
        """SGD over all parameters using the (possibly tuned) ``self.lr``."""
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.lr,
            momentum=self.momentum,
            weight_decay=self.weight_decay,
        )
        return optimizer

In [None]:
# Optional Weights & Biases logging: uncomment the logger below and the
# "logger" entry in `trainer_config` to enable it.
# wandb_logger = pl.loggers.WandbLogger(
#     project="land_use_image_classification",
# )

lr_monitor = pl.callbacks.LearningRateMonitor()

dm = UCMercedLandUseDataModule(batch_size=16)
# NOTE(review): the starting lr is presumably a placeholder that the LR finder
# below overwrites (Lightning's lr_find updates `lr` by default) -- confirm.
model = Classifier(lr=0.5)

trainer_config = {
    "accelerator": "auto",
    "callbacks": [lr_monitor],
    "max_epochs": 10,
    # Toggles (uncomment to enable):
    # "precision": 16,
    # "logger": [wandb_logger],
    # "fast_dev_run": True,
    # "accumulate_grad_batches": ...,
    # "gradient_clip_val": ...,
}
trainer = pl.Trainer(**trainer_config)

# Run the learning-rate search, then train.
tuner = pl.tuner.Tuner(trainer)
tuner.lr_find(model, dm)

trainer.fit(model, dm)

In [None]:
# wandb.finish()