In [None]:
#@title Install dependencies
!pip install -q torch torchvision torchaudio-f https://download.pytorch.org/whl/torch_stable.html
!pip install -q timm==0.4.12
!pip install -q tensorboard
!pip3 install -q hub
!pip install -q pytorch_lightning

In [None]:
#@title Imports
import json
import numpy as np
import os
import math
import time
import random
import sys
from PIL import Image
import matplotlib.pyplot as plt
import timm
import torch
from torchvision import transforms
import albumentations as A
import cv2
import pytorch_lightning as pl
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision
from collections import OrderedDict
import timm.optim.optim_factory as optim_factory
import util.misc as misc
import models_mae_cross
from pytorch_lightning.loggers import TensorBoardLogger
from util.misc import NativeScalerWithGradNormCount as NativeScaler

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("Current device:", device, "- Type:", torch.cuda.get_device_name(0))
else:
    print("Current device:", device)

In [None]:
#@title Load Best Checkpoint
# Colab form parameters: which checkpoint to load and where the dataset lives.
paper = False #@param {type:"boolean"}
PATH = 'countrV6' #@param {type:"string"}
DATA_PATH = '../data/GalaxiesV2/' #@param {type:"string"}
ANNO_FILE = 'annotation.json' #@param {type:"string"}
DATA_SPLIT_FILE = 'train_test_val.json' #@param {type:"string"}
IM_DIR = 'images/' #@param {type:"string"}
GT_DIR = 'density_maps/' #@param {type:"string"}

# Paths are plain concatenations, so DATA_PATH must end with '/'.
anno_file = DATA_PATH + ANNO_FILE
data_split_file = DATA_PATH + DATA_SPLIT_FILE
im_dir = DATA_PATH + IM_DIR
gt_dir = DATA_PATH + GT_DIR

with open(anno_file) as f:
    annotations = json.load(f)

with open(data_split_file) as f:
    data_split = json.load(f)

# Both pipelines restore weights into the same architecture.
model = models_mae_cross.__dict__['mae_vit_base_patch16'](norm_pix_loss=False)

# different loading pipeline between paper and lightning training
if paper:
    # Paper-style .pth checkpoint: read the stored training args, point
    # args.resume at the file and let misc.load_model_FSC restore the weights.
    PATH_TO_CKPT = f'./checkpoints/{PATH}.pth'
    # map_location lets a checkpoint saved on GPU be opened on a CPU-only machine.
    args = torch.load(PATH_TO_CKPT, map_location=torch.device('cpu'))['args']
    args.resume = PATH_TO_CKPT
    misc.load_model_FSC(args=args, model_without_ddp=model)
else:
    # Lightning checkpoint: weights live under 'state_dict' with a 6-character key
    # prefix that is stripped here (presumably the LightningModule attribute
    # 'model.' — confirm against the training module).
    PATH_TO_CKPT = f'./checkpoints/{PATH}.ckpt'
    state_dict = torch.load(PATH_TO_CKPT, map_location=torch.device('cpu'))['state_dict']
    pl_state_dict = OrderedDict((key[6:], value) for key, value in state_dict.items())
    model.load_state_dict(pl_state_dict)

model.to(device)

In [None]:
#@title Modified Class Dataset for Testing
class StarsDataset(Dataset):
    """Test-time dataset yielding images, point counts, exemplar crops and density maps.

    Relies on the notebook-level ``data_split``, ``annotations``, ``im_dir`` and
    ``gt_dir`` defined in the checkpoint-loading cell.

    Args:
        split: key into ``data_split`` (e.g. ``'test'``) listing image file names.
        plot: if True, keep only the first 5 images of the split (quick plotting).
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a dict with ``'image'`` and ``'mask'`` (albumentations style).
    """

    def __init__(self, split, plot=False, transform=None):
        # added a plot parameter to consider only 5 samples when loading data to plot
        self.img = data_split[split][:5] if plot else data_split[split]
        self.img_dir = im_dir
        self.transform = transform

    def __len__(self):
        return len(self.img)

    def __getitem__(self, idx):
        """Return ``(image, n_dots, boxes, rects, density)`` for the ``idx``-th image.

        image: normalized array, channels first (transposed from H x W x C).
        n_dots: number of annotated points.
        boxes: exemplar crops resized to 64x64, stacked as (n_boxes, 3, 64, 64).
        rects: exemplar rectangles as [y1, x1, y2, x2] in untransformed image coordinates.
        density: ground-truth density map as float32 (passed through ``transform``
            together with the image when one is set).
        """
        im_id = self.img[idx]
        anno = annotations[im_id]

        # Each exemplar is a list of corner points; corners 0 and 2 are read as the
        # (x, y) of two opposite corners and re-ordered to [y1, x1, y2, x2].
        rects = [[bbox[0][1], bbox[0][0], bbox[2][1], bbox[2][0]]
                 for bbox in anno['box_examples_coordinates']]

        n_dots = len(anno['points'])

        # convert('RGB') guarantees 3 channels (grayscale/RGBA files would otherwise
        # break the crop slicing and the 3-channel normalization); the context
        # manager closes the underlying file handle.
        with Image.open(os.path.join(self.img_dir, im_id)) as pil_image:
            image = np.array(pil_image.convert('RGB'))
        # splitext works for extensions of any length (im_id[:-4] assumed 3 chars).
        stem = os.path.splitext(im_id)[0]
        density = np.load(os.path.join(gt_dir, stem + '.npy')).astype('float32')

        crops = []
        for rect in rects:
            y1, x1, y2, x2 = (int(k) for k in rect)
            crop = Image.fromarray(image[y1:y2 + 1, x1:x2 + 1, :])
            crop = transforms.Resize((64, 64))(crop)
            crops.append(transforms.ToTensor()(crop))
        # torch.stack rejects an empty list: an image without exemplars yields an
        # empty (0, 3, 64, 64) tensor instead of crashing.
        boxes = torch.stack(crops) if crops else torch.zeros((0, 3, 64, 64))

        if self.transform is not None:
            aug = self.transform(image=image, mask=density)
            image = aug['image']
            density = aug['mask']

        # boxes shape [3,3,64,64], image shape [3,384,384], density shape[384,384]
        # Normalize only alters the image; the density map is returned as is.
        image = A.Normalize()(image=image)['image'].transpose(2, 0, 1)

        return image, n_dots, boxes, rects, density

In [None]:
#@title MAE Class
class customMAE(nn.Module):
    """Mean absolute error between predicted and ground-truth object counts.

    The predicted count of each image is the sum of its density map over
    dims 1 and 2, divided by a scaling factor.

    Args:
        factor: density-map scaling factor. When None (default) the
            notebook-level global ``factor`` is read at call time, which keeps
            the original behaviour of ``customMAE()``.
    """

    def __init__(self, factor=None):
        super().__init__()
        self.mae = nn.L1Loss()
        self.factor = factor

    def forward(self, yhat, y):
        """Return the batch-mean absolute count error as a scalar tensor.

        yhat: predicted density maps, summed over dims 1 and 2.
        y: ground-truth counts, one per image.
        """
        scale = factor if self.factor is None else self.factor
        pred_cnt = torch.sum(yhat / scale, dim=(1, 2))
        # Counts arrive as integer tensors from the DataLoader; compare in the
        # prediction's floating dtype.
        return self.mae(pred_cnt, y.to(pred_cnt.dtype))

In [None]:
#@title RMSE Class
class customRMSE(nn.Module):
    """Root-mean-square error between predicted and ground-truth object counts.

    The predicted count of each image is the sum of its density map over
    dims 1 and 2, divided by a scaling factor. The result is the RMSE of ONE
    batch; averaging these values across batches is not a dataset RMSE. Square
    them, weight by batch size, then take the root of the mean instead.

    Args:
        factor: density-map scaling factor. When None (default) the
            notebook-level global ``factor`` is read at call time, which keeps
            the original behaviour of ``customRMSE()``.
    """

    def __init__(self, factor=None):
        super().__init__()
        self.mse = nn.MSELoss()
        self.factor = factor

    def forward(self, yhat, y):
        """Return the batch RMSE of the counts as a Python float.

        yhat: predicted density maps, summed over dims 1 and 2.
        y: ground-truth counts, one per image.
        """
        scale = factor if self.factor is None else self.factor
        pred_cnt = torch.sum(yhat / scale, dim=(1, 2))
        # Counts arrive as integer tensors from the DataLoader; compare in the
        # prediction's floating dtype.
        return math.sqrt(self.mse(pred_cnt, y.to(pred_cnt.dtype)).item())

In [None]:
#@title Plot Predictions Function
def plot_predictions(fig, gt_map, output, path, batch_id):
    """Save a three-panel figure: input image, ground-truth density, predicted density.

    Args:
        fig: image tensor laid out (C, H, W); shown as-is in the first panel.
        gt_map: ground-truth density maps; only the first one is drawn.
        output: predicted density map(s); ``squeeze(0)`` is applied before drawing.
        path: checkpoint name, used for the output directory and the file name.
        batch_id: index appended to the saved file name.

    Counts are density sums divided by the notebook-level ``factor``; the
    ``.item()`` calls require a batch of one. The figure is written to
    ``../test/{path}/plots/`` (created if missing) and closed afterwards so
    that plotting many samples does not accumulate open figures.
    """
    # round() rather than int(): float sums such as 11.9999 must not truncate to 11.
    gt_count = round(torch.sum(gt_map / factor, dim=(1, 2)).item())
    pred_count = round(torch.sum(output / factor, dim=(1, 2)).item(), 1)

    figure, axes = plt.subplots(1, 3, figsize=(30, 20))
    panels = [
        ('Input Image', fig.detach().cpu().numpy().transpose(1, 2, 0), None),
        (f'Groundtruth Density = {gt_count}', gt_map[0].detach().cpu(), 'gray'),
        (f'Predicted Density = {pred_count}', output.squeeze(0).detach().cpu(), 'hot'),
    ]
    for ax, (title, img, cmap) in zip(axes, panels):
        ax.axis('off')
        ax.set_title(title, fontsize=20)
        ax.imshow(img, cmap=cmap)

    out_dir = f'../test/{path}/plots'
    os.makedirs(out_dir, exist_ok=True)
    figure.savefig(f'{out_dir}/{path}_{batch_id}')
    plt.close(figure)

In [None]:
#@title Test Dataloader
# batch_size must stay 1: plot_predictions and the overlay code in the testing
# cell call .item() / repeat(3, 1, 1) on per-batch tensors.
batch_size = 1
test_dataset = StarsDataset('test')
test_dl = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

mae = customMAE()
rmse = customRMSE()
# len(test_dl) counts batches; the number of samples is the dataset length.
print('Number of test samples:', len(test_dataset))

Number of test samples: 153


In [None]:
#@title Testing
model.eval()

spread = True #@param {type:"boolean", description:"if True factor is 60"}
# Density maps are divided by `factor` to turn their sum into a count. Without
# spreading no rescaling is needed; dividing by 0 would yield inf counts.
factor = 60 if spread else 1

shots = 'zero' #@param {type:"string", options:"either zero or few, if few automatic set to 3"}
assert shots in ('zero', 'few'), f"shots must be 'zero' or 'few', got {shots!r}"
num_ex = 0 if shots == 'zero' else 3

# Output layout: metrics per shot setting plus the plots from plot_predictions.
for sub_dir in ('few', 'zero', 'plots'):
    os.makedirs(f'../test/{PATH}/{sub_dir}', exist_ok=True)


def _draw_box_outlines(box_map, pos):
    """Set the outline of every exemplar rectangle in ``box_map`` (H, W) to 10.

    ``pos`` is the collated list of [y1, x1, y2, x2] rectangles for a batch of
    one. Indices are clamped to the last row/column of the map.
    """
    height, width = box_map.shape
    for rect in pos:
        y1, x1, y2, x2 = (int(c) for c in rect)
        rows = torch.arange(y1, y2, device=box_map.device).clamp(max=height - 1)
        cols = torch.arange(x1, x2, device=box_map.device).clamp(max=width - 1)
        box_map[rows, min(x1, width - 1)] = 10
        box_map[rows, min(x2, width - 1)] = 10
        box_map[min(y1, height - 1), cols] = 10
        box_map[min(y2, height - 1), cols] = 10
    return box_map


n_samples = 0
abs_err_sum = 0.0  # sum over images of |predicted count - true count|
sq_err_sum = 0.0   # sum over images of (predicted count - true count) ** 2

with torch.no_grad():
    for batch_id, batch in enumerate(test_dl):
        samples, gt_dots, boxes, pos, gt_map = batch
        samples = samples.to(device)
        gt_dots = gt_dots.to(device)
        boxes = boxes.to(device)
        gt_map = gt_map.to(device)

        output = model(samples, boxes, num_ex)

        # mae/rmse return batch means: weight by batch size to get per-image sums.
        # The dataset RMSE is sqrt(mean squared error); averaging per-batch RMSEs
        # instead would, with batch_size=1, just reproduce the MAE.
        batch_n = samples.shape[0]
        abs_err_sum += float(mae(output, gt_dots)) * batch_n
        sq_err_sum += rmse(output, gt_dots) ** 2 * batch_n
        n_samples += batch_n

        # Overlay: input image + exemplar outlines + half the predicted density.
        fig = samples[0]
        box_map = torch.zeros(fig.shape[1:], device=device)
        _draw_box_outlines(box_map, pos)
        overlay = fig + box_map.unsqueeze(0).repeat(3, 1, 1) + output.repeat(3, 1, 1) / 2
        overlay = torch.clamp(overlay, 0, 1)
        plot_predictions(overlay, gt_map, output, PATH, batch_id)

assert n_samples > 0, 'the test split is empty'
test_MAE = abs_err_sum / n_samples
test_RMSE = math.sqrt(sq_err_sum / n_samples)

with open(f'../test/{PATH}/{shots}/metrics.txt', 'w') as f:
    f.write(f'Test MAE: {test_MAE}\nTest RMSE: {test_RMSE}')

print(f'Test MAE: {test_MAE:.3f} | Test RMSE: {test_RMSE:.3f}')