In [1]:
import os
import sys
import time
import shutil
import json
import tensorflow as tf
from objax.util import EasyDict
from absl import flags

# Resolve the project root two levels above the current working directory and
# derive the standard sub-directories from it.
project_dir = os.path.abspath(os.path.join(os.getcwd(), '../../'))
src_dir = os.path.join(project_dir, 'src')
data_dir = os.path.join(project_dir, 'data')
fig_dir = os.path.join(project_dir, 'fig')
logs_dir = os.path.join(project_dir, 'logs')

# Create the output directories up front; exist_ok keeps this safe to re-run.
os.makedirs(fig_dir, exist_ok=True)
os.makedirs(data_dir, exist_ok=True)
os.makedirs(logs_dir, exist_ok=True)

# Make modules under src/ importable. The membership check keeps the cell
# idempotent: re-running it must not stack duplicate entries on sys.path.
if src_dir not in sys.path:
    sys.path.append(src_dir)

import numpy as np

2026-01-05 20:37:51.000794: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-01-05 20:37:51.057575: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
import re
import objax
from train import MemModule, network

# ============================================================================
# Inference Parameters
# ============================================================================
# Dataset name (also used as the log sub-directory name) and the number of
# leading training examples to keep.
dataset = 'cifar10'
dataset_size = 50000

# Number of examples sent through the model per call.
batch_size = 5000

# First epoch to generate logits for. None means "start at max_epoch - 1".
from_epoch = None

# Only experiment directories whose name matches this pattern are processed.
regex_pattern = '.*experiment.*'

# Directory that holds one sub-directory per experiment for this dataset.
base_logdir = os.path.join(logs_dir, 'exp', dataset)
os.makedirs(base_logdir, exist_ok=True)

# ============================================================================
# Helper Functions
# ============================================================================

# Hide every GPU from TensorFlow (JAX will handle GPU)
tf.config.experimental.set_visible_devices(devices=[], device_type="GPU")

def load_model(arch, nclass, mnist=False):
    """Build a ``MemModule`` wrapping the network for architecture ``arch``.

    Args:
        arch: Architecture name, passed both to ``network`` and to ``MemModule``.
        nclass: Number of output classes.
        mnist: Forwarded to ``MemModule`` unchanged (default: False).

    Returns:
        The constructed ``MemModule``. The remaining hyperparameters are fixed
        constants: lr=0.1, batch=0, epochs=0, weight_decay=0.
    """
    backbone = network(arch)
    module_kwargs = dict(
        nclass=nclass,
        mnist=mnist,
        arch=arch,
        lr=0.1,
        batch=0,
        epochs=0,
        weight_decay=0,
    )
    return MemModule(backbone, **module_kwargs)

def get_loss(model, xbatch, ybatch, shift=0, reflect=True, stride=1):
    """
    Generate logits for a batch under flip/shift augmentation.

    Every augmented view is pushed through ``model.model`` and the per-view
    outputs are stacked so that the batch axis comes first.

    Args:
        model: Object exposing ``model(x, training=...)``; it is called with
            channel-first input (N, C, H, W).
        xbatch: Input batch in channel-last layout (N, H, W, C).
        ybatch: Labels (not used, but kept for compatibility).
        shift: Reflect-padding amount per spatial side; crops are then taken
            at offsets 0..2*shift along each spatial axis (default: 0).
        reflect: Whether to also evaluate the horizontally flipped batch
            (default: True).
        stride: Step between consecutive crop offsets (default: 1).

    Returns:
        Array of shape (N, num_augs, nclass), where num_augs is the number of
        (flip, dx, dy) combinations evaluated.
    """
    # Original view first, then (optionally) the view flipped along the width axis.
    views = [xbatch]
    if reflect:
        views.append(xbatch[:, :, ::-1, :])

    # Pad only the two spatial axes. NumPy's 'reflect' mode mirrors without
    # repeating the edge pixel, so padding stays on the host with no TensorFlow
    # round trip.
    pad_widths = [(0, 0), (shift, shift), (shift, shift), (0, 0)]

    outs = []
    for aug in views:
        aug_pad = np.pad(aug, pad_widths, mode='reflect')
        for dx in range(0, 2*shift+1, stride):
            for dy in range(0, 2*shift+1, stride):
                # Crop size is fixed at 32 pixels; NHWC -> NCHW for the model.
                this_x = aug_pad[:, dx:dx+32, dy:dy+32, :].transpose((0,3,1,2))
                # NOTE(review): the forward pass runs with training=True even
                # though this is inference; confirm this is intended.
                logits = model.model(this_x, training=True)
                outs.append(logits)

    # Convert once: (num_augs, N, nclass). Reused for the log line and the result.
    stacked = np.array(outs)
    print(f"Generated logits shape: {stacked.shape}")
    return stacked.transpose((1, 0, 2))

def features(model, xbatch, ybatch):
    """Return logits for a batch using the original and mirrored views only.

    Delegates to ``get_loss`` with no shift, reflection enabled and stride 1.
    """
    augmentation = dict(shift=0, reflect=True, stride=1)
    return get_loss(model, xbatch, ybatch, **augmentation)


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# ============================================================================
# Load Training Data
# ============================================================================
# x_train.npy / y_train.npy are looked up in the dataset log directory first
# and then in the root of logs_dir. A location is accepted only when BOTH files
# are present, so a directory holding just one of them falls through to the
# next candidate instead of aborting the lookup.
for search_dir in (base_logdir, logs_dir):
    data_path = os.path.join(search_dir, "x_train.npy")
    labels_path = os.path.join(search_dir, "y_train.npy")
    if os.path.exists(data_path) and os.path.exists(labels_path):
        break
else:
    raise FileNotFoundError(
        f"Could not find x_train.npy and y_train.npy. "
        f"Tried: {os.path.join(base_logdir, 'x_train.npy')} and {os.path.join(logs_dir, 'x_train.npy')}"
    )

# Keep only the first dataset_size examples of each array.
xs_all = np.load(data_path)[:dataset_size]
ys_all = np.load(labels_path)[:dataset_size]

# Data and labels are sliced and consumed in lockstep later on; a length
# mismatch would silently misalign them, so fail early.
assert len(xs_all) == len(ys_all), (
    f"Data/label length mismatch: {len(xs_all)} examples vs {len(ys_all)} labels"
)

print(f"‚úÖ Loaded {len(xs_all)} training examples")
print(f"   Data shape: {xs_all.shape}")
print(f"   Labels shape: {ys_all.shape}")

# Determine number of classes
assert dataset in ['cifar10', 'cifar100', 'mnist']
nclass = 100 if dataset == 'cifar100' else 10
mnist = (dataset == 'mnist')

print(f"Dataset: {dataset}, Number of classes: {nclass}")


‚úÖ Loaded 50000 training examples
   Data shape: (50000, 32, 32, 3)
   Labels shape: (50000,)
Dataset: cifar10, Number of classes: 10


In [4]:
# ============================================================================
# Main Inference Loop
# ============================================================================
# One model instance per architecture name, built on first request and reused
# afterwards.
model_cache = {}

def get_cached_model(arch):
    """Return the cached model for ``arch``, building and storing it on a miss."""
    if arch in model_cache:
        return model_cache[arch]

    print(f"Loading model with architecture: {arch}")
    built = load_model(arch, nclass, mnist)
    model_cache[arch] = built
    return built

# Process each experiment directory.
# For every sub-directory of base_logdir whose name matches regex_pattern and
# that contains an hparams.json, restore the requested checkpoints one epoch at
# a time and save the logits of the whole training set under <exp>/logits/.
experiment_dirs = sorted([d for d in os.listdir(base_logdir)
                          if os.path.isdir(os.path.join(base_logdir, d))])

print(f"Found {len(experiment_dirs)} directories in {base_logdir}")
print(f"Filtering with regex: {regex_pattern}")

processed_count = 0
skipped_count = 0

for exp_dir in experiment_dirs:
    # Check if directory matches regex
    if re.search(regex_pattern, exp_dir) is None:
        print(f"Skipping '{exp_dir}' (doesn't match regex)")
        skipped_count += 1
        continue

    exp_path = os.path.join(base_logdir, exp_dir)
    hparams_path = os.path.join(exp_path, "hparams.json")

    if not os.path.exists(hparams_path):
        print(f"Skipping '{exp_dir}' (no hparams.json found)")
        skipped_count += 1
        continue

    # Load hyperparameters to get architecture
    with open(hparams_path, 'r') as f:
        hparams = json.load(f)
    arch = hparams['arch']

    print("=" * 80)
    print(f"Processing experiment: {exp_dir}")
    print(f"Architecture: {arch}")
    print("=" * 80)

    # Get or create model
    model = get_cached_model(arch)

    # Restoring without an explicit index is used here to discover the newest
    # checkpoint; a returned epoch of 0 is treated as "no checkpoint available".
    checkpoint = objax.io.Checkpoint(exp_path, keep_ckpts=10, makedir=True)
    max_epoch, last_ckpt = checkpoint.restore(model.vars())

    if max_epoch == 0:
        print(f"  ‚ö†Ô∏è  No checkpoints found for {exp_dir}, skipping")
        skipped_count += 1
        continue

    print(f"  Found checkpoints up to epoch {max_epoch}")

    # Create logits directory if it doesn't exist
    logits_dir = os.path.join(exp_path, "logits")
    os.makedirs(logits_dir, exist_ok=True)

    # Determine starting epoch
    if from_epoch is not None:
        first_epoch = from_epoch
    else:
        first_epoch = max_epoch - 1

    print(f"  Processing epochs from {first_epoch} to {max_epoch}")

    # Process each epoch
    for epoch in range(first_epoch, max_epoch + 1):
        ckpt_path = os.path.join(exp_path, "ckpt", f"{epoch:010d}.npz")
        logits_path = os.path.join(logits_dir, f"{epoch:010d}.npy")

        # Skip if checkpoint doesn't exist
        if not os.path.exists(ckpt_path):
            continue

        # Skip if logits already generated
        if os.path.exists(logits_path):
            print(f"  ‚è≠Ô∏è  Skipping epoch {epoch} (logits already exist)")
            continue

        # Load checkpoint (best effort: a checkpoint that fails to load is
        # reported and the epoch is skipped)
        try:
            start_epoch, last_ckpt = checkpoint.restore(model.vars(), epoch)
            print(f"  üìä Generating logits for epoch {epoch}...")
        except Exception as e:
            print(f"  ‚ùå Failed to load checkpoint for epoch {epoch}: {e}")
            continue

        # Generate logits in batches; keep one array per batch and join them
        # once at the end instead of collecting one small array per example.
        stats = []
        for i in range(0, len(xs_all), batch_size):
            batch_x = xs_all[i:i+batch_size]
            batch_y = ys_all[i:i+batch_size]
            stats.append(features(model, batch_x, batch_y))

        # Save logits
        # Shape will be (N, 1, num_augs, nclass) to match original format
        logits_array = np.concatenate(stats, axis=0)[:, None, :, :]

        # The skip check above keys on the existence of logits_path, so the
        # file must only appear once it is complete: write to a temporary file
        # next to the logits directory, then move it into place.
        tmp_path = os.path.join(exp_path, f"{epoch:010d}.logits.tmp")
        try:
            with open(tmp_path, "wb") as tmp_file:
                np.save(tmp_file, logits_array)
            os.replace(tmp_path, logits_path)
        finally:
            # Only still present when the write or the move did not complete.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        print(f"  ‚úÖ Saved logits for epoch {epoch} (shape: {logits_array.shape})")

    processed_count += 1
    print()

print("=" * 80)
print(f"‚úÖ Inference completed!")
print(f"   Processed: {processed_count} experiments")
print(f"   Skipped: {skipped_count} experiments")
print("=" * 80)


Found 16 directories in /storage/coda1/p-vzikas3/0/ywei368/Yu-Project/Auditing/lira_attack/logs/exp/cifar10
Filtering with regex: .*experiment.*
Processing experiment: experiment-0_16
Architecture: wrn28-2
Loading model with architecture: wrn28-2
Resuming from /storage/coda1/p-vzikas3/0/ywei368/Yu-Project/Auditing/lira_attack/logs/exp/cifar10/experiment-0_16/ckpt/0000000100.npz
  Found checkpoints up to epoch 100
  Processing epochs from 99 to 100
  ‚è≠Ô∏è  Skipping epoch 100 (logits already exist)

Processing experiment: experiment-10_16
Architecture: wrn28-2
Resuming from /storage/coda1/p-vzikas3/0/ywei368/Yu-Project/Auditing/lira_attack/logs/exp/cifar10/experiment-10_16/ckpt/0000000100.npz
  Found checkpoints up to epoch 100
  Processing epochs from 99 to 100
  ‚è≠Ô∏è  Skipping epoch 100 (logits already exist)

Processing experiment: experiment-11_16
Architecture: wrn28-2
Resuming from /storage/coda1/p-vzikas3/0/ywei368/Yu-Project/Auditing/lira_attack/logs/exp/cifar10/experiment-11_

In [5]:
# ============================================================================
# Compute Scores from Logits
# ============================================================================
# score.py uses module-level globals, so the module itself is imported (to set
# them) alongside the load_one entry point.
import score as score_module
from score import load_one

# Point score.py at this run's log directory and labels
score_module.logdir = base_logdir
score_module.labels = ys_all

# Sub-directories of base_logdir in sorted order, then narrowed to the regex matches
experiment_dirs = sorted(
    entry for entry in os.listdir(base_logdir)
    if os.path.isdir(os.path.join(base_logdir, entry))
)
matching_dirs = list(
    filter(lambda name: re.search(regex_pattern, name), experiment_dirs)
)

print("=" * 80)
print("Computing scores from logits")
print(f"Found {len(matching_dirs)} experiment directories")
print("=" * 80)

for exp_dir in matching_dirs:
    print(f"Processing: {exp_dir}")
    load_one(exp_dir)

print("=" * 80)
print("‚úÖ Score computation completed!")
print("=" * 80)


Computing scores from logits
Found 16 experiment directories
Processing: experiment-0_16
(50000, 1, 2)
mean acc 0.95478
Processing: experiment-10_16
(50000, 1, 2)
mean acc 0.9532
Processing: experiment-11_16
(50000, 1, 2)
mean acc 0.953
Processing: experiment-12_16
(50000, 1, 2)
mean acc 0.95402
Processing: experiment-13_16
(50000, 1, 2)
mean acc 0.95152
Processing: experiment-14_16
(50000, 1, 2)
mean acc 0.95402
Processing: experiment-15_16
(50000, 1, 2)
mean acc 0.95206
Processing: experiment-1_16
(50000, 1, 2)
mean acc 0.95634
Processing: experiment-2_16
(50000, 1, 2)
mean acc 0.95366
Processing: experiment-3_16
(50000, 1, 2)
mean acc 0.95286
Processing: experiment-4_16
(50000, 1, 2)
mean acc 0.95384
Processing: experiment-5_16
(50000, 1, 2)
mean acc 0.95192
Processing: experiment-6_16
(50000, 1, 2)
mean acc 0.95062
Processing: experiment-7_16
(50000, 1, 2)
mean acc 0.95442
Processing: experiment-8_16
(50000, 1, 2)
mean acc 0.95206
Processing: experiment-9_16
(50000, 1, 2)
mean acc 