##### Set-up

Import and set up the code. If we are running in Colab, this includes cloning the project from GitHub (or pulling the latest changes if it is already cloned) and downloading the datasets. Otherwise, when running the notebook locally, we move up the directory tree (at most five levels) until we reach the project root, identified by its ``_data`` directory.

In [None]:
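import os
import h5py

try:
    from google.colab import drive  # Only importable when running on Colab.
    IN_COLAB = True

    # Make sure the project is available: pull if already cloned, otherwise clone it.
    try:
        try:
            from src.representations import Representation
        except ImportError:
            os.chdir("dynamical-disentanglement")
            !git pull
    except FileNotFoundError:
        print("-----Cloning from source---\n\n")
        !git clone https://github.com/tomdbar/dynamical-disentanglement.git
        os.chdir("dynamical-disentanglement")

    # Download the datasets if they are not already present.
    try:
        with h5py.File("./_data/3dshapes.h5", "r") as f:
            f["images"]
    except (OSError, KeyError):
        print("\n\n-----Downloading datasets---\n\n")
        !scripts/download_3dcars.sh
        !scripts/download_3dshapes.sh

except ModuleNotFoundError:
    IN_COLAB = False

    # Walk up (at most 5 levels) until we reach the project root, which contains `_data`.
    i = 0
    while ('_data' not in os.listdir()) and (i < 5):
        os.chdir("../")
        i += 1
    if '_data' not in os.listdir():
        print("Warning: could not find the '_data' directory; check the working directory.")

if IN_COLAB:
    print("\n\nNotebook hosted on Google Colab.")
else:
    print("Notebook hosted in local environment.")
print("Set cwd to:", os.getcwd())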
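The local branch above searches upwards for the project root by changing directory repeatedly. As a minimal sketch, the same search can be written as a side-effect-free function that returns the root instead of moving there; the helper name `find_project_root` is hypothetical and is not part of the repository. Like the loop above, it checks the starting directory plus up to `max_levels` parents.

```python
from pathlib import Path
from typing import Optional


def find_project_root(start: str = ".", marker: str = "_data",
                      max_levels: int = 5) -> Optional[Path]:
    """Return the first directory at or above `start` (checking at most
    `max_levels` parents) that contains `marker`, or None if none does."""
    current = Path(start).resolve()
    for _ in range(max_levels + 1):
        if (current / marker).exists():
            return current
        if current.parent == current:  # Reached the filesystem root.
            break
        current = current.parent
    return None
```

Calling `os.chdir(root)` only when `root = find_project_root()` is not `None` keeps the working directory unchanged if the search fails, rather than leaving it five levels up.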

Import the required classes.

Yes, this really is all we need, as ``DynamicRepresentationLearner`` wraps up everything and does the heavy lifting. Of course, to extend the work you will probably want access to the inner workings, so the other notebook (``introduction.ipynb``) walks through them in more detail.

In [None]:
from src.learners import DynamicRepresentationLearner, LossTarget
from src.factorised_datasets import GridWorld, Cars3D, Shapes3D

##### Train

Create the desired dataset, then pass it to ``DynamicRepresentationLearner`` along with arguments specifying the latent space and the entanglement regularisation, and watch it go.

In [None]:
# Uncomment one of the lines below to use a different dataset.
# dataset = Shapes3D([5, 5, 1, 1, 1, 1])
# dataset = Cars3D([4, 4, 1])
dataset = GridWorld([5, 5, [0], [0], [4]])

In [None]:
# Only plot the dataset when it is small enough to display.
if len(dataset.dataset) < 30:
    dataset.imshow();

In [None]:
rep_learner = DynamicRepresentationLearner(dataset,
                                           # Latent space and episode settings.
                                           latent_dim=4,
                                           episode_length=20,
                                           num_parallel_episodes=3,
                                           max_action_magnitude=1,

                                           # Learning rates.
                                           lr_enc=5e-3,
                                           lr_dec=5e-3,
                                           lr_rep=5e-3,

                                           # Entanglement regularisation.
                                           ent_loss_weight=0,
                                           final_ent_loss_weight=1e-5,
                                           final_ent_loss_weight_iter=5000,
                                           ent_loss_target=LossTarget.MIN,
                                           random_resets=False,

                                           save_loc="_results/gridworld/")
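The parameter names suggest that the entanglement loss weight starts at ``ent_loss_weight`` and reaches ``final_ent_loss_weight`` after ``final_ent_loss_weight_iter`` iterations. We have not confirmed the schedule ``DynamicRepresentationLearner`` actually uses, so the linear ramp below is an assumption, and the function `ent_loss_weight_at` is hypothetical. It is only meant to show how the three arguments could interact.

```python
def ent_loss_weight_at(step: int,
                       ent_loss_weight: float = 0.0,
                       final_ent_loss_weight: float = 1e-5,
                       final_ent_loss_weight_iter: int = 5000) -> float:
    """Hypothetical linear ramp (an assumption, not the library's code) from
    `ent_loss_weight` to `final_ent_loss_weight`, reached at step
    `final_ent_loss_weight_iter` and held constant afterwards."""
    if final_ent_loss_weight_iter <= 0:
        return final_ent_loss_weight
    # Fraction of the ramp completed, clipped to [0, 1].
    frac = min(max(step, 0) / final_ent_loss_weight_iter, 1.0)
    return ent_loss_weight + frac * (final_ent_loss_weight - ent_loss_weight)
```

If the schedule is indeed linear, the short demo run below (``num_sgd_steps=100``) would only reach a weight of about 2e-7, so the entanglement regularisation would barely be active; longer runs would be needed to see its full effect.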

In [None]:
rep_learner.train(num_sgd_steps=100, log_freq=10)
rep_learner.save()

In [None]:
rep_learner.plot_training(save=False);

##### Test

Use the built-in helper functions to test the learnt representations.

In [None]:
rep_learner.test(episode_length=20,
                 num_episodes=10,
                 save_scores=True,  # Save a pd.DataFrame of the scores from each episode.
                 save_plot=True,  # Save a plot of the scores from each episode.
                 show_imgs=True,  # Show ground-truth and reconstructed images (from a single episode only).
                 save_imgs=True)  # Save ground-truth and reconstructed images (from all episodes).

In [None]:
rep_learner.plot_representations(save=True, num_highlight=1);