In [59]:
import flax.linen as nn
import jax.numpy as jnp
import os
import jax
from jax import random
import numpy as np
import pandas as pd

from flax.training import train_state
import optax

os.environ['CUDA_VISIBLE_DEVICES'] = '0'
devices = jax.local_devices()

print("Backend Selected:", jax.lib.xla_bridge.get_backend().platform)
print("Detected Devices:", jax.devices())



root_key = jax.random.PRNGKey(seed=0)
main_key, params_key, rng_key = jax.random.split(key=root_key, num=3)


Backend Selected: gpu
Detected Devices: [CudaDevice(id=0)]


  print("Backend Selected:", jax.lib.xla_bridge.get_backend().platform)


In [60]:
#making sample data
"""
num_samples = 30

# Test sample data
signal_data = [np.random.randn(1000) + 1j * np.random.randn(1000) for _ in range(num_samples)]
coefficients_data = [np.random.randn(6) + 1j * np.random.randn(6) for _ in range(num_samples)]
"""

'\nnum_samples = 30\n\n# Test sample data\nsignal_data = [np.random.randn(1000) + 1j * np.random.randn(1000) for _ in range(num_samples)]\ncoefficients_data = [np.random.randn(6) + 1j * np.random.randn(6) for _ in range(num_samples)]\n'

In [61]:
import pandas as pd
import numpy as np
import jax.numpy as jnp

# Define the path to the CSV files
label_file_path = '/home/houtlaw/iono-net/data/SAR_AF_ML_toyDataset_etc/radar_coeffs_csv/compl_ampls_20241026_201104.csv'
data_file_path = '/home/houtlaw/iono-net/data/SAR_AF_ML_toyDataset_etc/radar_coeffs_csv/nuStruct_withSpeckle_20241026_201052.csv'

# Function to convert complex strings (e.g., '5.7618732844527+1.82124094798357i') to complex numbers
def convert_to_complex(s):
    return complex(s.replace('i', 'j'))

# Load the CSV files using pandas and apply conversion to complex numbers
label_df = pd.read_csv(label_file_path, dtype=str)
data_df = pd.read_csv(data_file_path, dtype=str)

# Convert the string representations into complex values
label_matrix = label_df.applymap(convert_to_complex).to_numpy().T  # Transpose to get data points as rows
data_matrix = data_df.applymap(convert_to_complex).to_numpy().T    # Transpose to get data points as rows

print("Label Matrix Shape (after transpose):", label_matrix.shape)
print("Data Matrix Shape (after transpose):", data_matrix.shape)

# Split complex matrices into real and imaginary parts
def split_complex_to_imaginary(complex_array):
    return np.concatenate([complex_array.real, complex_array.imag], axis=-1)

# Now each row represents a data point with real and imaginary parts concatenated along the row
label_matrix_split = split_complex_to_imaginary(label_matrix)
data_matrix_split = split_complex_to_imaginary(data_matrix)

print("Label Matrix Split Shape:", label_matrix_split.shape)
print("Data Matrix Split Shape:", data_matrix_split.shape)

# Combine the signal (data) and coefficients (labels) into a dataset
# Each signal now has length 2000 (real + imaginary), coefficients have length 12 (real + imaginary)
dataset = list(zip(data_matrix_split, label_matrix_split))

# Data loader function
def data_loader(dataset, batch_size, shuffle=True):
    dataset_size = len(dataset)
    indices = np.arange(dataset_size)
    
    # Shuffle dataset if required
    if shuffle:
        np.random.shuffle(indices)
    
    # Loop over dataset and yield batches
    for start_idx in range(0, dataset_size, batch_size):
        end_idx = min(start_idx + batch_size, dataset_size)
        batch_indices = indices[start_idx:end_idx]
        
        # Extract the batch of signals and coefficients separately
        batch_signal = [dataset[i][0] for i in batch_indices]  # Signal of length 2000
        batch_coefficients = [dataset[i][1] for i in batch_indices]  # Coefficients of length 12
        
        # Convert the batch data to JAX arrays
        yield jnp.array(batch_signal), jnp.array(batch_coefficients)

# Example: process a batch of data
batch_size = 32
for batch_signal, batch_coefficients in data_loader(dataset, batch_size):
    print(f"Batch signal shape: {batch_signal.shape}")  # Should be (32, 2000)
    print(f"Batch coefficients shape: {batch_coefficients.shape}")  # Should be (32, 12)


  label_matrix = label_df.applymap(convert_to_complex).to_numpy().T  # Transpose to get data points as rows
  data_matrix = data_df.applymap(convert_to_complex).to_numpy().T    # Transpose to get data points as rows


Label Matrix Shape (after transpose): (10000, 6)
Data Matrix Shape (after transpose): (10000, 1441)
Label Matrix Split Shape: (10000, 12)
Data Matrix Split Shape: (10000, 2882)
Batch signal shape: (32, 2882)
Batch coefficients shape: (32, 12)
Batch signal shape: (32, 2882)
Batch coefficients shape: (32, 12)
Batch signal shape: (32, 2882)
Batch coefficients shape: (32, 12)
Batch signal shape: (32, 2882)
Batch coefficients shape: (32, 12)
Batch signal shape: (32, 2882)
Batch coefficients shape: (32, 12)
Batch signal shape: (32, 2882)
Batch coefficients shape: (32, 12)
Batch signal shape: (32, 2882)
Batch coefficients shape: (32, 12)
Batch signal shape: (32, 2882)
Batch coefficients shape: (32, 12)
Batch signal shape: (32, 2882)
Batch coefficients shape: (32, 12)
Batch signal shape: (32, 2882)
Batch coefficients shape: (32, 12)
Batch signal shape: (32, 2882)
Batch coefficients shape: (32, 12)
Batch signal shape: (32, 2882)
Batch coefficients shape: (32, 12)
Batch signal shape: (32, 2882)


In [62]:
# data config
batch_size = 32
signal_length = batch_signal.shape[0]

In [66]:
import flax.linen as nn
import jax
import jax.numpy as jnp
import optax

class ComplexFCNN(nn.Module):
    @nn.compact
    def __call__(self, x, deterministic, rngs={'dropout': None}):  # Remove rng_key as a default here
        # First dense layer: input shape should match the length of the signal (2000)
        x = nn.Dense(128)(x)  # First fully connected layer
        x = nn.relu(x)
        
        # Apply dropout after the first layer
        x = nn.Dropout(0.2)(x, deterministic=deterministic)
        
        # Second dense layer
        x = nn.Dense(64)(x)   # Second fully connected layer
        x = nn.relu(x)
        
        # Final dense layer to output 12 values (6 real and 6 imaginary values for coefficients)
        x = nn.Dense(12)(x)  # Output layer with 12 units
        return x

# Instance of the model
model = ComplexFCNN()

# Set seed
key = jax.random.PRNGKey(0)

# Assuming batch_size and signal_length have been defined earlier
batch_size = 32
signal_length = 1441  # From your transposed data (2000 real+imaginary)

# Define input shape: (batch_size, 2000) because of real and imaginary concatenation
input_shape = (batch_size, 2 * signal_length)

# Initialize model variables with random inputs of ones (for testing initialization)
variables = model.init(key, jnp.ones(input_shape), deterministic=True)

# Forward pass to test the configuration and output shape
output = model.apply(variables, jnp.ones(input_shape), deterministic=True)

print(f"Model output shape: {output.shape}")


Model output shape: (32, 12)


In [67]:
def loss_fn(params, model, inputs, true_coeffs, deterministic, rng_key):
    # Forward pass with the deterministic flag and PRNG key for dropout
    preds = model.apply({'params': params}, inputs, deterministic=deterministic, rngs={'dropout': rng_key})
    
    # Split predictions into real and imaginary parts
    preds_real, preds_imag = preds[:, :6], preds[:, 6:]
    print("preds_real", preds_real.shape)
    print("preds_imag", preds_imag.shape)
    true_real, true_imag = true_coeffs[:, :6], true_coeffs[:, 6:]
    
    # MSE loss for real and imaginary parts
    loss_real = jnp.mean((preds_real - true_real) ** 2)
    loss_imag = jnp.mean((preds_imag - true_imag) ** 2)
    
    return loss_real + loss_imag


In [68]:
# optimize using adam


class TrainState(train_state.TrainState):
    loss_fn = staticmethod(loss_fn)

#initialize training state and optimizer
tx = optax.adam(learning_rate = 1e-3)
state = TrainState.create(apply_fn=model.apply, params=variables['params'], tx = tx)

# Simplified training loop
for batch_signal, batch_coefficients in data_loader(dataset, batch_size):
    print(f"Batch signal shape: {batch_signal.shape}")  # Should be (batch_size, 2 * signal_length)
    print(f"Batch coefficients shape: {batch_coefficients.shape}")  # Should be (batch_size, 2 * 6)
    
    # Generate a new PRNG key for dropout for this batch
    rng_key, subkey = jax.random.split(rng_key)

    # Pass the PRNG key for dropout to the loss function
    loss, grads = jax.value_and_grad(state.loss_fn)(
        state.params, model, batch_signal, batch_coefficients, deterministic=False, rng_key=subkey
    )
    
    # Apply gradients to update model parameters
    state = state.apply_gradients(grads=grads)

Batch signal shape: (32, 2882)
Batch coefficients shape: (32, 12)
preds_real (32, 6)
preds_imag (32, 6)
Batch signal shape: (32, 2882)
Batch coefficients shape: (32, 12)
preds_real (32, 6)
preds_imag (32, 6)
Batch signal shape: (32, 2882)
Batch coefficients shape: (32, 12)
preds_real (32, 6)
preds_imag (32, 6)
Batch signal shape: (32, 2882)
Batch coefficients shape: (32, 12)
preds_real (32, 6)
preds_imag (32, 6)
Batch signal shape: (32, 2882)
Batch coefficients shape: (32, 12)
preds_real (32, 6)
preds_imag (32, 6)
Batch signal shape: (32, 2882)
Batch coefficients shape: (32, 12)
preds_real (32, 6)
preds_imag (32, 6)
Batch signal shape: (32, 2882)
Batch coefficients shape: (32, 12)
preds_real (32, 6)
preds_imag (32, 6)
Batch signal shape: (32, 2882)
Batch coefficients shape: (32, 12)
preds_real (32, 6)
preds_imag (32, 6)
Batch signal shape: (32, 2882)
Batch coefficients shape: (32, 12)
preds_real (32, 6)
preds_imag (32, 6)
Batch signal shape: (32, 2882)
Batch coefficients shape: (32, 12