In [1]:
import sys
import time
import os
import shutil
import json
import tensorflow as tf
from objax.util import EasyDict
import numpy as np

# Resolve the project root (two levels above the current working directory)
# and the sub-directories this notebook reads from and writes to.
# Note: nothing here changes the working directory; paths are only computed.
project_dir = os.path.abspath(os.path.join(os.getcwd(), '../../'))
src_dir = os.path.join(project_dir, 'src')
data_dir = os.path.join(project_dir, 'data')
fig_dir = os.path.join(project_dir, 'fig')
logs_dir = os.path.join(project_dir, 'logs')

# Create the output directories up front; existing ones are left untouched.
os.makedirs(fig_dir, exist_ok=True)
os.makedirs(data_dir, exist_ok=True)
os.makedirs(logs_dir, exist_ok=True)

# Make the project's src directory importable. The membership guard keeps the
# cell idempotent: re-running it no longer appends duplicate sys.path entries.
if src_dir not in sys.path:
    sys.path.append(src_dir)

from train import get_data, network, MemModule



2026-01-05 11:45:43.650040: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-01-05 11:45:43.705720: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ============================================================================
# Training Parameters - Set these directly
# ============================================================================
# Dataset and architecture
dataset = 'cifar10'
arch = 'wrn28-2'

# Training configuration
epochs = 2
save_steps = 20
batch = 256
lr = 0.1
weight_decay = 0.0005
augment = 'weak'
pkeep = 0.5

# Experiment configuration
expid = 0
num_experiments = 2
seed = None  # Will be auto-generated if None

# Optional parameters
only_subset = None
patience = None
dataset_size = 1024
eval_steps = 1
tunename = False

# Log directory layout: <logs_dir>/exp/<dataset>/experiment-<expid>_<num_experiments>
# The dataset folder follows the `dataset` setting above, so switching datasets
# no longer writes into (or wipes) the cifar10 directory.
base_logdir = os.path.join(logs_dir, 'exp', dataset)
os.makedirs(base_logdir, exist_ok=True)

logdir_path = os.path.join(base_logdir, f"experiment-{expid}_{num_experiments}")

# WARNING: re-running this cell deletes any previous results stored for the
# same (expid, num_experiments) pair so that the run starts from an empty folder.
if os.path.exists(logdir_path):
    shutil.rmtree(logdir_path)

os.makedirs(logdir_path, exist_ok=True)

# Create configuration dictionary for get_data
# NOTE(review): 'logdir' is the top-level logs directory, not `logdir_path`;
# confirm this is what get_data expects.
data_config = {
    'logdir': logs_dir,
    'dataset': dataset,
    'dataset_size': dataset_size,
    'num_experiments': num_experiments,
    'expid': expid,
    'pkeep': pkeep,
    'only_subset': only_subset,
    'augment': augment,
    'batch': batch,
    'data_dir': data_dir
}


In [3]:
# Give TensorFlow an empty list of visible GPUs (JAX will handle GPU).
tf.config.experimental.set_visible_devices([], "GPU")

# No seed configured: draw a random integer and XOR it with the current
# wall-clock time in whole seconds.
if seed is None:
    seed = np.random.randint(0, 1000000000) ^ int(time.time())

# Hyperparameters bundled into one attribute-accessible dictionary.
train_config = EasyDict(
    arch=arch, lr=lr, batch=batch,
    weight_decay=weight_decay, augment=augment,
    seed=seed,
)



In [5]:
# Load the data for this experiment; get_data returns the six values
# unpacked below.
(train_data, test_data, xs, ys, keep, nclass) = get_data(seed, data_config)

# Experiment-specific keyword arguments. The remaining hyperparameters
# (arch, lr, batch, weight_decay, augment, seed) come from train_config.
module_kwargs = dict(
    nclass=nclass,
    mnist=(dataset == 'mnist'),
    epochs=epochs,
    expid=expid,
    num_experiments=num_experiments,
    pkeep=pkeep,
    save_steps=save_steps,
    only_subset=only_subset,
)
tm = MemModule(network(arch), **module_kwargs, **train_config)

# ---------------------------------------------------------------------------
# Disabled: hyperparameter dump and training run. Everything below is
# intentionally commented out, so this cell only builds the module.
# ---------------------------------------------------------------------------
# # Save hyperparameters
# params = {}
# params.update(tm.params)

# with open(os.path.join(logdir_path, 'hparams.json'), 'w') as f:
#     json.dump(params, f)
# np.save(os.path.join(logdir_path, 'keep.npy'), keep)

# # Train
# print("-" * 80)
# tm.train(epochs, len(xs), train_data, test_data, logdir_path,
#             save_steps=save_steps, patience=patience, eval_steps=eval_steps)

# print("-" * 80)
# print(f"✅ Training completed! Results saved to {logdir_path}")




In [6]:
# Minimal JAX GPU check: list the devices JAX can see, report the relevant
# environment variables, and run a tiny computation on the default backend.
import jax
import jax.numpy as jnp
import os


def _array_device(arr):
    """Return the device holding `arr` (assumes a single-device array).

    Uses `arr.devices()` (a set of devices) instead of calling `arr.device()`,
    because `device` is a property rather than a method on newer JAX releases.
    """
    return next(iter(arr.devices()))


print("=" * 60)
print("JAX GPU Detection")
print("=" * 60)

# Check JAX version
print(f"\nJAX version: {jax.__version__}")

# List all devices
print(f"\nAll JAX devices: {jax.devices()}")

# Count devices by type
print(f"\nDevice count: {jax.device_count()}")
print(f"Local device count: {jax.local_device_count()}")

# Check for GPU specifically.
# `platform` is the backend name ('cpu', 'gpu', 'tpu'). `device_kind` is the
# hardware model string (e.g. an NVIDIA product name) and is never 'gpu', so
# filtering on it would report zero GPUs even when one is present.
gpu_devices = [d for d in jax.devices() if d.platform == 'gpu']
print(f"\nGPU devices: {len(gpu_devices)}")
if len(gpu_devices) > 0:
    print("✅ GPU is available!")
    for i, device in enumerate(gpu_devices):
        print(f"  GPU {i}: {device}")
else:
    print("❌ No GPU devices detected")

# Check environment variables
print("\n" + "=" * 60)
print("Environment Variables")
print("=" * 60)
print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', 'NOT SET')}")
print(f"XLA_PYTHON_CLIENT_PREALLOCATE: {os.environ.get('XLA_PYTHON_CLIENT_PREALLOCATE', 'NOT SET')}")

# Try a simple computation
print("\n" + "=" * 60)
print("GPU Computation Test")
print("=" * 60)

if len(gpu_devices) > 0:
    try:
        # Arrays are created on JAX's default device, which is a GPU here.
        x = jnp.array([1.0, 2.0, 3.0])
        y = jnp.array([4.0, 5.0, 6.0])
        z = x + y

        print("✅ GPU computation successful!")
        print(f"   Input x: {x}")
        print(f"   Input y: {y}")
        print(f"   Result z: {z}")
        print(f"   Device: {_array_device(z)}")
    except Exception as e:
        # Report the failure instead of aborting the diagnostic cell.
        print(f"❌ GPU computation failed: {e}")
else:
    print("⚠️  Running on CPU (no GPU available)")
    x = jnp.array([1.0, 2.0, 3.0])
    y = jnp.array([4.0, 5.0, 6.0])
    z = x + y
    print(f"   CPU computation: {z}")
    print(f"   Device: {_array_device(z)}")


JAX GPU Detection

JAX version: 0.4.13

All JAX devices: [CpuDevice(id=0)]

Device count: 1
Local device count: 1

GPU devices: 0
❌ No GPU devices detected

Environment Variables
CUDA_VISIBLE_DEVICES: NOT SET
XLA_PYTHON_CLIENT_PREALLOCATE: NOT SET

GPU Computation Test
⚠️  Running on CPU (no GPU available)
   CPU computation: [5. 7. 9.]
   Device: TFRT_CPU_0
