# Imports

In [1]:
import json
import os
import sys
import time
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import torch
import torch.nn as nn

from matplotlib import cm as CM
from scipy import io
from scipy.ndimage import gaussian_filter
from sklearn.neighbors import KDTree
from torchinfo import summary
from torchvision.transforms import v2

In [2]:
drive = None
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
PATH = "./"

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
PATH = PATH if drive is None else "/content/drive/MyDrive/self-learn/crowd-counting"

In [3]:
sys.path.append(PATH)

from constants import *
from models import CSRNet

In [4]:
root = root if drive is None else "/content/drive/MyDrive/crowd-counting"
path = path  #######

In [5]:
MODEL_NAME = "CSRNet"

# Model

In [6]:
model = CSRNet()
summary(model, (8, 3, 224, 224))

Layer (type:depth-idx)                   Output Shape              Param #
CSRNet                                   [8, 1, 224, 224]          --
├─Sequential: 1-1                        [8, 512, 28, 28]          --
│    └─Conv2d: 2-1                       [8, 64, 224, 224]         1,792
│    └─ReLU: 2-2                         [8, 64, 224, 224]         --
│    └─Conv2d: 2-3                       [8, 64, 224, 224]         36,928
│    └─ReLU: 2-4                         [8, 64, 224, 224]         --
│    └─MaxPool2d: 2-5                    [8, 64, 112, 112]         --
│    └─Conv2d: 2-6                       [8, 128, 112, 112]        73,856
│    └─ReLU: 2-7                         [8, 128, 112, 112]        --
│    └─Conv2d: 2-8                       [8, 128, 112, 112]        147,584
│    └─ReLU: 2-9                         [8, 128, 112, 112]        --
│    └─MaxPool2d: 2-10                   [8, 128, 56, 56]          --
│    └─Conv2d: 2-11                      [8, 256, 56, 56]          29

# Dataset preprocessing and loading

In [63]:
class AdaptiveResize:
    def __init__(self, out_shape):
        self.out_shape = out_shape

    def __call__(self, img):
        img = v2.Resize(self.out_shape, antialias=True)(img)
        return img

In [64]:
dmap_transform = v2.Compose(
    [
        lambda x: np.expand_dims(x, -1),  ## add channel dim to dmap
        v2.ToImage(),
        v2.ToDtype(
            torch.float32, scale=True
        ),  ## these two are equivalent to the deprecated v2.ToTensor()
    ]
)

img_transform = v2.Compose(
    [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.ColorJitter(),
        v2.RandomGrayscale(p=0.1),
        lambda x: v2.functional.adjust_gamma(
            x, gamma=torch.FloatTensor(1).uniform_(0.8, 1.2)
        ),
        # lambda x: x/255
    ]
)

### warning: if multiprocessing w/ num_workers > 0, lambda functions may break and custom classes may be needed

In [131]:
class CrowdDataset(torch.utils.data.Dataset):
    """
    CrowdDataset for ShanghaiTech_A.
    Indexing returns a tuple (img, gt_dmap), where `img` is a 3D tensor (C=3, W, H),
    and gt_dmap is the ground truth density map (1, W//ds_scale, H//ds_scale).
    ds_scale (int=4): Downsampling scale factor. In the case of MCNN, the model predicts a density map of size (W//4, H//4), so the ground
    truth map must be appropriately resized, and for CSRNet, the scale is 8.
    interpolate (bool=False): If True (for CSRNet), the model attempts to interpolate its downscaled output map back to the original
    image dimensions. i.e. (540 x 256) would become (540//ds_scale * ds_scale x 256//ds_scale * ds_scale) = (536 x 256)
    """

    def __init__(self, split, path="ShanghaiTech_A", ds_scale=4, interpolate=False):

        assert split in ["train", "test"], "`split` must be either `train` or `test`."
        self.data_path = f"{path}/{split}_data"
        self.ds_scale = ds_scale
        self.interpolate = interpolate

    def __getitem__(self, index):

        try:
            gt_dmap = np.load(f"{self.data_path}/gt_maps/GT_IMG_{index+1}.npy")
            img = plt.imread(
                f"{self.data_path}/images/IMG_{index+1}.jpg"
            ).copy()  # copy for writability
            if len(img.shape) == 2:
                img = np.repeat(
                    img[:, :, np.newaxis], 3, axis=-1
                )  ## convert B&W to RGB via repeat
        except FileNotFoundError as e:
            raise Exception("File not found. Index may be out of bounds.") from e

        if self.interpolate:
            out_shape = (
                img.shape[0] // self.ds_scale * self.ds_scale,
                img.shape[1] // self.ds_scale * self.ds_scale,
            )
        else:
            out_shape = (img.shape[0] // self.ds_scale, img.shape[1] // self.ds_scale)

        gt_dmap = dmap_transform(gt_dmap)
        gt_dmap = AdaptiveResize(out_shape)(gt_dmap)
        if not self.interpolate:
            gt_dmap *= self.ds_scale ** 2  ## re-normalize values after downsampling
        img = img_transform(img)

        return img, gt_dmap

    def __len__(self):
        return len(os.listdir(f"{self.data_path}/images"))

In [132]:
train_data = CrowdDataset(split="train", ds_scale=8, interpolate=True)
test_data = CrowdDataset(split="test", ds_scale=8, interpolate=True)

## Dataloaders

#### Note we use a batch size of 1, because PyTorch Dataloaders require identically sized images.
To circumvent this, there are a number of methods, such as writing a custom collate function. We will perform **gradient accumulation** over `BATCH_SIZE` steps in our training function later.

In [133]:
train_loader = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=True)

In [134]:
# for img, gt in train_loader:
#     break

## Train

In [138]:
def train(model, train_loader, val_loader, optimizer, criterion, device):

    model.train()
    model.to(device)
    train_losses, val_losses = [], []
    val_accuracies = []

    for epoch in range(EPOCHS):

        print(f"Epoch {epoch+1}/{EPOCHS}")

        for step, (img, gt) in enumerate(train_loader):

            img, gt = img.to(device), gt.to(device)

            out = model(img)
            loss = criterion(out, gt)
            loss /= BATCH_SIZE
            train_losses.append(loss.item())  # every step
            loss.backward()

            ########### NOTE: fix later, consider when last batch is not of size BATCH_SIZE!
            if (step + 1) % BATCH_SIZE == 0 or (step + 1) == len(train_loader):

                # # monitor overall gradient norm
                # grads = [
                #     param.grad.detach().flatten()
                #     for param in model.parameters()
                #     if param.grad is not None
                # ]
                # norm = torch.cat(grads).norm()

                optimizer.step()
                optimizer.zero_grad()

            if (step + 1) % (BATCH_SIZE * PRINT_ITERS) == 0 and step != 0:

                mae = abs(out.sum() - gt.sum())
                print(
                    f"Step: {step+1}/{len(train_loader)} | Loss: {loss.item():.2e} |",
                    f"Pred: {out.sum():.3f} | True: {gt.sum():.3f} |",
                    f"MAE: {mae:.3f}",
                )

                # val_loss, val_acc = eval(model, val_loader, criterion, device)
                # val_losses.append(val_loss)
                # val_accuracies.append(val_acc)
                # print(
                #     f"Step: {step}/{len(train_loader)}, Running Average Loss: {np.mean(train_losses):.3f} |",
                #     f"Val Loss: {val_loss:.3f} | Val Acc: {val_acc:.3f} | Grad Norm: {norm:.2f}",
                # )
                # model.train()

        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
            },
            f"{path}/checkpoints/{MODEL_NAME}_EPOCH_{epoch+1}_SEED_{SEED}.pt",
        )

        with open(
            f"{path}/train_logs/{MODEL_NAME}_SEED_{SEED}_train_losses.json", "w"
        ) as f:
            json.dump(train_losses, f)

        # with open(
        #     f"{path}/train_logs/{MODEL_NAME}_SEED_{SEED}_val_losses.json", "w"
        # ) as f2:
        #     json.dump(val_losses, f2)

        # with open(
        #     f"{path}/train_logs/{MODEL_NAME}_SEED_{SEED}_val_accuracies.json", "w"
        # ) as f3:
        #     json.dump(val_accuracies, f3)

        # torch.save(deltas, f"{path}/train_logs/{MODEL_NAME}_SEED_{SEED}_deltas.pt")

    # return train_losses, val_losses, val_accuracies
    return train_losses

In [139]:
model = CSRNet()
model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
criterion = nn.MSELoss()

## Driver code

In [140]:
BATCH_SIZE = 2  ### for now. move to 32/64 etc.
PRINT_ITERS = 1

In [None]:
## Driver code
train_losses = train(model, train_loader, test_loader, optimizer, criterion, device)
# train_losses, val_losses, val_accuracies = train(
#     model, train_loader, test_loader, optimizer, criterion, device
# )

Epoch 1/2
Step: 2/300 | Loss: 8.32e-06 | Pred: 32.190 | True: 332.012 | MAE: 299.822
Step: 4/300 | Loss: 8.99e-02 | Pred: 215902.000 | True: 1157.785 | MAE: 214744.219


## todo: eval function on val and test

#### temporary informal inference

In [423]:
test_idx = 64
gt_dmap = np.load(f"{data_path}/test_data/gt_maps/GT_IMG_{test_idx+1}.npy")
img = plt.imread(f"{data_path}/test_data/images/IMG_{test_idx+1}.jpg").copy()
if len(img.shape) == 2:
    img = np.repeat(img[:, :, np.newaxis], 3, axis=-1)
out_shape = (img.shape[0] // 4, img.shape[1] // 4)

gt_dmap = dmap_transform(gt_dmap)
gt_dmap = AdaptiveResize(out_shape)(gt_dmap)
gt_dmap *= 4 ** 2  ## re-normalize values after resizing
img = img_transform(img)

In [428]:
model.eval()
with torch.no_grad():
    out = model(img.unsqueeze(0)).squeeze(0)

In [None]:
fig, ax = plt.subplots(1, 3, figsize=(16, 12))
ax[0].imshow(img.cpu().permute(1, 2, 0))
ax[0].set_title("Original Image")
ax[1].imshow(gt_dmap.cpu().permute(1, 2, 0))
ax[1].set_title(f"Ground Truth DM: {gt_dmap.sum():.3f}")
ax[2].imshow(out.cpu().permute(1, 2, 0))
ax[2].set_title(f"Pred DM: {out.sum():.3f}")
plt.show()