In [1]:
import pickle
from utils import Trajectory


In [2]:
from ipynb.fs.full.common import *

In [3]:
def get_embedding(proc_obs, ae):
    import torch
    import numpy as np
    data = torch.tensor(proc_obs).float()
    
    encoded = ae.encode(data)

    print("1. Encoded Data: ",encoded.shape)
    encoded_np = encoded.detach().numpy()
    print("1. Encoded Data max, min, mean", np.max(encoded_np), np.min(encoded_np), np.mean(encoded_np))
    decoded = ae.decode(encoded)
    print("2. Recons Data: ",decoded.shape)

    flatten_obs = np.array(proc_obs)
    flatten_obs = flatten_obs.reshape(-1, flatten_obs.shape[1]*flatten_obs.shape[2]*flatten_obs.shape[3])
    print("3. Flatten", flatten_obs.shape)

    
    return encoded, encoded_np, decoded, flatten_obs 

In [4]:
def test_embedding_quality(trajectory_data, ae, ks = [5, 9], leaf_size=400, normalize = True):
    
    print("\t...Loading Data")
    traj_ids, obs_flat, gt_flat, proc_obs = load_data(trajectory_data, normalize= normalize)
    print("\t...Generating Embedding")
    encoded, encoded_np, decoded, flatten_obs = get_embedding(proc_obs, ae)
    
    scores = {}
    for K in ks:
        match_score, _, _, _, _ = neighborhood_comparison(K, flatten_obs, encoded_np, leaf_size=leaf_size)
        print(f"K: {K} Match Score: {match_score:2f}")
        scores[K]=match_score
        
    return scores

In [5]:
rts_rand_traj_100 = pickle.load(open("rts_rand_traj_100.pkl","rb"))
rts_rand_traj_10 = rts_rand_traj_100[0:10]

In [6]:
pns_rand_100 = pickle.load(open("data/pass_n_shoot_rand_traj_100.pkl","rb"))

In [7]:
easy_counter_rand_10 = pickle.load(open("data/easy_counter_rand_traj_10.pkl","rb"))

In [8]:
_3v1_rand_10 = pickle.load(open("data/3v1_rand_traj_10.pkl","rb"))

In [9]:
def test_all(ae, normalize = True):

    test_ds = { "Run To Score" : rts_rand_traj_100[0:10],
                "3 vs 1": _3v1_rand_10,
                "Easy Counter": easy_counter_rand_10,
                "Pass and Shoot": pns_rand_100[0:10]}

    
    scores = {}
    for name, ds in test_ds.items():
        print("Testing: ", name)
        s = test_embedding_quality(ds, ae , ks = [5], leaf_size=400, normalize=normalize)
        scores[name]=s

        print()

    print(scores)

In [11]:
import pytorch_lightning as pl
from torch import nn
from torch.nn import functional as F
import torch
#https://analyticsindiamag.com/how-to-implement-convolutional-autoencoder-in-pytorch-with-cuda/
#https://debuggercafe.com/machine-learning-hands-on-convolutional-autoencoders/ 
class NatureAE(pl.LightningModule):
    def __init__(self, n_input_channels=4, flat_dim = 720, latent_dim=2, print_network=False, loss_func="cross_entropy"):
        super().__init__()

        self.save_hyperparameters()

        enc_out_dim = latent_dim
        self.latent_dim = latent_dim
        self.enc_out_dim = latent_dim
        self.loss_function = loss_func
        
        
        self.bla1 = nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0)
        self.bla2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0)
        self.bla3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0)
        

        # encoder
        self.encoder1 = nn.Sequential(
            self.bla1,
            nn.ReLU(),
            self.bla2,
            nn.ReLU(),
            self.bla3,
            nn.ReLU(),
        )
        
        self.encoder2 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_dim, enc_out_dim),
            nn.ReLU()
        )
        
        if print_network:
            print("Encoder")
            print(self.encoder1)
            print(self.encoder2)

        # decoder

        self.decoder1 = nn.Sequential(
            nn.Linear(enc_out_dim, flat_dim),
            nn.ReLU(),
        )
        

        self.alb1 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0)
        self.alb2 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=4, stride=2, padding=0, output_padding=1)
        self.alb3 = nn.ConvTranspose2d(in_channels=32, out_channels=n_input_channels, kernel_size=8, stride=4, padding=0, output_padding=0)
        
            
        self.decoder2 = nn.Sequential(
            self.alb1,
            nn.ReLU(),
            self.alb2,
            nn.ReLU(),
            self.alb3
        )
            
        if print_network:
            print("Decoder")
            print(self.decoder1)
            print(self.decoder2)
        

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())
    
    
    def encode(self, x):
        
        """
        print("BLA START")
        print(x.shape)
        x = self.bla1(x)
        print(x.shape)
        x = self.bla2(x)
        print(x.shape)
        x = self.bla3(x)
        print(x.shape)
        print("BLA DONE")
        """
        
        
        x = self.encoder1(x)
        x = self.encoder2(x)
        
        
        return x
        
    
    def decode(self, x):
    
        """
        print("ALB Start")
        print(x.shape)
        x = self.decoder1(x)
        x = x.view(-1, 64, 5, 8)
        print(x.shape)
        x = self.alb1(x)
        print(x.shape)
        x = self.alb2(x)
        print(x.shape)
        x = self.alb3(x)
        print(x.shape)
        print("ALB Done")
        """
        x = self.decoder1(x)
        x = x.view(-1, 64, 5, 8)
        x = self.decoder2(x)
        
        return x
        

    def training_step(self, batch, batch_idx):

        x = batch[0]

        x_encoded = self.encode(x)
        #view
        x_hat = self.decode(x_encoded)
        
        #x_flat = x.view(x.shape[0], -1)
        #x_hat_flat = x_hat.view(x_hat.shape[0], -1)
        
        #print(x_flat.shape)
        #quit()
        

        #recon_loss = F.mse_loss(x, x_hat, reduction="sum")
        recon_loss = F.mse_loss(x, x_hat)
        
        #recon_loss = F.binary_cross_entropy_with_logits(x_hat_flat, x_flat, reduction="sum")
        #recon_loss = F.binary_cross_entropy_with_logits(x_hat, x)
        
        return recon_loss


    def embed(self, input):
        if isinstance(input, np.ndarray):
            input = torch.tensor(input).float()

        return self.encoder(input).detach().numpy()

In [15]:
ae = NatureAE.load_from_checkpoint(checkpoint_path="models/nature_ae_mse_no_scale_64_rts_100.ckpt")
test_all(ae, normalize = False)

Testing:  Run To Score
	...Loading Data
1. Loaded Raw Obs:  1178 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  1178 {(4, 72, 96)} [  0 255]
	...Generating Embedding
1. Encoded Data:  torch.Size([1178, 64])
1. Encoded Data max, min, mean 4997.617 0.0 740.1769
2. Recons Data:  torch.Size([1178, 4, 72, 96])
3. Flatten (1178, 27648)
K: 5 Match Score: 3.105263

Testing:  3 vs 1
	...Loading Data
1. Loaded Raw Obs:  927 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  927 {(4, 72, 96)} [  0 255]
	...Generating Embedding
1. Encoded Data:  torch.Size([927, 64])
1. Encoded Data max, min, mean 2875.4465 0.0 399.17087
2. Recons Data:  torch.Size([927, 4, 72, 96])
3. Flatten (927, 27648)
K: 5 Match Score: 2.367853

Testing:  Easy Counter
	...Loading Data
1. Loaded Raw Obs:  693 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  693 {(4, 72, 96)} [  0 255]
	...Generating Embedding
1. Encoded Data:  torch.Size([693, 64])
1. Encoded Data max, min, mean 7567.621 0.0 1232.5936
2. Recons Data:  torch.Size([69

In [None]:
ae = NatureAE.load_from_checkpoint(checkpoint_path="model_data/nature_ae_64_rts_1000.ckpt")
test_all(ae)
#rts_10->3.02

In [None]:
ae = NatureAE.load_from_checkpoint(checkpoint_path="models/nature_ae_64_rts_100.ckpt")
test_all(ae)
#rts_10->2.23 

In [16]:
ae = NatureAE.load_from_checkpoint(checkpoint_path="models/nature_ae_mse_no_scale_8_rts_100.ckpt")
test_all(ae, normalize = False)

Testing:  Run To Score
	...Loading Data
1. Loaded Raw Obs:  1178 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  1178 {(4, 72, 96)} [  0 255]
	...Generating Embedding
1. Encoded Data:  torch.Size([1178, 8])
1. Encoded Data max, min, mean 4452.1206 0.0 870.63885
2. Recons Data:  torch.Size([1178, 4, 72, 96])
3. Flatten (1178, 27648)
K: 5 Match Score: 2.363328

Testing:  3 vs 1
	...Loading Data
1. Loaded Raw Obs:  927 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  927 {(4, 72, 96)} [  0 255]
	...Generating Embedding
1. Encoded Data:  torch.Size([927, 8])
1. Encoded Data max, min, mean 3538.4011 0.0 492.76065
2. Recons Data:  torch.Size([927, 4, 72, 96])
3. Flatten (927, 27648)
K: 5 Match Score: 1.523193

Testing:  Easy Counter
	...Loading Data
1. Loaded Raw Obs:  693 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  693 {(4, 72, 96)} [  0 255]
	...Generating Embedding
1. Encoded Data:  torch.Size([693, 8])
1. Encoded Data max, min, mean 5405.8374 0.0 1382.5745
2. Recons Data:  torch.Size([69

In [18]:
ae = NatureAE.load_from_checkpoint(checkpoint_path="models/nature_ae_mse_no_scale_3_rts_100.ckpt")
test_all(ae, normalize = False)

Testing:  Run To Score
	...Loading Data
1. Loaded Raw Obs:  1178 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  1178 {(4, 72, 96)} [  0 255]
	...Generating Embedding
1. Encoded Data:  torch.Size([1178, 3])
1. Encoded Data max, min, mean 4703.8867 0.0 437.039
2. Recons Data:  torch.Size([1178, 4, 72, 96])
3. Flatten (1178, 27648)
K: 5 Match Score: 0.804754

Testing:  3 vs 1
	...Loading Data
1. Loaded Raw Obs:  927 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  927 {(4, 72, 96)} [  0 255]
	...Generating Embedding
1. Encoded Data:  torch.Size([927, 3])
1. Encoded Data max, min, mean 5254.8716 0.0 425.45425
2. Recons Data:  torch.Size([927, 4, 72, 96])
3. Flatten (927, 27648)
K: 5 Match Score: 0.567422

Testing:  Easy Counter
	...Loading Data
1. Loaded Raw Obs:  693 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  693 {(4, 72, 96)} [  0 255]
	...Generating Embedding
1. Encoded Data:  torch.Size([693, 3])
1. Encoded Data max, min, mean 3305.9016 0.0 178.86223
2. Recons Data:  torch.Size([693,

In [20]:
ae = NatureAE.load_from_checkpoint(checkpoint_path="models/nature_ae_mse_no_scale_2_rts_100.ckpt")
test_all(ae, normalize = False)

Testing:  Run To Score
	...Loading Data
1. Loaded Raw Obs:  1178 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  1178 {(4, 72, 96)} [  0 255]
	...Generating Embedding
1. Encoded Data:  torch.Size([1178, 2])
1. Encoded Data max, min, mean 4415.8477 0.0 452.01953
2. Recons Data:  torch.Size([1178, 4, 72, 96])
3. Flatten (1178, 27648)
K: 5 Match Score: 0.727504

Testing:  3 vs 1
	...Loading Data
1. Loaded Raw Obs:  927 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  927 {(4, 72, 96)} [  0 255]
	...Generating Embedding
1. Encoded Data:  torch.Size([927, 2])
1. Encoded Data max, min, mean 3249.7612 0.0 256.41666
2. Recons Data:  torch.Size([927, 4, 72, 96])
3. Flatten (927, 27648)
K: 5 Match Score: 0.338727

Testing:  Easy Counter
	...Loading Data
1. Loaded Raw Obs:  693 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  693 {(4, 72, 96)} [  0 255]
	...Generating Embedding
1. Encoded Data:  torch.Size([693, 2])
1. Encoded Data max, min, mean 2204.3794 0.0 62.961834
2. Recons Data:  torch.Size([69