In [None]:
from matplotlib.transforms import ScaledTranslation
import torch
import torch.nn as nn
import torch.utils.data as torch_data
import numpy as np
import torch.nn.functional as F 
import os
import json
from tqdm import tqdm
from google.colab import drive
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import random
import math
import cv2
import skimage.filters as filters
import skimage.transform as transform
import sys
np.set_printoptions(threshold=sys.maxsize)

# Select the compute device: the first CUDA GPU when one is available,
# otherwise the CPU.
cuda = torch.cuda.is_available()
device = 'cuda:0' if cuda else 'cpu'

#print(cuda)
# end='' keeps the status message from being followed by a newline.
print("GPU availability: {}; automation will be using: {} ".format(cuda, device), end= '')

# Mount Google Drive (Colab) and point dir_path at the folder holding the
# annotation JSON files and the image sub-folders.
drive.mount('/content/gdrive/', force_remount = True)
dir_path = "/content/gdrive/MyDrive/Colab Notebooks/Data/"

# Split selector read by Load_Data and get_image:
# True -> "train_annotation.json" / "train", False -> "val_annotation.json" / "val".
Train = False

def Load_Data():
  """Load the annotation records of the split selected by the global ``Train``.

  Reads "train_annotation.json" when ``Train`` equals True and
  "val_annotation.json" otherwise, both located under ``dir_path``.

  Returns:
    NumPy array built from the list stored under the "data" key of the
    annotation file.
  """
  annotation_name = "train_annotation.json" if Train == True else "val_annotation.json"
  annotation_path = os.path.join(dir_path, annotation_name)

  with open(annotation_path) as annotation_file:
    annotations = json.load(annotation_file)

  return np.asarray(list(annotations['data']))

def get_image(feature_list, index):
  """Read from disk the image referenced by one annotation record.

  Args:
    feature_list: indexable collection of annotation records; each record
      provides the image file name under the "file" key.
    index: position of the record in ``feature_list``.

  Returns:
    The array returned by ``mpimg.imread`` for that file, looked up in the
    "train" folder when the global ``Train`` equals True and in "val" otherwise.
  """
  split_folder = "train" if Train == True else "val"
  record = feature_list[index]
  image_path = os.path.join(dir_path, split_folder, record["file"])
  return mpimg.imread(image_path)

#Return batches of shuffled dataset indices (not the images themselves).
#Images have to be loaded one batch at a time with "get_image", because loading
#every image inside create_batches would exceed the available RAM.

def create_batches(data, batch_size):
  """Split a random permutation of the indices of ``data`` into mini-batches.

  Only full batches are produced: when ``len(data)`` is not a multiple of
  ``batch_size`` the leftover indices are dropped.

  Args:
    data: any sized collection; only its length is used.
    batch_size: number of indices per batch (integer).

  Returns:
    NumPy array of shape ``(len(data) // batch_size, batch_size)`` holding
    shuffled indices into ``data``, or an empty array when no full batch fits.

  Raises:
    ZeroDivisionError: if ``batch_size`` is 0.
    TypeError: if ``batch_size`` is not an integer. The previous bare
      ``except`` hid this and silently returned empty batches.
  """
  data_count = len(data)
  # Same shuffling call as before, so seeded runs produce the same permutation.
  index_array = random.sample(range(0, data_count), data_count)

  mini_batch_count = data_count // batch_size

  # Slicing cannot run past the end because only full batches are requested,
  # so no exception handling is needed here.
  mini_batch_x = [
      np.asarray(index_array[start:start + batch_size])
      for start in range(0, mini_batch_count * batch_size, batch_size)
  ]
  return np.asarray(mini_batch_x)
#torch_data.dataloader(data_class, batch_size = 1, shuffle = True)

def get_batch_data(data, batch, size):
  """Assemble the images and ground-truth maps for one mini-batch.

  Every index in ``batch`` is printed, its image is loaded with ``get_image``,
  cropped with ``crop_image`` and turned into heat maps with ``get_heatmap``.

  Args:
    data: indexable collection of annotation records.
    batch: sequence of indices into ``data``.
    size: side length handed to ``crop_image``.

  Returns:
    Tuple of three NumPy arrays: the cropped images transposed with
    ``(2, 0, 1)`` and divided by 255, the joint heat maps, and the center
    heat maps.
  """
  joint_map_shape = (32, 32)
  center_map_shape = (256, 256)

  images, joint_maps, center_maps = [], [], []

  for sample_index in batch:
    print(sample_index)
    raw_image = get_image(data, sample_index)
    cropped_image, cropped_joints, cropped_center = crop_image(raw_image, data[sample_index], size)

    # Joint coordinates are divided by 8 before being drawn on the 32x32 maps.
    blank_maps = [np.zeros(joint_map_shape), np.zeros(center_map_shape)]
    heat_maps, center_heat_map = get_heatmap(blank_maps, cropped_joints/8, cropped_center)

    images.append(np.transpose(cropped_image, (2, 0, 1))/255.0)
    joint_maps.append(heat_maps)
    center_maps.append(center_heat_map)

  return np.asarray(images), np.asarray(joint_maps), np.asarray(center_maps)

def crop_image(img, features, image_size):
  """Crop ``img`` to a square window around the bounding box and resize it.

  ``features["bbox"]`` is read as ``[x, y, width, height]``. Its edges are
  floored to whole pixels and the shorter side is widened symmetrically so the
  window is square (up to one pixel when the size difference is odd). If the
  window leaves the image, the image is padded (``np.pad`` default mode) before
  cropping. The crop is resized to ``image_size`` x ``image_size`` and the
  landmarks are mapped into the resized crop.

  Args:
    img: image array indexed as ``[row, column, ...]``.
    features: annotation record with a flat "landmarks" list
      ``[x0, y0, x1, y1, ...]`` and a "bbox" list. The record is NOT modified;
      the previous version rewrote ``features["bbox"]`` in place, so a second
      call on the same record cropped the wrong region.
    image_size: side length in pixels of the returned square image.

  Returns:
    Tuple ``(image, joints, center)``:
      image: the resized crop.
      joints: array with one ``[x, y]`` row per landmark, in resized-crop
        pixels. Coordinates equal to 0 in the annotation are left at 0
        (presumably they mark joints that are not annotated - confirm against
        the dataset description).
      center: integer array ``[cx, cy]``, the center of the crop window in
        resized-crop pixels (fractions truncated).
  """
  landmarks = np.asarray(features["landmarks"])
  bbox = features["bbox"]

  # Landmarks are stored flat: even positions are x, odd positions are y.
  width_points = landmarks[0::2]
  height_points = landmarks[1::2]

  Height, Width = img.shape[0], img.shape[1]

  bbox_x1 = math.floor(bbox[0])
  bbox_x2 = math.floor(bbox[0] + bbox[2])
  bbox_y1 = math.floor(bbox[1])
  bbox_y2 = math.floor(bbox[1] + bbox[3])

  cropped_x = bbox_x2 - bbox_x1
  cropped_y = bbox_y2 - bbox_y1

  # Widen the shorter side equally on both ends to make the window square.
  if cropped_y > cropped_x:
    half_gap = math.floor((cropped_y - cropped_x)/2)
    bbox_x1 -= half_gap
    bbox_x2 += half_gap
  elif cropped_x > cropped_y:
    half_gap = math.floor((cropped_x - cropped_y)/2)
    bbox_y1 -= half_gap
    bbox_y2 += half_gap

  # Shift landmarks into window coordinates; zero-valued coordinates stay zero.
  width_points = np.where(width_points != 0, width_points - bbox_x1, width_points)
  height_points = np.where(height_points != 0, height_points - bbox_y1, height_points)

  # Pad every side by the largest overshoot so the window lies inside the image.
  padding = 0
  if (bbox_y1 < 0) or (bbox_x1 < 0) or (bbox_y2 > Height) or (bbox_x2 > Width):
    padding = math.floor(max(-bbox_y1, -bbox_x1, bbox_y2 - Height, bbox_x2 - Width))
    # Only the two spatial axes are padded; works for 2-D and 3-D images.
    pad_spec = ((padding, padding), (padding, padding)) + ((0, 0),) * (img.ndim - 2)
    img = np.pad(img, pad_spec)

  img = img[bbox_y1 + padding : bbox_y2 + padding, bbox_x1 + padding : bbox_x2 + padding]

  # Rescale landmarks to the resized crop. The crop can be one pixel off
  # square, so x scales with the crop width and y with the crop height
  # (the previous version divided both by the width).
  Height, Width = img.shape[0], img.shape[1]
  width_points = (width_points * image_size)/Width
  height_points = (height_points * image_size)/Height

  # Integer dtype is kept on purpose: callers have always received a
  # truncated integer center.
  center = np.array([
      abs(math.floor((bbox_x2 - bbox_x1)/2)) * (image_size/Width),
      abs(math.floor((bbox_y2 - bbox_y1)/2)) * (image_size/Height),
  ]).astype(int)

  img = cv2.resize(img, (image_size, image_size))

  joint_points = [[x, y] for x, y in zip(width_points, height_points)]

  return img, np.asarray(joint_points), center

def get_heatmap(img, joint_points, cropped_center):
  """Build Gaussian ground-truth maps for the joints and for the crop center.

  Args:
    img: pair ``[joint_canvas, center_canvas]`` of 2-D arrays. Only the shape
      of ``joint_canvas`` is used; ``center_canvas`` gets the center pixel set
      to 1 in place before blurring, so callers should pass a fresh zero array.
    joint_points: array with one ``[x, y]`` row per joint, already scaled to
      the resolution of ``joint_canvas``. A row equal to ``[0, 0]`` is skipped
      and its map stays all zero. The array is not modified.
    cropped_center: ``[x, y]`` of the center, in ``center_canvas`` pixels.

  Returns:
    Tuple ``(heat_maps, center_heat_map)``:
      heat_maps: array of shape ``(len(joint_points) + 1, H, W)``; one map per
        joint with peak value 1, followed by a background map equal to
        1 minus the per-pixel maximum over the joint maps.
      center_heat_map: array of shape ``(1, H_c, W_c)`` with peak value 1.

  Raises:
    IndexError: if ``cropped_center`` lies outside ``center_canvas``.
  """
  joint_canvas = img[0]
  Height, Width = joint_canvas.shape[0], joint_canvas.shape[1]
  joint_count = len(joint_points)

  # One channel per joint plus a trailing background channel.
  heat_maps = np.zeros((Height, Width, joint_count + 1))

  for joint in range(joint_count):
    joint_x, joint_y = joint_points[joint, 0], joint_points[joint, 1]
    if (joint_x == 0) and (joint_y == 0):
      continue

    # Clamp to the map on BOTH sides. Previously only the upper bound was
    # clamped, so a joint left of / above the crop produced a negative index
    # that wrapped around to the opposite edge (or raised IndexError).
    row = min(max(math.floor(joint_y), 0), Height - 1)
    col = min(max(math.floor(joint_x), 0), Width - 1)

    # One-hot pixel blurred into a Gaussian blob, then rescaled to peak 1.
    heat_map = np.zeros((Height, Width))
    heat_map[row, col] = 1
    heat_map = filters.gaussian(heat_map, sigma = 2)
    heat_maps[:, :, joint] = heat_map/np.max(heat_map)

  # Background channel; initial=0 keeps this defined when there are no joints
  # and does not change the result otherwise because the maps are non-negative.
  heat_maps[:, :, joint_count] = 1 - np.max(heat_maps[:, :, :joint_count], axis = 2, initial = 0)
  heat_maps = np.transpose(heat_maps, (2, 0, 1))

  # Center map: same one-hot -> blur -> normalise recipe on the second canvas.
  center_heat_map = img[1]
  center_row = math.floor(cropped_center[1])
  center_col = math.floor(cropped_center[0])
  center_heat_map[center_row][center_col] = 1

  center_heat_map = filters.gaussian(center_heat_map, sigma = 2)
  center_heat_map = center_heat_map/np.max(center_heat_map)
  center_heat_map = np.expand_dims(center_heat_map, axis = 0)
  return heat_maps, center_heat_map

def view_heat_maps(img, heat_maps):
  """Show the image and its per-joint heat maps in a 4x5 grid of subplots.

  Args:
    img: image array indexed as ``[row, column, ...]``; shown in the first
      subplot.
    heat_maps: channel-first array ``(channels, H, W)``. The maps are resized
      to the image size when their size differs.

  The first subplot shows ``img``; subplot ``k`` (k >= 1) shows channel
  ``k - 1``, so the last channel is never displayed. Titles come from a fixed
  table with 18 entries, which limits the input to 18 channels.
  """
  Height, Width = img.shape[0], img.shape[1]

  # Move channels last WITHOUT swapping rows and columns. A bare
  # np.transpose reverses all axes, (C, H, W) -> (W, H, C), which displayed
  # every heat map mirrored along its diagonal relative to the image.
  heat_maps = np.transpose(heat_maps, (1, 2, 0))

  dictionary_of_names = {0: 'Image', 1: 'Right Eye', 2: 'Left Eye', 3: 'Nose', 4: 'Head',
    5: 'Neck', 6: 'Right Shoulder', 7: 'Right Elbow', 8: 'Right Wrist', 9: 'Left Shoulder',
    10: 'Left Elbow', 11: 'Left Wrist', 12: 'Hip', 13: 'Right Knee', 14: 'Right Ankle',
    15: 'Left Knee', 16: 'Left Ankle', 17: 'Tail'}

  # Bring the maps to the image resolution if either dimension differs.
  if (heat_maps.shape[0] != Height) or (heat_maps.shape[1] != Width):
    heat_maps = transform.resize(heat_maps, (Height, Width))

  for value in range(heat_maps.shape[2]):
    plt.subplot(4, 5, value + 1)
    plt.title(dictionary_of_names[value], fontdict = {"fontsize" : 12})
    plt.axis("off")

    if value == 0:
      plt.imshow(img)
    else:
      plt.imshow(heat_maps[:, :, value - 1])
  plt.show()


In [None]:
def crop_image_test(img, features, image_size): #crop_image function for testing-relevant features
  """Crop ``img`` to a square window around the bounding box and resize it.

  Landmark-free counterpart of ``crop_image`` for inference.
  ``features["bbox"]`` is read as ``[x, y, width, height]``; its edges are
  floored to whole pixels and the shorter side is widened symmetrically so the
  window is square (up to one pixel when the size difference is odd). If the
  window leaves the image, the image is padded (``np.pad`` default mode) first.

  Args:
    img: image array indexed as ``[row, column, ...]``.
    features: annotation record with a "bbox" list. The record is NOT
      modified; the previous version rewrote ``features["bbox"]`` in place,
      so a second call on the same record cropped the wrong region.
    image_size: side length in pixels of the returned square image.

  Returns:
    Tuple ``(image, center, bb)``: the resized crop, the integer array
    ``[cx, cy]`` with the center of the crop window in resized-crop pixels
    (fractions truncated), and a copy of the original bounding box.
  """
  bbox = features["bbox"]
  bb = bbox.copy()

  Height, Width = img.shape[0], img.shape[1]

  bbox_x1 = math.floor(bbox[0])
  bbox_x2 = math.floor(bbox[0] + bbox[2])
  bbox_y1 = math.floor(bbox[1])
  bbox_y2 = math.floor(bbox[1] + bbox[3])

  cropped_x = bbox_x2 - bbox_x1
  cropped_y = bbox_y2 - bbox_y1

  # Widen the shorter side equally on both ends to make the window square.
  if cropped_y > cropped_x:
    half_gap = math.floor((cropped_y - cropped_x)/2)
    bbox_x1 -= half_gap
    bbox_x2 += half_gap
  elif cropped_x > cropped_y:
    half_gap = math.floor((cropped_x - cropped_y)/2)
    bbox_y1 -= half_gap
    bbox_y2 += half_gap

  # Pad every side by the largest overshoot so the window lies inside the image.
  padding = 0
  if (bbox_y1 < 0) or (bbox_x1 < 0) or (bbox_y2 > Height) or (bbox_x2 > Width):
    padding = math.floor(max(-bbox_y1, -bbox_x1, bbox_y2 - Height, bbox_x2 - Width))
    # Only the two spatial axes are padded; works for 2-D and 3-D images.
    pad_spec = ((padding, padding), (padding, padding)) + ((0, 0),) * (img.ndim - 2)
    img = np.pad(img, pad_spec)

  img = img[bbox_y1 + padding : bbox_y2 + padding, bbox_x1 + padding : bbox_x2 + padding]

  # Size of the crop before resizing, used to scale the center.
  Height, Width = img.shape[0], img.shape[1]

  # Integer dtype is kept on purpose: callers have always received a
  # truncated integer center.
  center = np.array([
      abs(math.floor((bbox_x2 - bbox_x1)/2)) * (image_size/Width),
      abs(math.floor((bbox_y2 - bbox_y1)/2)) * (image_size/Height),
  ]).astype(int)

  img = cv2.resize(img, (image_size, image_size))

  return img, center, bb

def get_cmap(img, cropped_center):
  """Turn a blank canvas into a Gaussian map that peaks at the crop center.

  Args:
    img: 2-D array; the center pixel is set to 1 in place before blurring,
      so callers should pass a fresh zero array.
    cropped_center: ``[x, y]`` of the center, in canvas pixels.

  Returns:
    Array of shape ``(1, H, W)``: the blurred canvas divided by its maximum.
  """
  row = math.floor(cropped_center[1])
  col = math.floor(cropped_center[0])
  img[row][col] = 1

  blurred = filters.gaussian(img, sigma = 2)
  normalized = blurred/np.max(blurred)
  return np.expand_dims(normalized, axis = 0)

def get_landmarks(pred, bb, num_joints=17): # pred is predicted heat map, bb is bounding box from json
  """Convert predicted heat maps into landmark coordinates in the full image.

  For each joint the arg-max pixel of its map is scaled back to the size of
  the square crop window and shifted by the window origin, undoing the
  squaring done when the crop was taken.

  Args:
    pred: array of shape ``(channels, H, W)`` with one map per joint in the
      first ``num_joints`` channels. The scale is derived from ``H`` only, so
      square maps are assumed.
    bb: bounding box ``[x, y, width, height]``; entries may be plain numbers
      or 0-d tensors/arrays (they are converted with ``float``).
    num_joints: number of leading channels to decode. Defaults to 17, the
      value that used to be hard-coded.

  Returns:
    Flat list ``[x0, y0, x1, y1, ...]`` of Python ints (truncated toward zero).
  """
  # Plain floats make the arithmetic identical for JSON numbers and for the
  # per-sample tensors a DataLoader produces.
  x0, y0, w, h = (float(v) for v in bb[:4])
  map_h, map_w = pred.shape[1], pred.shape[2]

  # The crop window is the box widened along its shorter side; the widening
  # (half the size difference) has to be subtracted again on that axis.
  if h > w:
    scale = h / map_h
    x_pad, y_pad = (h - w)//2, 0
  else:
    scale = w / map_h
    x_pad, y_pad = 0, (w - h)//2

  landmarks = []
  for i in range(num_joints):
    row, col = np.unravel_index(np.argmax(pred[i]), (map_h, map_w))
    landmarks.append(int(col * scale - x_pad + x0))
    landmarks.append(int(row * scale - y_pad + y0))

  return landmarks

class TestData(Dataset):
    """Dataset over the records of 'test_prediction.json' and the 'test' images.

    Each item is ``(image, center_map, bb)``: the 256x256 crop transposed with
    ``(2, 0, 1)`` and divided by 255, the center map from ``get_cmap``, and the
    bounding box returned by ``crop_image_test``.
    """

    def __init__(self):
        super().__init__()
        annotation_path = os.path.join(dir_path, 'test_prediction.json')
        with open(annotation_path) as annotation_file:
            self.features = list(json.load(annotation_file)['data'])

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        record = self.features[idx]
        image = mpimg.imread(os.path.join(dir_path, 'test', record['file']))

        side = 256
        center_canvas_shape = (256, 256)

        cropped, center, bb = crop_image_test(image, record, side)
        cmap = get_cmap(np.zeros(center_canvas_shape), center)

        return np.transpose(cropped, (2, 0, 1))/255.0, cmap, bb

class ValData(Dataset):
    """Dataset over the records of 'val_prediction.json' and the 'val' images.

    Each item is ``(image, center_map, bb)``: the 256x256 crop transposed with
    ``(2, 0, 1)`` and divided by 255, the center map from ``get_cmap``, and the
    bounding box returned by ``crop_image_test``.
    """

    def __init__(self):
        super().__init__()
        annotation_path = os.path.join(dir_path, 'val_prediction.json')
        with open(annotation_path) as annotation_file:
            self.features = list(json.load(annotation_file)['data'])

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        record = self.features[idx]
        image = mpimg.imread(os.path.join(dir_path, 'val', record['file']))

        side = 256
        center_canvas_shape = (256, 256)

        cropped, center, bb = crop_image_test(image, record, side)
        cmap = get_cmap(np.zeros(center_canvas_shape), center)

        return np.transpose(cropped, (2, 0, 1))/255.0, cmap, bb

def predict_test():
  """Predict landmarks for the validation set and store them in the JSON file.

  Uses the globals ``model``, ``batch_size``, ``device`` and ``dir_path``.
  Every record of 'val_prediction.json' gets its "landmarks" entry replaced by
  the coordinates decoded from the last stage of the model output, and the
  file is rewritten once all batches have been processed.
  """
  # Imported locally because the notebook only imports DataLoader in a later
  # cell; this keeps the function usable regardless of cell order.
  from torch.utils.data import DataLoader

  model.eval()
  test_data = DataLoader(ValData(), batch_size=batch_size, shuffle=False, num_workers=4)
  prediction_path = os.path.join(dir_path, 'val_prediction.json')

  with open(prediction_path) as f:
    dictionary = json.load(f)

  # Inference only: no_grad avoids building the autograd graph for each batch.
  with torch.no_grad():
    for i, (im, center, bb) in enumerate(test_data):
      print('batch ' + str(i + 1) + '/' + str(len(test_data)))
      im = im.float().to(device)
      center = center.float().to(device)

      # Index -1 on the stage axis keeps only the final stage's heat maps.
      final_maps = model(im, center)[:, -1].cpu().numpy()

      # shuffle=False, so sample j of batch i is record i*batch_size + j.
      for j in range(len(final_maps)):
        dictionary['data'][i*batch_size+j]['landmarks'] = get_landmarks(final_maps[j], [bb[0][j], bb[1][j], bb[2][j], bb[3][j]])

  # Written once at the end: the file is also the input read above, so
  # rewriting it after every batch was redundant and risked leaving a
  # truncated file behind if the run was interrupted mid-write.
  with open(prediction_path, 'w') as out:
    json.dump(dictionary, out)

In [None]:
# Run validation-set inference. predict_test reads the globals `model` and
# `batch_size`, which are assigned in cells further down this notebook, so
# those cells must have been executed first.
predict_test()

# Proposed CPM

In [None]:
import torch
import torch.nn as nn
import torch.utils.data as torch_data
import numpy as np
import torch.nn.functional as F 

class ProposedCPM(nn.Module): # CPM model inherits nn.Module: https://pytorch.org/docs/stable/generated/torch.nn.Module.html
  """Six-stage convolutional pose machine with a shared refinement stage.

  A feature extractor maps the image to 32 channels at 1/4 resolution.
  Stage 1 predicts ``num_joints + 1`` maps from those features; stages 2-6 all
  reuse the same ``st_*`` layers and refine the previous stage's maps given
  the features and the average-pooled center map.

  Args:
    num_joints: number of joints; each stage outputs ``num_joints + 1`` maps.
  """

  def __init__(self, num_joints):
    super(ProposedCPM, self).__init__()
    self.num_joints = num_joints

    # Average-pool the center map by 4 so it matches the 1/4-resolution features.
    self.cpool = nn.AvgPool2d(4)

    # Feature extractor shared by every stage.
    # Conv2d arguments: (input channels, output channels, kernel size, padding).
    self.x_conv1 = nn.Conv2d(3, 128, 9, padding=4)
    self.x_pool1 = nn.MaxPool2d(2)                             # max pool operation with kernel size 2
    self.x_conv2 = nn.Conv2d(128, 128, 9, padding=4)
    self.x_pool2 = nn.MaxPool2d(2)
    self.x_conv3 = nn.Conv2d(128, 128, 9, padding=4)
    self.x_pool3 = nn.MaxPool2d(2)
    self.x_conv4 = nn.Conv2d(128, 128, 5, padding=2)
    self.x_upcv1 = nn.ConvTranspose2d(128, 64, 2, stride=2)    # upsample 1/8 -> 1/4
    self.x_conv5 = nn.Conv2d(192, 32, 3, padding=1)            # 192 = 128 (skip) + 64 (upsampled)
    self.x_bnorm = nn.BatchNorm2d(32)
    # self.x_upcv2 = nn.ConvTranspose2d(64, 32, 2, stride=2)
    # self.x_conv6 = nn.Conv2d(160, 32, 3, padding=1)

    # Stage 1 convolutions:
    self.s1_conv1 = nn.Conv2d(32, 512, 9, padding=4)
    self.s1_conv2 = nn.Conv2d(512, 512, 1)
    self.s1_conv3 = nn.Conv2d(512, num_joints+1, 1)

    # Stage >= 2 (t) convolutions.
    # Input channels = 32 image features + (num_joints + 1) maps from the
    # previous stage + 1 pooled center map = 34 + num_joints.
    self.st_conv1 = nn.Conv2d(34 + num_joints, 128, 11, padding=5)
    self.st_conv2 = nn.Conv2d(128, 128, 11, padding=5)
    self.st_conv3 = nn.Conv2d(128, 128, 11, padding=5)
    self.st_conv4 = nn.Conv2d(128, 128, 1)
    self.st_conv5 = nn.Conv2d(128, num_joints + 1, 1)

  def extract_features(self, im):
    """Map images (N, 3, H, W) to features (N, 32, H/4, W/4)."""
    x = self.x_pool1(F.relu(self.x_conv1(im)))      # 1/2 resolution
    x1 = self.x_pool2(F.relu(self.x_conv2(x)))      # 1/4 resolution, kept as skip connection
    x2 = self.x_pool3(F.relu(self.x_conv3(x1)))     # 1/8 resolution
    x2 = F.relu(self.x_conv4(x2))
    x2 = self.x_upcv1(x2)                           # back to 1/4 resolution
    x2 = F.relu(self.x_bnorm(self.x_conv5(torch.cat([x1, x2], dim=1))))
    return x2

  def stage_1(self, x):
    """Map features (N, 32, h, w) to the first maps (N, num_joints + 1, h, w)."""
    x = F.relu(self.s1_conv1(x))
    x = F.relu(self.s1_conv2(x))
    return F.relu(self.s1_conv3(x))

  def stage_t(self, x):
    """Map (N, 34 + num_joints, h, w) to refined maps (N, num_joints + 1, h, w)."""
    x = F.relu(self.st_conv1(x))
    x = F.relu(self.st_conv2(x))
    x = F.relu(self.st_conv3(x))
    x = F.relu(self.st_conv4(x))
    return F.relu(self.st_conv5(x))

  def forward(self, im, center_image): # for 6 stages
    """Run all six stages.

    Args:
      im: images of shape (N, 3, H, W).
      center_image: center maps of shape (N, 1, H, W).

    Returns:
      Tensor of shape (N, 6, num_joints + 1, H/4, W/4): the maps of every
      stage stacked along dimension 1.
    """
    cpool = self.cpool(center_image)

    # Extract the features once and share them across all stages. Calling
    # extract_features twice doubled the cost of the backbone and updated
    # the BatchNorm running statistics twice per training step.
    x = self.extract_features(im)

    maps = self.stage_1(x)
    stage_maps = [maps]
    for _ in range(5):  # stages 2..6 share the st_* weights
      maps = self.stage_t(torch.cat([x, maps, cpool], dim=1))
      stage_maps.append(maps)

    return torch.stack(stage_maps, dim=1)

In [None]:
from torch.utils.data import Dataset

class TrainDataProp(Dataset):
    """Dataset over the records of 'train_annotation.json' and the 'train' images.

    Each item is ``(image, heat_maps, center_map)``: the 256x256 crop
    transposed with ``(2, 0, 1)`` and divided by 255, plus the maps produced by
    ``get_heatmap`` on a 64x64 joint canvas and a 256x256 center canvas.
    """

    def __init__(self):
        super().__init__()
        annotation_path = os.path.join(dir_path, 'train_annotation.json')
        with open(annotation_path) as annotation_file:
            self.features = list(json.load(annotation_file)['data'])

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        record = self.features[idx]
        image = mpimg.imread(os.path.join(dir_path, 'train', record['file']))

        side = 256
        joint_canvas_shape = (64, 64)
        center_canvas_shape = (256, 256)

        cropped, joints, center = crop_image(image, record, side)

        # Joint coordinates are divided by 4 to go from the 256-pixel crop
        # to the 64-pixel joint canvas.
        blank_maps = [np.zeros(joint_canvas_shape), np.zeros(center_canvas_shape)]
        hmaps, cmap = get_heatmap(blank_maps, joints/4, center)

        return np.transpose(cropped, (2, 0, 1))/255.0, hmaps, cmap

In [None]:
from torch.utils.data import DataLoader

# Select the compute device again (same rule as the first cell) so this cell
# can be run on its own.
cuda = torch.cuda.is_available()
device = 'cuda:0' if cuda else 'cpu'
# Path of the periodic checkpoint. NOTE: this name shadows the builtin dir().
dir = os.path.join(dir_path, "cpm_proposed.pth")

# Hyper-parameters read as globals: all three by train(), batch_size also by
# predict_test().
epochs = 500
lr = 0.0001
batch_size = 24

def train():
  """Train the global ``model`` on ``TrainDataProp`` with MSE loss and Adam.

  Uses the globals ``model``, ``device``, ``dir_path``, ``epochs``, ``lr`` and
  ``batch_size``. The current weights are saved to "cpm_proposed.pth" at the
  start of every second epoch and once more after the last epoch. Whenever a
  single batch loss drops below the best value seen so far (starting from the
  fixed threshold 0.0021) the weights are also saved to
  "cpm_proposed_best.pth".

  Returns:
    List with one float per epoch: the loss of that epoch's last batch.

  Raises:
    ValueError: if the training loader yields no batches.
  """
  criterion = nn.MSELoss()
  optimizer = torch.optim.Adam(model.parameters(), lr=lr)
  loss_array = []
  best_loss = 0.0021
  model.train()

  checkpoint_path = os.path.join(dir_path, "cpm_proposed.pth")
  best_path = os.path.join(dir_path, "cpm_proposed_best.pth")

  train_data = DataLoader(TrainDataProp(), batch_size=batch_size, shuffle=True, num_workers=4)
  if len(train_data) == 0:
    # Previously this case ended in a NameError on the undefined `loss`.
    raise ValueError("training DataLoader is empty")

  for epoch in range(epochs):
    if(epoch % 2 == 0):
      torch.save(model.state_dict(), checkpoint_path)
      print("Saving Model.")

    batch_loss = None
    for i, (im, heat, center) in enumerate(train_data):

      im = im.float().to(device)
      heat = torch.stack([heat]*6, dim=1) # stack one set of heatmaps for each stage of our CPM
      heat = heat.float().to(device)
      center = center.float().to(device)

      pred = model(im, center)
      loss = criterion(pred, heat)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      # Convert to a Python float once; keeping the tensor itself would hold
      # on to device memory and its autograd history.
      batch_loss = loss.item()

      print('----------------- Epoch ' + str(epoch + 1) + ', batch ' + str(i + 1) + '/' + str(len(train_data)) + ', Loss: ' + str(batch_loss) + ' -----------------')
      if batch_loss < best_loss:
        best_loss = batch_loss
        torch.save(model.state_dict(), best_path)
        print('----------------- Saving Best Loss CPM ----------------------')

    # Record the loss of the epoch's last batch as a plain float.
    loss_array.append(batch_loss)
    print(batch_loss)

  #Save model one final time
  torch.save(model.state_dict(), checkpoint_path)
  return loss_array

In [None]:
# Build the network for 17 joints (each stage then outputs 18 maps) and move
# its parameters to the selected device. train() and predict_test() use this
# global.
model = ProposedCPM(17).to(device)

In [None]:
# Start training; checkpoints are written under dir_path while it runs.
train()

In [None]:
# Restore the weights of the periodic checkpoint written by train().
dir = os.path.join(dir_path, "cpm_proposed.pth")
# map_location places the stored tensors on the active device, so a checkpoint
# saved on a GPU runtime also loads on a CPU-only runtime.
model.load_state_dict(torch.load(dir, map_location=device))
# Switch to evaluation mode (BatchNorm uses its running statistics).
model.eval()

ProposedCPM(
  (cpool): AvgPool2d(kernel_size=4, stride=4, padding=0)
  (x_conv1): Conv2d(3, 128, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (x_pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (x_conv2): Conv2d(128, 128, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (x_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (x_conv3): Conv2d(128, 128, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (x_pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (x_conv4): Conv2d(128, 128, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (x_upcv1): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
  (x_conv5): Conv2d(192, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (x_bnorm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (s1_conv1): Conv2d(32, 512, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
  (s1_conv2): Con