In [1]:
%matplotlib notebook
import matplotlib.pyplot as plt
import numpy as np
import datetime
import torch
from torch.nn import functional as F
import scipy.stats
import sklearn.datasets
import glob

import sys
sys.path.append("../")
import curvvae_lib.train.predictive_passthrough_trainer as ppttrainer
import curvvae_lib.architecture.passthrough_vae as ptvae
import curvvae_lib.architecture.load_model as lm

Matplotlib created a temporary config/cache directory at /tmp/matplotlib-4597upsh because the default path (/home/tsr42/.cache/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def plot_closure():
    plt.show()
    #plt.close()

# Dataset

In [3]:
foodname = "banana"
foldername = f"fork_trajectory_{foodname}"
savefilename = f"{foodname}_clean_pickups"

In [4]:
train = []
training_ts = np.linspace(0,1,64)
attempt = 1
while True:
    try:
        raw_vals = np.load(f"{savefilename}/pickup_attempt{attempt}.npy")
        train.append(raw_vals.T.flatten())
    except:
        print(f"We found {attempt-1} pickup attempts")
        break
    attempt += 1

train = np.array(train).reshape(-1,7,64)
all_points = train[:,:,:]

time_shape = list(all_points.shape)
time_shape[1] = 1
# why be smart when you can be dumb
t = np.ones(time_shape)
for i in range(time_shape[2]):
    t[:,:,i] = t[:,:,i] * i / (time_shape[2] + 0.0)

all_points = np.concatenate((t, all_points), axis=1)
all_points = all_points.transpose(0,2,1)
print(all_points.shape)
all_poses = all_points.reshape(-1,8)

# See http://localhost:8889/notebooks/scratchwork/2021-09-17%20Rotation%20Scaling.ipynb
# for why we want quaternion values to be multiplied by 0.16 when position values are in meters 
# (if the relevant distance scale of the fork is 0.08 meters, ie: 8cm).
# scaling term doesn't affect time, so don't use time in calculation

mean = np.mean(all_poses, axis=0)
mean[0] = 0 # don't scale time
variance = np.var(all_poses[:,1:], axis=0) # don't scale time
print(mean)
print(variance)
position_std = np.sqrt(np.max(variance))
print("std of: ", position_std)
position_scaling = 1/position_std
rotation_scaling = 0.16 * position_scaling

We found 155 pickup attempts
(155, 64, 8)
[ 0.          0.00227658 -0.00117697  0.0121782  -0.04935249  0.37343377
 -0.89429268 -0.01921521]
[2.63014114e-05 3.40430938e-05 1.00819967e-04 7.90561700e-03
 3.18947674e-02 7.03375426e-03 1.11414372e-02]
std of:  0.17859106200728153


In [5]:
def print_to_csv(mean):
    string = ""
    for i,m in enumerate(mean):
        string += "%0.8f" % m
        if i < len(mean)-1:
            string+= ","
    return string
print(f"\
      mean = np.array(({print_to_csv(mean[1:])}))\n\
      ps = {position_scaling}\n\
      rs = {rotation_scaling}\n")
def scale_dataset(input_points):
    points = input_points - mean
    poss = position_scaling
    rts = rotation_scaling
    points = (points * np.array((1,poss,poss,poss,rts,rts,rts,rts)))
    return points

def unscale_poses(input_points):
    poss = position_scaling
    rts = rotation_scaling
    points = (input_points / np.array((poss,poss,poss,rts,rts,rts,rts)))
    points = points + mean[1:]
    return points 
    
def unscale_dataset(input_points):
    poss = position_scaling
    rts = rotation_scaling
    points = (input_points / np.array((1,poss,poss,poss,rts,rts,rts,rts)))
    points = points + mean
    return points

      mean = np.array((0.00227658,-0.00117697,0.01217820,-0.04935249,0.37343377,-0.89429268,-0.01921521))
      ps = 5.599384363139225
      rs = 0.895901498102276



In [6]:
dataset = scale_dataset(all_points)
t_all_points = torch.tensor(dataset,dtype=torch.float32)

## 2D VAE Fit to Dataset

In [7]:
def LoadDataBatch(all_points, batchsize, passthroughdim, predictive, device):
    """Sample Pair of Points from Trajectory"""
    # all_points should be of dimension: num_trajectories, numtimesteps, 1+spatialdims
    traj_ids = np.random.choice(all_points.shape[0], batchsize)
    t1_ids = np.random.choice(all_points.shape[1], batchsize)
    if predictive:
        t2_ids = np.random.choice(all_points.shape[1], batchsize)
    else:
        t2_ids = t1_ids
    return (torch.tensor(all_points[traj_ids, t1_ids,passthroughdim:], dtype=torch.float).to(device),
            torch.tensor(all_points[traj_ids, t1_ids,:passthroughdim], dtype=torch.float).to(device),
            torch.tensor(all_points[traj_ids, t2_ids,passthroughdim:], dtype=torch.float).to(device),
            torch.tensor(all_points[traj_ids, t2_ids,:passthroughdim], dtype=torch.float).to(device))

In [8]:
class Loader(object):
    def __init__(self, all_points, batchsize, passthroughdim, epochnumbatches, predictive, device):
        self.all_points = all_points
        self.batchsize = batchsize
        self.passthroughdim = passthroughdim
        self.epochnumbatches = epochnumbatches
        self.predictive = predictive
        self.device = device
        
    def __iter__(self):
        self.n = 0
        return self
    
    def __next__(self):
        if self.n >= self.epochnumbatches:
            # https://docs.python.org/3/library/exceptions.html#StopIteration
            raise StopIteration
        self.n += 1
        return LoadDataBatch(self.all_points, self.batchsize, self.passthroughdim, self.predictive, self.device)
    

device = "cuda"
train_loader = Loader(dataset[:,np.newaxis,:], 256, 0,10,predictive=False, device=device)
  

In [9]:
dataset_dim = dataset.shape[-1]
def make_vae(latent_dim):
    input_dim = dataset_dim 
    passthrough_dim = 0
    emb_layer_widths = [1000]
    recon_layer_widths = [1000]
    dtype = torch.float
    model = ptvae.FCPassthroughVAE(input_dim, passthrough_dim, latent_dim,
        emb_layer_widths, recon_layer_widths, dtype)
    return model

In [10]:
testname = f"trainedmodels/{foodname}_"

In [11]:
all_models = ['trainedmodels/banana_lat3_curvreg0_beta0.001_20220209-144826']
#all_models = [f'trainedmodels/banana_lat3_curvreg0.001_beta0.001_20220209-120436']
all_models = [f'trainedmodels/banana_lat3_curvreg0_beta0.01_20220209-152954']
all_models = [f'trainedmodels/banana_lat3_curvreg0_beta0.005_20220725-235649']
all_models = [f'trainedmodels/banana_lat3_curvreg0_beta0.002_20220725-191457']
all_models = [f'trainedmodels/banana_lat3_curvreg0_beta0.001_20220210-013953'] # 0.0001 learning rate


all_models

['trainedmodels/banana_lat3_curvreg0_beta0.001_20220210-013953']

In [12]:
latent_lim_values = np.array((2.,2,2))

In [13]:
modelname = all_models[0]
loaded_vae = lm.load_model(modelname)


In [14]:
import transforms3d as t3d

In [15]:
def loss_function(numpy_latent,target_angle, verbose=False):
    targ_ang_rads = 0.4 - (target_angle / 180. * np.pi)
    #print(targ_ang_rads)
    latent = torch.tensor(numpy_latent,dtype=torch.float32).reshape(1,3)
    t = torch.tensor((0),dtype=torch.float32).reshape(1,1)
    scaled_pose, _ = loaded_vae.decode(latent,t)
    pose = unscale_poses(scaled_pose.detach().cpu().numpy())
    quat = pose[0,3:]
    base_quat = (0,0,1,0)
    quat = t3d.quaternions.qmult(base_quat,quat)
    eul = np.array(t3d.euler.quat2euler(quat,axes='sxyz'))
    if verbose:
        print(eul * 180/np.pi)
    return np.sum(np.square(eul - (targ_ang_rads,0,-0.79478576)) * np.array((1,0.0001,0.0001))) # ok for a little side to side, but gotta be correct tilt

In [16]:
# https://pyswarms.readthedocs.io/en/latest/examples/usecases/train_neural_network.html#Constructing-a-custom-objective-function
def loss_function_batch(x, target_angle):
    """Higher-level method to do forward_prop in the
    whole swarm.

    Inputs
    ------
    x: numpy.ndarray of shape (n_particles, dimensions)
        The swarm that will perform the search

    Returns
    -------
    numpy.ndarray of shape (n_particles, )
        The computed loss for each particle
    """
    n_particles = x.shape[0]
    j = [loss_function(x[i],target_angle) for i in range(n_particles)]
    return np.array(j)

In [17]:
import pyswarms as ps

final_answers = []
for target_ang in [30,45,60]:
    lim = 3
    curbest = None
    while curbest is None:# or (np.sqrt(curbest[0])/np.pi) * 180 > 1: # keep expanding the search radius until within 1 degree of desired angle
        #lim += 0.5
        for _ in range(10):
            # Set-up hyperparameters
            options = {'c1': 0.5, 'c2': 0.3, 'w':0.9}

            # Call instance of GlobalBestPSO
            optimizer = ps.single.GlobalBestPSO(n_particles=30, dimensions=3,
                                                options=options, bounds=((-lim,-lim,-lim),(lim,lim,lim)))

            # Perform optimization
            stats = optimizer.optimize(loss_function_batch, target_angle=target_ang, iters=200)

            if curbest is None or stats[0] < curbest[0]:
                print("updating best")
                curbest = stats
    final_answers.append(curbest)
print(final_answers)

2022-07-26 10:12:41,617 - pyswarms.single.global_best - INFO - Optimize for 200 iters with {'c1': 0.5, 'c2': 0.3, 'w': 0.9}
pyswarms.single.global_best: 100%|██████████|200/200, best_cost=5.66e-8
2022-07-26 10:12:42,805 - pyswarms.single.global_best - INFO - Optimization finished | best cost: 5.660494125043583e-08, best pos: [-1.61421298  0.15577061  0.96836048]
2022-07-26 10:12:42,811 - pyswarms.single.global_best - INFO - Optimize for 200 iters with {'c1': 0.5, 'c2': 0.3, 'w': 0.9}


updating best


pyswarms.single.global_best: 100%|██████████|200/200, best_cost=8.72e-9
2022-07-26 10:12:43,983 - pyswarms.single.global_best - INFO - Optimization finished | best cost: 8.716868370595702e-09, best pos: [-1.34376667  0.04385996  1.02353547]
2022-07-26 10:12:43,989 - pyswarms.single.global_best - INFO - Optimize for 200 iters with {'c1': 0.5, 'c2': 0.3, 'w': 0.9}


updating best


pyswarms.single.global_best: 100%|██████████|200/200, best_cost=3.34e-8
2022-07-26 10:12:45,163 - pyswarms.single.global_best - INFO - Optimization finished | best cost: 3.339349944806831e-08, best pos: [-1.56518588  0.12916739  0.92959041]
2022-07-26 10:12:45,168 - pyswarms.single.global_best - INFO - Optimize for 200 iters with {'c1': 0.5, 'c2': 0.3, 'w': 0.9}
pyswarms.single.global_best: 100%|██████████|200/200, best_cost=7.62e-7
2022-07-26 10:12:46,339 - pyswarms.single.global_best - INFO - Optimization finished | best cost: 7.624822992110806e-07, best pos: [-1.45283283  0.24132446  0.53467268]
2022-07-26 10:12:46,345 - pyswarms.single.global_best - INFO - Optimize for 200 iters with {'c1': 0.5, 'c2': 0.3, 'w': 0.9}
pyswarms.single.global_best: 100%|██████████|200/200, best_cost=7.76e-8
2022-07-26 10:12:47,550 - pyswarms.single.global_best - INFO - Optimization finished | best cost: 7.763183105901939e-08, best pos: [-1.05430319 -0.10557305  0.71277429]
2022-07-26 10:12:47,555 - pys

updating best


pyswarms.single.global_best: 100%|██████████|200/200, best_cost=3.89e-8
2022-07-26 10:12:49,919 - pyswarms.single.global_best - INFO - Optimization finished | best cost: 3.885660686226928e-08, best pos: [-2.67677351  0.01682844  1.78251574]
2022-07-26 10:12:49,925 - pyswarms.single.global_best - INFO - Optimize for 200 iters with {'c1': 0.5, 'c2': 0.3, 'w': 0.9}
pyswarms.single.global_best: 100%|██████████|200/200, best_cost=8.91e-8
2022-07-26 10:12:51,116 - pyswarms.single.global_best - INFO - Optimization finished | best cost: 8.905351808020522e-08, best pos: [-1.09394639 -0.08723269  0.85690324]
2022-07-26 10:12:51,122 - pyswarms.single.global_best - INFO - Optimize for 200 iters with {'c1': 0.5, 'c2': 0.3, 'w': 0.9}
pyswarms.single.global_best: 100%|██████████|200/200, best_cost=1.87e-8
2022-07-26 10:12:52,335 - pyswarms.single.global_best - INFO - Optimization finished | best cost: 1.8677246238489106e-08, best pos: [-1.53280986  0.11903485  0.95173328]
2022-07-26 10:12:52,340 - py

updating best


pyswarms.single.global_best: 100%|██████████|200/200, best_cost=6.41e-6
2022-07-26 10:12:55,877 - pyswarms.single.global_best - INFO - Optimization finished | best cost: 6.406497258950826e-06, best pos: [-1.87550143  0.52190007 -0.55008877]
2022-07-26 10:12:55,883 - pyswarms.single.global_best - INFO - Optimize for 200 iters with {'c1': 0.5, 'c2': 0.3, 'w': 0.9}


updating best


pyswarms.single.global_best: 100%|██████████|200/200, best_cost=6.48e-6
2022-07-26 10:12:57,079 - pyswarms.single.global_best - INFO - Optimization finished | best cost: 6.4830138590770825e-06, best pos: [-1.84114946  0.51333029 -0.57980371]
2022-07-26 10:12:57,084 - pyswarms.single.global_best - INFO - Optimize for 200 iters with {'c1': 0.5, 'c2': 0.3, 'w': 0.9}
pyswarms.single.global_best: 100%|██████████|200/200, best_cost=6.43e-6
2022-07-26 10:12:58,266 - pyswarms.single.global_best - INFO - Optimization finished | best cost: 6.4327450487177785e-06, best pos: [-1.83008563  0.52372974 -0.55773673]
2022-07-26 10:12:58,272 - pyswarms.single.global_best - INFO - Optimize for 200 iters with {'c1': 0.5, 'c2': 0.3, 'w': 0.9}
pyswarms.single.global_best: 100%|██████████|200/200, best_cost=6.43e-6
2022-07-26 10:12:59,465 - pyswarms.single.global_best - INFO - Optimization finished | best cost: 6.430448259220071e-06, best pos: [-1.87486008  0.52728867 -0.54287637]
2022-07-26 10:12:59,470 - p

updating best


pyswarms.single.global_best: 100%|██████████|200/200, best_cost=6.27e-6
2022-07-26 10:13:03,145 - pyswarms.single.global_best - INFO - Optimization finished | best cost: 6.274763751310802e-06, best pos: [-2.5236909   0.61579484 -0.48340507]
2022-07-26 10:13:03,150 - pyswarms.single.global_best - INFO - Optimize for 200 iters with {'c1': 0.5, 'c2': 0.3, 'w': 0.9}
pyswarms.single.global_best: 100%|██████████|200/200, best_cost=6.46e-6
2022-07-26 10:13:04,327 - pyswarms.single.global_best - INFO - Optimization finished | best cost: 6.4557212810624605e-06, best pos: [-1.82801344  0.51833381 -0.56989319]
2022-07-26 10:13:04,332 - pyswarms.single.global_best - INFO - Optimize for 200 iters with {'c1': 0.5, 'c2': 0.3, 'w': 0.9}
pyswarms.single.global_best: 100%|██████████|200/200, best_cost=1.11e-5
2022-07-26 10:13:05,505 - pyswarms.single.global_best - INFO - Optimization finished | best cost: 1.1119939694871437e-05, best pos: [-1.48141531  0.93526411 -1.19215126]
2022-07-26 10:13:05,511 - p

updating best


pyswarms.single.global_best: 100%|██████████|200/200, best_cost=2.6e-5 
2022-07-26 10:13:07,874 - pyswarms.single.global_best - INFO - Optimization finished | best cost: 2.5994185768708246e-05, best pos: [-2.91579426  1.37205946 -1.04435788]
2022-07-26 10:13:07,880 - pyswarms.single.global_best - INFO - Optimize for 200 iters with {'c1': 0.5, 'c2': 0.3, 'w': 0.9}
pyswarms.single.global_best: 100%|██████████|200/200, best_cost=2.02e-5
2022-07-26 10:13:09,062 - pyswarms.single.global_best - INFO - Optimization finished | best cost: 2.020989752996988e-05, best pos: [-2.6349394   1.08102132 -1.36049144]
2022-07-26 10:13:09,068 - pyswarms.single.global_best - INFO - Optimize for 200 iters with {'c1': 0.5, 'c2': 0.3, 'w': 0.9}


updating best


pyswarms.single.global_best: 100%|██████████|200/200, best_cost=1.95e-5
2022-07-26 10:13:10,254 - pyswarms.single.global_best - INFO - Optimization finished | best cost: 1.9472268620205732e-05, best pos: [-2.75141325  0.81039103 -1.728153  ]
2022-07-26 10:13:10,259 - pyswarms.single.global_best - INFO - Optimize for 200 iters with {'c1': 0.5, 'c2': 0.3, 'w': 0.9}


updating best


pyswarms.single.global_best: 100%|██████████|200/200, best_cost=2.87e-5
2022-07-26 10:13:11,432 - pyswarms.single.global_best - INFO - Optimization finished | best cost: 2.871103296729191e-05, best pos: [-2.1192238   1.18386023 -1.73009294]
2022-07-26 10:13:11,437 - pyswarms.single.global_best - INFO - Optimize for 200 iters with {'c1': 0.5, 'c2': 0.3, 'w': 0.9}
pyswarms.single.global_best: 100%|██████████|200/200, best_cost=1.9e-5
2022-07-26 10:13:12,681 - pyswarms.single.global_best - INFO - Optimization finished | best cost: 1.9003079360508786e-05, best pos: [-2.85998382  0.97312894 -1.33378041]
2022-07-26 10:13:12,687 - pyswarms.single.global_best - INFO - Optimize for 200 iters with {'c1': 0.5, 'c2': 0.3, 'w': 0.9}


updating best


pyswarms.single.global_best: 100%|██████████|200/200, best_cost=1.92e-5
2022-07-26 10:13:13,873 - pyswarms.single.global_best - INFO - Optimization finished | best cost: 1.9239327417025177e-05, best pos: [-2.83045351  0.92070795 -1.41675606]
2022-07-26 10:13:13,878 - pyswarms.single.global_best - INFO - Optimize for 200 iters with {'c1': 0.5, 'c2': 0.3, 'w': 0.9}
pyswarms.single.global_best: 100%|██████████|200/200, best_cost=2.92e-5
2022-07-26 10:13:15,067 - pyswarms.single.global_best - INFO - Optimization finished | best cost: 2.9246227779658068e-05, best pos: [-2.1673636   1.07076798 -1.81496584]
2022-07-26 10:13:15,072 - pyswarms.single.global_best - INFO - Optimize for 200 iters with {'c1': 0.5, 'c2': 0.3, 'w': 0.9}
pyswarms.single.global_best: 100%|██████████|200/200, best_cost=1.96e-5
2022-07-26 10:13:16,303 - pyswarms.single.global_best - INFO - Optimization finished | best cost: 1.9580031681089798e-05, best pos: [-2.82242624  1.01535659 -1.29644694]
2022-07-26 10:13:16,308 - 

[(1.1137329049822525e-09, array([-1.35876521,  0.0551884 ,  1.01229962])), (6.104298232006867e-06, array([-2.90285753,  0.59823158, -0.50272271])), (1.9003079360508786e-05, array([-2.85998382,  0.97312894, -1.33378041]))]


In [18]:
for fb in final_answers:
    print(np.sqrt(fb[0]) / np.pi * 180)

0.001912111255312505
0.14155998374681178
0.24976675037813853


In [19]:
for fb in final_answers:
    print(f"[{','.join([f'{s}' for s in fb[1]])}],")

[-1.3587652105334609,0.055188400951311375,1.012299616765987],
[-2.9028575328539272,0.598231580485021,-0.502722708889766],
[-2.859983816519404,0.9731289363805093,-1.3337804063300918],


In [20]:
for fb in final_answers:
    loss_function(fb[1],1000,True)


[-7.08164935e+00 -6.11504154e-03 -4.57289435e+01]
[-22.10831043  -5.32375707 -32.6940981 ]
[-37.07781166  -4.37200122 -20.94987206]


In [21]:
0.4*180/np.pi

22.91831180523293

In [22]:
"""
[0.2478165554759397,-1.249154366817157,-1.5198855746754794],
[0.10129095212372306,-2.8613022436706315,0.9409312592318927],
[0.5889920784970968,-2.3291049154306642,2.6950374180789747],
"""

'\n[0.2478165554759397,-1.249154366817157,-1.5198855746754794],\n[0.10129095212372306,-2.8613022436706315,0.9409312592318927],\n[0.5889920784970968,-2.3291049154306642,2.6950374180789747],\n'