<a href="https://colab.research.google.com/github/routb68/Zummit_Infolab/blob/main/crowd_26Dec.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
import pdb
import numpy as np

__all__ = ['Inception3', 'inception_v3']


model_urls = {
    # Inception v3 ported from TensorFlow
    'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
}


def headCount_inceptionv3(pretrained=False, **kwargs):
    r"""Inception v3 model architecture from
    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        if 'transform_input' not in kwargs:
            kwargs['transform_input'] = True
        model = Inception3(**kwargs)
        model.load_state_dict(model_zoo.load_url(model_urls['inception_v3_google']),strict=False)
        return model

    return Inception3(**kwargs)


class Inception3(nn.Module):

    def __init__(self, num_classes=1000, aux_logits=False, transform_input=False):
        super(Inception3, self).__init__()
        self.aux_logits = aux_logits
        self.transform_input = transform_input
        self.Conv2d_1a_3x3 = BasicConv2d(3, 32, kernel_size=3, stride=2, padding=1)
        self.Conv2d_2a_3x3 = BasicConv2d(32, 32, kernel_size=3, padding=1)
        self.Conv2d_2b_3x3 = BasicConv2d(32, 64, kernel_size=3, padding=1)
        self.Conv2d_3b_1x1 = BasicConv2d(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = BasicConv2d(80, 192, kernel_size=3, padding=1)
        self.Mixed_5b = InceptionA(192, pool_features=32)
        self.Mixed_5c = InceptionA(256, pool_features=64)
        self.Mixed_5d = InceptionA(288, pool_features=64)
        self.Mixed_6a = InceptionB(288)
        self.Mixed_6b = InceptionC(768, channels_7x7=128)
        self.Mixed_6c = InceptionC(768, channels_7x7=160)
        self.Mixed_6d = InceptionC(768, channels_7x7=160)
        self.Mixed_6e = InceptionC(768, channels_7x7=192)
        if aux_logits:
            self.AuxLogits = InceptionAux(768, num_classes)
        self.Mixed_7a = InceptionD(768)
        self.Mixed_7b = InceptionE(1280)
        self.Mixed_7c = InceptionE(2048)

        self.relu = nn.ReLU(inplace=True)
        self.sigm = nn.Sigmoid()
        self.lconv1 = nn.Conv2d(288, 1, kernel_size = 1, stride=1, padding=0, bias=False)
        self.lconv2 = nn.Conv2d(768, 1, kernel_size = 1, stride=1, padding=0, bias=False)
        self.lconv3 = nn.Conv2d(2048, 1, kernel_size = 1, stride=1, padding=0, bias=False)
        self.att_conv = nn.Conv2d(2048, 1, kernel_size = 1, stride=1, padding=0, bias=False)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                import scipy.stats as stats
                stddev = m.stddev if hasattr(m, 'stddev') else 0.1
                X = stats.truncnorm(-2, 2, scale=stddev)
                values = torch.Tensor(X.rvs(m.weight.numel()))
                values = values.view(m.weight.size())
                m.weight.data.copy_(values)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        if self.transform_input:
            x = x.clone()
            x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
        # 299 x 299 x 3
        x = self.Conv2d_1a_3x3(x)
        # 149 x 149 x 32
        x = self.Conv2d_2a_3x3(x)
        # 147 x 147 x 32
        x = self.Conv2d_2b_3x3(x)
        # 147 x 147 x 64
        # x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        # 73 x 73 x 64
        x = self.Conv2d_3b_1x1(x)
        # 73 x 73 x 80
        x = self.Conv2d_4a_3x3(x)
        # 71 x 71 x 192
        # x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        # 35 x 35 x 192
        x = self.Mixed_5b(x)
        # 35 x 35 x 256
        x = self.Mixed_5c(x)
        # 35 x 35 x 288
        x = self.Mixed_5d(x)
        # 35 x 35 x 288
        # 128x128x288
        x = self.Mixed_6a(x)
        # 17 x 17 x 768
        x = self.Mixed_6b(x)
        # 17 x 17 x 768
        x = self.Mixed_6c(x)
        # 17 x 17 x 768
        x = self.Mixed_6d(x)
        # 17 x 17 x 768
        x = self.Mixed_6e(x)
        # 64x64x768
        # 17 x 17 x 768
                
        if self.training and self.aux_logits:
            aux = self.AuxLogits(x)
        # 17 x 17 x 768
        x = self.Mixed_7a(x)
        # 8 x 8 x 1280
        x = self.Mixed_7b(x)
        # 8 x 8 x 2048
        x = self.upsample(x)
        attention_map = self.sigm(self.att_conv(x))
        feature_map3 = self.Mixed_7c(x)
        feature_map3 = feature_map3*attention_map
        # 32x32x2048
        #feature_map4 = F.avg_pool2d(feature_map3,2)
        # x_cat = feature_map3
        #density_map1 = self.lconv1(feature_map1)
        #density_map1 = density_map1.view(-1,density_map1.size(2),density_map1.size(3))
        #density_map2 = self.lconv2(feature_map2)
        #density_map2 = density_map2.view(-1,density_map2.size(2),density_map2.size(3))
        density_map3 = self.lconv3(feature_map3)
        density_map3 = self.relu(density_map3)
        density_map3 = density_map3.view(-1,density_map3.size(2),density_map3.size(3))
        attention_map = attention_map.view(-1,attention_map.size(2),attention_map.size(3))
        
        #density_map4 = self.lconv4(feature_map4)
        #density_map4 = density_map4.view(-1,density_map4.size(2),density_map4.size(3))
        # density_map = F.avg_pool2d(density_map,kernel_size=2)
        return density_map3,attention_map

class SequenceWise(nn.Module):
    def __init__(self, module):
        """
        Collapses input of dim T*N*H to (T*N)*H, and applies to a module.
        Allows handling of variable sequence lengths and minibatch sizes.
        :param module: Module to apply input to.
        """
        super(SequenceWise, self).__init__()
        self.module = module

    def forward(self, x):
        t, n = x.size(0), x.size(1)
        x = x.contiguous().view(t * n, -1)
        x = self.module(x)
        x = x.view(t, n, -1)
        x = x.permute(1,0,2)
        return x

    def __repr__(self):
        tmpstr = self.__class__.__name__ + ' (\n'
        tmpstr += self.module.__repr__()
        tmpstr += ')'
        return tmpstr

class InceptionA(nn.Module):

    def __init__(self, in_channels, pool_features):
        super(InceptionA, self).__init__()
        self.branch1x1 = BasicConv2d(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = BasicConv2d(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = BasicConv2d(48, 64, kernel_size=5, padding=2)

        self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, padding=1)

        self.branch_pool = BasicConv2d(in_channels, pool_features, kernel_size=1)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class InceptionB(nn.Module):

    def __init__(self, in_channels):
        super(InceptionB, self).__init__()
        self.branch3x3 = BasicConv2d(in_channels, 384, kernel_size=3, stride=2, padding=1)

        self.branch3x3dbl_1 = BasicConv2d(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = BasicConv2d(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = BasicConv2d(96, 96, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        branch3x3 = self.branch3x3(x)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)

        outputs = [branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class InceptionC(nn.Module):

    def __init__(self, in_channels, channels_7x7):
        super(InceptionC, self).__init__()
        self.branch1x1 = BasicConv2d(in_channels, 192, kernel_size=1)

        c7 = channels_7x7
        self.branch7x7_1 = BasicConv2d(in_channels, c7, kernel_size=1)
        self.branch7x7_2 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = BasicConv2d(c7, 192, kernel_size=(7, 1), padding=(3, 0))

        self.branch7x7dbl_1 = BasicConv2d(in_channels, c7, kernel_size=1)
        self.branch7x7dbl_2 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = BasicConv2d(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = BasicConv2d(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = BasicConv2d(c7, 192, kernel_size=(1, 7), padding=(0, 3))

        self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return torch.cat(outputs, 1)


class InceptionD(nn.Module):

    def __init__(self, in_channels):
        super(InceptionD, self).__init__()
        self.branch3x3_1 = BasicConv2d(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = BasicConv2d(192, 320, kernel_size=3, stride=2, padding=1)

        self.branch7x7x3_1 = BasicConv2d(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = BasicConv2d(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = BasicConv2d(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = BasicConv2d(192, 192, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)

        branch7x7x3 = self.branch7x7x3_1(x)
        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
        outputs = [branch3x3, branch7x7x3, branch_pool]
        return torch.cat(outputs, 1)


class InceptionE(nn.Module):

    def __init__(self, in_channels):
        super(InceptionE, self).__init__()
        self.branch1x1 = BasicConv2d(in_channels, 320, kernel_size=1)

        self.branch3x3_1 = BasicConv2d(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3_2b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch3x3dbl_1 = BasicConv2d(in_channels, 448, kernel_size=1)
        self.branch3x3dbl_2 = BasicConv2d(448, 384, kernel_size=3, padding=1)
        self.branch3x3dbl_3a = BasicConv2d(384, 384, kernel_size=(1, 3), padding=(0, 1))
        self.branch3x3dbl_3b = BasicConv2d(384, 384, kernel_size=(3, 1), padding=(1, 0))

        self.branch_pool = BasicConv2d(in_channels, 192, kernel_size=1)

    def forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch3x3 = self.branch3x3_1(x)
        branch3x3 = [
            self.branch3x3_2a(branch3x3),
            self.branch3x3_2b(branch3x3),
        ]
        branch3x3 = torch.cat(branch3x3, 1)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = [
            self.branch3x3dbl_3a(branch3x3dbl),
            self.branch3x3dbl_3b(branch3x3dbl),
        ]
        branch3x3dbl = torch.cat(branch3x3dbl, 1)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
        return torch.cat(outputs, 1)


class InceptionAux(nn.Module):

    def __init__(self, in_channels, num_classes):
        super(InceptionAux, self).__init__()
        self.conv0 = BasicConv2d(in_channels, 128, kernel_size=1)
        self.conv1 = BasicConv2d(128, 768, kernel_size=5)
        self.conv1.stddev = 0.01
        self.fc = nn.Linear(768, num_classes)
        self.fc.stddev = 0.001

    def forward(self, x):
        # 17 x 17 x 768
        x = F.avg_pool2d(x, kernel_size=5, stride=3)
        # 5 x 5 x 768
        x = self.conv0(x)
        # 5 x 5 x 128
        x = self.conv1(x)
        # 1 x 1 x 768
        x = x.view(x.size(0), -1)
        # 768
        x = self.fc(x)
        # 1000
        return x


class BasicConv2d(nn.Module):

    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return F.relu(x, inplace=True)

In [3]:
import numpy as np
import scipy.spatial
import pdb
def generate_multi_density_map(shape=(5,5),points=None,f_sz=15,sigma=4,num=3):
    '''
    generate multiple density maps according to the density
    '''
    # calculate the distance of each point to the nearest neighbour
    dist = scipy.spatial.distance.cdist(points,points,metric='euclidean')
    dist.sort()
    k = 3
    f_sz_vec = [15,15,15]
    meanDist = dist[:,1:k+1].mean(axis=1)
    thresholds = np.array([[0,20],[20,50],[50,1e9]])
    density_map = np.zeros((num,shape[0],shape[1]))
    for i in range(num):
        selector = (meanDist>thresholds[i,0]) & (meanDist<=thresholds[i,1])
        points_subset = points[selector,:]
        density_map[i,] = generate_density_map(shape,points_subset,f_sz_vec[i],sigma)
    return density_map
        
def generate_density_map(shape=(5,5),points=None,f_sz=15,sigma=4):
    """
    generate density map given head coordinations
    """
    im_density = np.zeros(shape[0:2])
    h, w = shape[0:2]
    if len(points) == 0:
        return im_density
    for j in range(len(points)):
        H = matlab_style_gauss2D((f_sz,f_sz),sigma)
        x = np.minimum(w,np.maximum(1,np.abs(np.int32(np.floor(points[j,0])))))
        y = np.minimum(h,np.maximum(1,np.abs(np.int32(np.floor(points[j,1])))))
        if x>w or y>h:
            continue
        x1 = x - np.int32(np.floor(f_sz/2))
        y1 = y - np.int32(np.floor(f_sz/2))
        x2 = x + np.int32(np.floor(f_sz/2))
        y2 = y + np.int32(np.floor(f_sz/2))
        dfx1 = 0
        dfy1 = 0
        dfx2 = 0
        dfy2 = 0
        change_H = False
        if x1 < 1:
            dfx1 = np.abs(x1)+1
            x1 = 1
            change_H = True
        if y1 < 1:
            dfy1 = np.abs(y1)+1
            y1 = 1
            change_H = True
        if x2 > w:
            dfx2 = x2 - w
            x2 = w
            change_H = True
        if y2 > h:
            dfy2 = y2 - h
            y2 = h
            change_H = True
        x1h = 1+dfx1
        y1h = 1+dfy1
        x2h = f_sz - dfx2
        y2h = f_sz - dfy2
        if change_H:
            H =  matlab_style_gauss2D((y2h-y1h+1,x2h-x1h+1),sigma)
        im_density[y1-1:y2,x1-1:x2] = im_density[y1-1:y2,x1-1:x2] +  H;
    return im_density
     
def matlab_style_gauss2D(shape=(3,3),sigma=0.5):
    """
    2D gaussian mask - should give the same result as MATLAB's
    fspecial('gaussian',[shape],[sigma])
    """
    m,n = [(ss-1.)/2. for ss in shape]
    y,x = np.ogrid[-m:m+1,-n:n+1]
    h = np.exp( -(x*x + y*y) / (2.*sigma*sigma) )
    h[ h < np.finfo(h.dtype).eps*h.max() ] = 0
    sumh = h.sum()
    if sumh != 0:
        h /= sumh
    return h


In [4]:
# -*- coding: utf-8 -*-
"""
==========================
**Author**: Qian Wang, qian.wang173@hotmail.com
"""


from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import pandas as pd
from skimage import io, transform
import torch.nn.functional as F
import cv2
import skimage.measure
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES=True
import scipy
import scipy.io
import pdb
plt.ion()   # interactive mode
'''
from model import CANNet
from model_mcnn import MCNN
from model_cffnet import CFFNet
from model_csrnet import CSRNet
from model_sanet import SANet
from model_tednet import TEDNet
from myInception_segLoss import headCount_inceptionv3
from generate_density_map import generate_multi_density_map,generate_density_map
'''

IMG_EXTENSIONS = ['.JPG','.JPEG','.jpg', '.jpeg', '.PNG', '.png', '.ppm', '.bmp', '.pgm', '.tif']
def has_file_allowed_extension(filename, extensions):
    """Checks if a file is an allowed extension.

    Args:
        filename (string): path to a file

    Returns:
        bool: True if the filename ends with a known image extension
    """
    filename_lower = filename.lower()
    return any(filename_lower.endswith(ext) for ext in extensions)

def make_dataset(dir, extensions):
    images = []
    dir = os.path.expanduser(dir)
    """
    for target in sorted(os.listdir(dir)):
        d = os.path.join(dir, target)
        if not os.path.isdir(d):
            continue
    """
    d = os.path.join(dir,'images')
    for root, _, fnames in sorted(os.walk(d)):
            for fname in sorted(fnames):
                if has_file_allowed_extension(fname, extensions):
                    image_path = os.path.join(root, fname)
                    head,tail = os.path.split(root)
                    label_path = os.path.join(head,'ground_truth','GT_'+fname[:-4]+'.mat')
                    item = [image_path, label_path]
                    images.append(item)

    return images

class ShanghaiTechDataset(Dataset):
    def __init__(self, data_dir, transform=None, phase='train',extensions=IMG_EXTENSIONS,patch_size=128,num_patches_per_image=4):
        self.samples = make_dataset(data_dir,extensions)
        self.image_dir = data_dir
        self.transform = transform
        self.phase = phase
        self.patch_size = patch_size
        self.numPatches = num_patches_per_image
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self,idx):        
        img_file,label_file = self.samples[idx]
        image = cv2.imread(img_file)
        height, width, channel = image.shape
        annPoints = scipy.io.loadmat(label_file)
        annPoints = annPoints['image_info'][0][0][0][0][0]
        positions = generate_density_map(shape=image.shape,points=annPoints,f_sz=15,sigma=4)
        fbs = generate_density_map(shape=image.shape,points=annPoints,f_sz=25,sigma=1)
        fbs = np.int32(fbs>0)
        targetSize = [self.patch_size,self.patch_size]
        height, width, channel = image.shape
        if height < targetSize[0] or width < targetSize[1]:
            image = cv2.resize(image,(np.maximum(targetSize[0]+2,height),np.maximum(targetSize[1]+2,width)))
            count = positions.sum()
            max_value = positions.max()
            # down density map
            positions = cv2.resize(positions, (np.maximum(targetSize[0]+2,height),np.maximum(targetSize[1]+2,width)))
            count2 = positions.sum()
            positions = np.minimum(positions*count/(count2+1e-8),max_value*10)
            fbs = cv2.resize(fbs,(np.maximum(targetSize[0]+2,height),np.maximum(targetSize[1]+2,width)))
            fbs = np.int32(fbs>0)
        if len(image.shape)==2:
            image = np.expand_dims(image,2)
            image = np.concatenate((image,image,image),axis=2)
        # transpose from h x w x channel to channel x h x w
        image = image.transpose(2,0,1)
        numPatches = self.numPatches
        if self.phase == 'train':
            patchSet, countSet, fbsSet = getRandomPatchesFromImage(image,positions,fbs,targetSize,numPatches)
            x = np.zeros((patchSet.shape[0],3,targetSize[0],targetSize[1]))
            if self.transform:
              for i in range(patchSet.shape[0]):
                #transpose to original:h x w x channel
                x[i,:,:,:] = self.transform(np.uint8(patchSet[i,:,:,:]).transpose(1,2,0))
            patchSet = x
        if self.phase == 'val' or self.phase == 'test':
            patchSet, countSet, fbsSet = getAllFromImage(image, positions, fbs)
            patchSet[0,:,:,:] = self.transform(np.uint8(patchSet[0,:,:,:]).transpose(1,2,0))
        return patchSet, countSet, fbsSet

def getRandomPatchesFromImage(image,positions,fbs,target_size,numPatches):
    # generate random cropped patches with pre-defined size, e.g., 224x224
    imageShape = image.shape
    if np.random.random()>0.5:
        for channel in range(3):
            image[channel,:,:] = np.fliplr(image[channel,:,:])
        positions = np.fliplr(positions)
        fbs = np.fliplr(fbs)
    patchSet = np.zeros((numPatches,3,target_size[0],target_size[1]))
    # generate density map
    countSet = np.zeros((numPatches,1,target_size[0],target_size[1]))
    fbsSet = np.zeros((numPatches,1,target_size[0],target_size[1]))
    for i in range(numPatches):
        topLeftX = np.random.randint(imageShape[1]-target_size[0]+1)#x-height
        topLeftY = np.random.randint(imageShape[2]-target_size[1]+1)#y-width
        thisPatch = image[:,topLeftX:topLeftX+target_size[0],topLeftY:topLeftY+target_size[1]]
        patchSet[i,:,:,:] = thisPatch
        # density map
        position = positions[topLeftX:topLeftX+target_size[0],topLeftY:topLeftY+target_size[1]]
        fb = fbs[topLeftX:topLeftX+target_size[0],topLeftY:topLeftY+target_size[1]]
        position = position.reshape((1, position.shape[0], position.shape[1]))
        fb = fb.reshape((1, fb.shape[0], fb.shape[1]))
        countSet[i,:,:,:] = position
        fbsSet[i,:,:,:] = fb
    return patchSet, countSet, fbsSet

def getAllPatchesFromImage(image,positions,target_size):
    # generate all patches from an image for prediction
    nchannel,height,width = image.shape
    nRow = np.int(height/target_size[1])
    nCol = np.int(width/target_size[0])
    target_size[1] = np.int(height/nRow)
    target_size[0] = np.int(width/nCol)
    patchSet = np.zeros((nRow*nCol,3,target_size[1],target_size[0]))
    for i in range(nRow):
      for j in range(nCol):
        patchSet[i*nCol+j,:,:,:] = image[:,i*target_size[1]:(i+1)*target_size[1], j*target_size[0]:(j+1)*target_size[0]]
    return patchSet

def getAllFromImage(image,positions,fbs):
    nchannel, height, width = image.shape
    patchSet =np.zeros((1,3,height, width))
    patchSet[0,:,:,:] = image[:,:,:]
    countSet = positions.reshape((1,1,positions.shape[0], positions.shape[1]))
    fbsSet = fbs.reshape((1,1,fbs.shape[0], fbs.shape[1]))
    return patchSet, countSet, fbsSet

data_transforms = {
    'train': transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'test': transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
}
class SubsetSampler(torch.utils.data.sampler.Sampler):
    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        return (self.indices[i] for i in range(len(self.indices)))

    def __len__(self):
        return len(self.indices)


def train_model(model, optimizer, scheduler, num_epochs=100, seg_loss=False, cl_loss=False, test_step=10):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_mae_val = 1e6
    best_mae_by_val = 1e6
    best_mae_by_test = 1e6
    best_mse_by_val = 1e6
    best_mse_by_test = 1e6
    criterion1 = nn.MSELoss(reduce=False) # for density map loss
    criterion2 = nn.BCELoss() # for segmentation map loss
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)
        model.train()  # Set model to training mode
        running_loss = 0.0        
        # Iterate over data.
        for index, (inputs, labels, fbs) in enumerate(dataloaders['train']):
            labels = labels*100
            labels = skimage.measure.block_reduce(labels.cpu().numpy(),(1,1,1,4,4),np.sum)
            fbs = skimage.measure.block_reduce(fbs.cpu().numpy(),(1,1,1,4,4),np.max)
            fbs = np.float32(fbs>0)
            labels = torch.from_numpy(labels)
            fbs = torch.from_numpy(fbs)
            labels = labels.to(device)
            fbs = fbs.to(device)
            inputs = inputs.to(device)
            inputs = inputs.view(-1,inputs.shape[2],inputs.shape[3],inputs.shape[4])
            labels = labels.view(-1,labels.shape[3],labels.shape[4])
            fbs = fbs.view(-1,fbs.shape[3],fbs.shape[4])
            inputs = inputs.float()
            labels = labels.float()
            # zero the parameter gradients
            optimizer.zero_grad()

            # forward
            # track history if only in train
            with torch.set_grad_enabled(True):
                output,fbs_out = model(inputs)
                loss_den = criterion1(output, labels)
                loss_seg = criterion2(fbs_out, fbs)
                if cl_loss:
                    th = 0.1*epoch+5 #cl2
                else:
                    th=1000 # no curriculum loss when th is set a big number
                weights = th/(F.relu(labels-th)+th)
                loss_den = loss_den*weights
                loss_den = loss_den.sum()/weights.sum()
                if seg_loss:
                    loss = loss_den + 20*loss_seg
                else:
                    loss = loss_den

                loss.backward()
                optimizer.step()
            running_loss += loss.item() * inputs.size(0)
               
        scheduler.step()    
        epoch_loss = running_loss / dataset_sizes['train']            
        
        print('Train Loss: {:.4f}'.format(epoch_loss))
        print()
        if epoch%test_step==0:
            tmp,epoch_mae,epoch_mse,epoch_mre=test_model(model,optimizer,'val')
            tmp,epoch_mae_test,epoch_mse_test,epoch_mre_test = test_model(model,optimizer,'test')
            if  epoch_mae < best_mae_val:
                best_mae_val = epoch_mae
                best_mae_by_val = epoch_mae_test
                best_mse_by_val = epoch_mse_test
                best_epoch_val = epoch
                best_model_wts = copy.deepcopy(model.state_dict())
            if epoch_mae_test < best_mae_by_test:
                best_mae_by_test = epoch_mae_test
                best_mse_by_test = epoch_mse_test
                best_epoch_test = epoch
            print()
            print('best MAE and MSE by val:  {:2.2f} and {:2.2f} at Epoch {}'.format(best_mae_by_val,best_mse_by_val, best_epoch_val))
            print('best MAE and MSE by test: {:2.2f} and {:2.2f} at Epoch {}'.format(best_mae_by_test,best_mse_by_test, best_epoch_test))
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model


def test_model(model,optimizer,phase):
    since = time.time()
    model.eval()
    mae = 0
    mse = 0
    mre = 0
    pred = np.zeros((3000,2))
    # Iterate over data.
    for index, (inputs, labels, fbs) in enumerate(dataloaders[phase]):
        inputs = inputs.to(device)
        labels = labels.to(device)
        inputs = inputs.float()
        labels = labels.float()
        inputs = inputs.view(-1,inputs.shape[2],inputs.shape[3],inputs.shape[4])
        labels = labels.view(-1,labels.shape[3],labels.shape[4])
        # zero the parameter gradients
        optimizer.zero_grad()
        # forward
        with torch.set_grad_enabled(False):
            outputs,fbs_out = model(inputs)
            outputs = outputs.to(torch.device("cpu")).numpy()/100
            pred_count = outputs.sum()
        true_count = labels.to(torch.device("cpu")).numpy().sum()
        # backward + optimize only if in training phase
        mse = mse + np.square(pred_count-true_count)
        mae = mae + np.abs(pred_count-true_count)
        mre = mre + np.abs(pred_count-true_count)/true_count
        pred[index,0] = pred_count
        pred[index,1] = true_count
    pred = pred[0:index+1,:]
    mse = np.sqrt(mse/(index+1))
    mae = mae/(index+1)
    mre = mre/(index+1)
    print(phase+':')
    print(f'MAE:{mae:2.2f}, RMSE:{mse:2.2f}, MRE:{mre:2.4f}')
    time_elapsed = time.time() - since
    return pred,mae,mse,mre

#####################################################################
# set parameters here
seg_loss = True
cl_loss = True
test_step = 1
batch_size = 6
num_workers = 4
patch_size = 128
num_patches_per_image = 4
data_dir = '/content/drive/MyDrive/testing/'

# define data set
image_datasets = {x: ShanghaiTechDataset(data_dir+x+'_data', 
                        phase=x, 
                        transform=data_transforms[x],
                        patch_size=patch_size,
                        num_patches_per_image=num_patches_per_image)
                    for x in ['train','test']}
image_datasets['val'] = ShanghaiTechDataset(data_dir+'train_data',
                            phase='val',
                            transform=data_transforms['val'],
                            patch_size=patch_size,
                            num_patches_per_image=num_patches_per_image)
## split the data into train/validation/test subsets
indices = list(range(len(image_datasets['train'])))
split = np.int(len(image_datasets['train'])*0.2)

val_idx = np.random.choice(indices, size=split, replace=False)
train_idx = indices#list(set(indices)-set(val_idx))
test_idx = range(len(image_datasets['test']))

train_sampler = SubsetRandomSampler(train_idx)
val_sampler = SubsetRandomSampler(val_idx)
test_sampler = SubsetSampler(test_idx)

train_loader = torch.utils.data.DataLoader(dataset=image_datasets['train'],batch_size=batch_size,sampler=train_sampler, num_workers=num_workers)
val_loader = torch.utils.data.DataLoader(dataset=image_datasets['val'],batch_size=1,sampler=val_sampler, num_workers=num_workers)
test_loader = torch.utils.data.DataLoader(dataset=image_datasets['test'],batch_size=1,sampler=test_sampler, num_workers=num_workers)

dataset_sizes = {'train':len(train_idx),'val':len(val_idx),'test':len(image_datasets['test'])}
dataloaders = {'train':train_loader,'val':val_loader,'test':test_loader}
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

########################################################################
# define models and training
model = headCount_inceptionv3(pretrained=True)
# model = MCNN()
# model = SANet()
# model = TEDNet(use_bn=True)
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

model = train_model(model, optimizer, exp_lr_scheduler,
                    num_epochs=5, #501
                    seg_loss=seg_loss, 
                    cl_loss=cl_loss, 
                    test_step=test_step)
                    
pred,mae,mse,mre = test_model(model,optimizer,'test')
scipy.io.savemat('./results.mat', mdict={'pred': pred, 'mse': mse, 'mae': mae,'mre': mre})
model_dir = './'
torch.save(model.state_dict(), model_dir+'saved_model.pt')


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  split = np.int(len(image_datasets['train'])*0.2)
Downloading: "https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth" to /root/.cache/torch/hub/checkpoints/inception_v3_google-1a9a5a14.pth


  0%|          | 0.00/104M [00:00<?, ?B/s]

Epoch 0/4
----------




Train Loss: 70.1457

val:
MAE:5.38, RMSE:5.38, MRE:0.0543
test:
MAE:32.17, RMSE:38.14, MRE:0.9540

best MAE and MSE by val:  32.17 and 38.14 at Epoch 0
best MAE and MSE by test: 32.17 and 38.14 at Epoch 0
Epoch 1/4
----------
Train Loss: 52.9634

val:
MAE:34.50, RMSE:34.50, MRE:0.3485
test:
MAE:69.89, RMSE:72.18, MRE:1.6574

best MAE and MSE by val:  32.17 and 38.14 at Epoch 0
best MAE and MSE by test: 32.17 and 38.14 at Epoch 0
Epoch 2/4
----------
Train Loss: 39.3105

val:
MAE:68.38, RMSE:68.38, MRE:0.6908
test:
MAE:106.11, RMSE:107.74, MRE:2.3753

best MAE and MSE by val:  32.17 and 38.14 at Epoch 0
best MAE and MSE by test: 32.17 and 38.14 at Epoch 0
Epoch 3/4
----------
Train Loss: 28.9451

val:
MAE:54.72, RMSE:54.72, MRE:0.5527
test:
MAE:91.76, RMSE:93.80, MRE:2.0692

best MAE and MSE by val:  32.17 and 38.14 at Epoch 0
best MAE and MSE by test: 32.17 and 38.14 at Epoch 0
Epoch 4/4
----------
Train Loss: 32.3184

val:
MAE:29.26, RMSE:29.26, MRE:0.2956
test:
MAE:66.73, RMSE:69.27,