In [1]:
import os

# Switches that TensorFlow reads from the environment. They are set in the
# first cell so they are already in place when `import tensorflow` runs below.
os.environ.update(
    {
        # "2" filters INFO and WARNING output from the C++ backend; ERROR still shows.
        "TF_CPP_MIN_LOG_LEVEL": "2",
        # Optional: disable oneDNN custom ops (silences the oneDNN info message).
        "TF_ENABLE_ONEDNN_OPTS": "0",
        # Optional: disable XLA devices, intended to reduce cu* factory warnings.
        "TF_XLA_FLAGS": "--tf_xla_enable_xla_devices=false",
    }
)

In [2]:
import tensorflow as tf
import pandas as pd
import numpy as np
from pathlib import Path
# Run one trivial op right after the import — presumably so TensorFlow's
# start-up/device messages are emitted in this cell (TODO confirm intent).
tf.constant(1.0)
import logging
# Only let absl messages of severity ERROR or higher through.
logging.getLogger('absl').setLevel(logging.ERROR)
from tensorflow.keras import layers
import matplotlib.pyplot as plt
import seaborn as sns
import ray

2025-08-05 18:18:31.189963: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754435911.204641   77526 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754435911.208554   77526 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1754435911.218658   77526 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1754435911.218681   77526 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1754435911.218683   77526 computation_placer.cc:177] computation placer alr

In [3]:
%load_ext autoreload
%autoreload 2
from surface_coil_loader import load_full_dataset, load_split_datasets
from transformers import TransformerEncoder, TransformerDecoder
from CoilAutoencoder import CoilAutoencoderModel

In [5]:
# Dataset location handed to the loader functions in the cells below
# (relative path, so it is resolved against the current working directory).
tfrecord_dir = Path("mini_surface_coil_tfrecords")

In [None]:
def build_encoder(max_coils, features_per_coil, embed_dim, num_heads, ff_dim, num_sab_blocks, dropout):
    """Assemble the functional Keras model named "coil_encoder".

    The model takes a dict with two named inputs:
        coil_data: shape (batch, max_coils + 1, features_per_coil)
        coil_mask: float32, shape (batch, max_coils)

    NOTE(review): coil_data has one more row than the mask covers — presumably
    an extra non-coil token; confirm against the data loader.

    Args:
        max_coils: number of mask entries; coil_data carries max_coils + 1 rows.
        features_per_coil: size of the last axis of coil_data.
        embed_dim, num_heads, ff_dim, dropout: forwarded to TransformerEncoder.
        num_sab_blocks: forwarded to TransformerEncoder as ``num_sab_blocks``.

    Returns:
        tf.keras.Model whose output is the first of the two values returned
        by TransformerEncoder (the second one is discarded).
    """
    coil_input = tf.keras.Input(name='coil_data', shape=(max_coils + 1, features_per_coil))
    mask_input = tf.keras.Input(name='coil_mask', shape=(max_coils,), dtype=tf.float32)

    encoder_settings = dict(
        embed_dim=embed_dim,
        num_heads=num_heads,
        ff_dim=ff_dim,
        dropout=dropout,
        num_sab_blocks=num_sab_blocks,
    )
    pooled, _unused = TransformerEncoder(coil_input, mask_input, **encoder_settings)

    named_inputs = {"coil_data": coil_input, "coil_mask": mask_input}
    return tf.keras.Model(inputs=named_inputs, outputs=pooled, name="coil_encoder")

In [17]:
def build_decoder(embed_dim, num_heads, ff_dim, num_layers, max_coils, features_per_coil, dropout):
    """Assemble the functional Keras model named "coil_decoder".

    The single input, 'encoded_latent', has shape (batch, 1, embed_dim).

    Args:
        embed_dim, num_heads, ff_dim, num_layers, dropout: forwarded to
            TransformerDecoder under the same names.
        max_coils: forwarded to TransformerDecoder as ``max_sets``.
        features_per_coil: forwarded to TransformerDecoder as ``features_per_set``.

    Returns:
        tf.keras.Model whose output is the first of the two values returned
        by TransformerDecoder (the second one is discarded).
    """
    latent_input = tf.keras.Input(name='encoded_latent', shape=(1, embed_dim))

    decoder_settings = dict(
        embed_dim=embed_dim,
        num_heads=num_heads,
        ff_dim=ff_dim,
        num_layers=num_layers,
        max_sets=max_coils,
        features_per_set=features_per_coil,
        dropout=dropout,
    )
    reconstruction, _unused = TransformerDecoder(latent_input, **decoder_settings)

    return tf.keras.Model(inputs=latent_input, outputs=reconstruction, name="coil_decoder")

In [8]:
def build_coil_autoencoder(hp, max_coils=6, features_per_coil=100):
    """Build the encoder, the decoder and the autoencoder that wraps them.

    Args:
        hp: mapping of hyperparameters. Keys read here: "embed_dim",
            "num_heads", "ff_dim" (shared by encoder and decoder),
            "sab_blocks" and "enc_dropout" (encoder), "decoder_blocks" and
            "dec_dropout" (decoder).
        max_coils: set size passed to both sub-models (default 6, the value
            that used to be hard-coded).
        features_per_coil: per-coil feature count passed to both sub-models
            (default 100, the value that used to be hard-coded).

    Returns:
        (autoencoder, encoder, decoder) where autoencoder is
        ``CoilAutoencoderModel(encoder, decoder)``.
    """
    encoder = build_encoder(
        max_coils=max_coils, features_per_coil=features_per_coil, embed_dim=hp["embed_dim"],
        num_heads=hp["num_heads"], ff_dim=hp["ff_dim"],
        num_sab_blocks=hp["sab_blocks"], dropout=hp["enc_dropout"]
    )

    decoder = build_decoder(
        embed_dim=hp["embed_dim"], num_heads=hp["num_heads"], ff_dim=hp["ff_dim"],
        num_layers=hp["decoder_blocks"], max_coils=max_coils,
        features_per_coil=features_per_coil, dropout=hp["dec_dropout"]
    )

    autoencoder = CoilAutoencoderModel(encoder, decoder)
    return autoencoder, encoder, decoder

In [9]:
def sample_hyperparams(rng=None):
    """Draw one random hyperparameter configuration.

    Args:
        rng: optional source of randomness exposing ``choice`` and ``uniform``
            (e.g. ``np.random.default_rng(seed)``). When omitted, NumPy's
            global ``np.random`` state is used, as before.

    Returns:
        dict of native Python ``int``/``float`` values (not NumPy scalars, so
        the dict can be JSON-serialised and logged as-is):
            batch_size, embed_dim, ff_dim, num_heads: one of the listed choices
            enc_dropout, dec_dropout: uniform in [0.0, 0.3)
            learning_rate: log-uniform in [1e-5, 3e-3)
            weight_decay: log-uniform in [5e-3, 1e-2)
            sab_blocks, decoder_blocks: integer in 1..6
    """
    rng = np.random if rng is None else rng

    def pick(options):
        # rng.choice returns a NumPy integer; convert to a plain int.
        return int(rng.choice(options))

    def log_uniform(low_exponent, high_exponent):
        # Uniform in the base-10 exponent, i.e. log-uniform in the value.
        return float(10 ** rng.uniform(low_exponent, high_exponent))

    # The draw order below is kept fixed so a seeded rng is reproducible.
    return {
        'batch_size': pick([64, 128, 256]),
        "embed_dim": pick([64, 128, 256]),
        "num_heads": pick([4, 6, 8]),
        "ff_dim": pick([128, 256, 512]),
        "enc_dropout": float(rng.uniform(0.0, 0.3)),
        "dec_dropout": float(rng.uniform(0.0, 0.3)),
        "learning_rate": log_uniform(-5, np.log10(3) - 3),
        'weight_decay': log_uniform(np.log10(5) - 3, -2),
        'sab_blocks': pick([1, 2, 3, 4, 5, 6]),
        'decoder_blocks': pick([1, 2, 3, 4, 5, 6]),
    }

In [10]:
def build_model(hp):
    """Return a compiled autoencoder for the hyperparameters in ``hp``.

    Besides the keys consumed by build_coil_autoencoder, this reads
    hp['learning_rate'] and hp['weight_decay'] for the Lion optimizer.
    compile() receives only the optimizer; no loss or metrics are passed here.
    """
    autoencoder, _encoder, _decoder = build_coil_autoencoder(hp)
    autoencoder.compile(
        optimizer=tf.keras.optimizers.Lion(
            learning_rate=hp['learning_rate'],
            weight_decay=hp['weight_decay'],
        )
    )
    return autoencoder

In [11]:
def train_trial(hp, train_dataset, val_dataset, trial_id=0, epochs=10, use_wandb=False):
    """Train one model for ``hp`` and return its validation metrics.

    Args:
        hp: hyperparameter mapping passed to build_model (and recorded as the
            W&B run config when ``use_wandb`` is true).
        train_dataset: data passed to ``model.fit``.
        val_dataset: data used for validation during fit and for the final
            ``model.evaluate`` call.
        trial_id: used in the TensorBoard log directory ("logs/trial_<id>")
            and in the W&B run name ("trial_<id>").
        epochs: number of training epochs.
        use_wandb: when true, also log the run to Weights & Biases.

    Returns:
        Whatever ``model.evaluate(val_dataset)`` returns.
    """
    model = build_model(hp)

    # TensorBoard logging: one log directory per trial.
    callbacks = [tf.keras.callbacks.TensorBoard(log_dir=f"logs/trial_{trial_id}")]

    wandb = None
    try:
        # Optional: wandb logging (imported lazily so it stays an optional dependency).
        if use_wandb:
            import wandb
            from wandb.keras import WandbCallback
            wandb.init(project="coil_autoencoder", config=hp, name=f"trial_{trial_id}")
            callbacks.append(WandbCallback())

        model.fit(
            train_dataset,
            validation_data=val_dataset,
            epochs=epochs,
            callbacks=callbacks
        )

        return model.evaluate(val_dataset)
    finally:
        # Always close the run, even if training fails, so the next trial in
        # this process starts its own run instead of inheriting this one.
        if wandb is not None:
            wandb.finish()

In [12]:
@ray.remote(num_gpus=0)  # requests no GPU from Ray, i.e. a CPU-only task
def parallel_train_trial(hp, trial_id, use_wandb):
    """Ray task: load the data splits, train one trial, return (val_loss, hp).

    The datasets are loaded inside the task from the notebook-level
    ``tfrecord_dir`` with a 0.9 / 0.1 train/validation split and
    hp['batch_size']; the third split returned by the loader is not used.
    """
    train_ds, val_ds, _test_ds = load_split_datasets(
        tfrecord_dir,
        batch_size=hp['batch_size'],
        train_frac=0.9,
        val_frac=0.1,
    )
    return train_trial(hp, train_ds, val_ds, trial_id, use_wandb=use_wandb), hp


In [13]:
def run_random_search_parallel(n_trials=10, use_wandb=False):
    """Run ``n_trials`` random-search trials as Ray tasks.

    Every trial gets its own freshly sampled hyperparameters. All tasks are
    submitted first; the call then blocks until every one has finished.

    Returns:
        list of (val_loss, hp) tuples, sorted ascending by val_loss[0].
    """
    ray.init(ignore_reinit_error=True)

    pending = [
        parallel_train_trial.remote(sample_hyperparams(), trial_id, use_wandb)
        for trial_id in range(n_trials)
    ]

    finished = ray.get(pending)  # blocks until all trials are done
    finished.sort(key=lambda trial: trial[0][0])  # first metric of each trial
    return finished


In [14]:
def run_random_search(n_trials=10, use_wandb=True, save_dir="saved_models"):
    """Run ``n_trials`` random-search trials one after another.

    Args:
        n_trials: number of hyperparameter configurations to sample and train.
        use_wandb: forwarded to train_trial.
        save_dir: directory created up front and used as the prefix of each
            trial's ``save_path``.

    Returns:
        list of (val_loss, hp, save_path) tuples sorted ascending by the
        primary validation loss. ``val_loss`` is exactly what train_trial
        returned. NOTE: ``save_path`` is only a label — nothing is written
        there, because train_trial does not hand the trained model back.
    """
    def primary_loss(val_loss):
        # Keras' evaluate() yields a bare scalar for a model without extra
        # metrics and a list ([loss, metric, ...]) otherwise; accept both.
        if isinstance(val_loss, (list, tuple)):
            val_loss = val_loss[0]
        return float(val_loss)

    os.makedirs(save_dir, exist_ok=True)

    results = []
    for trial_id in range(n_trials):
        hp = sample_hyperparams()
        # NOTE(review): the Ray variant also passes val_frac=0.1 — confirm the
        # loader's default gives the same split here.
        train, val, _ = load_split_datasets(tfrecord_dir, batch_size=hp['batch_size'], train_frac=0.9)
        val_loss = train_trial(hp, train, val, trial_id, use_wandb=use_wandb)
        save_path = os.path.join(save_dir, f"trial_{trial_id:03d}_val_{primary_loss(val_loss):.4f}")
        results.append((val_loss, hp, save_path))

    # Rank on the scalar primary loss (same criterion as the Ray variant)
    # rather than comparing whole metric lists element by element.
    results.sort(key=lambda trial: primary_loss(trial[0]))
    return results


In [18]:
# Short sequential run: two trials, W&B logging switched off.
res = run_random_search(n_trials=2, use_wandb=False)



Epoch 1/10


Expected: {'coil_data': 'coil_data', 'coil_mask': 'coil_mask'}
Received: inputs={'coil_data': 'Tensor(shape=(None, None, 100))', 'coil_mask': 'Tensor(shape=(None, None))', 'surface_data': 'Tensor(shape=(None, None, 4))', 'surface_mask': 'Tensor(shape=(None, None))'}


     43/Unknown [1m13s[0m 129ms/step - coil_loss: 0.2087 - loss: 1.1691 - mae: 0.2569 - scaler_loss: 0.9604 - unmasked_mse: 0.4729



[1m43/43[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 148ms/step - coil_loss: 0.2076 - loss: 1.1465 - mae: 0.2543 - scaler_loss: 0.9389 - unmasked_mse: 0.4710 - val_coil_loss: 0.1802 - val_loss: 0.1845 - val_mae: 0.1294 - val_scaler_loss: 0.0043 - val_unmasked_mse: 0.3920
Epoch 2/10
[1m43/43[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 134ms/step - coil_loss: 0.1721 - loss: 0.1781 - mae: 0.1406 - scaler_loss: 0.0060 - unmasked_mse: 0.3844 - val_coil_loss: 0.1761 - val_loss: 0.1763 - val_mae: 0.1273 - val_scaler_loss: 1.9236e-04 - val_unmasked_mse: 0.3779
Epoch 3/10
[1m43/43[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 127ms/step - coil_loss: 0.1725 - loss: 0.1782 - mae: 0.1513 - scaler_loss: 0.0056 - unmasked_mse: 0.3851 - val_coil_loss: 0.1797 - val_loss: 0.1809 - val_mae: 0.1494 - val_scaler_loss: 0.0013 - val_unmasked_mse: 0.3867
Epoch 4/10
[1m43/43[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 124ms/step - coil_loss: 0.1714 - loss: 0.1842 - m

In [17]:
# Display the (val_loss, hp, save_path) tuples returned by run_random_search.
res

[([<tf.Tensor: shape=(), dtype=float32, numpy=0.17599551379680634>,
   <tf.Tensor: shape=(), dtype=float32, numpy=0.1227232962846756>,
   <tf.Tensor: shape=(), dtype=float32, numpy=0.3796854019165039>,
   <tf.Tensor: shape=(), dtype=float32, numpy=0.05327221378684044>,
   <tf.Tensor: shape=(), dtype=float32, numpy=0.27929821610450745>],
  {'batch_size': 128,
   'embed_dim': 256,
   'num_heads': 4,
   'ff_dim': 256,
   'enc_dropout': 0.2549885360880258,
   'dec_dropout': 0.04841201638431631,
   'learning_rate': 2.759230028803988e-05,
   'weight_decay': 0.007528391760608386,
   'sab_blocks': 1,
   'decoder_blocks': 6},
  'saved_models/trial_001_val_0.1760'),
 ([<tf.Tensor: shape=(), dtype=float32, numpy=0.24534067511558533>,
   <tf.Tensor: shape=(), dtype=float32, numpy=0.1677953153848648>,
   <tf.Tensor: shape=(), dtype=float32, numpy=0.29536762833595276>,
   <tf.Tensor: shape=(), dtype=float32, numpy=0.07754535228013992>,
   <tf.Tensor: shape=(), dtype=float32, numpy=0.4115840196609497

In [14]:
# Build one compiled model from a random configuration for the manual
# forward-pass check in the next cell.
hp = sample_hyperparams()
model = build_model(hp)



In [15]:
# NOTE(review): `load_coil_dataset` is not among the names imported from
# surface_coil_loader in this notebook — confirm where it is defined.
dataset = load_coil_dataset(tfrecord_dir, batch_size=64)

# Push a single batch through the model; the targets are not needed here.
# The loop variable is named `batch_inputs` so the builtin `input` is not shadowed.
for batch_inputs, _ in dataset.take(1):
    output = model(
        inputs={
            'coil_data': batch_inputs['coil_data'],
            'coil_mask': batch_inputs['coil_mask'],
        }
    )  # presumably (B, N, D) — TODO confirm

hi


In [21]:
# Take a one-batch view of the dataset and display it (shows the element_spec).
data = dataset.take(1)
data

<_TakeDataset element_spec=({'coil_data': TensorSpec(shape=(None, 7, 100), dtype=tf.float32, name=None), 'coil_mask': TensorSpec(shape=(None, 6), dtype=tf.int64, name=None)}, TensorSpec(shape=(None, 6, 100), dtype=tf.float32, name=None))>

In [None]:
# `res` is the list of (val_loss, hp, save_path) tuples produced by
# run_random_search above; no variable called `results` exists at notebook scope.
# Rank on the first entry of each trial's metric list (the primary loss).
top_trials = sorted(res, key=lambda trial: float(trial[0][0]))[:20]  # val_loss ascending
top_configs = [dict(config) for _, config, _ in top_trials]

In [None]:
# W&B sweep definition: Bayesian search that minimises the logged "val_loss".
# NOTE(review): only embed_dim / num_heads / ff_dim / dropout / learning_rate
# are swept here, while build_coil_autoencoder also reads "sab_blocks",
# "decoder_blocks", "enc_dropout" and "dec_dropout" from its hp mapping.
bayesian_sweep_config = {
    "method": "bayes",
    "metric": {
        "name": "val_loss",
        "goal": "minimize"
    },
    "parameters": {
        "embed_dim": {"values": [64, 128, 256]},
        "num_heads": {"values": [4, 6, 8]},
        "ff_dim": {"values": [128, 256, 512]},
        "dropout": {"min": 0.0, "max": 0.3},
        "learning_rate": {"distribution": "log_uniform_values", "min": 1e-4, "max": 1e-3}
    },
    "early_terminate": {"type": "hyperband", "min_iter": 5},
    # NOTE(review): "initial_points" is seeded with the best random-search
    # configs; confirm that W&B sweeps actually honour this key.
    "initial_points": top_configs
}

In [None]:
def sweep_train_fn():
    """Train and evaluate one configuration chosen by a W&B sweep agent.

    Reads embed_dim, num_heads, ff_dim, dropout and learning_rate from
    ``wandb.config``, maps them onto the hp keys build_coil_autoencoder
    expects, trains for 30 epochs, logs the final primary validation loss as
    "val_loss" and saves the model under saved_models/.

    NOTE(review): the sweep does not define sab_blocks, decoder_blocks or
    batch_size, so the fallbacks below (2, 2 and 128) are placeholders —
    confirm them or add the parameters to the sweep config.
    """
    import wandb
    from wandb.keras import WandbCallback

    wandb.init(project="coil_autoencoder")
    try:
        config = wandb.config

        hp = {
            "embed_dim": config.embed_dim,
            "num_heads": config.num_heads,
            "ff_dim": config.ff_dim,
            # The sweep exposes a single dropout rate; use it on both halves.
            "enc_dropout": config.dropout,
            "dec_dropout": config.dropout,
            "learning_rate": config.learning_rate,
            "sab_blocks": config.get("sab_blocks", 2),
            "decoder_blocks": config.get("decoder_blocks", 2),
        }

        # Load the splits here so the function does not rely on notebook
        # globals that may not exist after a kernel restart.
        train_ds, val_ds, _ = load_split_datasets(
            tfrecord_dir, batch_size=config.get("batch_size", 128),
            train_frac=0.9, val_frac=0.1
        )

        # build_coil_autoencoder takes the hp mapping and returns
        # (autoencoder, encoder, decoder); only the autoencoder is trained.
        model, _, _ = build_coil_autoencoder(hp)
        model.compile(optimizer=tf.keras.optimizers.Lion(hp["learning_rate"]))

        callbacks = [WandbCallback()]
        model.fit(train_ds, validation_data=val_ds, epochs=30, callbacks=callbacks)

        # evaluate() may return a list of metrics; the sweep metric must be a scalar.
        val_loss = model.evaluate(val_ds)
        if isinstance(val_loss, (list, tuple)):
            val_loss = val_loss[0]
        wandb.log({"val_loss": float(val_loss)})

        # Save model (Keras 3 requires an explicit .keras/.h5 extension).
        os.makedirs("saved_models", exist_ok=True)
        model.save(f"saved_models/wandb_trial_{wandb.run.id}.keras")
    finally:
        wandb.finish()


In [None]:
# wandb is otherwise only imported inside functions, so bring it into the
# notebook namespace before using it here.
import wandb

# Register the sweep, then let this process act as the agent for up to 40 runs.
sweep_id = wandb.sweep(bayesian_sweep_config, project="coil_autoencoder")
wandb.agent(sweep_id, function=sweep_train_fn, count=40)

In [None]:
#visualize learned queries
def plot_learned_queries(model, layer_name="learned_query_decoder"):
    """Show a layer's ``learned_queries`` weights as a heatmap.

    Args:
        model: object exposing ``get_layer(name)``.
        layer_name: name of the layer whose ``learned_queries`` attribute is
            plotted; it must provide ``.numpy()``.
    """
    # Presumably one row per query and one column per embedding dimension — TODO confirm.
    query_matrix = model.get_layer(layer_name).learned_queries.numpy()

    _, ax = plt.subplots(figsize=(10, 6))
    sns.heatmap(query_matrix, cmap="viridis", cbar=True, ax=ax)
    ax.set_xlabel("Embedding Dimension")
    ax.set_ylabel("Query Index")
    ax.set_title("Learned Queries (Coil Decoder)")
    plt.show()


In [None]:
# Reference snippet only, kept as a comment: it uses `self`, so it belongs
# inside a layer method (presumably the decoder's call()). Executed at the top
# level of a cell it raises NameError and stops the cell before
# visualize_attention is defined.
# attn_out, attn_scores = self.attn(queries, encoded_set, return_attention_scores=True)

def visualize_attention(attn_scores, query_idx=0):
    """
    Plot, for the first batch item, how one query attends to the encoded set.

    attn_scores: array-like attention scores, either
        (B, num_heads, num_queries, seq_len) — the layout Keras'
            MultiHeadAttention returns with return_attention_scores=True;
            one heatmap row per head; or
        (B, num_queries, seq_len) — no head axis; drawn as a single row.
    query_idx: index of the query whose attention weights are shown.

    Raises:
        ValueError: if attn_scores is neither 3- nor 4-dimensional.
    """
    scores = np.asarray(attn_scores)
    if scores.ndim == 4:
        # heads x keys for the selected query, matching the axis labels below
        per_head = scores[0, :, query_idx, :]
    elif scores.ndim == 3:
        # heatmap needs 2-D data, so present the single row as (1, seq_len)
        per_head = scores[0, query_idx][np.newaxis, :]
    else:
        raise ValueError(
            f"attn_scores must have 3 or 4 dimensions, got shape {scores.shape}"
        )

    plt.figure(figsize=(10, 4))
    sns.heatmap(per_head, cmap="magma")
    plt.title(f"Attention Weights for Query {query_idx}")
    plt.xlabel("Encoded Coil Index")
    plt.ylabel("Head")
    plt.show()