In [1]:
import os 
import pandas as pd
import numpy as np
from simsopt.field import Current
from simsopt.geo import SurfaceRZFourier
from simsopt._core import load
from pathlib import Path
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from concurrent.futures import ProcessPoolExecutor
from tqdm import tqdm

2025-06-27 22:31:15.456103: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-06-27 22:31:15.471163: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-06-27 22:31:15.475631: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [None]:
# Tensor layout: each sample holds up to MAX_COILS coil rows plus one extra
# context row, and every row is FEATURES_PER_COIL float32 values wide.
MAX_COILS = 6
FEATURES_PER_COIL = 100

# Input directory scanned for *.json simsopt files, and destination for TFRecords.
data_dir = Path('quasr_simsopt_files')
output_coil_dir = Path('coil_tfrecords')

In [None]:
def _bytes_feature(value):
    """Wrap a single byte string in a ``tf.train.Feature`` backed by a BytesList.

    Args:
        value: ``bytes`` to store. A ``str`` is also accepted and is UTF-8
            encoded first, so callers may pass text directly.

    Returns:
        A ``tf.train.Feature`` whose ``bytes_list`` holds exactly one element.
    """
    if isinstance(value, str):
        value = value.encode('utf-8')
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def _float_feature(value):
    """Wrap numeric values in a ``tf.train.Feature`` backed by a FloatList.

    Args:
        value: NumPy array of any shape, a (nested) list of numbers, or a
            scalar. It is flattened to 1-D in C order before storage, so the
            original shape must be restored by the reader.

    Returns:
        A ``tf.train.Feature`` whose ``float_list`` holds the flattened values.
    """
    # np.ravel works for ndarrays, lists and scalars alike (value.flatten()
    # would fail on anything that is not already an ndarray).
    return tf.train.Feature(float_list=tf.train.FloatList(value=np.ravel(value)))

def _int_feature(value):
    """Wrap an iterable of integers in a ``tf.train.Feature`` backed by an Int64List.

    Args:
        value: 1-D iterable of integers (e.g. an int64 NumPy array).

    Returns:
        A ``tf.train.Feature`` whose ``int64_list`` holds those integers.
    """
    int_list = tf.train.Int64List(value=value)
    return tf.train.Feature(int64_list=int_list)

def serialize_coil(id: str, coils: np.ndarray, coil_mask: np.ndarray):
    """Serialize one sample's coil tensor and mask into ``tf.train.Example`` bytes.

    Args:
        id: sample identifier, stored UTF-8 encoded under the ``'ID'`` key.
        coils: numpy array of shape (N+1, D), dtype float32; stored flattened
            under ``'coil_data'``.
        coil_mask: integer array stored under ``'coil_mask'``.

    Returns:
        The serialized ``tf.train.Example`` as a byte string.
    """
    encoded_id = id.encode('utf-8')
    features = tf.train.Features(feature={
        'ID': _bytes_feature(encoded_id),
        'coil_data': _float_feature(coils),
        'coil_mask': _int_feature(coil_mask),
    })
    return tf.train.Example(features=features).SerializeToString()

In [4]:
def process_coils(file_path):
    """Load one simsopt JSON file and serialize its coils as a ``tf.train.Example``.

    The coil tensor has ``MAX_COILS + 1`` rows of ``FEATURES_PER_COIL`` float32
    values. Rows ``0..num_coils-1`` hold ``[current, *params]`` for each kept
    coil, where ``params`` are the last ``FEATURES_PER_COIL - 1`` entries of
    ``coil.x``. Unused rows stay zero. The final row is a context token filled
    with ``log10`` of the first coil's current scale. ``coil_mask`` marks which
    of the ``MAX_COILS`` rows are populated.

    Args:
        file_path: path to a ``*.json`` file that ``load`` turns into
            ``(surfaces, coils)``.

    Returns:
        The serialized example bytes, or ``None`` if the file could not be
        processed. The error is printed so a batch run keeps going.
    """
    # One slot per row is reserved for the current; the rest hold coil parameters.
    n_params = FEATURES_PER_COIL - 1
    try:
        # Sample ID = last 7 characters of the file stem (e.g. ".../serial0012345.json"
        # -> "0012345"). Using the stem avoids picking up directory characters.
        sample_id = Path(file_path).stem[-7:]
        surfaces, coils = load(str(file_path))
        surface = surfaces[-1]

        # NOTE(review): assumes the coil list is the base coils replicated
        # 2 * nfp times by symmetry, with the unique coils first — confirm.
        num_coils_total = len(coils) // (surface.nfp * 2)
        if num_coils_total > MAX_COILS:
            print(f"Warning: {file_path} has {num_coils_total} unique coils; "
                  f"keeping only the first {MAX_COILS}")
        num_coils = min(num_coils_total, MAX_COILS)

        log_scaler = np.log10(coils[0].current.scale)
        # A non-positive scale would silently write NaN/-inf into the context token.
        if not np.isfinite(log_scaler):
            raise ValueError(f"non-finite log10 current scale: {log_scaler}")

        coil_array = np.zeros((MAX_COILS + 1, FEATURES_PER_COIL), dtype=np.float32)
        for i in range(num_coils):
            params = np.asarray(coils[i].x)[-n_params:]
            if params.size != n_params:
                raise ValueError(
                    f"coil {i} has {params.size} parameters, expected {n_params}")
            curr = np.array(coils[i].current.current_to_scale.current)
            coil_array[i] = np.append(curr, params)  # [current, *params]

        coil_array[-1] = log_scaler  # context token (broadcast across the row)

        coil_mask = (np.arange(MAX_COILS) < num_coils).astype(np.int64)

        return serialize_coil(sample_id, coil_array, coil_mask)

    except Exception as e:
        # Best-effort: skip bad files but keep the exception type in the log.
        print(f"Failed on {file_path}: {type(e).__name__}: {e}")
        return None

In [None]:
def write_tfrecord_chunk(serialized_examples, output_path):
    """Write serialized examples into a single TFRecord file.

    Args:
        serialized_examples: iterable of serialized example bytes. Falsy
            entries (``None`` or empty) are skipped.
        output_path: destination path; converted with ``str()`` for TensorFlow.
    """
    kept_records = filter(None, serialized_examples)
    with tf.io.TFRecordWriter(str(output_path)) as writer:
        for record in kept_records:
            writer.write(record)

def datasets_to_tfrecords(directory: Path, output_coil_dir: Path,
                          chunk_size=10000, num_workers=64):
    """Convert every ``*.json`` file in ``directory`` into chunked TFRecord files.

    Files are sorted by path so chunk membership is reproducible across runs.
    They are processed in groups of ``chunk_size`` with ``process_coils`` on a
    single process pool of ``num_workers`` workers. Each group is written to
    ``output_coil_dir / f"coils_chunk_{idx:03d}.tfrecord"``. Files for which
    ``process_coils`` returns ``None`` are left out.

    Args:
        directory: directory scanned (non-recursively) for ``*.json`` files.
        output_coil_dir: output directory; created if missing.
        chunk_size: maximum number of input files per output TFRecord.
        num_workers: number of worker processes.
    """
    # glob() order is filesystem-dependent; sort for deterministic chunks.
    files = sorted(directory.glob("*.json"))
    total_files = len(files)
    output_coil_dir.mkdir(parents=True, exist_ok=True)
    if not files:
        print(f"No *.json files found in {directory}")
        return

    # Reuse one pool for all chunks instead of re-spawning workers per chunk.
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        for chunk_idx, start in enumerate(range(0, total_files, chunk_size)):
            chunk_files = files[start:start + chunk_size]
            # Batch tasks per IPC round-trip; chunksize=1 is slow for long inputs.
            map_chunksize = max(1, len(chunk_files) // (num_workers * 4))
            serialized_examples = list(tqdm(
                executor.map(process_coils, chunk_files, chunksize=map_chunksize),
                total=len(chunk_files),
                desc=f"Chunk {chunk_idx:03d}"
            ))

            serialized_coils = [ex for ex in serialized_examples if ex is not None]

            output_coil_path = output_coil_dir / f"coils_chunk_{chunk_idx:03d}.tfrecord"

            write_tfrecord_chunk(serialized_coils, output_coil_path)
            print(f"✅ Saved {len(serialized_coils)} coil samples to {output_coil_path}")


In [None]:
datasets_to_tfrecords(data_dir, output_coil_dir)  # convert all JSON files with default chunking

Chunk 000: 100%|██████████| 3000/3000 [00:05<00:00, 535.87it/s]


✅ Saved 3000 coil samples to coil_tfrecords
