In [4]:
from IPython.display import HTML, Javascript, clear_output
import os
import sys

from google.colab import drive

drive.mount("/content/drive")

!pip install human_id
!pip install mlflow

# Define the repo URL and folder name
REPO_URL = "https://github.com/rosinusserrano/autoencoding_experiments.git"
LOCAL_DIR = "/content/autoencoding_experiments"

# Inject a floating button with JavaScript
def inject_floating_button():
    """Display a fixed-position "Reload Repo" button in the current cell output.

    Clicking the button calls the kernel callback registered under
    ``'notebook.reload_repo'`` via ``google.colab.kernel.invokeFunction``.

    The script body is wrapped in an IIFE so that repeated injections into the
    same output document do not collide on top-level ``const`` declarations
    (a second top-level ``const button`` in the same page raises a SyntaxError
    and the whole script is skipped). Any button left over from an earlier
    injection is removed before the new one is appended, so at most one
    button is shown.
    """
    display(HTML("""
    <script>
    (function () {
        const BUTTON_ID = 'reload-repo-button';
        // Drop a button from a previous injection instead of stacking another.
        const previous = document.getElementById(BUTTON_ID);
        if (previous) {
            previous.remove();
        }
        const button = document.createElement('button');
        button.id = BUTTON_ID;
        button.textContent = 'Reload Repo';
        button.style.position = 'fixed';
        button.style.bottom = '20px';
        button.style.right = '20px';
        button.style.backgroundColor = '#ff5050';
        button.style.color = 'white';
        button.style.border = 'none';
        button.style.padding = '10px 20px';
        button.style.borderRadius = '5px';
        button.style.boxShadow = '0px 4px 6px rgba(0,0,0,0.1)';
        button.style.cursor = 'pointer';
        button.onclick = function () {
            google.colab.kernel.invokeFunction('notebook.reload_repo', [], {});
        };
        document.body.appendChild(button);
    })();
    </script>
    """))

def clone_and_import():
    # Clear output and display new button
    clear_output(wait=True)
    inject_floating_button()
    # Clone the repository if it doesn't exist
    if not os.path.exists(LOCAL_DIR):
        !git clone {REPO_URL} {LOCAL_DIR}
    else:
        # Pull the latest changes if the repo already exists
        !cd {LOCAL_DIR} && git pull

    # Append the repository to the system path
    if LOCAL_DIR not in sys.path:
        sys.path.append(LOCAL_DIR + "/src")
    print(f"Repository at '{LOCAL_DIR}' is ready for imports.")

# Expose clone_and_import to the page's JavaScript under the name the
# "Reload Repo" button invokes ('notebook.reload_repo').
from google.colab import output
output.register_callback('notebook.reload_repo', clone_and_import)

# Initial setup. clone_and_import() clears the output and injects the button
# itself, so no separate inject_floating_button() call is needed here; an
# extra call would be cleared right away and would run the button script
# twice in the same output.
clone_and_import()

remote: Enumerating objects: 14, done.[K
remote: Counting objects:   7% (1/14)[Kremote: Counting objects:  14% (2/14)[Kremote: Counting objects:  21% (3/14)[Kremote: Counting objects:  28% (4/14)[Kremote: Counting objects:  35% (5/14)[Kremote: Counting objects:  42% (6/14)[Kremote: Counting objects:  50% (7/14)[Kremote: Counting objects:  57% (8/14)[Kremote: Counting objects:  64% (9/14)[Kremote: Counting objects:  71% (10/14)[Kremote: Counting objects:  78% (11/14)[Kremote: Counting objects:  85% (12/14)[Kremote: Counting objects:  92% (13/14)[Kremote: Counting objects: 100% (14/14)[Kremote: Counting objects: 100% (14/14), done.[K
remote: Compressing objects:  16% (1/6)[Kremote: Compressing objects:  33% (2/6)[Kremote: Compressing objects:  50% (3/6)[Kremote: Compressing objects:  66% (4/6)[Kremote: Compressing objects:  83% (5/6)[Kremote: Compressing objects: 100% (6/6)[Kremote: Compressing objects: 100% (6/6), done.[K
remote: Total 10 (delt

In [16]:
# Load (or reload) IPython's autoreload extension. Mode 2 reloads modules
# before executing each cell, so code updated by clone_and_import()
# (git pull) is picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [17]:
from functools import partial

from datasets import DatasetConfig
from logger.caveman import CavemanLogger
from logger.mlflow import MlFlowLogger
from models.vae import VAEConfig, mse_and_kld_loss
from utils.train import TrainConfig, standard_training_pipeline

import mlflow

In [None]:
experiment_name = "VAE on STL10"

# Model, data and optimisation settings for this run.
model_config = VAEConfig(kld_weight=1)
dataset_config = DatasetConfig(
    root="/content/drive/MyDrive/Autoencoding Experiments/datasets/STL10",
    dataset_name="stl10",
    validation_split=0.1,
    batch_size=128,
)
train_config = TrainConfig("adam", learning_rate=0.0003, n_epochs=50)

# Alternative file-based loggers (local runs dir or Google Drive):
#   logger = CavemanLogger(f"../runs/{experiment_name}")
#   logger = CavemanLogger(f"/content/drive/MyDrive/Autoencoding Experiments/{experiment_name}")

# End any MLflow run that is still active (e.g. from an interrupted earlier
# execution of this cell) before the logger is created.
mlflow.end_run()
logger = MlFlowLogger(
    experiment_name=experiment_name,
    remote_url="https://mlflow.sniggles.de",
    debug=True,
)

# Bind the model config into the loss so the pipeline can call it directly.
loss_fn = partial(mse_and_kld_loss, model_config=model_config)

# Intervals appear to be counted in epochs (see "Logging images for epoch N").
standard_training_pipeline(
    model_config=model_config,
    dataset_config=dataset_config,
    train_config=train_config,
    logger=logger,
    loss_fn=loss_fn,
    validation_interval=1,
    test_interval=10,
    visualization_interval=1,
)

🏃 View run colorful-fox-784 at: https://mlflow.sniggles.de/#/experiments/387638843709741952/runs/b640328a97a84b46afdecd716bbe2b2f
🧪 View experiment at: https://mlflow.sniggles.de/#/experiments/387638843709741952
Setting tracking uri for mlflow
Created experiment with id 387638843709741952
Starting run
Files already downloaded and verified
Files already downloaded and verified
Logging configs.


  0%|          | 0/704 [00:00<?, ?it/s]

Logging kld_loss_weighted_train 0.26166348557241936
Logging kld_loss_train 0.26166348557241936
Logging mse_loss_train 0.035111318808048964
Logging total_loss_train 0.2967748041529293
Logging kld_loss_weighted_val 0.03987291205344321
Logging kld_loss_val 0.03987291205344321
Logging mse_loss_val 0.02359276703452762
Logging total_loss_val 0.06346567852210395
Logging kld_loss_weighted_test 0.028960284643939564
Logging kld_loss_test 0.028960284643939564
Logging mse_loss_test 0.01688556970348434
Logging total_loss_test 0.04584585414046333
Logging images for epoch 0
Logging images for epoch 0


  0%|          | 0/704 [00:00<?, ?it/s]