In [4]:
import numpy as np
from base_notebook.pose_data_tools.graph import Graph
from base_notebook.pose_data_tools.generate_data import read_xyz
from src.preprocessing.pre_normaliser import preNormaliser
import matplotlib.pyplot as plt
from matplotlib import animation, rc
from IPython.display import HTML
import os
import pandas as pd
import pickle

In [5]:
# Locations of the pickled label table and of the raw text samples.
PATH_LABELS = "_input/train_label.pkl"
PATH_RAW_TXT = "_input/train/"

# Pair of coordinate axes to plot, each chosen from (0, 1, 2).
AXES = (0, 2)

# Class label (0-154) and index of the sample within that class (~0-400).
LABEL = 10
SAMPLE = 5

In [6]:
# NOTE(security): pickle.load can execute arbitrary code while deserialising;
# only point PATH_LABELS at a file from a trusted source.
with open(PATH_LABELS, 'rb') as f:
    labels = pickle.load(f)

# labels[0] holds the sample paths and labels[1] the matching class ids.
# Build the frame column-wise so the ids are not coerced to strings by a
# mixed-type numpy array first.
df = pd.DataFrame({
    'filename': [x.split('/')[-1] for x in labels[0]],
    'label': labels[1],
})
# `np.int` was deprecated in NumPy 1.20 and removed in 1.24; the builtin
# `int` is the supported spelling and yields the same dtype.
df['label'] = df['label'].astype(int)

# Keep only the rows of the requested class and pick the SAMPLE-th of them.
permitted = [LABEL]
idx = df.loc[df['label'].isin(permitted)].index
if not -len(idx) <= SAMPLE < len(idx):
    raise IndexError(
        f"SAMPLE={SAMPLE} is out of range: label {LABEL} has {len(idx)} sample(s)"
    )
fname = os.path.join(PATH_RAW_TXT, df.loc[idx, 'filename'].iloc[SAMPLE])

In [7]:
# Explicit override: this assignment replaces the `fname` chosen from
# LABEL/SAMPLE in the previous cell. Comment it out to use that selection.
fname = "../iccv2021/_input/train/P117S05G10B30H40UC021000LC021000A063R0_10031037.txt"

In [None]:
# Read the sample and prepend a leading axis of length one, giving the
# 5-D layout that `visualize` unpacks as (N, C, T, V, M).
data = read_xyz(fname, 2, 17)[np.newaxis, ...]

# When index 1 of the last axis (the second person) is entirely zero,
# keep only index 0 of that axis.
if not np.any(data[0, ..., 1]):
    data = data[..., :1]

# data = pre_normalization(data)[0]
# visualize(data)

In [None]:
# NOTE(review): `build_vae_model` is neither imported nor defined in the
# visible cells -- confirm where it comes from before a fresh-kernel run.
encoder, decoder, vae = build_vae_model()

# Swap axes 1 and 2 of `data`, then collapse everything after the second
# axis into a single trailing axis.
# NOTE(review): 77 is hard-coded; presumably the frame count the model was
# built for -- confirm it matches data.shape[2].
X = np.transpose(data, (0, 2, 1, 3, 4)).reshape(1, 77, -1)
vae(X)

In [150]:
def visualize(data, axes=None):
    """Animate one skeleton sequence and return it as an embeddable video.

    Parameters
    ----------
    data : numpy.ndarray
        5-D array unpacked as ``(N, C, T, V, M)``. Only ``data[0]`` is
        animated; the plot limits are taken over all ``N`` entries.
    axes : sequence of two ints, optional
        Indices into axis 1 of ``data`` used for the horizontal and
        vertical plot axes. Defaults to the module-level ``AXES``.

    Returns
    -------
    IPython.display.HTML
        HTML wrapper around the video produced by ``to_html5_video``.
    """
    if axes is None:
        axes = AXES
    axes = list(axes)
    ax_h, ax_v = axes

    N, C, T, V, M = data.shape

    x0, x1 = np.min(data[:, ax_h, :, :, :]), np.max(data[:, ax_h, :, :, :])
    y0, y1 = np.min(data[:, ax_v, :, :, :]), np.max(data[:, ax_v, :, :, :])

    # Figure width is fixed; height follows the data's aspect ratio. A zero
    # (or NaN) extent on either axis would give an invalid figure size, so
    # fall back to a square canvas in that case.
    size = 3
    x_span, y_span = x1 - x0, y1 - y0
    ratio = y_span / x_span if (x_span > 0 and y_span > 0) else 1.0

    graph = Graph()
    fig, ax = plt.subplots(figsize=(size, ratio * size))

    ax.set_xlim((x0, x1))
    ax.set_ylim((y0, y1))

    edge = graph.inward

    # One line style per person; cycled so more than len(p_type) persons
    # no longer raises IndexError.
    p_type = ['b-', 'g-', 'g-', 'c-', 'm-', 'y-', 'k-', 'k-', 'k-', 'k-']

    # pose[m][i] is the line artist for edge i of person m.
    pose = [
        [ax.plot(np.zeros(2), np.zeros(2), p_type[m % len(p_type)])[0] for _ in edge]
        for m in range(M)
    ]

    def animate(t):
        """Move every edge line to its position in frame ``t``."""
        for m in range(M):
            for i, (v1, v2) in enumerate(edge):
                p1 = data[0, axes, t, v1, m]
                p2 = data[0, axes, t, v2, m]
                # A joint counts as present when any plotted coordinate is
                # non-zero. Testing the sum instead would wrongly drop
                # joints whose coordinates cancel out (e.g. x == -y).
                both_present = np.any(p1 != 0) and np.any(p2 != 0)
                # Edges touching vertex 1 are always redrawn; other edges
                # with a missing endpoint keep their previous coordinates.
                if both_present or v1 == 1 or v2 == 1:
                    pose[m][i].set_xdata(data[0, ax_h, t, [v1, v2], m])
                    pose[m][i].set_ydata(data[0, ax_v, t, [v1, v2], m])

        # Flat list of artists, as required by blit=True.
        return [line for person in pose for line in person]

    anim = animation.FuncAnimation(fig, animate, frames=T, interval=20, blit=True)
    try:
        return HTML(anim.to_html5_video())
    finally:
        # Close the figure once encoded so it is not leaked across calls or
        # rendered a second time as a static image below the video.
        plt.close(fig)