In [1]:
import os

# Quiet TensorFlow's native (C++) logging. This cell runs before the cell that
# imports tensorflow, so the variables are already in the environment when the
# library loads.
# Level "2" filters INFO and WARNING messages; ERROR messages are still shown
# (level "3" would be needed to filter those as well).
# NOTE(review): the captured output of the import cell still contains cuFFT /
# cuDNN / cuBLAS "Unable to register ... factory" lines, so these settings do
# not silence everything.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # hide INFO and WARNING; errors still print

# Optional: Disable oneDNN info message
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"

# Optional: Disable XLA to reduce cu* factory warnings
os.environ["TF_XLA_FLAGS"] = "--tf_xla_enable_xla_devices=false"

In [2]:
import tensorflow as tf
import pandas as pd
import numpy as np
from pathlib import Path
# Execute one trivial TensorFlow op immediately after the import.
# NOTE(review): presumably this forces runtime initialisation to happen in this
# cell so its start-up log lines are emitted here rather than later - confirm.
tf.constant(1.0)  # Trigger basic op
import logging
# Raise the 'absl' logger threshold so only ERROR-and-above records pass through.
logging.getLogger('absl').setLevel(logging.ERROR)
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import seaborn as sns
import ray

2025-08-05 23:20:05.973777: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754454005.990144  143922 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754454005.993994  143922 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1754454006.004100  143922 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1754454006.004126  143922 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1754454006.004128  143922 computation_placer.cc:177] computation placer alr

In [3]:
%load_ext autoreload
%autoreload 2
from latent_loader import load_full_dataset, load_split_datasets
from transformers import TransformerEncoder, TransformerDecoder
from SurfaceEncoder import SurfaceEncoderModel

In [4]:
# Data location handed to `load_split_datasets` by the search/training functions
# in this notebook (TFRecord files, judging by the name - confirm).
# It is a relative path, so it resolves against the kernel's working directory.
tfrecord_dir = Path("mini_latents_tfrecords")


In [None]:
def build_encoder(max_sets, features_per_set, embed_dim, num_heads, ff_dim, num_sab_blocks, dropout):
    """Wire two named Keras inputs through ``TransformerEncoder`` and wrap them in a model.

    Inputs declared:
      * ``surface_data`` - shape ``(max_sets + 1, features_per_set)``
      * ``surface_mask`` - shape ``(max_sets,)``, dtype float32

    ``TransformerEncoder`` is expected to return a pair; only its first element
    is exposed as the model output.

    Returns:
        A ``tf.keras.Model`` named ``"surface_encoder"`` whose inputs are a dict
        keyed by the two input names above.
    """
    # NOTE(review): the data input has one more row than the mask has entries
    # (max_sets + 1 vs max_sets). Presumably an extra token is handled inside
    # TransformerEncoder - confirm against that implementation.
    data_shape = (max_sets + 1, features_per_set)
    surface_in = tf.keras.Input(shape=data_shape, name='surface_data')
    mask_in = tf.keras.Input(shape=(max_sets,), dtype=tf.float32, name='surface_mask')

    encoder_options = dict(
        embed_dim=embed_dim,
        num_heads=num_heads,
        ff_dim=ff_dim,
        dropout=dropout,
        num_sab_blocks=num_sab_blocks,
    )
    pooled, _second = TransformerEncoder(surface_in, mask_in, **encoder_options)

    named_inputs = {"surface_data": surface_in, "surface_mask": mask_in}
    return tf.keras.Model(inputs=named_inputs, outputs=pooled, name="surface_encoder")

In [6]:
def build_decoder(embed_dim, num_heads, ff_dim, num_layers, max_sets, features_per_set, dropout):
    """Wrap ``TransformerDecoder`` in a Keras model fed by a single latent input.

    The model takes one input named ``encoded_latent`` with shape
    ``(1, embed_dim)``. ``TransformerDecoder`` is expected to return a pair;
    only its first element becomes the model output.

    Returns:
        A ``tf.keras.Model`` named ``"surface_decoder"``.
    """
    latent_in = tf.keras.Input(shape=(1, embed_dim), name='encoded_latent')

    decoder_options = dict(
        embed_dim=embed_dim,
        num_heads=num_heads,
        ff_dim=ff_dim,
        num_layers=num_layers,
        max_sets=max_sets,
        features_per_set=features_per_set,
        dropout=dropout,
    )
    decoded, _second = TransformerDecoder(latent_in, **decoder_options)

    return tf.keras.Model(inputs=latent_in, outputs=decoded, name="surface_decoder")

In [7]:
def build_surface_encoder(hp, max_sets=441, features_per_set=4):
    """Build the encoder, the decoder and the combined ``SurfaceEncoderModel``.

    Args:
        hp: Mapping of hyperparameters. Keys read here: ``embed_dim``,
            ``num_heads``, ``ff_dim``, ``sab_blocks``, ``enc_dropout``,
            ``decoder_blocks`` and ``dec_dropout``.
        max_sets: Set-size dimension forwarded to both ``build_encoder`` and
            ``build_decoder``. Defaults to 441, the value previously hard-coded.
        features_per_set: Per-element feature count forwarded to both builders.
            Defaults to 4, the value previously hard-coded.

    Returns:
        A ``(model, encoder, decoder)`` tuple.

    Raises:
        KeyError: If ``hp`` lacks one of the keys listed above.
    """
    # The data dimensions are taken from a single place so the encoder and the
    # decoder can never be built with mismatching sizes.
    encoder = build_encoder(
        max_sets=max_sets, features_per_set=features_per_set, embed_dim=hp["embed_dim"],
        num_heads=hp["num_heads"], ff_dim=hp["ff_dim"],
        num_sab_blocks=hp["sab_blocks"], dropout=hp["enc_dropout"]
    )

    decoder = build_decoder(
        embed_dim=hp["embed_dim"], num_heads=hp["num_heads"], ff_dim=hp["ff_dim"],
        num_layers=hp["decoder_blocks"], max_sets=max_sets,
        features_per_set=features_per_set, dropout=hp["dec_dropout"]
    )

    model = SurfaceEncoderModel(encoder, decoder)
    return model, encoder, decoder

In [8]:
def sample_hyperparams(rng=None):
    """Draw one hyperparameter configuration for a random-search trial.

    Architecture-related entries are currently fixed (``batch_size=64``,
    ``embed_dim=64``, ``num_heads=4``, ``ff_dim=128``, ``sab_blocks=1``,
    ``decoder_blocks=1``); only the dropouts, learning rate and weight decay
    are sampled.

    A wider search space was used previously and can be restored by sampling
    these entries instead of fixing them:
    ``batch_size`` in {64, 128, 256}, ``num_heads`` in {4, 6, 8},
    ``ff_dim`` in {128, 256, 512}, ``sab_blocks`` and ``decoder_blocks`` in 1..6.

    Args:
        rng: Optional random source exposing ``uniform(low, high)``, e.g.
            ``np.random.default_rng(seed)``. When ``None`` (the default) the
            global ``np.random`` state is used, exactly as before, so seeding
            with ``np.random.seed`` keeps working.

    Returns:
        dict with keys ``batch_size``, ``embed_dim``, ``num_heads``, ``ff_dim``,
        ``enc_dropout``, ``dec_dropout``, ``learning_rate``, ``weight_decay``,
        ``sab_blocks`` and ``decoder_blocks``.
    """
    source = np.random if rng is None else rng

    # The draws happen in dict-literal order (enc_dropout, dec_dropout,
    # learning_rate, weight_decay); keep that order so seeded runs are stable.
    return {
        'batch_size': 64,
        "embed_dim": 64,
        "num_heads": 4,
        "ff_dim": 128,
        "enc_dropout": source.uniform(0.0, 0.3),
        "dec_dropout": source.uniform(0.0, 0.3),
        # Log-uniform between 1e-5 and 3e-3.
        "learning_rate": 10 ** source.uniform(-5, np.log10(3)-3),
        # Log-uniform between 5e-3 and 1e-2.
        'weight_decay': 10 ** source.uniform(np.log10(5)-3, -2),
        'sab_blocks': 1,
        'decoder_blocks': 1
    }

In [9]:
def build_model(hp):
    """Create the combined model and compile it with a Lion optimizer.

    ``hp['learning_rate']`` and ``hp['weight_decay']`` configure the optimizer;
    the rest of ``hp`` is consumed by ``build_surface_encoder``. No loss is
    passed to ``compile``.

    Returns:
        The compiled model (first element of ``build_surface_encoder``'s result).
    """
    autoencoder, _encoder, _decoder = build_surface_encoder(hp)
    lion = tf.keras.optimizers.Lion(
        learning_rate=hp['learning_rate'],
        weight_decay=hp['weight_decay'],
    )
    autoencoder.compile(optimizer=lion)
    return autoencoder

In [10]:
def train_trial(hp, train_dataset, val_dataset, trial_id=0, epochs=10, use_wandb=False):
    """Train one model for a hyperparameter configuration and score it.

    Args:
        hp: Hyperparameter mapping passed to ``build_model`` (and logged as the
            wandb run config when ``use_wandb`` is true).
        train_dataset: Dataset handed to ``model.fit``.
        val_dataset: Dataset used for validation during ``fit`` and for the
            final ``model.evaluate`` call.
        trial_id: Used in the TensorBoard log directory (``logs/trial_<id>``)
            and in the wandb run name (``trial_<id>``).
        epochs: Number of epochs passed to ``model.fit``.
        use_wandb: When true, a Weights & Biases run is opened for the trial and
            a ``WandbCallback`` is added. wandb is imported lazily so it is only
            required when this flag is set.

    Returns:
        Whatever ``model.evaluate(val_dataset)`` returns. Callers in this
        notebook index it with ``[0]``, i.e. they expect a sequence.
    """
    model = build_model(hp)

    # TensorBoard logging
    callbacks = [tf.keras.callbacks.TensorBoard(log_dir=f"logs/trial_{trial_id}")]

    # Optional: wandb logging
    if use_wandb:
        import wandb
        from wandb.keras import WandbCallback
        wandb.init(project="coil_autoencoder", config=hp, name=f"trial_{trial_id}")
        callbacks.append(WandbCallback())

    try:
        model.fit(
            train_dataset,
            validation_data=val_dataset,
            epochs=epochs,
            callbacks=callbacks
        )
        return model.evaluate(val_dataset)
    finally:
        # Close the run opened above even when training fails, so consecutive
        # trials do not leave a run open or log into each other's run.
        if use_wandb:
            wandb.finish()

In [11]:
@ray.remote(num_gpus=0)  # num_gpus=0 requests no GPU for the task (CPU-only run)
def parallel_train_trial(hp, trial_id, use_wandb):
    """Ray task that loads the data splits, trains one trial and reports its score.

    The datasets are loaded inside the task from the notebook-level
    ``tfrecord_dir`` using ``hp['batch_size']``, with ``train_frac=0.9`` and
    ``val_frac=0.1``. The third split returned by the loader is not used.

    Returns:
        ``(val_loss, hp)`` where ``val_loss`` is the result of ``train_trial``.
    """
    splits = load_split_datasets(
        tfrecord_dir, batch_size=hp['batch_size'], train_frac=0.9, val_frac=0.1
    )
    train_ds, val_ds, _unused_test_ds = splits
    score = train_trial(hp, train_ds, val_ds, trial_id, use_wandb=use_wandb)
    return score, hp


In [12]:
def run_random_search_parallel(n_trials=10, use_wandb=False):
    """Launch ``n_trials`` random-search trials as Ray tasks and rank the results.

    Ray is initialised with ``ignore_reinit_error=True`` so re-running the cell
    in a live kernel does not fail. One hyperparameter set is sampled per trial
    and submitted via ``parallel_train_trial.remote``.

    Returns:
        List of ``(val_loss, hp)`` tuples sorted ascending by ``val_loss[0]``.
    """
    ray.init(ignore_reinit_error=True)

    # Submit everything first; each .remote() call returns immediately.
    pending = [
        parallel_train_trial.remote(sample_hyperparams(), trial_id, use_wandb)
        for trial_id in range(n_trials)
    ]

    # Blocks until every trial has finished.
    finished = ray.get(pending)

    # Rank by the first entry of each trial's evaluation result.
    return sorted(finished, key=lambda outcome: outcome[0][0])


In [13]:
def run_random_search(n_trials=10, use_wandb=True, save_dir="saved_models"):
    """Run ``n_trials`` random-search trials one after another and rank them.

    For every trial a hyperparameter set is sampled, the datasets are reloaded
    with that trial's batch size (``train_frac=0.9``), and ``train_trial`` is
    run. ``save_dir`` is created up front and a per-trial path inside it is
    recorded in the results, but no model is written to it here.

    Returns:
        List of ``(val_loss, hp, save_path)`` tuples sorted ascending by
        ``val_loss``.
    """
    os.makedirs(save_dir, exist_ok=True)
    outcomes = []

    for trial_id in range(n_trials):
        hp = sample_hyperparams()
        train_ds, val_ds, _unused_test_ds = load_split_datasets(
            tfrecord_dir, batch_size=hp['batch_size'], train_frac=0.9
        )
        val_loss = train_trial(hp, train_ds, val_ds, trial_id, use_wandb=use_wandb)

        # Path is only recorded; saving the model is not done in this function.
        trial_name = f"trial_{trial_id:03d}_val_{val_loss[0]:.4f}"
        outcomes.append((val_loss, hp, os.path.join(save_dir, trial_name)))

    return sorted(outcomes, key=lambda outcome: outcome[0])


In [14]:
res = run_random_search(n_trials=1, use_wandb=False)

Epoch 1/10




     43/Unknown [1m74s[0m 2s/step - coil_latent_loss: 0.1152 - loss: 0.6350 - mae: 0.5179 - recon_loss: 0.5774 - scaler_loss: 0.2071 - unmasked_mse: 0.4640



[1m43/43[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m77s[0m 2s/step - coil_latent_loss: 0.1157 - loss: 0.6279 - mae: 0.5155 - recon_loss: 0.5700 - scaler_loss: 0.2052 - unmasked_mse: 0.4594 - val_coil_latent_loss: 0.1636 - val_loss: 0.2584 - val_mae: 0.3488 - val_recon_loss: 0.1766 - val_scaler_loss: 0.0432 - val_unmasked_mse: 0.1910
Epoch 2/10
[1m43/43[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m73s[0m 2s/step - coil_latent_loss: 0.1433 - loss: 0.3086 - mae: 0.3771 - recon_loss: 0.2369 - scaler_loss: 0.0716 - unmasked_mse: 0.2274 - val_coil_latent_loss: 0.0878 - val_loss: 0.2144 - val_mae: 0.3464 - val_recon_loss: 0.1705 - val_scaler_loss: 0.0395 - val_unmasked_mse: 0.2030
Epoch 3/10
[1m43/43[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m78s[0m 2s/step - coil_latent_loss: 0.0482 - loss: 0.2166 - mae: 0.3579 - recon_loss: 0.1925 - scaler_loss: 0.0633 - unmasked_mse: 0.2122 - val_coil_latent_loss: 0.0248 - val_loss: 0.1548 - val_mae: 0.2988 - val_recon_loss: 0.1423 - val_sc