In [145]:
import argparse
import time
import torch.optim
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict
import torch.utils.data as data
import torch.backends.cudnn as cudnn
import sys
import numpy as np
import json
import resnet
import math

from PIL import Image
import torch.backends.cudnn as cudnn
from tqdm import tqdm

import cv2

In [78]:
class AtrousModule(nn.Module):
    def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm):
        super(AtrousModule, self).__init__()
        self.conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size,
                                            stride=1, padding=padding, dilation=dilation, bias=False)
        self.batch = BatchNorm(planes)
        self.relu = nn.ReLU()

        self._init_weight()

    def forward(self, x):
        x = self.conv(x)
        x = self.batch(x)

        return self.relu(x)

    def _init_weight(self):
        for mod in self.modules():
            if isinstance(mod, nn.Conv2d):
                torch.nn.init.kaiming_normal_(mod.weight)
            elif isinstance(mod, nn.BatchNorm2d):
                mod.weight.data.fill_(1)
                mod.bias.data.zero_()
            elif isinstance(mod, nn.BatchNorm2d):
                mod.weight.data.fill_(1)
                mod.bias.data.zero_()

class wasp(nn.Module):
    def __init__(self, backbone, output_stride, BatchNorm):
        super(wasp,self).__init__()
        if backbone == 'drn' :
            inplanes = 512 
        elif backbone == 'mobilenet' :
            inplanes - 320
        else:
            inplanes = 2048
        if output_stride == 16:
            dilations = [24, 18, 12,  6]
        elif output_stride == 8:
            dilations = [48, 36, 24, 12]
        else: 
            print('build wasp error')

        self.aspp1 = AtrousModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm)
        self.aspp2 = AtrousModule(256, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm)
        self.aspp3 = AtrousModule(256, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm)
        self.aspp4 = AtrousModule(256, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm)

        self.g_avg_pool = nn.Sequential(
                            nn.AdaptiveAvgPool2d((1, 1)),
                            nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
                            nn.ReLU())

        self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
        self.conv2 = nn.Conv2d(256,256,1,bias=False)
        self.bn1 = BatchNorm(256)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self._init_weight()

    def forward(self, x):
        #atrous layers
        x1 = self.aspp1(x);x2 = self.aspp2(x1);x3 = self.aspp3(x2);x4 = self.aspp4(x3)

        #conv layers
        x1 = self.conv2(x1);x2 = self.conv2(x2);x3 = self.conv2(x3);x4 = self.conv2(x4)
    
        x1 = self.conv2(x1);x2 = self.conv2(x2);x3 = self.conv2(x3);x4 = self.conv2(x4)

        x5 = self.g_avg_pool(x)
        x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        return self.dropout(x)
        
    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
                
def build_wasp(backbone, output_stride, BatchNorm):
    return wasp(backbone, output_stride, BatchNorm)   

In [149]:
class Decoder(nn.Module):
    def __init__(self, dataset, num_classes, backbone, BatchNorm,limbsNum):
        super(Decoder, self).__init__()
        if backbone == 'resnet':
            low_level_inplanes = 256
            limbsNum = limbsNum

        self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False)
        self.batch1 = BatchNorm(48)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(2048, 256, 1, bias=False)
        self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                    BatchNorm(256),
                                    nn.ReLU(),
                                    nn.Dropout(0.5),
                                    nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
                                    BatchNorm(256),
                                    nn.ReLU(),
                                    nn.Dropout(0.1),
                                    nn.Conv2d(256, num_classes+1, kernel_size=1, stride=1))


        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self._init_weight()


    def forward(self, x, low_level_feat):

        low_level_feat = self.conv1(low_level_feat)
        low_level_feat = self.batch1(low_level_feat)
        low_level_feat = self.relu(low_level_feat)

        low_level_feat = self.maxpool(low_level_feat)

        x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True)

        x = torch.cat((x, low_level_feat), dim=1)
        x = self.last_conv(x)


        return x

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

def build_decoder(dataset, num_classes, backbone, BatchNorm,limbsNum):
    return Decoder(dataset ,num_classes, backbone, BatchNorm,limbsNum)

In [150]:
def build_backbone(backbone, output_stride, BatchNorm):
    if backbone == 'resnet':
        return resnet.ResNet101(output_stride, BatchNorm)
    else:
        print("build backbone error")

In [151]:
class unipose(nn.Module):
    def __init__(self, dataset, backbone = 'resnet', output_stride = 16, num_classes = 21, sync_bn = True, freeze_bn = False, stride = 8,b = False):
        super(unipose,self).__init__()
        self.stride = stride; self.num_classes = num_classes

        BatchNorm = nn.BatchNorm2d

        self.pool_center   = nn.AvgPool2d(kernel_size=9, stride=8, padding=1)

        self.backbone      = build_backbone(backbone, output_stride, BatchNorm)
        self.wasp          = build_wasp(backbone, output_stride, BatchNorm)
        self.decoder       = build_decoder(dataset, num_classes, backbone, BatchNorm,13)

        if freeze_bn:
            self.freeze_bn()

    def forward(self, input):
        x, low_level_feat = self.backbone(input)
        x = self.wasp(x)
        x = self.decoder(x, low_level_feat)
        #resize
        if self.stride != 8:
            x = F.interpolate(x, size=(input.size()[2:]), mode='bilinear', align_corners=True)

        if b == True:
            return x[:,0:self.num_classes+1,:,:], x[:,self.num_classes+1:,:,:] 
    
        else:
            return x

    def freeze_bn(self):
        for m in self.modules():
            if isinstance(m, SynchronizedBatchNorm2d):
                m.eval()
            elif isinstance(m, nn.BatchNorm2d):
                m.eval()

    def get_1x_lr_params(self):
        modules = [self.backbone]
        for i in range(len(modules)):
            for m in modules[i].named_modules():
                if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
                        or isinstance(m[1], nn.BatchNorm2d):
                    for p in m[1].parameters():
                        if p.requires_grad:
                            yield p

    def get_10x_lr_params(self):
        modules = [self.aspp, self.decoder]
        for i in range(len(modules)):
            for m in modules[i].named_modules():
                if isinstance(m[1], nn.Conv2d) or isinstance(m[1], SynchronizedBatchNorm2d) \
                        or isinstance(m[1], nn.BatchNorm2d):
                    for p in m[1].parameters():
                        if p.requires_grad:
                            yield p


if __name__ == "__main__":
    model = wasp(backbone='resnet', output_stride=16, BatchNorm =nn.BatchNorm2d)
    model.eval()
    #input = torch.rand(1, 3, 513, 513)
    #output = model(input)
    #print(output.size())


In [152]:
def adjust_learning_rate(optimizer, iters, base_lr, gamma, step_size, policy='step', multiple=[1]):
    if policy == 'fixed':
        lr = base_lr
    elif policy == 'step':
        lr = base_lr * (gamma ** (iters // step_size))
    for i, param_group in enumerate(optimizer.param_groups):
        param_group['lr'] = lr * multiple[i]
    return lr

In [153]:
def save_checkpoint(state, is_best, filename='checkpoint'):

    if is_best:
        torch.save(state, filename + '_best.pth.tar')

In [154]:
def get_parameters(model, lr, isdefault=True):

    if isdefault:
        return model.parameters(), [1.]
    lr_1 = []
    lr_2 = []
    lr_4 = []
    lr_8 = []
    params_dict = dict(model.named_parameters())
    for key, value in params_dict.items():
        if ('model1_' not in key) and ('model0.' not in key):
            if key[-4:] == 'bias':
                lr_8.append(value)
            else:
                lr_4.append(value)
        elif key[-4:] == 'bias':
            lr_2.append(value)
        else:
            lr_1.append(value)
    params = [{'params': lr_1, 'lr': lr},
            {'params': lr_2, 'lr': lr * 2.},
            {'params': lr_4, 'lr': lr * 4.},
            {'params': lr_8, 'lr': lr * 8.}]

    return params, [1., 2., 4., 8.]

In [155]:
def draw_paint(im, kpts, mapNumber, epoch, model_arch, dataset):
    colors = [[000,000,255], [000,255,000], [000,000,255], [255,255,000], [255,255,000], [255,000,255], [000,255,000],\
              [255,000,000], [255,255,000], [255,000,255], [000,255,000], [000,255,000], [000,000,255], [255,255,000], [255,000,000]]
    limbSeq = [[ 8, 9], [ 7,12], [12,11], [11,10], [ 7,13], [13,14], [14,15], [ 7, 6], [ 6, 2], [ 2, 1], [ 1, 0], [ 6, 3], [ 3, 4], [ 4, 5], [ 7, 8]]
    for k in kpts:
        x = k[0]
        y = k[1]
        cv2.circle(im, (x, y), radius=3, thickness=-1, color=(0, 0, 255))

    # draw lines
    for i in range(len(limbSeq)):
        cur_im = im.copy()
        limb = limbSeq[i]
        [Y0, X0] = kpts[limb[0]]
        [Y1, X1] = kpts[limb[1]]
        # mX = np.mean([X0, X1])
        # mY = np.mean([Y0, Y1])
        # length = ((X0 - X1) ** 2 + (Y0 - Y1) ** 2) ** 0.5
        # angle = math.degrees(math.atan2(X0 - X1, Y0 - Y1))
        # polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), 4), int(angle), 0, 360, 1)
        # cv2.fillConvexPoly(cur_im, polygon, colors[i])
        # if X0!=0 and Y0!=0 and X1!=0 and Y1!=0:
        #     im = cv2.addWeighted(im, 0.4, cur_im, 0.6, 0)

        if X0!=0 and Y0!=0 and X1!=0 and Y1!=0:
            if i<len(limbSeq)-4:
                cv2.line(cur_im, (Y0,X0), (Y1,X1), colors[i], 5)
            else:
                cv2.line(cur_im, (Y0,X0), (Y1,X1), [0,0,255], 5)

        im = cv2.addWeighted(im, 0.2, cur_im, 0.8, 0)

    cv2.imwrite('samples/WASPpose/Pose/'+str(mapNumber)+'.png', im)


In [156]:
class Resized(object):
    
    def __init__(self, size):
        assert (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2))
        if isinstance(size, int):
            self.size = (size, size)
        else:
            self.size = size

    def get_params(img, output_size):

        height, width, _ = img.shape
        
        return (output_size[0] * 1.0 / height, output_size[1] * 1.0 / width)

    def __call__(self, img, kpt, center):

        ratio = self.get_params(img, self.size)

        return resize(img, kpt, center, ratio)

In [157]:
def guassian_kernel(size_w, size_h, center_x, center_y, sigma):
    gridy, gridx = np.mgrid[0:size_h, 0:size_w]
    D2 = (gridx - center_x) ** 2 + (gridy - center_y) ** 2
    return np.exp(-D2 / 2.0 / sigma / sigma)


In [158]:
class mpii(data.Dataset):
    def __init__(self, root_dir, sigma, is_train, transform=None):
        self.width       = 368
        self.height      = 368
        self.transformer = transform
        self.is_train    = is_train
        self.sigma       = sigma
        self.parts_num   = 16
        self.stride      = 8

        self.labels_dir  = root_dir 
        self.images_dir  = 'images/'
        

        self.videosFolders = {}
        self.labelFiles    = {}
        self.full_img_List = {}
        self.numPeople     = []
        
        
        with open(self.labels_dir) as anno_file:
            self.anno = json.load(anno_file)
        
        if is_train == True:
            self.train_list = []
            for idx,val in enumerate(self.anno):
                self.train_list.append(idx)
                
        elif is_train == False: 
            self.val_list = []
            for idx,val in enumerate(self.anno):
                self.val_list.append(idx)
            
                
        if is_train == True:
            self.img_List = self.train_list
            print("Train images ",len(self.img_List))

        elif is_train == False:
            self.img_List = self.val_list
            print("Val   images ",len(self.img_List))
        
    def __getitem__(self, index):
        scale_factor = 0.25

        variable = self.anno[self.img_List[index]]

        while not os.path.isfile(self.labels_dir + variable['img_paths'][:-4]+'.jpg'):
            index = index - 1
            variable = self.anno[self.img_List[index]]

        img_path  = self.images_dir + variable['image'] 
        points   = torch.Tensor(variable['joints'])
        center   = torch.Tensor(variable['center'])
        scale    = variable['scale']

        if center[0] != -1:
            center[1] = center[1] + 15*scale
            scale     = scale*1.25

        nParts = points.size(0)
        img    = cv2.imread(img_path)

        kpt = points
        if img.shape[0] != 368 or img.shape[1] != 368:
            kpt[:,0] = kpt[:,0] * (368/img.shape[1])
            kpt[:,1] = kpt[:,1] * (368/img.shape[0])
            img = cv2.resize(img,(368,368))
        height, width, _ = img.shape
        heatmap = np.zeros((int(height/self.stride), int(width/self.stride), int(len(kpt)+1)), dtype=np.float32)
        for i in range(len(kpt)):
            # resize from 368 to 46
            x = int(kpt[i][0]) * 1.0 / self.stride
            y = int(kpt[i][1]) * 1.0 / self.stride
            heat_map = guassian_kernel(size_h=int(height/self.stride),size_w=int(width/self.stride), center_x=x, center_y=y, sigma=self.sigma)
            heat_map[heat_map > 1] = 1
            heat_map[heat_map < 0.0099] = 0
            heatmap[:, :, i + 1] = heat_map
        centermap = np.zeros((int(height/self.stride), int(width/self.stride), 1), dtype=np.float32)
        center_map = guassian_kernel(size_h=int(height/self.stride), size_w=int(width/self.stride), center_x=int(center[0]/self.stride), center_y=int(center[1]/self.stride), sigma=3)
        center_map[center_map > 1] = 1
        center_map[center_map < 0.0099] = 0
        centermap[:, :, 0] = center_map

        orig_img = cv2.imread(img_path)

        img = Mytransforms.normalize(Mytransforms.to_tensor(img), [128.0, 128.0, 128.0],
                                     [256.0, 256.0, 256.0])
        heatmap   = Mytransforms.to_tensor(heatmap)
        centermap = Mytransforms.to_tensor(centermap)

        return img, heatmap, centermap, img_path
        
    def __len__(self):
        return len(self.img_List)        
    

In [159]:
def getDataloader(dataset, train_dir, val_dir, sigma, stride, workers, batch_size,*test_dir,):
    if dataset == 'MPII':
        train_loader = torch.utils.data.DataLoader(
                                            mpii(train_dir, sigma, True,
                                            Resized(368)),
                                            batch_size  = batch_size, shuffle=True,num_workers = workers, pin_memory=True)
    
        val_loader   = torch.utils.data.DataLoader(
                                            mpii(val_dir, sigma, False,
                                            Resized(368)),batch_size  = 1, shuffle=True,num_workers = 1, pin_memory=True)
    
        #test_loader   = torch.utils.data.DataLoader(
                                          #  mpii(test_dir, sigma, "Val",
                                          #  Resized(368)),batch_size  = 1, shuffle=True,num_workers = 1, pin_memory=True)
        return train_loader, val_loader #, test_loader


In [160]:
def printAccuracies(mAP, AP, mPCKh, PCKh, mPCK, PCK, dataset):
    if dataset == "MPII":
        print("\nmAP:   %.2f%%" % (mAP*100))
        print("APs:   %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%"\
            % (AP[0]*100,AP[1]*100,AP[2]*100,AP[3]*100,AP[4]*100,AP[5]*100,AP[6]*100,AP[7]*100,AP[8]*100,AP[9]*100,PCKh[10]*100,\
                AP[11]*100,AP[12]*100,AP[13]*100,AP[14]*100,AP[15]*100,AP[16]*100))

        print("mPCK:  %.2f%%" % (mPCK*100))
        print("PCKs:  %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%"\
            % (PCK[0]*100,PCK[1]*100,PCK[2]*100,PCK[3]*100,PCK[4]*100,PCK[5]*100,PCK[6]*100,PCK[7]*100,PCK[8]*100,PCK[9]*100,PCK[10]*100,\
                PCK[11]*100,PCK[12]*100,PCK[13]*100,PCK[14]*100,PCK[15]*100,PCK[16]*100))

        print("mPCKh: %.2f%%" % (mPCKh*100))
        print("PCKhs: %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%, %2.2f%%"\
            % (PCKh[0]*100,PCKh[1]*100,PCKh[2]*100,PCKh[3]*100,PCKh[4]*100,PCKh[5]*100,PCKh[6]*100,PCKh[7]*100,PCKh[8]*100,PCKh[9]*100,\
                PCKh[10]*100,PCKh[11]*100,PCKh[12]*100,PCKh[13]*100,PCKh[14]*100,PCKh[15]*100,PCKh[16]*100))

In [161]:
def getOutImages(heat, input_var, img_path, outName):
    heat = F.interpolate(heat, size=input_var.size()[2:], mode='bilinear', align_corners=True)

    heat = heat.detach().cpu().numpy()

    heat = heat[0].transpose(1,2,0)


    for i in range(heat.shape[0]):
        for j in range(heat.shape[1]):
            for k in range(heat.shape[2]):
                if heat[i,j,k] < 0:
                    heat[i,j,k] = 0
                

    im       = cv2.resize(cv2.imread(img_path[0]),(368,368))

    heatmap = []
    for i in range(15):
        heatmap = cv2.applyColorMap(np.uint8(255*heat[:,:,i]), cv2.COLORMAP_JET)
        im_heat  = cv2.addWeighted(im, 0.6, heatmap, 0.4, 0)
        cv2.imwrite('samples/WASPpose/heat/'+outName+'_'+str(i)+'.png', im_heat)

In [162]:
def get_model_summary(model, *input_tensors, item_length=26, verbose=False):
    """
    :param model:
    :param input_tensors:
    :param item_length:
    :return:
    """

    summary = []

    ModuleDetails = namedtuple(
        "Layer", ["name", "input_size", "output_size", "num_parameters", "multiply_adds"])
    hooks = []
    layer_instances = {}

    def add_hooks(module):

        def hook(module, input, output):
            class_name = str(module.__class__.__name__)

            instance_index = 1
            if class_name not in layer_instances:
                layer_instances[class_name] = instance_index
            else:
                instance_index = layer_instances[class_name] + 1
                layer_instances[class_name] = instance_index

            layer_name = class_name + "_" + str(instance_index)

            params = 0

            if class_name.find("Conv") != -1 or class_name.find("BatchNorm") != -1 or \
               class_name.find("Linear") != -1:
                for param_ in module.parameters():
                    params += param_.view(-1).size(0)

            flops = "Not Available"
            if class_name.find("Conv") != -1 and hasattr(module, "weight"):
                flops = (
                    torch.prod(
                        torch.LongTensor(list(module.weight.data.size()))) *
                    torch.prod(
                        torch.LongTensor(list(output.size())[2:]))).item()
            elif isinstance(module, nn.Linear):
                flops = (torch.prod(torch.LongTensor(list(output.size()))) \
                         * input[0].size(1)).item()

            if isinstance(input[0], list):
                input = input[0]
            if isinstance(output, list):
                output = output[0]

            summary.append(
                ModuleDetails(
                    name=layer_name,
                    input_size=list(input[0].size()),
                    output_size=list(output.size()),
                    num_parameters=params,
                    multiply_adds=flops)
            )
            if not isinstance(module, nn.ModuleList) and not isinstance(module, nn.Sequential) and module != model:
                hooks.append(module.register_forward_hook(hook))

    model.eval()
    model.apply(add_hooks)

    space_len = item_length

    model(*input_tensors)
    for hook in hooks:
        hook.remove()

    details = ''
    if verbose:
        details = "Model Summary" + \
            os.linesep + \
            "Name{}Input Size{}Output Size{}Parameters{}Multiply Adds (Flops){}".format(
                ' ' * (space_len - len("Name")),
                ' ' * (space_len - len("Input Size")),
                ' ' * (space_len - len("Output Size")),
                ' ' * (space_len - len("Parameters")),
                ' ' * (space_len - len("Multiply Adds (Flops)"))) \
                + os.linesep + '-' * space_len * 5 + os.linesep
    params_sum = 0
    flops_sum = 0
    for layer in summary:
        params_sum += layer.num_parameters
        if layer.multiply_adds != "Not Available":
            flops_sum += layer.multiply_adds
        if verbose:
            details += "{}{}{}{}{}{}{}{}{}{}".format(
                layer.name,
                ' ' * (space_len - len(layer.name)),
                layer.input_size,
                ' ' * (space_len - len(str(layer.input_size))),
                layer.output_size,
                ' ' * (space_len - len(str(layer.output_size))),
                layer.num_parameters,
                ' ' * (space_len - len(str(layer.num_parameters))),
                layer.multiply_adds,
                ' ' * (space_len - len(str(layer.multiply_adds)))) \
                + os.linesep + '-' * space_len * 5 + os.linesep

    details += os.linesep \
        + "Total Parameters: {:,}".format(params_sum) \
        + os.linesep + '-' * space_len * 5 + os.linesep
    details += "Total Multiply Adds (For Convolution and Linear Layers only): {:,} GFLOPs".format(flops_sum/(1024**3)) \
        + os.linesep + '-' * space_len * 5 + os.linesep
    details += "Number of Layers" + os.linesep
    for layer in layer_instances:
        details += "{} : {} layers   ".format(layer, layer_instances[layer])

    return details


In [163]:
def get_kpts(maps, img_h = 368.0, img_w = 368.0):

    # maps (1,15,46,46)
    maps = maps.clone().cpu().data.numpy()
    map_6 = maps[0]

    kpts = []
    for m in map_6[1:]:
        h, w = np.unravel_index(m.argmax(), m.shape)
        x = int(w * img_w / m.shape[1])
        y = int(h * img_h / m.shape[0])
        kpts.append([x,y])
    return kpts

In [164]:
class AverageMeter(object):
    """ Computes ans stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.
        self.avg = 0.
        self.sum = 0.
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [165]:
class Trainer(object):
    def __init__(self, pretrained,dataset,train_dir,val_dir,model_name,model_arch = 'unipose'):

        self.train_dir    = ""
        self.val_dir      = ""
        self.dataset      = "MPII"


        self.workers      = 1
        self.weight_decay = 0.0005
        self.momentum     = 0.9
        self.batch_size   = 8
        self.lr           = 0.0001
        self.gamma        = 0.333
        self.step_size    = 13275
        self.sigma        = 3
        self.stride       = 8
        #cudnn.benchmark   = True
    
        if self.dataset == "MPII":
            self.numClasses  = 16
        self.train_loader, self.val_loader = getDataloader(self.dataset, 'train.json','valid.json', self.sigma, self.stride, self.workers, self.batch_size)
        model = unipose(self.dataset, num_classes=self.numClasses,backbone='resnet',output_stride=16,sync_bn=True,freeze_bn=False, stride=self.stride)
        self.model  = model.cuda()
        self.criterion   = nn.MSELoss().cuda()
        self.optimizer   = torch.optim.Adam(self.model.parameters(), lr=self.lr)
        self.best_model  = 12345678.9
        self.iters       = 0
        if self.args.pretrained is not None:
            checkpoint = torch.load(self.args.pretrained)
            p = checkpoint['state_dict']

            state_dict = self.model.state_dict()
            model_dict = {}

            for k,v in p.items():
                if k in state_dict:
                    model_dict[k] = v

            state_dict.update(model_dict)
            self.model.load_state_dict(state_dict)
            self.isBest = 0
        self.bestPCK  = 0
        self.bestPCKh = 0

    # Print model summary and metrics
    #dump_input = torch.rand(([1, 3, 368, 368]))
   # print(get_model_summary(self.modelmodel, dump_input))
    def training(self, epoch):
        train_loss = 0.0
        self.model.train()
        print("epoch " + str(epoch) + ':') 
        tbar = tqdm(self.train_loader)
        for i, (input, heatmap, centermap, img_path) in enumerate(tbar):
            learning_rate = adjust_learning_rate(self.optimizer, self.iters, self.lr, policy='step',
                                                 gamma=self.gamma, step_size=self.step_size)

            input_var     =     input.cuda()
            heatmap_var   =    heatmap.cuda()
            self.optimizer.zero_grad()
            heat = self.model(input_var)
            loss_heat   = self.criterion(heat,  heatmap_var)
            loss = loss_heat
            train_loss += loss_heat.item()
            loss.backward()
            self.optimizer.step()
            tbar.set_description('Train loss: %.6f' % (train_loss / ((i + 1)*self.batch_size)))
            self.iters += 1
            if i == 10000:
            	break
    def validation(self, epoch):
        self.model.eval()
        tbar = tqdm(self.val_loader, desc='\r')
        val_loss = 0.0
        
        AP    = np.zeros(self.numClasses+1)
        PCK   = np.zeros(self.numClasses+1)
        PCKh  = np.zeros(self.numClasses+1)
        count = np.zeros(self.numClasses+1)

        cnt = 0
        for i, (input, heatmap, centermap, img_path) in enumerate(tbar):

            cnt += 1

            input_var     =      input.cuda()
            heatmap_var   =    heatmap.cuda()
            self.optimizer.zero_grad()

            heat = self.model(input_var)
            loss_heat   = self.criterion(heat,  heatmap_var)

            loss = loss_heat

            val_loss += loss_heat.item()

            tbar.set_description('Val   loss: %.6f' % (val_loss / ((i + 1)*self.batch_size)))

            acc, acc_PCK, acc_PCKh, cnt, pred, visible = evaluate.accuracy(heat.detach().cpu().numpy(), heatmap_var.detach().cpu().numpy(),0.2,0.5, self.dataset)

            AP[0]     = (AP[0]  *i + acc[0])      / (i + 1)
            PCK[0]    = (PCK[0] *i + acc_PCK[0])  / (i + 1)
            PCKh[0]   = (PCKh[0]*i + acc_PCKh[0]) / (i + 1)
            for j in range(1,self.numClasses+1):
                if visible[j] == 1:
                    AP[j]     = (AP[j]  *count[j] + acc[j])      / (count[j] + 1)
                    PCK[j]    = (PCK[j] *count[j] + acc_PCK[j])  / (count[j] + 1)
                    PCKh[j]   = (PCKh[j]*count[j] + acc_PCKh[j]) / (count[j] + 1)
                    count[j] += 1

            mAP     =   AP[1:].sum()/(self.numClasses)
            mPCK    =  PCK[1:].sum()/(self.numClasses)
            mPCKh   = PCKh[1:].sum()/(self.numClasses)
	
        printAccuracies(mAP, AP, mPCKh, PCKh, mPCK, PCK, self.dataset)
            
        PCKhAvg = PCKh.sum()/(self.numClasses+1)
        PCKAvg  =  PCK.sum()/(self.numClasses+1)
        if mAP > self.isBest:
            self.isBest = mAP
            save_checkpoint({'state_dict': self.model.state_dict()}, self.isBest, self.args.model_name)
            print("Model saved to "+self.args.model_name)

        if mPCKh > self.bestPCKh:
            self.bestPCKh = mPCKh
        if mPCK > self.bestPCK:
            self.bestPCK = mPCK

        print("Best AP = %.2f%%; PCK = %2.2f%%; PCKh = %2.2f%%" % (self.isBest*100, self.bestPCK*100,self.bestPCKh*100))

    def test(self,epoch):
        self.model.eval()
        print("Testing") 

        for idx in range(1):
            print(idx,"/",2000)
            img_path = '/PATH/TO/TEST/IMAGE'

            center   = [184, 184]

            img  = np.array(cv2.resize(cv2.imread(img_path),(368,368)), dtype=np.float32)
            img  = img.transpose(2, 0, 1)
            img  = torch.from_numpy(img)
            mean = [128.0, 128.0, 128.0]
            std  = [256.0, 256.0, 256.0]
            for t, m, s in zip(img, mean, std):
                t.sub_(m).div_(s)

            img       = torch.unsqueeze(img, 0)

            self.model.eval()

            input_var   = img.cuda()

            heat = self.model(input_var)

            heat = F.interpolate(heat, size=input_var.size()[2:], mode='bilinear', align_corners=True)

            kpts = get_kpts(heat, img_h=368.0, img_w=368.0)
            draw_paint(img_path, kpts, idx, epoch, self.model_arch, self.dataset)

            heat = heat.detach().cpu().numpy()

            heat = heat[0].transpose(1,2,0)

            for i in range(heat.shape[0]):
                for j in range(heat.shape[1]):
                    for k in range(heat.shape[2]):
                        if heat[i,j,k] < 0:
                            heat[i,j,k] = 0
                        

            im  = cv2.resize(cv2.imread(img_path),(368,368))

            heatmap = []
            for i in range(self.numClasses+1):
                heatmap = cv2.applyColorMap(np.uint8(255*heat[:,:,i]), cv2.COLORMAP_JET)
                im_heat  = cv2.addWeighted(im, 0.6, heatmap, 0.4, 0)
                cv2.imwrite('samples/heat/unipose'+str(i)+'.png', im_heat)


In [166]:
starter_epoch =    0
epochs        =  1
trainer = Trainer(None,'MPII', '', '', None, 'unipose' )
for epoch in range(starter_epoch, epochs):
    trainer.training(epoch)
    trainer.validation(epoch)

Train images  22246
Val   images  2958


AssertionError: Torch not compiled with CUDA enabled

In [None]:
parser = argparse.ArgumentParser()
parser.add_argument('--pretrained', default=None,type=str, dest='pretrained')
parser.add_argument('--dataset', type=str, dest='dataset', default='MPII')
parser.add_argument('--train_dir', default='',type=str, dest='')
parser.add_argument('--val_dir', type=str, dest='val_dir', default='')
parser.add_argument('--model_name', default=None, type=str)
parser.add_argument('--model_arch', default='unipose', type=str)


args = parser.parse_args()