In [1]:
import os
import sys
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import numpy as np
from torchsummary import summary
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import torchvision.transforms as tf
import cv2
from PIL import Image
from tqdm.auto import tqdm

In [2]:
sys.path.append('/openpose')
from model import PoseEstimationWithMobileNet, example
from include import CocoKeypoints
import transforms

In [3]:
n_epoch = 10000
lr = 1e-3
batch = 32

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PoseEstimationWithMobileNet()
model = nn.DataParallel(model).to(device)

criterion = nn.MSELoss().to(device)
optimizer = optim.SGD(model.parameters(), lr = lr, weight_decay= 1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode='min')

In [6]:
preprocess = transforms.Compose([
        transforms.Normalize(),
        transforms.RandomApply(transforms.HFlip(), 0.5),
        transforms.RescaleRelative(),
        transforms.Crop(368),
        transforms.CenterPad(368),
    ])

In [7]:
data_dir = "/var/lib/docker/openpose/coco"
ann_train_dir = [os.path.join(data_dir, 'annotations', item) 
                 for item in ['person_keypoints_val2017.json']]
ann_val_dir = os.path.join(data_dir, 'images/val2017')

image_transform = None
n_train = 200
n_val = 50
datas = [CocoKeypoints(
            root=ann_val_dir,
            annFile=item,
            preprocess=preprocess,
            image_transform=image_transform,
            target_transforms=image_transform, 
            n_images = n_train+n_val,
        ) for item in ann_train_dir]

loading annotations into memory...
Done (t=0.44s)
creating index...
index created!
filter for keypoint annotations ...
... done.
Images: 250


In [8]:
# data = datas[0]
# (image, heatmaps, pafs), filepath = data[0]
# # print(f"Shape of original image = {im.shape}")
# # print(f"Shape of heatmap = {heatmap.shape}")
# # print(f"Shape of paf = {paf.shape}")

# # reference image
# fig = plt.figure()
# ref_image = Image.open(filepath)
# plt.imshow(ref_image)

# # image from dataloader
# im = image.numpy()
# im = np.transpose(im, (1, 2, 0))
# max_value_i = np.max(im)
# min_value_i = np.min(im)
# print(f"max:{max_value_i}, min:{min_value_i}")
# n_im = (im - min_value_i)/(max_value_i-min_value_i)
# fig = plt.figure()
# plt.imshow(n_im)

In [9]:
# heatmap = heatmaps.numpy()
# heatmap = np.transpose(heatmap, (1, 2, 0))
# heatmap = cv2.resize(heatmap, (0, 0), fx=8, fy=8, interpolation=cv2.INTER_CUBIC)
# max_value = np.max(heatmap)
# min_value = np.min(heatmap)

# fig = plt.figure(figsize = (50, 24))
# im_ = Image.fromarray(im.astype(np.uint8)*255)
# im_ = im_.convert('RGB')
# label = ['nose', 'neck', 'Rsho', 'Relb', 'Rwri', 'Lsho', 'Lelb', 'Lwri', 'Rhip', 
#          'Rkne', 'Rank', 'Lhip', 'Lkne', 'Lank', 'Leye', 'Reye', 'Lear', 'Rear', 'pt19'] 
# for i in range(18):
#     vis_img = (heatmap[:, :, i]-min_value)/max_value
#     vis_img = Image.fromarray(np.uint8(cm.jet(vis_img) * 255))
#     vis_img = vis_img.convert('RGB') 
#     vis_img = Image.blend(im_, vis_img, 0.5)
#     vis_img = np.array(vis_img)
#     plt.subplot(3,6,i+1)
#     plt.xticks([])
#     plt.yticks([])
#     plt.grid(False)
#     plt.imshow(vis_img)
#     plt.xlabel(label[i], fontsize=24)

In [10]:
# paf = pafs.numpy()
# paf = np.transpose(paf, (1, 2, 0))
# paf = cv2.resize(paf, (0, 0), fx=8, fy=8, interpolation=cv2.INTER_CUBIC)

# fig = plt.figure(figsize = (50, 24))
# max_value = np.max(paf)
# min_value = np.min(paf)
# for i in range(19):
#     vis_img_x = (paf[:, :, i*2]-min_value)/max_value 
#     vis_img_y = (paf[:, :, i*2+1]-min_value)/max_value /255   
#     plt.subplot(3,7,i+1)
#     plt.xticks([])
#     plt.yticks([])
#     plt.grid(False)
#     plt.imshow(im, interpolation='nearest')
#     plt.imshow(vis_img_x, cmap=plt.cm.jet, alpha=0.3)
#     plt.imshow(vis_img_y, cmap=plt.cm.jet, alpha=0.3)

In [11]:
# summary(model, image.shape)

In [12]:
dataset = torch.utils.data.ConcatDataset(datas)

train_data, val_data = torch.utils.data.random_split(dataset, (n_train, n_val))
 
train_loader = torch.utils.data.DataLoader(
            train_data, 
            batch_size = batch, 
            shuffle = True,
            num_workers = 8,
            drop_last = True,
            )
val_loader = torch.utils.data.DataLoader(
            val_data, 
            batch_size = batch,
            shuffle = True,
            num_workers = 8,
            drop_last = True,
            )

In [13]:
%load_ext tensorboard
import datetime
from torch.utils.tensorboard import SummaryWriter
log_dir = '/openpose/runs/' + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
summary = SummaryWriter('/openpose/runs')
%tensorboard --logdir /openpose/runs --port 6004 --bind_all

Reusing TensorBoard on port 6004 (pid 24070), started 13:14:38 ago. (Use '!kill 24070' to kill it.)

In [14]:
ti = time.time()
save_dir = "./best_model"
best_val_loss = np.inf
for epoch in tqdm(range(n_epoch)):    
    train_tot_loss = 0
    val_tot_loss = 0
    ######### TRAIN ########
    model.train()
    i = 0
    for i, ((images, heatmaps, pafs), _) in enumerate(train_loader):
        i = i+1
        images = images.to(device)
        heatmaps = heatmaps.to(device)
        pafs = pafs.to(device)
        
        optimizer.zero_grad()
        stages_output = model(images)
        paf = stages_output[-1]
        heatmap = stages_output[-2]
    
#         hm_loss = criterion(heatmaps, heatmap)
#         paf_loss = criterion(pafs, paf)
#         train_tot_loss += hm_loss
#         train_tot_loss += paf_loss
        train_tot_loss = criterion(heatmaps, heatmap) + criterion(pafs, paf)
        train_tot_loss.requires_grad = True
        train_tot_loss.backward()
        
        print(f"tot_loss: {train_tot_loss}")
        optimizer.step()
        
    summary.add_scalar('train_loss', train_tot_loss, epoch)

#     ######## VALIDATION ########
    with torch.no_grad():
        model.eval()
        i=0
        for i, ((images, heatmaps, pafs), filepath) in enumerate(val_loader):
            i = i+1
            images = images.to(device)
            heatmaps = heatmaps.to(device)
            pafs = pafs.to(device)
            
            stages_output = model(images)
            paf = stages_output[-1].to(device=device, dtype=torch.long)
            heatmap = stages_output[-2].to(device=device, dtype=torch.long)
            
            val_loss = criterion(heatmaps, heatmap) + criterion(pafs, paf)
            val_tot_loss += val_loss
            
    scheduler.step(val_tot_loss)
    summary.add_scalar('val_loss', val_tot_loss, epoch)    
    if val_tot_loss<best_val_loss:
        torch.save(model.state_dict(), save_dir)
   
    # print status every 10 epoch
    if epoch % 10 == 0 or epoch == n_epoch-1:
        ti2 = time.time()
        ti_ = ti2-ti
        for param in optimizer.param_groups:
            print(f"for {epoch} epoch, {int(ti_/60)}min {int(ti_%60)}sec elapsed, lr = {param['lr']}")

# summary.close()

  0%|          | 0/10000 [00:00<?, ?it/s]

RuntimeError: you can only change requires_grad flags of leaf variables.

In [None]:
# def create_samples(self, z, y=None):
#         self.generator.eval()
#         batch_size = z.size(0)
#         # Parse y
#         if y is None:
#             y = self.ydist.sample((batch_size,))
#         elif isinstance(y, int):
#             y = torch.full((batch_size,), y,
#                            device=self.device, dtype=torch.int64)
#         # Sample x
#         with torch.no_grad():
#             x = self.generator(z, y)
#         return x

    
    
#         outdir = os.path.join(self.img_dir, class_name)
#         if not os.path.exists(outdir):
#             os.makedirs(outdir)
#         outfile = os.path.join(outdir, '%08d.png' % it)

#         imgs = imgs / 2 + 0.5
#         imgs = torchvision.utils.make_grid(imgs)
#         torchvision.utils.save_image(imgs, outfile, nrow=8)

#         if self.monitoring == 'tensorboard':
#             self.tb.add_image(class_name, imgs, it)

In [None]:
    heatmap = heatmaps[0]
    device2 = torch.device("cpu")
    heatmap = heatmap.to(device2)
    heatmap = heatmap.detach()
    heatmap = heatmap.numpy()
    
    heatmap = np.transpose(heatmap, (1, 2, 0)) # (46, 46, 19)
#     heatmap = cv2.resize(heatmap, (0, 0), fx=8, fy=8, interpolation=cv2.INTER_CUBIC)
    max_value = np.max(heatmap)
    min_value = np.min(heatmap)
    label = ['nose', 'neck', 'Rsho', 'Relb', 'Rwri', 'Lsho', 'Lelb', 'Lwri', 'Rhip', 
             'Rkne', 'Rank', 'Lhip', 'Lkne', 'Lank', 'Leye', 'Reye', 'Lear', 'Rear', 'pt19'] 
    for i in range(18):
        vis_img = (heatmap[:, :, i]-min_value)/max_value
        vis_img = Image.fromarray(np.uint8(cm.jet(vis_img) * 255))
        vis_img = vis_img.convert('RGB')  
        vis_img = np.array(vis_img) #(46, 46, 3)
#         vis_img = np.transpose(heatmap, (2, 0, 1))
#         summary.add_image('ground truth', vis_img)
#         print(vis_img.shape)
#         if i == 0:
#             vis = vis_img
#         else:
#             vis = np.concatenate((vis, vis_img), axis=-1)
#         print(f"{vis.shape} of epoch")
#     vis = np.reshape(vis, (46, 46, 3, 18))
#     vis = np.transpose(vis, (2, 0, 1, 3))
    
#     summary.add_image('ground truth', vis, max_outputs=18, step = 1)
    