In [1]:
!pip install -qqq ml-collections

In [None]:
import os
# The `scripts.*` imports below are expected to resolve from the parent directory of
# this notebook, so step up one level. Remember the starting directory on the first
# execution so that re-running this cell does not keep climbing the directory tree
# (a bare `os.chdir("..")` moves up again on every re-run).
if "NOTEBOOK_DIR" not in globals():
  NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))

import jax
import ml_collections

import pandas as pd

import glob
from datetime import datetime

import scripts.movielens_exp as movielens_run
import scripts.mnist_exp as mnist_run
import scripts.tabular_exp as tabular_run
import scripts.tabular_subspace_exp as tabular_sub_run

# Sanity check before launching the experiments: how many devices JAX can see.
n_devices = jax.device_count()
print(n_devices)

In [None]:
def get_config(filepath, ntrials=10):
  """Build the run configuration passed to each experiment's ``main``.

  Args:
    filepath: Results file path, stored as ``config.filepath``.
    ntrials: Number of trials, stored as ``config.ntrials`` (defaults to 10).
      Its exact meaning is defined by the experiment scripts that read it.

  Returns:
    An ``ml_collections.ConfigDict`` with ``filepath`` and ``ntrials`` set.
  """
  config = ml_collections.ConfigDict()
  config.filepath = filepath
  config.ntrials = ntrials
  return config

In [None]:
# A single run identifier (POSIX time in seconds, as a float), shared by every
# results filename below so outputs from this session can be matched up later.
timestamp = datetime.now().timestamp()

# Run tabular experiments

In [None]:
# Tabular experiments; the script receives the output path as `config.filepath`.
tabular_filename = "./results/tabular_results_{}.csv".format(timestamp)
config = get_config(filepath=tabular_filename)
tabular_run.main(config)

# Run MNIST experiments

In [None]:
# MNIST experiments; the script receives the output path as `config.filepath`.
mnist_filename = "./results/mnist_results_{}.csv".format(timestamp)
config = get_config(filepath=mnist_filename)
mnist_run.main(config)

# Run MovieLens experiments

In [None]:
# MovieLens experiments; the script receives the output path as `config.filepath`.
movielens_filename = "./results/movielens_results_{}.csv".format(timestamp)
config = get_config(filepath=movielens_filename)
movielens_run.main(config)

# Run tabular subspace experiment

In [None]:
# Tabular subspace experiment; the script receives the output path as `config.filepath`.
tabular_sub_filename = "./results/tabular_subspace_results_{}.csv".format(timestamp)
config = get_config(filepath=tabular_sub_filename)
tabular_sub_run.main(config)