In [11]:
import pickle
from utils import Trajectory
from ipynb.fs.full.common import *

In [14]:
def train_nature_ae(proc_obs, latent_dim=64, batch_size = 512, max_epochs = 25, ckpt_path=""):
    """Fit a ``NatureAE`` on the given observations and save a checkpoint.

    Args:
        proc_obs: Observations to train on; converted with
            ``torch.tensor(proc_obs).float()``.
        latent_dim: Size of the autoencoder bottleneck.
        batch_size: Number of samples per training batch.
        max_epochs: Number of epochs the Lightning trainer runs.
        ckpt_path: Where to write the checkpoint. An empty string selects
            ``f"ae_{latent_dim}_{max_epochs}"``.

    Returns:
        The ``pl.Trainer`` instance after fitting and saving.
    """
    import torch
    from torch.utils.data import DataLoader, RandomSampler, TensorDataset

    # An empty path means "derive the file name from the hyperparameters".
    checkpoint_file = f"ae_{latent_dim}_{max_epochs}" if ckpt_path == "" else ckpt_path

    obs_dataset = TensorDataset(torch.tensor(proc_obs).float())
    loader = DataLoader(
        obs_dataset,
        sampler=RandomSampler(obs_dataset),  # draws samples in random order
        batch_size=batch_size,
    )

    # flat_dim is fixed at 2560 here; it is not derived from proc_obs.
    model = NatureAE(flat_dim=2560, latent_dim=latent_dim)
    fitter = pl.Trainer(gpus=0, max_epochs=max_epochs)
    fitter.fit(model, loader)
    fitter.save_checkpoint(checkpoint_file)

    return fitter

In [13]:
# Load the pickled trajectories and preprocess them without normalization.
traj_data_path = "rts_rand_traj_100.pkl"
print("Loading Data...")

# Use a context manager so the file handle is closed after unpickling.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open(traj_data_path, "rb") as traj_file:
    trajs = pickle.load(traj_file)
traj_ids, obs_flat, gt_flat, proc_obs = load_data(trajs, normalize=False)


Loading Data...
1. Loaded Raw Obs:  13669 {(72, 96, 16)} [  0 255]
2. Preprocessed Obs:  13669 {(4, 72, 96)} [  0 255]


In [15]:
import pytorch_lightning as pl
from torch import nn
from torch.nn import functional as F
import torch
#https://analyticsindiamag.com/how-to-implement-convolutional-autoencoder-in-pytorch-with-cuda/
#https://debuggercafe.com/machine-learning-hands-on-convolutional-autoencoders/ 
class NatureAE(pl.LightningModule):
    """Convolutional autoencoder with a three-layer conv encoder/decoder.

    The encoder applies three ``Conv2d`` layers (8x8 stride 4, 4x4 stride 2,
    3x3 stride 1), flattens the result and projects it to ``latent_dim``.
    The decoder mirrors this with a linear layer followed by three
    ``ConvTranspose2d`` layers.

    Args:
        n_input_channels: Number of channels of the input (and of the
            reconstruction).
        flat_dim: Size of the flattened encoder feature map. ``decode``
            reshapes to ``(-1, 64, 5, 8)``, so this needs to be
            ``64 * 5 * 8 = 2560`` for the batch dimension to be preserved.
        latent_dim: Size of the bottleneck vector.
        print_network: If true, print the encoder and decoder modules.
        loss_func: Stored as ``self.loss_function``; ``training_step``
            currently always uses mean squared error regardless of it.
    """

    # Feature-map shape (channels, height, width) that decode() reshapes to.
    _FEATURE_SHAPE = (64, 5, 8)

    def __init__(self, n_input_channels=4, flat_dim = 720, latent_dim=2, print_network=False, loss_func="cross_entropy"):
        super().__init__()

        self.save_hyperparameters()

        enc_out_dim = latent_dim
        self.latent_dim = latent_dim
        self.enc_out_dim = latent_dim
        self.loss_function = loss_func

        # Attribute names and construction order are kept as-is so that
        # existing checkpoints (state_dict keys) remain loadable.
        self.bla1 = nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=4, padding=0)
        self.bla2 = nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0)
        self.bla3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0)

        # encoder: convolutional feature extractor
        self.encoder1 = nn.Sequential(
            self.bla1,
            nn.ReLU(),
            self.bla2,
            nn.ReLU(),
            self.bla3,
            nn.ReLU(),
        )

        # encoder: flatten + projection to the latent vector
        self.encoder2 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_dim, enc_out_dim),
            nn.ReLU()
        )

        if print_network:
            print("Encoder")
            print(self.encoder1)
            print(self.encoder2)

        # decoder: projection from the latent vector back to flat_dim
        self.decoder1 = nn.Sequential(
            nn.Linear(enc_out_dim, flat_dim),
            nn.ReLU(),
        )

        self.alb1 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=0)
        self.alb2 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=4, stride=2, padding=0, output_padding=1)
        self.alb3 = nn.ConvTranspose2d(in_channels=32, out_channels=n_input_channels, kernel_size=8, stride=4, padding=0, output_padding=0)

        # decoder: transposed convolutions; no activation after the last layer
        self.decoder2 = nn.Sequential(
            self.alb1,
            nn.ReLU(),
            self.alb2,
            nn.ReLU(),
            self.alb3
        )

        if print_network:
            print("Decoder")
            print(self.decoder1)
            print(self.decoder2)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with default settings."""
        return torch.optim.Adam(self.parameters())

    def encode(self, x):
        """Map an input batch to its latent representation.

        Args:
            x: Input tensor fed to the first ``Conv2d`` layer.

        Returns:
            Tensor produced by the conv stack followed by flatten + linear + ReLU.
        """
        x = self.encoder1(x)
        x = self.encoder2(x)
        return x

    def decode(self, x):
        """Reconstruct an input-shaped tensor from latent vectors.

        Args:
            x: Latent tensor as returned by ``encode``.

        Returns:
            Output of the transposed-convolution stack.
        """
        x = self.decoder1(x)
        # Reshape the flat vector back into a (64, 5, 8) feature map.
        x = x.view(-1, *self._FEATURE_SHAPE)
        x = self.decoder2(x)
        return x

    def training_step(self, batch, batch_idx):
        """Compute the mean-squared reconstruction error for one batch.

        Args:
            batch: Sequence whose first element is the input tensor.
            batch_idx: Index of the batch (unused).

        Returns:
            Scalar MSE loss between the reconstruction and the input.
        """
        x = batch[0]

        x_encoded = self.encode(x)
        x_hat = self.decode(x_encoded)

        # (input, target) order; MSE is symmetric so the value is unchanged.
        recon_loss = F.mse_loss(x_hat, x)

        return recon_loss

    def embed(self, input):
        """Encode ``input`` and return the latent vectors as a NumPy array.

        Args:
            input: A ``np.ndarray`` (converted to a float tensor) or a tensor.

        Returns:
            ``np.ndarray`` holding the output of ``encode``.
        """
        if isinstance(input, np.ndarray):
            input = torch.tensor(input).float()

        # Fix: the module defines encode(), not an `encoder` attribute, so the
        # previous `self.encoder(input)` raised AttributeError.
        with torch.no_grad():
            return self.encode(input).detach().cpu().numpy()

In [23]:

# Training configuration for this run.
batch_size = 256
max_epochs = 50
latent_dim = 2

data_dir = "models"
ckpt_path = f"{data_dir}/nature_ae_mse_no_scale_{latent_dim}_rts_100.ckpt"

print("Start Training")
# Fix: batch_size was defined above but never passed, so the function's
# default (512) was silently used instead of 256.
train_nature_ae(
    proc_obs,
    latent_dim=latent_dim,
    batch_size=batch_size,
    max_epochs=max_epochs,
    ckpt_path=ckpt_path,
)

Start Training


GPU available: False, used: False
TPU available: False, using: 0 TPU cores

  | Name     | Type            | Params
---------------------------------------------
0 | bla1     | Conv2d          | 8.2 K 
1 | bla2     | Conv2d          | 32.8 K
2 | bla3     | Conv2d          | 36.9 K
3 | encoder1 | Sequential      | 78.0 K
4 | encoder2 | Sequential      | 5.1 K 
5 | decoder1 | Sequential      | 7.7 K 
6 | alb1     | ConvTranspose2d | 36.9 K
7 | alb2     | ConvTranspose2d | 32.8 K
8 | alb3     | ConvTranspose2d | 8.2 K 
9 | decoder2 | Sequential      | 77.9 K
---------------------------------------------
168 K     Trainable params
0         Non-trainable params
168 K     Total params
0.675     Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

<pytorch_lightning.trainer.trainer.Trainer at 0x12e8fd100>