In [1]:
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import numpy as np
from torchsummary import summary
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import torchvision.transforms as tf
import cv2
from PIL import Image

In [2]:
# Extend the import search path with /openpose so the three imports below can
# be resolved (presumably the project checkout lives there -- confirm).
sys.path.append('/openpose')
from model import PoseEstimationWithMobileNet
from include import CocoKeypoints
# NOTE: this `transforms` is a different module from torchvision.transforms,
# which the first cell binds to the name `tf`.
import transforms

In [3]:
# Training hyperparameters consumed by the cells further down.
n_epoch = 5_000  # iterations of the outer `for epoch in range(n_epoch)` loop
batch = 32       # batch_size passed to both DataLoaders
lr = 1e-3        # learning rate passed to the Adam optimizer

In [4]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Network under training, moved onto the selected device.
model = PoseEstimationWithMobileNet()
model = model.to(device)
# model = nn.DataParallel(model).to(device)

# Mean-squared-error loss; the training cell applies it to both the heatmap
# and the PAF outputs.
criterion = nn.MSELoss()
criterion = criterion.to(device)

# Adam with weight decay; the plateau scheduler watches a quantity that should
# decrease (mode='min').
optimizer = optim.Adam(params=model.parameters(), lr=lr, weight_decay=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min')

In [5]:
# Steps of the joint image/annotation preprocessing pipeline, in the order
# they are applied. These classes come from the project-local `transforms`
# module; their exact semantics are not visible here.
_preprocess_steps = [
    transforms.Normalize(),
    # presumably a horizontal flip applied with probability 0.5 -- confirm
    transforms.RandomApply(transforms.HFlip(), 0.5),
    transforms.RescaleRelative(),
    # Crop then pad with the same size argument; the sample printed in the
    # next cell is 368 x 368.
    transforms.Crop(368),
    transforms.CenterPad(368),
]
preprocess = transforms.Compose(_preprocess_steps)

In [6]:
# Location of the COCO data on disk.
data_dir = "/var/lib/docker/openpose/coco"

# Full paths of the keypoint annotation files (one dataset is built per file).
ann_train_dir = []
for annotation_name in ['person_keypoints_train2017.json']:
    ann_train_dir.append(os.path.join(data_dir, 'annotations', annotation_name))

# NOTE: despite the name, this is the train2017 *image* folder; it is passed
# as `root` to CocoKeypoints below.
ann_val_dir = os.path.join(data_dir, 'images/train2017')

# No extra per-image / per-target transform is applied on top of `preprocess`.
image_transform = None

datas = []
for annotation_path in ann_train_dir:
    datas.append(
        CocoKeypoints(
            root=ann_val_dir,
            annFile=annotation_path,
            preprocess=preprocess,
            image_transform=image_transform,
            target_transforms=image_transform,
        )
    )

# Pull one sample (index 1) out of the first dataset to sanity-check shapes.
data = datas[0]
image, heatmaps, pafs = data[1]

# print(f"Filepath = {filepath}")
print(f"Shape of original image = {image.shape}")
print(f"Shape of original heatmaps = {heatmaps.shape}")
print(f"Shape of original pafs = {pafs.shape}")

# summary(model, image.shape)

loading annotations into memory...
Done (t=10.12s)
creating index...
index created!
filter for keypoint annotations ...
... done.
Images: 56599
Shape of original image = torch.Size([3, 368, 368])
Shape of original heatmaps = torch.Size([19, 46, 46])
Shape of original pafs = torch.Size([38, 46, 46])


In [7]:
# Reorder the sample from (C, H, W) to (H, W, C) and copy it into a NumPy
# array so it can be handed to image-plotting code.
im = np.array(image.permute(1, 2, 0))
print(im.shape)

# Large canvas for the per-keypoint overlays (the plotting code that would
# draw on it is currently commented out below).
fig = plt.figure(figsize=(75, 36))


# # im = np.squeeze(image)
# # im_ = cm.jet(im)
# # im_ = Image.fromarray(np.uint8(im))
# # im_ = im_.convert('RGB')

# # heatmaps = heatmaps.permute(1,2, 0)
# heatmaps = np.array(heatmaps)
# heatmap = np.squeeze(heatmaps)  # output 1 is heatmaps
# heatmap = heatmap.swapaxes(0, 1).swapaxes(1,2)
# heatmap = cv2.resize(heatmap, (0, 0), fx=8, fy=8, interpolation=cv2.INTER_CUBIC)
# # heatmap = heatmap[:im.shape[0] - pad[2], :im.shape[1] - pad[3], :]
# heatmap = cv2.resize(heatmap, (im.shape[1], im.shape[0]), interpolation=cv2.INTER_CUBIC)

# max_value = np.max(heatmap)
# min_value = np.min(heatmap)
# label = ['nose', 'neck', 'Rsho', 'Relb', 'Rwri', 'Lsho', 'Lelb', 'Lwri', 'Rhip', 'Rkne', 'Rank', 'Lhip', 'Lkne', 'Lank', 'Leye', 'Reye', 'Lear', 'Rear', 'pt19'] 
# for i in range(18):
#     vis_img = (heatmap[:, :, i]-min_value)/max_value
#     vis_img = Image.fromarray(np.uint8(cm.jet(vis_img) * 255))
#     vis_img = vis_img.convert('RGB') 
#     vis_img = Image.blend(im_, vis_img, 0.8)
#     vis_img = np.array(vis_img)
#     plt.subplot(3,6,i+1)
#     plt.xticks([])
#     plt.yticks([])
#     plt.grid(False)
#     plt.imshow(vis_img)
#     plt.xlabel(label[i], fontsize=24)
     
# plt.show()
# plt.imshow(im_)


# ref_im = cv2.imread(filepath)
# i = [2, 1, 0]
# ref_im = ref_im[:,:,i]
# ref_im_ = Image.fromarray(np.uint8(ref_im))
# fig = plt.figure(figsize = (25,12))
# plt.show
# plt.imshow(ref_im_)

(368, 368, 3)


<Figure size 5400x2592 with 0 Axes>

In [8]:
# pafs = np.array(pafs)
# paf = np.squeeze(pafs)  # output 1 is heatmaps
# paf = cv2.resize(paf, (0, 0), fx=8, fy=8, interpolation=cv2.INTER_CUBIC)
# # paf = paf[:im.shape[0] - pad[2], :im.shape[1] - pad[3], :]
# paf = paf.swapaxes(0, 1).swapaxes(1,2)
# paf = cv2.resize(paf, (im.shape[1], im.shape[0]), interpolation=cv2.INTER_CUBIC)

# print(f"Shape of original image = {im.shape}")
# print(f"Shape of heatmap = {heatmap.shape}")
# print(f"Shape of paf = {paf.shape}")

# fig = plt.figure(figsize = (75, 36))
# max_value = np.max(paf)
# min_value = np.min(paf)
# for i in range(19):
#     vis_img_x = (paf[:, :, i*2]-min_value)/max_value
#     vis_img_y = (paf[:, :, i*2+1]-min_value)/max_value    
#     plt.subplot(3,7,i+1)
#     plt.xticks([])
#     plt.yticks([])
#     plt.grid(False)
#     plt.imshow(im, interpolation='nearest')
#     plt.imshow(vis_img_x, cmap=plt.cm.jet, alpha=0.3)
#     plt.imshow(vis_img_y, cmap=plt.cm.jet, alpha=0.3)
        
# plt.show()

In [9]:
# Merge every per-annotation-file dataset into a single indexable dataset.
dataset = torch.utils.data.ConcatDataset(datas)

# Sizes of the train / validation subsets; whatever remains is left unused.
n_train = 20000
n_val = 5000
n_unused = len(dataset) - n_train - n_val
if n_unused < 0:
    raise ValueError(
        f"n_train + n_val = {n_train + n_val} exceeds the dataset size {len(dataset)}"
    )

# Seed the split so the same images land in train / val after every kernel
# restart; otherwise validation losses are not comparable between runs.
split_generator = torch.Generator().manual_seed(0)
# The remainder is bound to a real name rather than `_`, which IPython uses
# for the last cell output.
train_data, val_data, unused_data = torch.utils.data.random_split(
    dataset,
    (n_train, n_val, n_unused),
    generator=split_generator,
)

train_loader = torch.utils.data.DataLoader(
    train_data,
    batch_size=batch,
    shuffle=True,
    num_workers=0,
    drop_last=True,
)
# Validation is evaluated in a fixed order over every sample: no shuffling and
# no dropped tail batch, so the metric is computed on the same data each epoch.
val_loader = torch.utils.data.DataLoader(
    val_data,
    batch_size=batch,
    shuffle=False,
    num_workers=0,
    drop_last=False,
)

In [None]:
%load_ext tensorboard
import datetime
from torch.utils.tensorboard import SummaryWriter
print(os.getcwd())
log_dir = '/openpose/runs/' + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
summary = SummaryWriter('/openpose/runs')
# !kill 643
%tensorboard --logdir /openpose/runs --port 6004 --bind_all

/tf


Launching TensorBoard...

In [None]:
# File the best model weights are written to (torch.save receives this value
# directly, so despite the name it is used as a file path).
save_dir = "./best_model"

# Caps on the number of batches processed per epoch. They reproduce the early
# `break`s of the smoke-test version of this loop; set either one to None to
# iterate over the whole loader.
max_train_batches = 1
max_val_batches = 2

# Best validation loss seen so far. It has to live outside the epoch loop and
# be updated on improvement, otherwise every epoch overwrites the checkpoint.
best_val_loss = np.inf

for epoch in range(n_epoch):
    train_tot_loss = 0.0
    val_tot_loss = 0.0
    n_train_batches = 0
    n_val_batches = 0

    ######### TRAIN ########
    model.train()
    for step, (images, heatmaps, pafs) in enumerate(train_loader, start=1):
        images = images.to(device)
        heatmaps = heatmaps.to(device)
        pafs = pafs.to(device)

        stages_output = model(images)
        # Keep the predictions in floating point: casting them to an integer
        # dtype truncates the values and cuts them out of the autograd graph,
        # so no gradient would ever reach the model.
        paf = stages_output[-1]
        heatmap = stages_output[-2]

        # Targets are matched to the prediction dtype so the loss never sees
        # mixed dtypes.
        train_loss = (criterion(heatmap, heatmaps.to(heatmap.dtype))
                      + criterion(paf, pafs.to(paf.dtype)))

        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        train_tot_loss += train_loss.item()
        n_train_batches = step

        if max_train_batches is not None and step >= max_train_batches:
            break

    ######## VALIDATION ########
    model.eval()
    with torch.no_grad():
        for step, (images, heatmaps, pafs) in enumerate(val_loader, start=1):
            images = images.to(device)
            heatmaps = heatmaps.to(device)
            pafs = pafs.to(device)

            stages_output = model(images)
            paf = stages_output[-1]
            heatmap = stages_output[-2]

            val_loss = (criterion(heatmap, heatmaps.to(heatmap.dtype))
                        + criterion(paf, pafs.to(paf.dtype)))

            val_tot_loss += val_loss.item()
            n_val_batches = step

            if max_val_batches is not None and step >= max_val_batches:
                break

    # Per-epoch means over the batches that were actually processed.
    train_epoch_loss = train_tot_loss / max(n_train_batches, 1)
    val_epoch_loss = val_tot_loss / max(n_val_batches, 1)

    # One point per epoch and per curve, logged as plain floats.
    summary.add_scalar('train_loss', train_epoch_loss, epoch)
    summary.add_scalar('val_loss', val_epoch_loss, epoch)

    # Drive the plateau scheduler with the epoch mean rather than with the
    # loss of whichever validation batch happened to come last.
    scheduler.step(val_epoch_loss)

    # Checkpoint only when validation improves, and remember the new best.
    if val_epoch_loss < best_val_loss:
        best_val_loss = val_epoch_loss
        torch.save(model.state_dict(), save_dir)
    
# summary.close()
# TODO:
# - finish the input/output handling
# - initialize
# - git
# - gpu