In [1]:
%matplotlib inline
import sys
sys.path.append("..")
sys.path.append("../train_model")

import glob
import torch
import numpy as np
import matplotlib.pyplot as plt
import timewarp_lib.load_model as lm
import timewarp_lib.vector_timewarpers as vt

import transforms3d as t3d
from stl import mesh
import mpl_toolkits.mplot3d as mplot3d

#import cpp_dtw as cdtw

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
datapath = f'../forkdata/forkTrajectoryData.npz'
model_path_template = "../results/retrainedforkdata/{timestr}/savedmodel"
rawdata = np.load(datapath)

In [3]:
summary=[]
testsummary=[]
train = rawdata["train"]
test = rawdata["test"]
TRAJLEN = 200
train_tensor = torch.tensor(rawdata["train"],dtype=torch.float)


In [4]:
def load_this_model(modelname,train):
    model = lm.LoadedModel(modelname)
    ts = torch.tensor(np.linspace(0,1,train.shape[1]).reshape((1,train.shape[1],1)), dtype=torch.float).expand((train.shape[0],train.shape[1],1))
    recons, mu, logvar, scaled_ts = model.model.noiseless_forward(train,ts)

    return mu, recons, scaled_ts, model

In [5]:
from IPython.display import display

def plot_1d_sweep(modelname, sweepaxis=0,num_trajs_to_plot=201):
    cmap = plt.get_cmap("viridis")
    model = lm.LoadedModel(modelname)
    print(model.modeldata)
    # Create a new plot
    model_data_path = str(model.modeldata["datafile"])
    modeldata_info = np.load("../"+ model_data_path)
    pose_scaling = modeldata_info["pose_scaling"]
    pose_mean = modeldata_info["pose_mean"]
    train = torch.tensor(rawdata["train"],dtype=torch.float)
    embed, recons, scaled_ts, model = load_this_model(modelname,train)
    #print(embed)
    latent_dim = model.modeldata["latent_dim"]
    num_timesteps = TRAJLEN
    
    scaled_ts = scaled_ts.detach().cpu().numpy()
    npts = np.mean(scaled_ts,axis=0,keepdims=True)
    scaled_ts = torch.tensor(npts,dtype=torch.float).expand((num_trajs_to_plot,num_timesteps,1))
    
    embednp = embed.detach().cpu().numpy()
    sweep_embed = np.zeros(shape=(num_trajs_to_plot,latent_dim))    
    for i in range(latent_dim):
        sweep_embed[:,i] = np.median(embednp[:,i])
    med_sweep_embed = torch.tensor(sweep_embed,dtype=torch.float)
    med_traj = model.model.decoder.decode(med_sweep_embed[:1],scaled_ts[:1]).detach().numpy()

    unscale_recons = (med_traj / pose_scaling) + pose_mean
    
    #xs = unscale_recons[0,:TRAJLEN,0]
    #ys = unscale_recons[0,:TRAJLEN,1]
    #zs = unscale_recons[0,:TRAJLEN,2]
    #axes.plot(xs,ys,zs,color="k",linewidth=10)
    #
    median_result_trajs = model.model.decoder.decode(med_sweep_embed,scaled_ts).detach().numpy()
    for i in range(latent_dim):
        sweep_embed[:,i] = np.median(embednp[:,i])
    sweep_embed[:,sweepaxis] = np.percentile(embednp[:,sweepaxis],np.linspace(2,98,num_trajs_to_plot).astype(int))
    
    sweep_embed = torch.tensor(sweep_embed,dtype=torch.float)
    
    result_trajs = model.model.decoder.decode(sweep_embed,scaled_ts).detach().numpy()

    
    sweep_embed = sweep_embed.detach().cpu().numpy()
    recons = result_trajs#.detach().cpu().numpy()
    train = train.detach().cpu().numpy()
    
    
    # matplotlib doesn't do raytracing, so we gotta manually plot in order
    
    minembed = np.min(sweep_embed[:,sweepaxis]) if sweep_embed.shape[1]>0 else 0
    maxembed = np.max(sweep_embed[:,sweepaxis]) if sweep_embed.shape[1]>0 else 0
    order = np.argsort(-sweep_embed[:,sweepaxis]) if model.modeldata["latent_dim"] > 0 else np.arange(len(train))
    for rto, raytracing_order in enumerate([order]):#,np.argsort(embed.flatten())]):
        for name,data in [("Reconstructed",result_trajs)]: #
            th = np.linspace(0,np.pi*2,200)
            radius = 84/1000
            center = [0.4,0.7,0.017]

            unscale_recons = (data / pose_scaling) + pose_mean

            for framenumber, trajid in enumerate(raytracing_order):
                figure = plt.figure(figsize=(10,10))
                axes = plt.axes(projection='3d')
                axes.plot(np.sin(th)*radius + center[0], np.cos(th)*radius+center[1],0*th+center[2],c=((0,0,0,0.5)))
                #if trajid % 2 != 0:
                #    continue
                xs = unscale_recons[trajid,:TRAJLEN,0]
                ys = unscale_recons[trajid,:TRAJLEN,1]
                zs = unscale_recons[trajid,:TRAJLEN,2]
                colorval = cmap((sweep_embed[trajid,sweepaxis]-minembed)/(maxembed-minembed)) if embed.shape[1] > 0  else "black"
                
                colorval = (colorval[0],colorval[1],colorval[2],0.8)
                # backfill in white first
                axes.plot(xs,ys,zs,color=(1,1,1),linewidth=7)
                axes.plot(xs,ys,zs,color=colorval,linewidth=7)
               
                
                for i in [0]:#[int(t) for t in np.linspace(0,100-1,5)]:
                    #if trajid % 5 == 0:
                        translation = np.array((xs[i],ys[i],zs[i])).reshape(3,1)
                        quaternion = unscale_recons[trajid,i,3:7] 
                        #print(quaternion)
                        rotmat = t3d.quaternions.quat2mat(quaternion)

                        transformmat = np.concatenate((
                                          np.concatenate((rotmat,translation),axis=1),
                                          np.array((0,0,0,1)).reshape((1,4))),axis=0)
                        #print(transformmat)

                        # Load the STL files and add the vectors to the plot
                        your_mesh = mesh.Mesh.from_file('../paper_images/forkTipFromTurboSquid.stl')

                        # convert mm to m
                        your_mesh.vectors /= (1000 / 2)

                        your_mesh.transform(transformmat)

                        polycollection = mplot3d.art3d.Poly3DCollection(your_mesh.vectors)
                        polycollection.set_facecolor((colorval[0],colorval[1],colorval[2],0.4))
                        polycollection.set_edgecolors((0,0,0,0.05))
                        axes.add_collection3d(polycollection)

                scale = 0.12
                xmid = 0.4
                ymid = 0.70
                axes.set_xlim(xmid-scale,xmid+scale)
                axes.set_ylim(ymid-scale,ymid+scale)
                axes.set_zlim(0,2*scale)
                axes.set_xlabel("x")
                plotorigin = [xmid-radius-0.01,ymid]
                axiscolor = "gray"
                if sweepaxis == 0:
                    axes.quiver(plotorigin[0],plotorigin[1],center[2],0.02,0,0,arrow_length_ratio=0.1, color=axiscolor)
                    axes.quiver(plotorigin[0],plotorigin[1],center[2],0,0.02,0,arrow_length_ratio=0.1, color=axiscolor)
                    axes.quiver(plotorigin[0],plotorigin[1],center[2],0,0,0.02,arrow_length_ratio=0.1, color=axiscolor)
                    axes.text(plotorigin[0]+0.02,plotorigin[1],center[2],"x",fontsize=20,
                          horizontalalignment="center",verticalalignment="center")
                    axes.text(plotorigin[0],plotorigin[1]+0.02,center[2],"y",fontsize=20,
                          horizontalalignment="center",verticalalignment="center")
                for elev,azim in [(90,0),(10,30)]:
                        print("elev,azim: ",elev,azim)
                        axes.view_init(elev=elev, azim=azim)
                        #plt.savefig(f"3dplotfit_{name}_{modelname[7:]}_{trajid}_elev{elev}_azim{azim}_rto.png",
                        #           bbox_inches='tight')
                        axes.set_axis_off()
                        #display(figure);
                        if sweepaxis == 0 and elev != 90:
                            axes.text(plotorigin[0],plotorigin[1],center[2]+0.02,"z",fontsize=20,
                                horizontalalignment="center",verticalalignment="center")
                        figure.savefig(f"animatedLatent/latent{sweepaxis}-elev{elev}-azim{azim}-frame{framenumber}.png", bbox_inches='tight', pad_inches=0)
                plt.close()

In [6]:
timestrs = ["20230928-065206.641930"]
for timestr in timestrs:
    modelname = model_path_template.format(timestr=timestr)
    model = lm.LoadedModel(modelname)
    print(f"For model {timestr}:")
    [print(f"\t {k}: {v}") for (k,v) in model.modeldata.items()]

For model 20230928-065206.641930:
	 dtype_string: float
	 datafile: forkdata/forkTrajectoryData.npz
	 model_save_dir: results/retrainedforkdata/20230928-065206.641930/savedmodel
	 num_epochs: 10000
	 latent_dim: 3
	 device: cuda
	 dtype: torch.float32
	 traj_len: 200
	 traj_channels: 7
	 beta: 1.0
	 training_data_added_timing_noise: 0.1
	 pre_time_learning_epochs: 0
	 scalar_timewarping_lr: 0.0001
	 scalar_timewarping_eps: 1e-06
	 scalar_timewarper_timereg: 0.05
	 scalar_timewarper_endpointreg: 0
	 scaltw_min_canonical_time: 0.0
	 scaltw_max_canonical_time: 1.0
	 dec_use_softplus: False
	 dec_use_elu: True
	 dec_conv_use_elu: True
	 dec_template_use_custom_initialization: True
	 dec_template_custom_initialization_grad_t: 10.0
	 dec_template_custom_initialization_t_intercept_padding: 0.1
	 decoding_l2_weight_decay: 0.0
	 decoding_spatial_derivative_regularization: 0.0
	 dec_spatial_regularization_factor: 1.0
	 decoding_lr: 0.0001
	 encoding_lr: 0.0001
	 decoding_eps: 0.0001
	 encoding_e

In [7]:

has_plotted_train = False
for sweepaxis in range(1):
    for timestr in timestrs:
        modelname = model_path_template.format(timestr=timestr)
        plot_1d_sweep(modelname,sweepaxis)
        #axes.set_xlim(xmid-scale,xmid+scale)
        #axes.set_ylim(ymid-scale,ymid+scale)
        #axes.set_zlim(0,2*scale)
        #plt.show()

{'dtype_string': array('float', dtype='<U5'), 'datafile': array('forkdata/forkTrajectoryData.npz', dtype='<U31'), 'model_save_dir': array('results/retrainedforkdata/20230928-065206.641930/savedmodel',
      dtype='<U59'), 'num_epochs': array(10000), 'latent_dim': array(3), 'device': array('cuda', dtype='<U4'), 'dtype': torch.float32, 'traj_len': array(200), 'traj_channels': array(7), 'beta': array(1.), 'training_data_added_timing_noise': array(0.1), 'pre_time_learning_epochs': array(0), 'scalar_timewarping_lr': array(0.0001), 'scalar_timewarping_eps': array(1.e-06), 'scalar_timewarper_timereg': array(0.05), 'scalar_timewarper_endpointreg': array(0), 'scaltw_min_canonical_time': array(0.), 'scaltw_max_canonical_time': array(1.), 'dec_use_softplus': array(False), 'dec_use_elu': array(True), 'dec_conv_use_elu': array(True), 'dec_template_use_custom_initialization': array(True), 'dec_template_custom_initialization_grad_t': array(10.), 'dec_template_custom_initialization_t_intercept_padding

elev,azim:  90 0
elev,azim:  10 30
elev,azim:  90 0
elev,azim:  10 30
elev,azim:  90 0
elev,azim:  10 30
elev,azim:  90 0
elev,azim:  10 30
elev,azim:  90 0
elev,azim:  10 30
elev,azim:  90 0
elev,azim:  10 30
elev,azim:  90 0
elev,azim:  10 30
elev,azim:  90 0
elev,azim:  10 30
elev,azim:  90 0
elev,azim:  10 30
elev,azim:  90 0
elev,azim:  10 30
elev,azim:  90 0
elev,azim:  10 30
elev,azim:  90 0
elev,azim:  10 30
elev,azim:  90 0
elev,azim:  10 30
elev,azim:  90 0
elev,azim:  10 30
elev,azim:  90 0
elev,azim:  10 30
elev,azim:  90 0
elev,azim:  10 30
elev,azim:  90 0
elev,azim:  10 30
elev,azim:  90 0
elev,azim:  10 30
elev,azim:  90 0
elev,azim:  10 30
elev,azim:  90 0
elev,azim:  10 30
elev,azim:  90 0
elev,azim:  10 30
elev,azim:  90 0
elev,azim:  10 30
elev,azim:  90 0
elev,azim:  10 30
elev,azim:  90 0
elev,azim:  10 30
elev,azim:  90 0
elev,azim:  10 30
elev,azim:  90 0
elev,azim:  10 30
elev,azim:  90 0
elev,azim:  10 30
elev,azim:  90 0
elev,azim:  10 30
elev,azim:  90 0
ele