## **Structured Knowledge Transfer (SKT)**

### **Utils**

In [None]:
import h5py
import torch
import torch.nn as nn
import torch.nn.functional as F
import shutil
import os
import time

In [None]:
def save_net(fname, net):
    """Write every entry of ``net.state_dict()`` to an HDF5 file.

    Each state-dict key becomes one dataset in ``fname``; the tensor is moved
    to the CPU and stored as a NumPy array. The file is opened in ``'w'`` mode.
    """
    with h5py.File(fname, 'w') as archive:
        weights = net.state_dict()
        for name in weights:
            archive.create_dataset(name, data=weights[name].cpu().numpy())


def load_net(fname, net):
    """Copy weights stored in an HDF5 file into ``net`` in place.

    For every key in ``net.state_dict()`` the dataset of the same name is read
    from ``fname`` and copied into the corresponding state-dict tensor.

    Args:
        fname: Path of the HDF5 file (opened read-only).
        net: Module whose state-dict tensors are overwritten.

    Raises:
        KeyError: If the file lacks a dataset for one of the state-dict keys.
    """
    # Imported locally: the Utils cell does not import numpy, so without this
    # the function only works once a later cell has defined ``np``.
    import numpy as np

    with h5py.File(fname, 'r') as h5f:
        for k, v in net.state_dict().items():
            param = torch.from_numpy(np.asarray(h5f[k]))
            v.copy_(param)
# Checkpoint location on Google Drive.
# NOTE(review): this value ends in a file name, while save_checkpoint() joins its
# `path` argument with a file name (i.e. treats it as a directory) - confirm usage.
path="/content/drive/MyDrive/ShanghaiTech_Crowd_Counting_Dataset/CSRNet_models_weights/checkpoint.pth.tar"
def save_checkpoint(state, mae_is_best, mse_is_best, path, filename='checkpoint.pth.tar'):
    """Save ``state`` with ``torch.save`` and keep copies of the best epochs.

    The checkpoint is written to ``path/filename``. When ``mae_is_best`` or
    ``mse_is_best`` is truthy, that file is additionally copied to
    ``path/epoch<N>_best_mae.pth.tar`` / ``path/epoch<N>_best_mse.pth.tar``,
    where ``<N>`` is ``state['epoch']``.
    """
    checkpoint_file = os.path.join(path, filename)
    torch.save(state, checkpoint_file)

    epoch_tag = 'epoch' + str(state['epoch'])
    best_flags = (
        (mae_is_best, '_best_mae.pth.tar'),
        (mse_is_best, '_best_mse.pth.tar'),
    )
    for is_best, suffix in best_flags:
        if is_best:
            shutil.copyfile(checkpoint_file, os.path.join(path, epoch_tag + suffix))


def cal_para(net):
    """Print the total number of scalar elements over ``net.parameters()``."""
    # numel() is the product of a tensor's dimensions (1 for a 0-dim tensor).
    total = sum(p.numel() for p in net.parameters())
    print("the amount of para: " + str(total))


def crop_img_patches(img, size=512):
    """Split a 4-D image array into a grid of patches.

    Axis 2 is treated as height and axis 3 as width. The grid has
    ``int(w/size)+1`` columns and ``int(h/size)+1`` rows; every cell is
    ``int(w/cols)`` by ``int(h/rows)`` except the last column/row, which
    extends to the image border. Patches are returned column by column
    (all rows of the first column, then the second column, ...).

    Per the original note: used when testing UCF data, where the crowd count
    of each patch is computed separately and the counts are summed.
    """
    height, width = img.shape[2], img.shape[3]
    n_cols = int(width / size) + 1
    n_rows = int(height / size) + 1
    step_x = int(width / n_cols)
    step_y = int(height / n_rows)

    # Cell boundaries; the final boundary is the image edge so the last
    # column/row absorbs any remainder.
    x_edges = [step_x * c for c in range(n_cols)] + [width]
    y_edges = [step_y * r for r in range(n_rows)] + [height]

    return [
        img[:, :, y_edges[r]:y_edges[r + 1], x_edges[c]:x_edges[c + 1]]
        for c in range(n_cols)
        for r in range(n_rows)
    ]

### **Image**

In [None]:
import random
import os
from PIL import Image,ImageFilter,ImageDraw
import numpy as np
import h5py
from PIL import ImageStat
import cv2
import time

In [None]:
def load_data(img_path,train=True, dataset='shanghai'):
    """Load an image and its density map from the matching ``.h5`` file.

    The ground-truth path is derived from ``img_path`` by replacing ``.jpg``
    with ``.h5`` and ``images`` with ``ground_truth``; the ``'density'``
    dataset is read from it.

    When ``train`` is true:
      * for ``dataset == 'shanghai'`` a random crop covering 50%-100% of each
        side is taken from both image and density map;
      * with probability 0.2 both are flipped left-right.

    The density map is then passed through ``reshape_target(target, 3)`` and
    given a leading axis of size 1.

    Returns:
        ``(img, target)``: an RGB PIL image and the processed density array.
    """
    gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground_truth')
    img = Image.open(img_path).convert('RGB')
    # Read the density map inside a context manager so the HDF5 handle is
    # released immediately instead of leaking one open file per sample.
    with h5py.File(gt_path, 'r') as gt_file:
        target = np.asarray(gt_file['density'])
    if train:
        if dataset == 'shanghai':
            crop_ratio = random.uniform(0.5, 1.0)
            crop_size = (int(crop_ratio*img.size[0]), int(crop_ratio*img.size[1]))
            dx = int(random.random() * (img.size[0]-crop_size[0]))
            dy = int(random.random() * (img.size[1]-crop_size[1]))

            img = img.crop((dx,dy,crop_size[0]+dx,crop_size[1]+dy))
            target = target[dy:crop_size[1]+dy,dx:crop_size[0]+dx]

        if random.random() > 0.8:
            target = np.fliplr(target)
            img = img.transpose(Image.FLIP_LEFT_RIGHT)

    target = reshape_target(target, 3)
    target = np.expand_dims(target, axis=0)

    # Return independent copies of both outputs (the array may still be a view).
    img = img.copy()
    target = target.copy()
    return img, target


def load_ucf_ori_data(img_path):
    """Load an original UCF-QNRF image and its density map for testing.

    The ground-truth path is derived from ``img_path`` by replacing ``.jpg``
    with ``.h5`` and ``images`` with ``ground_truth``. No cropping, flipping
    or resizing is applied.

    Returns:
        ``(img, target)``: an RGB PIL image and the ``'density'`` dataset as a
        NumPy array.
    """
    gt_path = img_path.replace('.jpg', '.h5').replace('images', 'ground_truth')
    img = Image.open(img_path).convert('RGB')
    # Close the HDF5 file as soon as the data has been copied into memory.
    with h5py.File(gt_path, 'r') as gt_file:
        target = np.asarray(gt_file['density'])
    return img, target


def reshape_target(target, down_sample=3):
    """Shrink a density map by repeated halving and rescale its values.

    Each of the ``down_sample`` steps halves height and width, rounding up
    (the original comment ties this to ``ceil_mode=True`` of ``nn.MaxPool2d``
    in the model). With the default of 3 the map ends up at 1/8 resolution.
    The cubic-resized result is multiplied by ``2**(down_sample*2)``.
    """
    out_h, out_w = target.shape[0], target.shape[1]
    for _ in range(down_sample):
        out_h = int((out_h + 1) / 2)
        out_w = int((out_w + 1) / 2)

    gain = 2 ** (down_sample * 2)
    resized = cv2.resize(target, (out_w, out_h), interpolation=cv2.INTER_CUBIC)
    return resized * gain

### **Dataset**

In [None]:
def load_data(img_path,train = True):
    """Load an image and its ``.npy`` ground-truth array, without augmentation.

    The ground-truth path is derived from ``img_path`` by replacing ``.jpg``
    with ``.npy`` and ``images`` with ``ground_truth``.

    NOTE: this definition shadows the ``load_data`` defined earlier in the
    file (the ``.h5`` variant that accepts a ``dataset`` argument).

    Args:
        img_path: Path of the ``.jpg`` image.
        train: Accepted for call compatibility; it has no effect. The random
            crop / flip code that used to live here sat behind ``if False:``
            and could never run, so it has been removed.

    Returns:
        ``(img, target)``: an RGB PIL image and the array from ``np.load``.
    """
    gt_path = img_path.replace('.jpg','.npy').replace('images','ground_truth')
    img = Image.open(img_path).convert('RGB')
    target = np.load(gt_path)
    return img,target

In [None]:
import os
import random
import torch
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
#from image import *
import torchvision.transforms.functional as F


class listDataset(Dataset):
    """Dataset over a list of image paths that yields ``(image, density)`` pairs.

    Args:
        root: List of image paths.
        shape: Stored on the instance; not used by this class.
        transform: Optional callable applied to the loaded image. The density
            map is only resized and converted to a tensor when a transform
            is given.
        shuffle: Shuffle the sample list once at construction time.
        train: Forwarded to ``load_data``; with ``dataset == 'shanghai'`` the
            sample list is also repeated four times.
        seen, batch_size, num_workers: Stored on the instance; not used here.
        dataset: ``'ucf_test'`` loads samples with ``load_ucf_ori_data``;
            any other value uses ``load_data``.
    """

    # Side length the density map is resized to before the 1/8 down-sampling.
    # NOTE(review): presumably matches the image size produced by `transform` - confirm.
    _TARGET_SIDE = 768

    def __init__(self, root, shape=None, transform=None, shuffle=True,  train=False, seen=0,
                 batch_size=1, num_workers=0, dataset='shanghai'):
        # Work on a copy so the caller's list is never reordered in place.
        lines = list(root)
        if train and dataset == 'shanghai':
            lines = lines * 4
        # Honour the `shuffle` flag (it used to be ignored).
        if shuffle:
            random.shuffle(lines)

        self.nSamples = len(lines)
        self.lines = lines
        self.transform = transform
        self.train = train
        self.shape = shape
        self.seen = seen
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.dataset = dataset

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        # Valid indices stop at len - 1; raise rather than assert so the check
        # survives `python -O`.
        if index >= len(self):
            raise IndexError('index range error')

        img_path = self.lines[index]

        if self.dataset == 'ucf_test':
            # test in UCF
            img, target = load_ucf_ori_data(img_path)
        else:
            img, target = load_data(img_path, self.train)

        if self.transform is not None:
            img = self.transform(img)

            side = self._TARGET_SIDE
            # Resizing changes the number of cells, so the values must be scaled
            # by the area ratio to keep the sum (the count) unchanged. Using only
            # the height ratio squared is wrong for non-square maps.
            area_ratio = (target.shape[0] * target.shape[1]) / (side * side)
            target = cv2.resize(target,(side,side),interpolation = cv2.INTER_CUBIC)*area_ratio
            # Down-sample by 8 per side; x64 compensates for the 64x fewer cells.
            target = cv2.resize(target,(int(target.shape[1]/8),int(target.shape[0]/8)),interpolation = cv2.INTER_CUBIC)*64
            target = torch.tensor(target, dtype=torch.float32).unsqueeze(0)
        return img, target

In [None]:
# Imported here because this cell runs before the Preprocess cell that imports json.
import json

# Load the list of training image paths for ShanghaiTech part A.
train_json = "/content/drive/MyDrive/ShanghaiTech_Crowd_Counting_Dataset/part_A_final/json/part_A_train.json"
with open(train_json, 'r') as outfile:
    train_list = json.load(outfile)
workers = 0
batch_size = 6

### **Preprocess**

#### **make_json**

In [None]:
"""
Make json files for dataset
"""
import json
import os


def get_val(root,
            val_json="/content/drive/MyDrive/ShanghaiTech_Crowd_Counting_Dataset/part_A_final/json/part_A_val.json",
            out_json='A_val.json'):
    """Rewrite the CSRNet validation list so its paths point at ``root``.

    Validation set follows part_A_val.json in CSRNet
    https://github.com/leeyeehoo/CSRNet-pytorch

    Every entry of ``val_json`` has the prefix
    ``/home/leeyh/Downloads/Shanghai/`` replaced by ``root``; the resulting
    list is written to ``out_json``.

    Args:
        root: Dataset directory, with or without a trailing slash.
        val_json: Path of the source validation list.
        out_json: Path of the JSON file to write.
    """
    # The replaced prefix ends with '/', so the replacement must too; otherwise
    # a root without a trailing slash is glued onto the next path component.
    prefix = root if root.endswith(('/', os.sep)) else root + '/'

    with open(val_json) as f:
        val_list = json.load(f)
    new_val = [item.replace('/home/leeyh/Downloads/Shanghai/', prefix) for item in val_list]
    with open(out_json, 'w') as f:
        json.dump(new_val, f)


def get_train(root):
    """Write the training image paths to ``A_train.json``.

    Lists ``<root>/part_A_final/train_data/images`` and dumps the full path of
    every entry, in ``os.listdir`` order, to ``A_train.json`` in the current
    working directory.
    """
    image_dir = os.path.join(root, 'part_A_final', 'train_data', 'images')
    image_paths = []
    for entry in os.listdir(image_dir):
        image_paths.append(os.path.join(image_dir, entry))
    with open('A_train.json', 'w') as out:
        json.dump(image_paths, out)


def get_test(root):
    """Write the test image paths to ``A_test.json``.

    Lists ``<root>/part_A_final/test_data/images`` and dumps the full path of
    every entry, in ``os.listdir`` order, to ``A_test.json`` in the current
    working directory.
    """
    image_dir = os.path.join(root, 'part_A_final', 'test_data', 'images')
    image_paths = []
    for entry in os.listdir(image_dir):
        image_paths.append(os.path.join(image_dir, entry))
    with open('A_test.json', 'w') as out:
        json.dump(image_paths, out)


# Dataset path
root = '/content/drive/MyDrive/ShanghaiTech_Crowd_Counting_Dataset'
# Build the train, validation and test lists, in that order.
for build_split in (get_train, get_val, get_test):
    build_split(root)
print('Finish!')

Finish!
