In [1]:
%matplotlib inline
import datetime
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import transforms3d as t3d
from torch.nn import functional as F

# The path tweak must run before the curvvae_lib imports below.
sys.path.append("../")
import curvvae_lib.train.predictive_passthrough_trainer as ppttrainer
import curvvae_lib.architecture.passthrough_vae as ptvae
import curvvae_lib.architecture.save_model as sm
import curvvae_lib.architecture.load_model as lm

In [2]:
# Food item name; it is used below to build the input data directory and the
# output file prefix.
foodname = "banana"
# NOTE(review): not referenced by any other visible cell — confirm it is still needed.
foldername = f"fork_trajectory_{foodname}"
# Directory containing the pickup_attempt<N>.npy files read by the loading cell.
savefilename = f"{foodname}_clean_pickups"

In [3]:
# Load every consecutively numbered pickup attempt (pickup_attempt1.npy,
# pickup_attempt2.npy, ...) until the first missing file.
train = []
training_ts = np.linspace(0,1,64)
attempt = 1
while True:
    try:
        raw_vals = np.load(f"{savefilename}/pickup_attempt{attempt}.npy")
    except FileNotFoundError:
        # Only a missing file ends the sequence. Any other failure (corrupt
        # file, permission problem, interrupt) now surfaces instead of being
        # silently treated as "no more attempts".
        print(f"We found {attempt-1} pickup attempts")
        break
    # Transpose + flatten here and reshape to (-1, 7, 64) below, i.e. each
    # file is stored channel-major.
    # presumably each file is (64 time steps, 7 channels) — TODO confirm
    train.append(raw_vals.T.flatten())
    attempt += 1

if not train:
    # Without this guard every statistic computed downstream would be NaN.
    raise FileNotFoundError(
        f"No pickup attempts found under '{savefilename}/' "
        "(expected pickup_attempt1.npy, pickup_attempt2.npy, ...)"
    )

# (num_attempts, 7 channels, 64 time steps)
train = np.array(train).reshape(-1,7,64)
all_points = train[:,:,:]

# Build a time channel shaped like the data but with a single channel:
# step k of every trajectory gets the value k / num_steps.
time_shape = list(all_points.shape)
time_shape[1] = 1
num_steps = time_shape[2]
step_fractions = np.arange(num_steps) / (num_steps + 0.0)
t = np.ones(time_shape) * step_fractions

# Prepend time as channel 0, then reorder the axes to
# (trajectory, time step, channel).
all_points = np.concatenate((t, all_points), axis=1).transpose(0,2,1)
print(all_points.shape)

# Rationale for the 0.16 quaternion factor (positions in meters, relevant fork
# length scale of 0.08 m, i.e. 8 cm) is in the local scratch notebook:
# http://localhost:8889/notebooks/scratchwork/2021-09-17%20Rotation%20Scaling.ipynb
# Time is excluded from both the centering and the scale computation.
stats_reshaped = all_points.reshape(-1,8)

# Per-channel mean over every (trajectory, time step) sample; channel 0 (time)
# is zeroed so that centering leaves time untouched.
mean = stats_reshaped.mean(axis=0)
mean[0] = 0

# Per-channel variance of the 7 non-time channels.
variance = stats_reshaped[:,1:].var(axis=0)
print(mean)
print(variance)

# NOTE(review): the max is taken over all 7 non-time channels, not only the
# three position channels, even though the result is named position_std. In
# the recorded output the largest variance belongs to a quaternion channel.
# Confirm this is intended.
position_std = np.sqrt(variance.max())
print("std of: ", position_std)
position_scaling = 1/position_std
rotation_scaling = 0.16 * position_scaling

We found 179 pickup attempts
(179, 64, 8)
[ 0.          0.0021176  -0.00108136  0.01226278 -0.05341325  0.3809648
 -0.86529588 -0.02303113]
[2.74856348e-05 3.50234479e-05 1.02086472e-04 1.03992603e-02
 5.06524753e-02 2.98499090e-02 1.16367323e-02]
std of:  0.2250610480383746


In [4]:
def scale_dataset(input_points, mean_offset=None, position_scale=None, rotation_scale=None):
    """Center and scale trajectory points whose last axis has 8 channels.

    Channel layout follows the scaling vector: 1 time channel (left unscaled),
    3 channels scaled by the position factor, 4 channels scaled by the
    rotation factor.

    Args:
        input_points: array whose last axis has length 8.
        mean_offset: length-8 offset subtracted before scaling. Defaults to
            the notebook-level ``mean``.
        position_scale: multiplier for channels 1-3. Defaults to the
            notebook-level ``position_scaling``.
        rotation_scale: multiplier for channels 4-7. Defaults to the
            notebook-level ``rotation_scaling``.

    Returns:
        A new array of the same shape, centered and scaled.

    Note:
        The defaults are read from notebook globals at call time. A later cell
        rebinds ``mean`` to a different-length vector, so pass the statistics
        explicitly when calling this after that point.
    """
    if mean_offset is None:
        mean_offset = mean
    if position_scale is None:
        position_scale = position_scaling
    if rotation_scale is None:
        rotation_scale = rotation_scaling
    channel_scale = np.array((1,) + (position_scale,) * 3 + (rotation_scale,) * 4)
    return (input_points - mean_offset) * channel_scale
    
def unscale_dataset(input_points, mean_offset=None, position_scale=None, rotation_scale=None):
    """Invert ``scale_dataset``: divide by the per-channel scale, then add the offset.

    Channel layout follows the scaling vector: 1 time channel (left unscaled),
    3 channels divided by the position factor, 4 channels divided by the
    rotation factor.

    Args:
        input_points: array whose last axis has length 8.
        mean_offset: length-8 offset added back after unscaling. Defaults to
            the notebook-level ``mean``.
        position_scale: divisor for channels 1-3. Defaults to the
            notebook-level ``position_scaling``.
        rotation_scale: divisor for channels 4-7. Defaults to the
            notebook-level ``rotation_scaling``.

    Returns:
        A new array of the same shape in the original (unscaled) units.

    Note:
        The defaults are read from notebook globals at call time. A later cell
        rebinds ``mean`` to a different-length vector, so pass the statistics
        explicitly when calling this after that point.
    """
    if mean_offset is None:
        mean_offset = mean
    if position_scale is None:
        position_scale = position_scaling
    if rotation_scale is None:
        rotation_scale = rotation_scaling
    channel_scale = np.array((1,) + (position_scale,) * 3 + (rotation_scale,) * 4)
    return input_points / channel_scale + mean_offset

In [5]:
# Scale the full dataset, then print per-channel mean and variance over every
# (trajectory, time step) sample as a sanity check.
scaled_points = scale_dataset(all_points)
scaled_rows = scaled_points.reshape(-1,8)
print(scaled_rows.mean(axis=0))
print(scaled_rows.var(axis=0))

[ 4.92187500e-01 -2.81528009e-17 -2.17591502e-17  2.48869826e-16
 -5.98249433e-18 -2.07415750e-16 -3.70239233e-15 -2.10880351e-17]
[0.08331299 0.00054263 0.00069145 0.00201543 0.00525584 0.0256
 0.01508628 0.00588126]


In [6]:
# Output directory for the PCA models; exist_ok makes re-running this cell safe.
foldname = "pcamodels"
os.makedirs(foldname,exist_ok=True)
# Shared path prefix for every file saved for this food item.
testname = f"{foldname}/{foodname}_"

In [7]:
# Number of principal components to keep.
latentdim=3
# A timestamp in the name gives each run its own output file.
run_timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
savedir  = f'{testname}lat{latentdim}_pca_{run_timestamp}'
print(savedir)

pcamodels/banana_lat3_pca_20221031-110710


In [8]:
print(scaled_points.shape)
# Drop channel 0 (time) and flatten each trajectory into a single row.
num_trajectories = scaled_points.shape[0]
without_time = scaled_points[:,:,1:]
flattened_points = without_time.reshape(num_trajectories,-1)
print(flattened_points.shape) # datapoints x (timepoints * joints)

(179, 64, 8)
(179, 448)


In [9]:
# PCA via SVD of the feature-centered data matrix.
# NOTE(review): this rebinds the notebook-level `mean`, which an earlier cell
# set to the length-8 per-channel mean. Any code that reads the global `mean`
# after this cell sees the per-feature vector computed here instead.
mean = flattened_points.mean(axis=0)
shifted = flattened_points - mean
u,s,vt = np.linalg.svd(shifted)

# Keep the leading `latentdim` right-singular vectors, each scaled by its
# singular value, and store them as columns.
leading_singular_values = s[:latentdim, None]
pca_components = (leading_singular_values * vt[:latentdim,:]).T
print(pca_components.shape) # want to save a (timepoints * joints) X latent_dim matrix

(448, 3)


In [10]:
# Sanity check on the PCA mean. At this point `mean` is the per-feature mean
# of the flattened trajectories, not the length-8 per-channel mean computed
# for scaling.
print(mean.shape)

(448,)


In [11]:
# Persist the PCA basis and the feature mean together in one archive. savedir
# has no extension; the reload cell reads it back from savedir + ".npz".
np.savez(savedir, pca_components=pca_components, mean=mean)

In [12]:
# Reload the archive just written to confirm it round-trips.
# NOTE(review): this rebinds `pca_components` from the component matrix to the
# loaded archive object (indexed by key in the next cell). Re-running the save
# cell after this one would therefore pass the archive, not the matrix.
pca_components = np.load(savedir+".npz")

In [13]:
# Look up both saved arrays by key. Only the last expression is echoed as the
# cell output, so the "mean" lookup on the first line is evaluated and discarded.
pca_components["mean"]
pca_components["pca_components"]

array([[-0.01790024,  0.16757963,  0.09063121],
       [-0.11341525,  0.0888046 , -0.09900836],
       [-0.08943301, -0.00847028, -0.05994208],
       ...,
       [-1.76284471,  0.33714694,  0.01451369],
       [-0.99938237, -0.29799988,  0.21099649],
       [ 0.1159829 ,  0.07750834, -0.67683254]])