In [95]:
import flax.linen as nn
import jax.numpy as jnp
import os
import jax
from jax import random
import numpy as np

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
devices = jax.local_devices()

print("Backend Selected:", jax.lib.xla_bridge.get_backend().platform)
print("Detected Devices:", jax.devices())

Backend Selected: gpu
Detected Devices: [CudaDevice(id=0)]


  print("Backend Selected:", jax.lib.xla_bridge.get_backend().platform)


In [96]:
# data config
batch_size = 100
signal_length = 1000

In [97]:
#making sample data

def split_complex_to_imaginary(complex_array):
    """Splits a complex array into real and imaginary parts and flattens it."""
    return np.concatenate([complex_array.real, complex_array.imag], axis=-1)

num_samples = 30

# Test sample data
signal_data = [np.random.randn(1000) + 1j * np.random.randn(1000) for _ in range(num_samples)]
coefficients_data = [np.random.randn(6) + 1j * np.random.randn(6) for _ in range(num_samples)]

In [98]:
# dataloader (shouldn't use pytorch DataLoader objects? build my own fast one)

# Each signal has length 1000, coefficients have length 6
signal_data_real_imag = [split_complex_to_imaginary(signal) for signal in signal_data]  # Now each is length 2000
coefficients_data_real_imag = [split_complex_to_imaginary(coeff) for coeff in coefficients_data]  # Now each is length 12

# Create dataset as pairs of (signal, coefficients)
dataset = list(zip(signal_data_real_imag, coefficients_data_real_imag))

def data_loader(dataset, batch_size, shuffle=True):
    dataset_size = len(dataset)
    indices = np.arange(dataset_size)
    
    # Shuffle dataset if required
    if shuffle:
        np.random.shuffle(indices)
    
    # Loop over dataset and yield batches
    for start_idx in range(0, dataset_size, batch_size):
        end_idx = min(start_idx + batch_size, dataset_size)
        batch_indices = indices[start_idx:end_idx]
        
        # Extract the batch of signals and coefficients separately
        batch_signal = [dataset[i][0] for i in batch_indices]  # Signal of length 2000
        batch_coefficients = [dataset[i][1] for i in batch_indices]  # Coefficients of length 12
        
        # Convert the batch data to JAX arrays
        yield jnp.array(batch_signal), jnp.array(batch_coefficients)


for batch_signal, batch_coefficients in data_loader(dataset, batch_size):
    print(f"Batch signal shape: {batch_signal.shape}")  # Should be (32, 2000)
    print(f"Batch coefficients shape: {batch_coefficients.shape}")  # Should be (32, 12)

Batch signal shape: (30, 2000)
Batch coefficients shape: (30, 12)


In [114]:
class ComplexFCNN(nn.Module):

    @nn.compact
    def __call__(self, x, deterministic, rngs={'dropout': key}):
        x = nn.Dense(128)(x)
        x = nn.relu(x)
        x = nn.Dropout(0.2)(x, deterministic=deterministic)
        x = nn.Dense(64)(x)
        x = nn.relu(x)
        x = nn.Dense(12)(x) # needs to be 12 for 6 re values and 6 im values - last layer
        return x

# instance model
model = ComplexFCNN()

# set seed
key = random.PRNGKey(0)

input_shape = (batch_size, 2 * signal_length)

# initialize as 1s
variables = model.init(key, jnp.ones(input_shape), deterministic = True)

#forward pass to config
output = model.apply(variables, jnp.ones(input_shape), deterministic=True)

In [115]:
def loss_fn(params, model, inputs, true_coeffs, deterministic, rng_key):
    # Forward pass with the deterministic flag and PRNG key for dropout
    preds = model.apply({'params': params}, inputs, deterministic=deterministic, rngs={'dropout': rng_key})
    
    # Split predictions into real and imaginary parts
    preds_real, preds_imag = preds[:, :6], preds[:, 6:]
    print("preds_real", preds_real.shape)
    print("preds_imag", preds_imag.shape)
    true_real, true_imag = true_coeffs[:, :6], true_coeffs[:, 6:]
    
    # MSE loss for real and imaginary parts
    loss_real = jnp.mean((preds_real - true_real) ** 2)
    loss_imag = jnp.mean((preds_imag - true_imag) ** 2)
    
    return loss_real + loss_imag


In [116]:
# optimize using adam

from flax.training import train_state
import optax

root_key = jax.random.PRNGKey(seed=0)
main_key, params_key, rng_key = jax.random.split(key=root_key, num=3)

class TrainState(train_state.TrainState):
    loss_fn = staticmethod(loss_fn)

#initialize training state and optimizer
tx = optax.adam(learning_rate = 1e-3)
state = TrainState.create(apply_fn=model.apply, params=variables['params'], tx = tx)

# Simplified training loop
for batch_signal, batch_coefficients in data_loader(dataset, batch_size):
    print(f"Batch signal shape: {batch_signal.shape}")  # Should be (batch_size, 2 * signal_length)
    print(f"Batch coefficients shape: {batch_coefficients.shape}")  # Should be (batch_size, 2 * 6)
    
    # Generate a new PRNG key for dropout for this batch
    rng_key, subkey = jax.random.split(rng_key)

    # Pass the PRNG key for dropout to the loss function
    loss, grads = jax.value_and_grad(state.loss_fn)(
        state.params, model, batch_signal, batch_coefficients, deterministic=False, rng_key=subkey
    )
    
    # Apply gradients to update model parameters
    state = state.apply_gradients(grads=grads)

Batch signal shape: (30, 2000)
Batch coefficients shape: (30, 12)
preds_real (30, 6)
preds_imag (30, 6)
