In [26]:
import pickle
from utils import Trajectory





In [27]:

rts_rand_traj = pickle.load(open("rts_rand_traj_100","rb"))
print(len(rts_rand_traj))

100


In [28]:
rts_rand_traj_5 = pickle.load(open("rts_rand_traj_5","rb"))
print(len(rts_rand_traj_5))

5


In [4]:
from typing import List
def extract_data_from_trajs(trajs:List[Trajectory]):
    
    traj_ids, obs_flat, gt_flat = [], [], []
    
    for traj_id, traj in enumerate(trajs):
        
        obs_flat.extend(traj.obs_arr)
        gt_flat.extend(traj.gt_arr)
        traj_ids.extend([traj_id]*len(traj.obs_arr))
        
    
    return traj_ids, obs_flat, gt_flat


traj_ids, obs_flat, gt_flat = extract_data_from_trajs(rts_rand_traj_5)

In [29]:
def preprocess_obs(obs_arr):
    
    mod_arr = []
    
    for obs in obs_arr:
        obs = obs[:,:,-4:]
        obs = np.moveaxis(obs, [0,1], [1,2])
        obs = obs/255.0
        mod_arr.append(obs)
    
    return mod_arr

prep_obs = preprocess_obs(obs_flat)
print("Original: ", {obs.shape for obs in obs_flat}, " Unique value: ", np.unique(obs_flat))
print("Preprocessed: ", {obs.shape for obs in prep_obs}, " Unique value: ", np.unique(prep_obs))
print("Total Obs: ", len(prep_obs))

Original:  {(72, 96, 16)}  Unique value:  [  0 255]
Preprocessed:  {(4, 72, 96)}  Unique value:  [0. 1.]
Total Obs:  836


array([0., 1.])

In [32]:
import pytorch_lightning as pl
from torch import nn

In [108]:
#https://analyticsindiamag.com/how-to-implement-convolutional-autoencoder-in-pytorch-with-cuda/
#https://debuggercafe.com/machine-learning-hands-on-convolutional-autoencoders/ 
class NatureAE(pl.LightningModule):
    def __init__(self, n_input_channels=4, flat_dim = 720, latent_dim=2):
        super().__init__()

        self.save_hyperparameters()

        enc_out_dim = latent_dim
        self.latent_dim = latent_dim
        self.enc_out_dim = latent_dim
                
        
        
        self.bla1 = nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0)
        self.bla2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0)
        self.bla3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0)
        

        # encoder
        self.encoder1 = nn.Sequential(
            self.bla1,
            nn.ReLU(),
            self.bla2,
            nn.ReLU(),
            self.bla3,
            nn.ReLU(),
        )
        
        self.encoder2 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_dim, enc_out_dim),
            nn.ReLU()
        )

        print("Encoder")
        print(self.encoder1)
        print(self.encoder2)

        # decoder

        self.decoder1 = nn.Sequential(
            nn.Linear(enc_out_dim, flat_dim),
            nn.ReLU(),
        )
        

        self.alb1 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0)
        self.alb2 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=4, stride=2, padding=0, output_padding=1)
        self.alb3 = nn.ConvTranspose2d(in_channels=32, out_channels=n_input_channels, kernel_size=8, stride=4, padding=0, output_padding=0)
        
            
        self.decoder2 = nn.Sequential(
            self.alb1,
            nn.ReLU(),
            self.alb2,
            nn.ReLU()
        )

        print("Decoder")
        print(self.decoder1)
        print(self.decoder2)
        

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())
    
    
    def encode(self, x):
        
        
        print("BLA START")
        print(x.shape)
        x = self.bla1(x)
        print(x.shape)
        x = self.bla2(x)
        print(x.shape)
        x = self.bla3(x)
        print(x.shape)
        print("BLA DONE")
        
        x = self.encoder2(x)
        """
        x = self.encoder1(x)
        
        
        """
        return x
        
    
    def decode(self, x):

        
        print("ALB Start")
        print(x.shape)
        x = self.decoder1(x)
        x = x.view(-1, 64, 5, 8)
        print(x.shape)
        x = self.alb1(x)
        print(x.shape)
        x = self.alb2(x)
        print(x.shape)
        x = self.alb3(x)
        print(x.shape)
        print("ALB Done")
        return x
        

    def training_step(self, batch, batch_idx):

        x = batch[0]

        x_encoded = self.encoder(x)
        #view
        x_hat = self.decoder(x_encoded)

        recon_loss = F.mse_loss(x, x_hat)
        return recon_loss


    def embed(self, input):
        if isinstance(input, np.ndarray):
            input = torch.tensor(input).float()

        return self.encoder(input).detach().numpy()

In [109]:
import torch 

ae = NatureAE(flat_dim=2560)

obs_in = np.array(prep_obs)
print("Data Shape", obs_in.shape)


data = torch.tensor(obs_in).float()

encoded = ae.encode(data)

print(encoded.shape)

Encoder
Sequential(
  (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
  (1): ReLU()
  (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
  (3): ReLU()
  (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (5): ReLU()
)
Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=2560, out_features=2, bias=True)
  (2): ReLU()
)
Decoder
Sequential(
  (0): Linear(in_features=2, out_features=2560, bias=True)
  (1): ReLU()
)
Sequential(
  (0): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  (1): ReLU()
  (2): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), output_padding=(1, 1))
  (3): ReLU()
)
Data Shape (836, 4, 72, 96)
BLA START
torch.Size([836, 4, 72, 96])
torch.Size([836, 32, 17, 23])
torch.Size([836, 64, 7, 10])
torch.Size([836, 64, 5, 8])
BLA DONE
torch.Size([836, 2])


In [110]:
decoded = ae.decode(encoded)

ALB Start
torch.Size([836, 2])
torch.Size([836, 64, 5, 8])
torch.Size([836, 64, 7, 10])
torch.Size([836, 32, 17, 23])
torch.Size([836, 4, 72, 96])
ALB Done


In [111]:
decoded.shape

torch.Size([836, 4, 72, 96])