In [1]:
import os
import sys
import time
import shutil
import json

import numpy as np
import tensorflow as tf
from objax.util import EasyDict
from absl import flags

# Resolve the project layout relative to the notebook's working directory.
# NOTE(review): assumes the kernel runs two directory levels below the
# project root -- confirm if this notebook is moved.
project_dir = os.path.abspath(os.path.join(os.getcwd(), '../../'))
src_dir = os.path.join(project_dir, 'src')
data_dir = os.path.join(project_dir, 'data')
fig_dir = os.path.join(project_dir, 'fig')
logs_dir = os.path.join(project_dir, 'logs')

# Create output directories up front so later cells can write into them.
for _output_dir in (fig_dir, data_dir, logs_dir):
    os.makedirs(_output_dir, exist_ok=True)

# Make modules under src/ (e.g. `train`) importable. Guarded so that
# re-running this cell does not keep appending duplicate sys.path entries.
if src_dir not in sys.path:
    sys.path.append(src_dir)

2026-01-05 20:11:41.012959: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-01-05 20:11:41.069417: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
from train import get_data, network, MemModule

FLAGS = flags.FLAGS

# ============================================================================
# Training Parameters - Set these directly
# ============================================================================
# Dataset and architecture
dataset = 'cifar10'
arch = 'wrn28-2'

# Training configuration
epochs = 100
save_steps = 20
batch = 256
lr = 0.1
weight_decay = 0.0005
augment = 'weak'
pkeep = 0.5

# Experiment configuration
expid = 3
num_experiments = 16
seed = None  # Will be auto-generated if None

# Optional parameters
only_subset = None
patience = None
dataset_size = 50000
eval_steps = 1
tunename = False

# ============================================================================
# Training Logic (from train.py main function)
# ============================================================================

# Disable GPU for TensorFlow (JAX will handle GPU)
tf.config.experimental.set_visible_devices([], "GPU")

# Draw a seed when none is configured. XOR-ing in the wall-clock time means
# each fresh kernel gets a different seed.
if seed is None:
    seed = np.random.randint(0, 1000000000)
    seed ^= int(time.time())

# Hyperparameters forwarded to MemModule as keyword arguments
args = EasyDict(
    arch=arch,
    lr=lr,
    batch=batch,
    weight_decay=weight_decay,
    augment=augment,
    seed=seed
)

assert expid is not None and num_experiments is not None

# Runs are grouped under the configured dataset, so switching `dataset`
# never mixes its runs into another dataset's experiment tree.
base_logdir = os.path.join(logs_dir, 'exp', dataset)
os.makedirs(base_logdir, exist_ok=True)

logdir_path = os.path.join(base_logdir, f"experiment-{expid}_{num_experiments}")

# A run counts as complete once its final-epoch checkpoint exists.
final_ckpt = os.path.join(logdir_path, "ckpt", f"{epochs:010d}.npz")
if os.path.exists(final_ckpt):
    print(f"Run {expid} already completed.")
else:
    # A directory without the final checkpoint is an incomplete run; remove
    # it so its leftover files do not mix with the fresh run.
    if os.path.exists(logdir_path):
        print(f"Deleting run {expid} that did not complete.")
        shutil.rmtree(logdir_path)

    print(f"Creating experiment directory: {logdir_path}")
    os.makedirs(logdir_path, exist_ok=True)

  from .autonotebook import tqdm as notebook_tqdm


Run 3 already completed.


In [3]:
# Create configuration dictionary for get_data
config = {
    # NOTE(review): this is the logs root, not this run's logdir_path --
    # confirm which directory get_data expects.
    'logdir': logs_dir,
    'dataset': dataset,
    'dataset_size': dataset_size,
    'num_experiments': num_experiments,
    'expid': expid,
    'pkeep': pkeep,
    'only_subset': only_subset,
    'augment': augment,
    'batch': batch,
    'data_dir': data_dir
}

# Get data - pass config dictionary
train_data, test_data, xs, ys, keep, nclass = get_data(seed, config)

# Define the network and training module
tm = MemModule(
    network(arch),
    nclass=nclass,
    mnist=(dataset == 'mnist'),
    epochs=epochs,
    expid=expid,
    num_experiments=num_experiments,
    pkeep=pkeep,
    save_steps=save_steps,
    only_subset=only_subset,
    **args
)

# Save hyperparameters and the `keep` mask -- but only for a run that has not
# finished yet. With `seed = None` the seed is redrawn on every fresh kernel,
# so re-saving into a completed run's directory could leave its recorded
# metadata out of sync with how it was actually trained.
params = dict(tm.params)
final_ckpt = os.path.join(logdir_path, "ckpt", f"{epochs:010d}.npz")
if os.path.exists(final_ckpt):
    print(f"Run {expid} already completed; leaving its hparams.json and keep.npy untouched.")
else:
    with open(os.path.join(logdir_path, 'hparams.json'), 'w') as f:
        json.dump(params, f)
    np.save(os.path.join(logdir_path, 'keep.npy'), keep)

# # Train
# print("-" * 80)
# tm.train(epochs, len(xs), train_data, test_data, logdir_path,
#             save_steps=save_steps, patience=patience, eval_steps=eval_steps)

# print("-" * 80)
# print(f"✅ Training completed! Results saved to {logdir_path}")
