# Libraries

In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
from glob import glob
from PIL import Image
import os, json, cv2, torch, yaml
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

import sys
# Hide every CUDA device so torch runs on the CPU for this inspection notebook.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
# Put the parent directory on the import path so the `lib.*` imports resolve.
sys.path.append("../")

from lib.config import cfg
import lib.dataset as dataset
from lib.utils.draw import draw_heatmaps

# Build datasets

In [None]:
# Load the experiment config: unfreeze the global cfg, merge the YAML, refreeze.
config_file = "../experiments/coco/hrnet/debug39.yaml"
cfg.defrost()
cfg.merge_from_file(config_file)
cfg.freeze()

# Per-channel normalization statistics (the standard ImageNet mean/std).
# `mean` and `std` are reused by the visualization cells to undo normalization.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean, std)

# One pipeline shared by both splits: ToTensor followed by Normalize. Neither
# transform keeps per-call state, so a single Compose instance is safe to share.
data_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

# Resolve the dataset class by name from the `dataset` package. getattr performs
# the same attribute lookup as eval('dataset.' + name) without executing config
# text as Python, and raises AttributeError for an unknown dataset name.
dataset_cls = getattr(dataset, cfg.DATASET.DATASET)

# Same positional arguments as before; the fourth is True for the training split
# and False for the test split.
train_dataset = dataset_cls(
    cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True, data_transform,
)
valid_dataset = dataset_cls(
    cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False, data_transform,
)

In [None]:
# Report the number of samples in each split.
print(f"train_dataset: {len(train_dataset)}")
print(f"valid_dataset: {len(valid_dataset)}")

# Visualize train_dataset

In [None]:
# Pick one random training sample and show it next to its heatmap overlay.
random_idx = np.random.randint(0, len(train_dataset))
print("random_idx", random_idx)

# A sample unpacks into four items. The first is bound to `input_tensor` rather
# than `input` so the Python builtin is not shadowed in the notebook namespace.
data = train_dataset[random_idx]
input_tensor, target, target_weight, meta = data
print("input", input_tensor.shape)
print("target", target.shape)
print("target_weight", target_weight.shape)
print(target_weight.numpy()[:, 0])
print(meta.keys())

# Undo Normalize (x * std + mean) on the channels-last image and scale to
# [0, 255]. Round and clip before casting: a bare astype('uint8') truncates, so
# float error such as 254.9999 would become 254, and out-of-range values would
# not be clamped.
m = np.array(mean)
s = np.array(std)
image = input_tensor.numpy().transpose((1, 2, 0))
image = 255 * (image * s[None, None, :] + m[None, None, :])
image = np.clip(np.rint(image), 0, 255).astype('uint8')

# Target heatmaps, moved to channels-last for drawing.
# NOTE(review): if the targets are float heatmaps in [0, 1], astype('uint8')
# truncates nearly every value to 0; scaling by 255 first may be intended.
# Confirm against the input range draw_heatmaps expects.
heatmaps = target.numpy().transpose((1, 2, 0)).astype('uint8')

# Show the denormalized image beside the heatmap overlay.
drawn_image = draw_heatmaps(image, heatmaps)
plt.figure(figsize=(16, 8))
plt.subplot(1, 2, 1)
plt.imshow(image)
plt.title("image")
plt.subplot(1, 2, 2)
plt.imshow(drawn_image)
plt.title("drawn_image")
plt.show()

# Visualize valid_dataset

In [None]:
# Pick one random validation sample and show it next to its heatmap overlay.
random_idx = np.random.randint(0, len(valid_dataset))
print("random_idx", random_idx)

# A sample unpacks into four items. The first is bound to `input_tensor` rather
# than `input` so the Python builtin is not shadowed in the notebook namespace.
data = valid_dataset[random_idx]
input_tensor, target, target_weight, meta = data
print("input", input_tensor.shape)
print("target", target.shape)
print("target_weight", target_weight.shape)
print(target_weight.numpy()[:, 0])
print(meta.keys())

# Undo Normalize (x * std + mean) on the channels-last image and scale to
# [0, 255]. Round and clip before casting: a bare astype('uint8') truncates, so
# float error such as 254.9999 would become 254, and out-of-range values would
# not be clamped.
m = np.array(mean)
s = np.array(std)
image = input_tensor.numpy().transpose((1, 2, 0))
image = 255 * (image * s[None, None, :] + m[None, None, :])
image = np.clip(np.rint(image), 0, 255).astype('uint8')

# Target heatmaps, moved to channels-last for drawing.
# NOTE(review): if the targets are float heatmaps in [0, 1], astype('uint8')
# truncates nearly every value to 0; scaling by 255 first may be intended.
# Confirm against the input range draw_heatmaps expects.
heatmaps = target.numpy().transpose((1, 2, 0)).astype('uint8')

# Show the denormalized image beside the heatmap overlay.
drawn_image = draw_heatmaps(image, heatmaps)
plt.figure(figsize=(16, 8))
plt.subplot(1, 2, 1)
plt.imshow(image)
plt.title("image")
plt.subplot(1, 2, 2)
plt.imshow(drawn_image)
plt.title("drawn_image")
plt.show()