In [1]:
import os
import sys
import time

# Project layout: the project root is taken to be two levels above the
# current working directory (i.e. the directory the notebook is run from).
project_dir = os.path.abspath(os.path.join(os.getcwd(), '../../'))
src_dir = os.path.join(project_dir, 'src')
data_dir = os.path.join(project_dir, 'data')
fig_dir = os.path.join(project_dir, 'fig')
logs_dir = os.path.join(project_dir, 'logs')

# Make sure every output directory exists before later cells write into it.
for _output_dir in (fig_dir, data_dir, logs_dir):
    os.makedirs(_output_dir, exist_ok=True)

# Make the modules under src/ importable. The membership check keeps the cell
# idempotent: re-running it no longer piles duplicate entries onto sys.path.
if src_dir not in sys.path:
    sys.path.append(src_dir)

import numpy as np

In [2]:
import os
import shutil
import json
import time
import tensorflow as tf
from objax.util import EasyDict
from absl import flags

# Define flags that train.py uses (normally defined in if __name__ == '__main__').
# Each entry is (absl definer, flag name, default value, help text).
_FLAG_SPECS = (
    (flags.DEFINE_string, 'arch', 'cnn32-3-mean', 'Model architecture.'),
    (flags.DEFINE_float, 'lr', 0.1, 'Learning rate.'),
    (flags.DEFINE_string, 'dataset', 'cifar10', 'Dataset.'),
    (flags.DEFINE_float, 'weight_decay', 0.0005, 'Weight decay ratio.'),
    (flags.DEFINE_integer, 'batch', 256, 'Batch size'),
    (flags.DEFINE_integer, 'epochs', 501, 'Training duration in number of epochs.'),
    (flags.DEFINE_string, 'logdir', 'experiments', 'Directory where to save checkpoints and tensorboard data.'),
    (flags.DEFINE_integer, 'seed', None, 'Training seed.'),
    (flags.DEFINE_float, 'pkeep', .5, 'Probability to keep examples.'),
    (flags.DEFINE_integer, 'expid', None, 'Experiment ID'),
    (flags.DEFINE_integer, 'num_experiments', None, 'Number of experiments'),
    (flags.DEFINE_string, 'augment', 'weak', 'Strong or weak augmentation'),
    (flags.DEFINE_integer, 'only_subset', None, 'Only train on a subset of images.'),
    (flags.DEFINE_integer, 'dataset_size', 50000, 'number of examples to keep.'),
    (flags.DEFINE_integer, 'eval_steps', 1, 'how often to get eval accuracy.'),
    (flags.DEFINE_integer, 'abort_after_epoch', None, 'stop trainin early at an epoch'),
    (flags.DEFINE_integer, 'save_steps', 10, 'how often to get save model.'),
    (flags.DEFINE_integer, 'patience', None, 'Early stopping after this many epochs without progress'),
    (flags.DEFINE_bool, 'tunename', False, 'Use tune name?'),
)

# absl raises DuplicateFlagError when a flag name is registered twice, which is
# exactly what happens when this cell is re-run in a live kernel. Registering a
# flag only when it is not known yet keeps the cell safe to execute repeatedly.
for _define, _flag_name, _default, _help_text in _FLAG_SPECS:
    if _flag_name not in flags.FLAGS:
        _define(_flag_name, _default, _help_text)


# Imported only after the flags exist, preserving the original cell order.
from train import get_data, network, MemModule

FLAGS = flags.FLAGS

# Parse flags (required before accessing FLAGS values)
# Mark flags as parsed so we can access them without command-line args.
# NOTE(review): the flag values stay at the defaults above; the run itself is
# configured through the plain variables set further down in this cell.
FLAGS.mark_as_parsed()

# ----------------------------------------------------------------------------
# Run configuration -- edit the values in this section to change the run
# ----------------------------------------------------------------------------

# Model and data
arch = 'wrn28-2'
dataset = 'cifar10'
dataset_size = 50_000
only_subset = None
augment = 'weak'
pkeep = 0.5

# Optimisation
lr, weight_decay = 0.1, 5e-4
batch = 256

# Schedule: total epochs, checkpoint interval, evaluation interval, early stopping
epochs, save_steps, eval_steps = 100, 20, 1
patience = None

# Which experiment this is, out of how many
expid, num_experiments = 3, 16

# Left as None, a seed is generated later in this cell
seed = None

# Miscellaneous
tunename = False

# ============================================================================
# Training Logic (from train.py main function)
# ============================================================================

# Hide every GPU from TensorFlow (the intent is that JAX uses the GPU instead)
tf.config.experimental.set_visible_devices([], "GPU")

# Without an explicit seed, draw a random integer and XOR it with the current
# Unix time in whole seconds.
if seed is None:
    seed = np.random.randint(0, 1000000000) ^ int(time.time())

# Bundle the hyperparameters that are later forwarded to the training module
args = EasyDict(**{
    'arch': arch,
    'lr': lr,
    'batch': batch,
    'weight_decay': weight_decay,
    'augment': augment,
    'seed': seed,
})

# Both values are needed to name the run directory below.
assert expid is not None and num_experiments is not None

# Group runs by dataset. This used to be hard-coded to 'cifar10', which would
# silently file runs for any other `dataset` under the CIFAR-10 directory; for
# dataset == 'cifar10' the resulting path is unchanged.
base_logdir = os.path.join(logs_dir, 'exp', dataset)
os.makedirs(base_logdir, exist_ok=True)

logdir_path = os.path.join(base_logdir, f"experiment-{expid}_{num_experiments}")

# A run counts as finished when the checkpoint named after the final epoch exists.
final_ckpt_path = os.path.join(logdir_path, "ckpt", f"{epochs:010d}.npz")

if os.path.exists(final_ckpt_path):
    print(f"Run {expid} already completed.")
else:
    # WARNING: destructive. Any existing directory for this run is removed when
    # the final-epoch checkpoint is missing -- note that this also applies when
    # `epochs` has been changed since the directory was written.
    if os.path.exists(logdir_path):
        print(f"Deleting run {expid} that did not complete.")
        shutil.rmtree(logdir_path)

    print(f"Creating experiment directory: {logdir_path}")
    os.makedirs(logdir_path, exist_ok=True)

2026-01-05 18:20:31.959953: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-01-05 18:20:32.016387: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from .autonotebook import tqdm as notebook_tqdm


Deleting run 3 that did not complete.
Creating experiment directory: /storage/coda1/p-vzikas3/0/ywei368/Yu-Project/Auditing/lira_attack/logs/exp/cifar10/experiment-3_16


In [3]:
# Settings handed to get_data as a single dictionary
config = dict(
    logdir=logs_dir,
    dataset=dataset,
    dataset_size=dataset_size,
    num_experiments=num_experiments,
    expid=expid,
    pkeep=pkeep,
    only_subset=only_subset,
    augment=augment,
    batch=batch,
    data_dir=data_dir,
)

# Load the data; get_data receives the seed and the configuration dictionary
train_data, test_data, xs, ys, keep, nclass = get_data(seed, config)

# Build the model for the chosen architecture first, then wrap it in the
# training module together with the run settings.
model = network(arch)

# Keyword order matches the original call: explicit settings first, then the
# hyperparameters bundled in `args`.
module_settings = dict(
    nclass=nclass,
    mnist=(dataset == 'mnist'),
    epochs=epochs,
    expid=expid,
    num_experiments=num_experiments,
    pkeep=pkeep,
    save_steps=save_steps,
    only_subset=only_subset,
)
tm = MemModule(model, **module_settings, **args)

# Snapshot of the training module's hyperparameters
params = dict(tm.params)

# The previous cell only *reports* a finished run; without this guard this cell
# would still overwrite hparams.json (with a freshly drawn seed) and keep.npy
# and call tm.train again. The check is the same final-epoch checkpoint test
# used when the run directory was prepared.
final_ckpt_path = os.path.join(logdir_path, "ckpt", f"{epochs:010d}.npz")

if os.path.exists(final_ckpt_path):
    print(f"Run {expid} already completed; skipping training.")
else:
    # Save hyperparameters and the `keep` array returned by get_data next to the run
    with open(os.path.join(logdir_path, 'hparams.json'), 'w') as f:
        json.dump(params, f)
    np.save(os.path.join(logdir_path, 'keep.npy'), keep)

    # Train
    print("-" * 80)
    tm.train(epochs, len(xs), train_data, test_data, logdir_path,
             save_steps=save_steps, patience=patience, eval_steps=eval_steps)

    print("-" * 80)
    print(f"✅ Training completed! Results saved to {logdir_path}")


--------------------------------------------------------------------------------
No checkpoints found. Skipping restoring variables.
Epoch 0001  Loss 1.84  Accuracy 20.87
Epoch 0002  Loss 1.35  Accuracy 27.72
Epoch 0003  Loss 1.07  Accuracy 41.02
Epoch 0004  Loss 0.94  Accuracy 51.19
Epoch 0005  Loss 0.83  Accuracy 58.73
Epoch 0006  Loss 0.74  Accuracy 63.18
Epoch 0007  Loss 0.68  Accuracy 66.55
Epoch 0008  Loss 0.63  Accuracy 69.64
Epoch 0009  Loss 0.59  Accuracy 71.63
Epoch 0010  Loss 0.56  Accuracy 73.72
Epoch 0011  Loss 0.51  Accuracy 75.35
Epoch 0012  Loss 0.49  Accuracy 76.94
Epoch 0013  Loss 0.48  Accuracy 78.18
Epoch 0014  Loss 0.44  Accuracy 79.42
Epoch 0015  Loss 0.45  Accuracy 80.21
Epoch 0016  Loss 0.42  Accuracy 81.26
Epoch 0017  Loss 0.41  Accuracy 82.10
Epoch 0018  Loss 0.39  Accuracy 82.94
Epoch 0019  Loss 0.38  Accuracy 83.74
Epoch 0020  Loss 0.37  Accuracy 84.37
Epoch 0021  Loss 0.35  Accuracy 84.89
Epoch 0022  Loss 0.34  Accuracy 85.43
Epoch 0023  Loss 0.34  Accuracy