In [1]:
import pickle

import torch

from utils import Trajectory


In [2]:
from ipynb.fs.full.common import *

In [3]:
def load_data(trajectory_data):
    """Extract observations from trajectories and preprocess them, printing summaries.

    Parameters
    ----------
    trajectory_data : trajectories accepted by ``extract_data_from_trajs``.

    Returns
    -------
    tuple
        ``(traj_ids, obs_flat, gt_flat, proc_obs)``: the three values returned by
        ``extract_data_from_trajs`` followed by ``preprocess_obs(obs_flat)``.
    """
    import numpy as np

    ids, raw_obs, ground_truth = extract_data_from_trajs(trajectory_data)
    raw_shapes = set(o.shape for o in raw_obs)
    # Sanity check: sample count, distinct per-observation shapes, distinct values.
    print("1. Loaded Raw Obs: ", len(raw_obs), raw_shapes, np.unique(raw_obs))

    processed = preprocess_obs(raw_obs)
    processed_shapes = set(o.shape for o in processed)
    print("2. Preprocessed Obs: ", len(processed), processed_shapes, np.unique(processed))

    return ids, raw_obs, ground_truth, processed
    

In [4]:
def get_embedding(proc_obs, ae):
    """Encode preprocessed observations with ``ae`` and build a flattened raw-input view.

    Parameters
    ----------
    proc_obs : array-like
        Preprocessed observations stacked along the first axis (e.g. a list of
        equally shaped ndarrays as returned by ``load_data``). Any per-sample
        rank is accepted.
    ae : object
        Model exposing ``encode(tensor)`` and ``decode(tensor)``.

    Returns
    -------
    tuple
        ``(encoded, encoded_np, decoded, flatten_obs)``: the latent tensor, its
        NumPy copy, the reconstruction tensor, and a copy of ``proc_obs``
        reshaped to ``(n_samples, n_features)``.
    """
    import numpy as np
    import torch

    # Stack once. Building a tensor directly from a list of ndarrays is very
    # slow, and the same private copy is reused for the flattened view below.
    obs_arr = np.array(proc_obs)
    data = torch.as_tensor(obs_arr, dtype=torch.float32)

    # Inference only: without no_grad, autograd keeps every intermediate
    # activation of the whole batch alive.
    # NOTE(review): `ae` is not switched to eval mode here. If it has dropout or
    # batch-norm layers, the embedding depends on its current mode; confirm.
    with torch.no_grad():
        encoded = ae.encode(data)
        print("1. Encoded Data: ", encoded.shape)
        encoded_np = encoded.detach().numpy()
        print("1. Encoded Data max, min, mean", np.max(encoded_np), np.min(encoded_np), np.mean(encoded_np))
        decoded = ae.decode(encoded)
        print("2. Recons Data: ", decoded.shape)

    # One row per sample, with all remaining axes flattened into features.
    flatten_obs = obs_arr.reshape(obs_arr.shape[0], -1)
    print("3. Flatten", flatten_obs.shape)

    return encoded, encoded_np, decoded, flatten_obs

In [5]:
def test_embedding_quality(trajectory_data, ae, ks=(5, 9), leaf_size=400):
    """Score how well ``ae``'s embedding preserves raw-observation neighbourhoods.

    Loads and preprocesses ``trajectory_data`` and embeds it with ``ae``. Then,
    for each neighbourhood size K, it calls ``neighborhood_comparison`` on the
    flattened raw observations versus the embedding and keeps its first
    return value.

    Parameters
    ----------
    trajectory_data : trajectories accepted by ``load_data``.
    ae : model accepted by ``get_embedding``.
    ks : iterable of int
        Neighbourhood sizes to evaluate. The tuple default avoids the shared
        mutable-default pitfall; lists such as ``[5]`` are still accepted.
    leaf_size : int
        Forwarded unchanged to ``neighborhood_comparison``.

    Returns
    -------
    dict
        Mapping ``K -> match score``.
    """
    print("\t...Loading Data")
    _, _, _, proc_obs = load_data(trajectory_data)
    print("\t...Generating Embedding")
    _, encoded_np, _, flatten_obs = get_embedding(proc_obs, ae)

    scores = {}
    for K in ks:
        match_score, _, _, _, _ = neighborhood_comparison(K, flatten_obs, encoded_np, leaf_size=leaf_size)
        # `.6f` states the precision explicitly. The former `:2f` was a field
        # width of 2 and only printed six decimals because that is the default.
        print(f"K: {K} Match Score: {match_score:.6f}")
        scores[K] = match_score

    return scores

In [8]:
# Untrained baseline. Seed first so the "random AE" scores are reproducible
# (assumes NatureAE initialises its weights from torch's global RNG).
torch.manual_seed(0)
random_ae = NatureAE(flat_dim=2560)
# NOTE(review): the other checkpoints in this notebook are read from `models/`;
# confirm that `model_data/` is the intended location for this one.
pret_rte_ae_64_1000 = NatureAE.load_from_checkpoint(checkpoint_path="model_data/nature_ae_64_rts_1000.ckpt")

In [6]:
# `with` closes the file handles that `pickle.load(open(...))` leaked.
# Only unpickle trusted, locally generated files: pickle can execute arbitrary code.
# NOTE(review): these files live in the working directory while the other
# scenarios are under `data/`; confirm the paths.
with open("rts_rand_traj_5.pkl", "rb") as f:
    rts_rand_traj_5 = pickle.load(f)
with open("rts_rand_traj_100.pkl", "rb") as f:
    rts_rand_traj_100 = pickle.load(f)

In [8]:
# Baseline: untrained AE on 5 random Run-To-Score trajectories.
print("Random RTS 5  + Random AE")
test_embedding_quality(trajectory_data=rts_rand_traj_5, ae=random_ae)

Random RTS 5  + Random AE
	...Loading Data
1. Loaded Raw Obs:  631 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  631 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([631, 2])
1. Encoded Data max, min, mean 0.028704867 0.019616608 0.023952829
2. Recons Data:  torch.Size([631, 4, 72, 96])
3. Flatten (631, 27648)
K: 5 Match Score: 0.329635
K: 9 Match Score: 0.637084


{5: 0.329635499207607, 9: 0.6370839936608558}

In [9]:
# Same 5 trajectories, now embedded with the 64-d checkpoint loaded above.
print("Random RTS 5 + pret_rte_ae_64_1000")
test_embedding_quality(trajectory_data=rts_rand_traj_5, ae=pret_rte_ae_64_1000)

Random RTS 5 + pret_rte_ae_64_1000
	...Loading Data
1. Loaded Raw Obs:  631 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  631 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([631, 64])
1. Encoded Data max, min, mean 16.047356 0.0 2.247074
2. Recons Data:  torch.Size([631, 4, 72, 96])
3. Flatten (631, 27648)
K: 5 Match Score: 3.044374
K: 9 Match Score: 5.232964


{5: 3.0443740095087164, 9: 5.23296354992076}

In [10]:
print("Random RTS 10  + pret_rte_ae_64_1000")

# The first 10 of the 100 random Run-To-Score trajectories.
rts_rand_traj_10 = rts_rand_traj_100[:10]

test_embedding_quality(trajectory_data=rts_rand_traj_10, ae=pret_rte_ae_64_1000)

Random RTS 10  + pret_rte_ae_64_1000
	...Loading Data
1. Loaded Raw Obs:  1178 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  1178 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([1178, 64])
1. Encoded Data max, min, mean 15.73028 0.0 2.2597044
2. Recons Data:  torch.Size([1178, 4, 72, 96])
3. Flatten (1178, 27648)
K: 5 Match Score: 3.024618
K: 9 Match Score: 4.889643


{5: 3.024617996604414, 9: 4.889643463497453}

In [7]:
# `with` closes the handle; only unpickle trusted local files.
with open("data/pass_n_shoot_rand_traj_100.pkl", "rb") as f:
    pns_rand_100 = pickle.load(f)

In [8]:
# `with` closes the handle; only unpickle trusted local files.
with open("data/easy_counter_rand_traj_10.pkl", "rb") as f:
    easy_counter_rand_10 = pickle.load(f)

In [9]:
# `with` closes the handle; only unpickle trusted local files.
with open("data/3v1_rand_traj_10.pkl", "rb") as f:
    _3v1_rand_10 = pickle.load(f)

In [13]:
# Baseline: untrained AE on the first 10 Pass-and-Shoot trajectories.
print("Random Pass and Shoot 10  + random")
test_embedding_quality(trajectory_data=pns_rand_100[:10], ae=random_ae)

Random Pass and Shoot 10  + random
	...Loading Data
1. Loaded Raw Obs:  1855 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  1855 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([1855, 2])
1. Encoded Data max, min, mean 0.028563423 0.018904222 0.024005113
2. Recons Data:  torch.Size([1855, 4, 72, 96])
3. Flatten (1855, 27648)
K: 5 Match Score: 0.263073
K: 9 Match Score: 0.449596


{5: 0.26307277628032344, 9: 0.4495956873315364}

In [9]:
# Easy Counter scenario with the 64-d checkpoint (label typo "Cuunter" fixed).
print("Random Easy Counter 10  + pret_rte_ae_64_1000")
test_embedding_quality(easy_counter_rand_10, pret_rte_ae_64_1000)

Random Easy Cuunter 10  + pret_rte_ae_64_1000
	...Loading Data
1. Loaded Raw Obs:  693 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  693 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([693, 64])
1. Encoded Data max, min, mean 14.413122 0.0 2.5350876
2. Recons Data:  torch.Size([693, 4, 72, 96])
3. Flatten (693, 27648)
K: 5 Match Score: 2.927850
K: 9 Match Score: 5.053391


{5: 2.9278499278499277, 9: 5.053391053391054}

In [12]:
# 3-vs-1 scenario with the 64-d checkpoint.
print("Random 3v1 10  + pret_rte_ae_64_1000")
test_embedding_quality(trajectory_data=_3v1_rand_10, ae=pret_rte_ae_64_1000)

Random 3v1 10  + pret_rte_ae_64_1000
	...Loading Data
1. Loaded Raw Obs:  927 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  927 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([927, 64])
1. Encoded Data max, min, mean 10.874697 0.0 2.1587253
2. Recons Data:  torch.Size([927, 4, 72, 96])
3. Flatten (927, 27648)
K: 5 Match Score: 1.919094
K: 9 Match Score: 3.245955


{5: 1.919093851132686, 9: 3.2459546925566345}

In [14]:
# Pass-and-Shoot (first 10) with the 64-d checkpoint.
print("Random Pass and Shoot 10  + pret_rte_ae_64_1000")
test_embedding_quality(trajectory_data=pns_rand_100[:10], ae=pret_rte_ae_64_1000)

Random Pass and Shoot 10  + pret_rte_ae_64_1000
	...Loading Data
1. Loaded Raw Obs:  1855 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  1855 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([1855, 64])
1. Encoded Data max, min, mean 12.200652 0.0 2.145445
2. Recons Data:  torch.Size([1855, 4, 72, 96])
3. Flatten (1855, 27648)
K: 5 Match Score: 2.115903
K: 9 Match Score: 3.554178


{5: 2.115902964959569, 9: 3.554177897574124}

In [14]:
# NOTE(review): the checkpoint filename suggests a 64-d AE trained on 100 Run-To-Score
# trajectories; the variable says "rte" while the file says "rts". Confirm.
rte_ae_64_100 = NatureAE.load_from_checkpoint(checkpoint_path="models/nature_ae_64_rts_100.ckpt")

In [16]:
# Run-To-Score (first 10) with the 64-d `rte_ae_64_100` checkpoint.
print("Random RTS 10  + rte_ae_64_100")
test_embedding_quality(trajectory_data=rts_rand_traj_100[:10], ae=rte_ae_64_100)

Random RTS 10  + rte_ae_64_100
	...Loading Data
1. Loaded Raw Obs:  1178 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  1178 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([1178, 64])
1. Encoded Data max, min, mean 12.36917 0.0 0.24070898
2. Recons Data:  torch.Size([1178, 4, 72, 96])
3. Flatten (1178, 27648)
K: 5 Match Score: 2.216469
K: 9 Match Score: 3.803056


{5: 2.2164685908319184, 9: 3.803056027164686}

In [15]:
# Easy Counter scenario with `rte_ae_64_100` (label typo "Cuunter" fixed).
print("Random Easy Counter 10  + rte_ae_64_100")
test_embedding_quality(easy_counter_rand_10, rte_ae_64_100)

Random Easy Cuunter 10  + rte_ae_64_100
	...Loading Data
1. Loaded Raw Obs:  693 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  693 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([693, 64])
1. Encoded Data max, min, mean 6.5344625 0.0 0.39137703
2. Recons Data:  torch.Size([693, 4, 72, 96])
3. Flatten (693, 27648)
K: 5 Match Score: 2.088023
K: 9 Match Score: 3.910534


{5: 2.088023088023088, 9: 3.9105339105339105}

In [16]:
# 3-vs-1 scenario with `rte_ae_64_100`.
print("Random 3v1 10  + rte_ae_64_100")
test_embedding_quality(trajectory_data=_3v1_rand_10, ae=rte_ae_64_100)

Random 3v1 10  + rte_ae_64_100
	...Loading Data
1. Loaded Raw Obs:  927 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  927 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([927, 64])
1. Encoded Data max, min, mean 3.0973241 0.0 0.1366758
2. Recons Data:  torch.Size([927, 4, 72, 96])
3. Flatten (927, 27648)
K: 5 Match Score: 1.728155
K: 9 Match Score: 3.019417


{5: 1.7281553398058251, 9: 3.0194174757281553}

In [19]:
# Pass-and-Shoot (first 10) with `rte_ae_64_100`.
print("Random Pass and Shoot 10  + rte_ae_64_100")
test_embedding_quality(trajectory_data=pns_rand_100[:10], ae=rte_ae_64_100)

Random Pass and Shoot 10  + rte_ae_64_100
	...Loading Data
1. Loaded Raw Obs:  1855 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  1855 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([1855, 64])
1. Encoded Data max, min, mean 2.9609768 0.0 0.14203393
2. Recons Data:  torch.Size([1855, 4, 72, 96])
3. Flatten (1855, 27648)
K: 5 Match Score: 1.969811
K: 9 Match Score: 3.371429


{5: 1.969811320754717, 9: 3.3714285714285714}

In [31]:
test_ds = { "Run To Score" : rts_rand_traj_100[0:10],
            "3 vs 1": _3v1_rand_10,
            "Easy Counter": easy_counter_rand_10,
            "Pass and Shoot": pns_rand_100[0:10]}

# Untrained AE with a 2048-d latent. It is seeded for reproducibility and given
# its own name, so it no longer silently replaces the global `random_ae` baseline.
torch.manual_seed(0)
random_ae_2048 = NatureAE(flat_dim=2560, latent_dim=2048)
scores = {}
for name, ds in test_ds.items():
    print("Testing: ", name)
    s = test_embedding_quality(ds, random_ae_2048, ks=[5], leaf_size=400)
    scores[name] = s

    print()

print(scores)

Testing:  Run To Score
	...Loading Data
1. Loaded Raw Obs:  1178 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  1178 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([1178, 2048])
1. Encoded Data max, min, mean 0.050157707 0.0 0.00628261
2. Recons Data:  torch.Size([1178, 4, 72, 96])
3. Flatten (1178, 27648)
K: 5 Match Score: 2.844652

Testing:  3 vs 1
	...Loading Data
1. Loaded Raw Obs:  927 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  927 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([927, 2048])
1. Encoded Data max, min, mean 0.04958695 0.0 0.0062758634
2. Recons Data:  torch.Size([927, 4, 72, 96])
3. Flatten (927, 27648)
K: 5 Match Score: 2.565264

Testing:  Easy Counter
	...Loading Data
1. Loaded Raw Obs:  693 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  693 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([693, 2048])
1. Encoded Data max, min, mean 0.050560053 0.0 0.0062970524
2. Recons Data: 

In [23]:
test_ds = { "Run To Score" : rts_rand_traj_100[0:10],
            "3 vs 1": _3v1_rand_10,
            "Easy Counter": easy_counter_rand_10,
            "Pass and Shoot": pns_rand_100[0:10]}

# Untrained AE with an 8-d latent. It is seeded for reproducibility and given
# its own name, so it no longer silently replaces the global `random_ae` baseline.
torch.manual_seed(0)
random_ae_8 = NatureAE(flat_dim=2560, latent_dim=8)
scores = {}
for name, ds in test_ds.items():
    print("Testing: ", name)
    s = test_embedding_quality(ds, random_ae_8, ks=[5], leaf_size=400)
    scores[name] = s

Testing:  Run To Score
	...Loading Data
1. Loaded Raw Obs:  1178 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  1178 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([1178, 8])
1. Encoded Data max, min, mean 0.027482502 0.0 0.0083416095
2. Recons Data:  torch.Size([1178, 4, 72, 96])
3. Flatten (1178, 27648)
K: 5 Match Score: 0.269100
Testing:  3 vs 1
	...Loading Data
1. Loaded Raw Obs:  927 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  927 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([927, 8])
1. Encoded Data max, min, mean 0.027228521 0.0 0.00834634
2. Recons Data:  torch.Size([927, 4, 72, 96])
3. Flatten (927, 27648)
K: 5 Match Score: 0.477886
Testing:  Easy Counter
	...Loading Data
1. Loaded Raw Obs:  693 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  693 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([693, 8])
1. Encoded Data max, min, mean 0.028454682 0.0 0.008402378
2. Recons Data:  torch.Size

In [24]:
# Show the per-scenario scores from the most recently executed sweep
# (per the captured output, the latent_dim=8 random-AE run).
scores

{'Run To Score': {5: 0.2691001697792869},
 '3 vs 1': {5: 0.4778856526429342},
 'Easy Counter': {5: 0.3448773448773449},
 'Pass and Shoot': {5: 0.4064690026954178}}

In [25]:
test_ds = { "Run To Score" : rts_rand_traj_100[0:10],
            "3 vs 1": _3v1_rand_10,
            "Easy Counter": easy_counter_rand_10,
            "Pass and Shoot": pns_rand_100[0:10]}

# Untrained AE with a 2-d latent. It is seeded for reproducibility and given
# its own name, so it no longer silently replaces the global `random_ae` baseline.
torch.manual_seed(0)
random_ae_2 = NatureAE(flat_dim=2560, latent_dim=2)
scores = {}
for name, ds in test_ds.items():
    print("Testing: ", name)
    s = test_embedding_quality(ds, random_ae_2, ks=[5], leaf_size=400)
    scores[name] = s

    print()

print(scores)

Testing:  Run To Score
	...Loading Data
1. Loaded Raw Obs:  1178 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  1178 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([1178, 2])
1. Encoded Data max, min, mean 0.0019134113 0.0 9.22515e-05
2. Recons Data:  torch.Size([1178, 4, 72, 96])
3. Flatten (1178, 27648)
K: 5 Match Score: 0.047538

Testing:  3 vs 1
	...Loading Data
1. Loaded Raw Obs:  927 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  927 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([927, 2])
1. Encoded Data max, min, mean 0.0016815114 0.0 5.0206858e-05
2. Recons Data:  torch.Size([927, 4, 72, 96])
3. Flatten (927, 27648)
K: 5 Match Score: 0.047465

Testing:  Easy Counter
	...Loading Data
1. Loaded Raw Obs:  693 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  693 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([693, 2])
1. Encoded Data max, min, mean 0.0036977953 0.0 0.000219903
2. Recons Data:  tor

In [20]:
# 8-d AE checkpoint (per filename). NOTE: the sweep below reports an all-zero
# embedding for it (max = min = mean = 0.0), so its scores are not meaningful.
rts_ae_8_1000 = NatureAE.load_from_checkpoint(checkpoint_path="models/nature_ae_8_rts_1000.ckpt")

In [25]:
test_ds = dict([
    ("Run To Score", rts_rand_traj_100[0:10]),
    ("3 vs 1", _3v1_rand_10),
    ("Easy Counter", easy_counter_rand_10),
    ("Pass and Shoot", pns_rand_100[0:10]),
])

# Evaluate `rts_ae_8_1000` on every scenario at K = 5 and K = 9.
scores = {}
for name in test_ds:
    ds = test_ds[name]
    print("Testing: ", name)
    s = test_embedding_quality(ds, rts_ae_8_1000, ks=[5, 9], leaf_size=400)
    scores[name] = s
    


Testing:  Run To Score
	...Loading Data
1. Loaded Raw Obs:  1178 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  1178 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([1178, 8])
1. Encoded Data max, min, mean 0.0 0.0 0.0
2. Recons Data:  torch.Size([1178, 4, 72, 96])
3. Flatten (1178, 27648)
K: 5 Match Score: -0.943124
K: 9 Match Score: -0.882852
Testing:  3 vs 1
	...Loading Data
1. Loaded Raw Obs:  927 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  927 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([927, 8])
1. Encoded Data max, min, mean 0.0 0.0 0.0
2. Recons Data:  torch.Size([927, 4, 72, 96])
3. Flatten (927, 27648)
K: 5 Match Score: -0.925566
K: 9 Match Score: -0.837109
Testing:  Easy Counter
	...Loading Data
1. Loaded Raw Obs:  693 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  693 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([693, 8])
1. Encoded Data max, min, mean 0.0 0.0 0.0
2. Recons Data:  

In [26]:
scores

{'Run To Score': {5: -0.9431239388794567, 9: -0.8828522920203735},
 '3 vs 1': {5: -0.9255663430420712, 9: -0.837108953613808},
 'Easy Counter': {5: -0.924963924963925, 9: -0.8210678210678211},
 'Pass and Shoot': {5: -0.9611859838274932, 9: -0.9266846361185984}}

In [29]:
rts_ae_8_100 = NatureAE.load_from_checkpoint(checkpoint_path="models/nature_ae_8_rts_100.ckpt")
test_ds = dict([
    ("Run To Score", rts_rand_traj_100[0:10]),
    ("3 vs 1", _3v1_rand_10),
    ("Easy Counter", easy_counter_rand_10),
    ("Pass and Shoot", pns_rand_100[0:10]),
])

# Evaluate the freshly loaded 8-d checkpoint on every scenario at K = 5.
scores = {}
for name in test_ds:
    ds = test_ds[name]
    print("Testing: ", name)
    s = test_embedding_quality(ds, rts_ae_8_100, ks=[5], leaf_size=400)
    scores[name] = s

Testing:  Run To Score
	...Loading Data
1. Loaded Raw Obs:  1178 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  1178 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([1178, 8])
1. Encoded Data max, min, mean 0.0 0.0 0.0
2. Recons Data:  torch.Size([1178, 4, 72, 96])
3. Flatten (1178, 27648)
K: 5 Match Score: 0.008489
Testing:  3 vs 1
	...Loading Data
1. Loaded Raw Obs:  927 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  927 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([927, 8])
1. Encoded Data max, min, mean 0.0 0.0 0.0
2. Recons Data:  torch.Size([927, 4, 72, 96])
3. Flatten (927, 27648)
K: 5 Match Score: 0.020496
Testing:  Easy Counter
	...Loading Data
1. Loaded Raw Obs:  693 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  693 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([693, 8])
1. Encoded Data max, min, mean 0.0 0.0 0.0
2. Recons Data:  torch.Size([693, 4, 72, 96])
3. Flatten (693, 27648)
K: 5 

KeyboardInterrupt: 

In [14]:
ae = NatureAE.load_from_checkpoint(checkpoint_path="models/nature_ae_sse_8_rts_100.ckpt")
test_ds = dict([("Run To Score", rts_rand_traj_100[0:10])])

# Single-scenario check of the `sse` 8-d checkpoint at K = 5.
scores = {}
for name in test_ds:
    ds = test_ds[name]
    print("Testing: ", name)
    s = test_embedding_quality(ds, ae, ks=[5], leaf_size=400)
    scores[name] = s

Testing:  Run To Score
	...Loading Data
1. Loaded Raw Obs:  1178 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  1178 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([1178, 8])
1. Encoded Data max, min, mean 0.46478882 0.0 0.014321777
2. Recons Data:  torch.Size([1178, 4, 72, 96])
3. Flatten (1178, 27648)
K: 5 Match Score: 0.222411


In [11]:
ae = NatureAE.load_from_checkpoint(checkpoint_path="models/nature_ae_bcel_64_rts_100.ckpt")
test_ds = dict([
    ("Run To Score", rts_rand_traj_100[0:10]),
    ("3 vs 1", _3v1_rand_10),
    ("Easy Counter", easy_counter_rand_10),
    ("Pass and Shoot", pns_rand_100[0:10]),
])

# Evaluate the `bcel` 64-d checkpoint on every scenario at K = 5.
scores = {}
for name in test_ds:
    ds = test_ds[name]
    print("Testing: ", name)
    s = test_embedding_quality(ds, ae, ks=[5], leaf_size=400)
    scores[name] = s

# Reference point: a previous training run reportedly scored 0.75 here.
# NOTE(review): which scenario/K that refers to was not recorded; presumably the
# K=5 Run-To-Score match score. Confirm.

Testing:  Run To Score
	...Loading Data
1. Loaded Raw Obs:  1178 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  1178 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([1178, 64])
1. Encoded Data max, min, mean 30.621178 0.0 2.3540118
2. Recons Data:  torch.Size([1178, 4, 72, 96])
3. Flatten (1178, 27648)
K: 5 Match Score: 0.876061
Testing:  3 vs 1
	...Loading Data
1. Loaded Raw Obs:  927 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  927 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([927, 64])
1. Encoded Data max, min, mean 27.043846 0.0 2.3255394
2. Recons Data:  torch.Size([927, 4, 72, 96])
3. Flatten (927, 27648)
K: 5 Match Score: 0.908306
Testing:  Easy Counter
	...Loading Data
1. Loaded Raw Obs:  693 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  693 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([693, 64])
1. Encoded Data max, min, mean 31.735456 0.0 2.5649571
2. Recons Data:  torch.Size([693, 4,

In [16]:
ae = NatureAE.load_from_checkpoint(checkpoint_path="models/nature_ae_bce_8_rts_100.ckpt")
test_ds = dict([("Run To Score", rts_rand_traj_100[0:10])])

# Single-scenario check of the `bce` 8-d checkpoint at K = 5.
scores = {}
for name in test_ds:
    ds = test_ds[name]
    print("Testing: ", name)
    s = test_embedding_quality(ds, ae, ks=[5], leaf_size=400)
    scores[name] = s

Testing:  Run To Score
	...Loading Data
1. Loaded Raw Obs:  1178 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  1178 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([1178, 8])
1. Encoded Data max, min, mean 66931772.0 0.0 16020477.0
2. Recons Data:  torch.Size([1178, 4, 72, 96])
3. Flatten (1178, 27648)
K: 5 Match Score: 0.724109


In [10]:
# Re-run the first two stages of test_embedding_quality by hand, so the
# intermediate arrays stay available for the inspection cells that follow.
print("\t...Loading Data")
traj_ids, obs_flat, gt_flat, proc_obs = load_data(trajectory_data=easy_counter_rand_10)
print("\t...Generating Embedding")
encoded, encoded_np, decoded, flatten_obs = get_embedding(proc_obs=proc_obs, ae=rts_ae_8_100)



	...Loading Data
1. Loaded Raw Obs:  693 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  693 {(4, 72, 96)} [0. 1.]
	...Generating Embedding
1. Encoded Data:  torch.Size([693, 8])
1. Encoded Data max, min, mean 0.0 0.0 0.0
2. Recons Data:  torch.Size([693, 4, 72, 96])
3. Flatten (693, 27648)


In [12]:
# NumPy copy of the reconstruction produced from the all-zero embedding above.
decoded_arr = decoded.detach().numpy()

In [13]:
# Reconstruction statistics (mean, max, min); values near 0 show the decoder output is nearly flat.
decoded_arr.mean(), decoded_arr.max(), decoded_arr.min()

(0.00049331127, 0.021051798, -0.014660655)

In [None]:
# Inspect the raw-space vs. embedding-space neighbour indices for K=5
# (leaf_size is left at neighborhood_comparison's default here).
s, matches, intersects, ind_st_arr, ind_embd_arr = neighborhood_comparison(5, flatten_obs, encoded_np)

In [33]:
# Neighbour indices computed on the flattened raw observations.
ind_st_arr

array([[132, 438,  51,   0, 176, 349],
       [  1,  52, 604, 603, 664, 133],
       [  2,   3,  53,  54, 605, 440],
       ...,
       [690, 689, 691, 688, 692, 171],
       [691, 692, 690, 203, 202, 161],
       [692, 691, 204, 690, 169, 227]])

In [34]:
# Neighbour indices in embedding space. Every row is identical because the embedding
# is all zeros, so all points are equidistant and each query returns the same set.
ind_embd_arr

array([[692,   1, 344,  82, 170,   0],
       [692,   1, 344,  82, 170,   0],
       [692,   1, 344,  82, 170,   0],
       ...,
       [692,   1, 344,  82, 170,   0],
       [692,   1, 344,  82, 170,   0],
       [692,   1, 344,  82, 170,   0]])

In [38]:
# Confirms the collapsed embedding (max is 0.0).
# NOTE(review): no visible top-level cell imports numpy, so `np` presumably comes
# from the `common` star import. Confirm it is available on a fresh kernel.
np.max(encoded_np)

0.0