In [None]:
import os
import scipy.io
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader, Subset
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from complexPyTorch.complexLayers import ComplexLinear, ComplexConv2d
from complexPyTorch.complexFunctions import complex_relu, complex_avg_pool2d
import matplotlib.pyplot as plt

In [None]:
# Define the Dataset
class RASPNetDataset(Dataset):
    """
    In-memory dataset of RASPNet samples stored as separate real/imaginary .mat files.

    Sample ``i`` (1-based) is read from ``<data_dir>/<split>/real{i}.mat`` (variable ``Y_real``)
    and ``<data_dir>/<split>/imag{i}.mat`` (variable ``Y_imag``); each array must have shape
    (5, 21, 16). The two arrays are concatenated along axis 0, so every feature tensor is
    float32 with shape (10, 21, 16): indices 0-4 hold the real part, 5-9 the imaginary part.
    Labels are the ``R_idx``, ``Az_idx``, ``El_idx`` columns of the first ``limit`` CSV rows.
    """

    # Shape required for every Y_real / Y_imag array.
    SAMPLE_SHAPE = (5, 21, 16)

    def __init__(self, data_dir, csv_file, split, limit):
        """
        Args:
            data_dir (string): Directory with all the data.
            csv_file (string): Path to the csv file with labels.
            split (string): 'train' or 'test' (sub-folder of ``data_dir`` holding the .mat files).
            limit (int): Number of samples to load (files 1..limit, CSV rows 0..limit-1).

        Raises:
            ValueError: if the CSV has fewer than ``limit`` rows or a .mat array has the wrong shape.
            KeyError: if a .mat file lacks the expected variable.
            FileNotFoundError: if a .mat file is missing.
        """
        self.data_dir = data_dir
        self.split = split
        self.limit = limit

        # Labels: one (R_idx, Az_idx, El_idx) row per sample, as float32.
        labels_df = pd.read_csv(csv_file)
        label_array = labels_df[['R_idx', 'Az_idx', 'El_idx']].values[:limit].astype(np.float32)
        if len(label_array) < limit:
            # Fail here rather than with an IndexError deep inside the DataLoader.
            raise ValueError(f"{csv_file} has {len(label_array)} label rows, fewer than limit={limit}")
        self.labels = torch.from_numpy(label_array)

        # Real and imaginary files live side by side in the split folder.
        split_dir = os.path.join(data_dir, split)

        features_list = []
        for i in range(1, limit + 1):
            real_file = f'real{i}.mat'
            imag_file = f'imag{i}.mat'

            try:
                real_data = scipy.io.loadmat(os.path.join(split_dir, real_file))['Y_real']  # Adjust the key if different
                imag_data = scipy.io.loadmat(os.path.join(split_dir, imag_file))['Y_imag']  # Adjust the key if different
            except KeyError as e:
                raise KeyError(f"Variable not found in {real_file} or {imag_file}: {e}") from e
            except FileNotFoundError as e:
                raise FileNotFoundError(f"File not found: {e}") from e

            for file_name, array in ((real_file, real_data), (imag_file, imag_data)):
                if array.shape != self.SAMPLE_SHAPE:
                    raise ValueError(f"Unexpected shape for {file_name}: {array.shape}")

            # Real part first, then imaginary, stacked along axis 0 -> (10, 21, 16).
            features_list.append(np.concatenate([real_data, imag_data]).astype(np.float32))

            # Progress report every 5000 files and after the last one.
            if i % 5000 == 0 or i == limit:
                print(f'Loaded {i}/{limit} samples from {split} set.')

        # (limit, 10, 21, 16) float32. np.stack rejects an empty list, so limit=0 gets an
        # explicitly shaped empty array instead.
        if features_list:
            stacked = np.stack(features_list)
        else:
            stacked = np.empty((0, 2 * self.SAMPLE_SHAPE[0]) + self.SAMPLE_SHAPE[1:], dtype=np.float32)
        self.features = torch.from_numpy(stacked)

    def __len__(self):
        return self.limit

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

In [None]:
# Hyperparameters
# Scenario selection: every data path below is derived from SCENARIO.
SCENARIO = 'num29'
DATA_DIR = f'data/CVNN/{SCENARIO}'
TRAIN_CSV = f'{DATA_DIR}/train.csv'
TEST_CSV = f'{DATA_DIR}/test.csv'

# Load the first 20000 training and 5000 test samples fully into memory.
train_dataset = RASPNetDataset(DATA_DIR, TRAIN_CSV, 'train', 20000)
test_dataset = RASPNetDataset(DATA_DIR, TEST_CSV, 'test', 5000)

# Mini-batches of 100; only the training set is shuffled.
train_loader, test_loader = (
    DataLoader(train_dataset, batch_size=100, shuffle=True),
    DataLoader(test_dataset, batch_size=100, shuffle=False),
)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Training and test dataset global constants.
# Per-scenario reference [range, azimuth (deg), elevation (deg)] that Spher2Cart_1D_torch adds
# to the scaled label indices before converting to Cartesian. Train and test share it here.
SCENARIO_COORDS = {
    'num29': [10851, 215, -5.4],
    'num35': [11381, 215, -0.95],
    'num60': [11073, 215, -5.3],
    'num62': [11471, 215, -5.6],
    'num76': [11388, 215, -6.15],
}
if SCENARIO not in SCENARIO_COORDS:
    # Fail fast: otherwise coord_tr/coord_ts would be undefined, or silently stale from an
    # earlier execution of this cell with a different SCENARIO.
    raise KeyError(f"No reference coordinates for SCENARIO={SCENARIO!r}; "
                   f"expected one of {sorted(SCENARIO_COORDS)}")
coord_tr = torch.tensor(SCENARIO_COORDS[SCENARIO], device=device, dtype=torch.float32)
coord_ts = torch.tensor(SCENARIO_COORDS[SCENARIO], device=device, dtype=torch.float32)

# Multipliers applied to the [R_idx, Az_idx, El_idx] labels (angles end up in degrees).
rng_res_tr = 59.9585 / 2          # Range resolution
az_step_tr = 0.4                  # Azimuth step size
el_step_tr = 0.01                 # Elevation step size
scale_tr = [rng_res_tr, az_step_tr, el_step_tr]
rng_res_ts = 59.9585 / 2          # Range resolution
az_step_ts = 0.4                  # Azimuth step size
el_step_ts = 0.01                 # Elevation step size
scale_ts = [rng_res_ts, az_step_ts, el_step_ts]

# Define the average Euclidean distance function
def average_euclidean_distance(pred, target):
    """
    Mean L2 distance between matching rows of `pred` and `target`.

    Parameters:
    - pred (Tensor): Predicted Cartesian coordinates, shape (batch_size, 3).
    - target (Tensor): Ground-truth Cartesian coordinates, same shape as `pred`.

    Returns:
    - float: The per-row Euclidean distance averaged over the batch, as a Python float.
    """
    per_sample_distance = (pred - target).norm(p=2, dim=1)
    return per_sample_distance.mean().item()

# Define the Spher2Cart_1D_torch function as previously modified
def Spher2Cart_1D_torch(spherical, scale, coord):
    """
    Convert scaled spherical indices to Cartesian coordinates.

    Each triple is first mapped to physical spherical values with
    ``spherical * scale + coord`` -> [r, az_deg, el_deg] (angles in degrees), then
        x =  r * cos(el) * cos(az)
        y = -r * cos(el) * sin(az)
        z =  r * sin(el)

    Parameters:
    - spherical: Tensor of shape (..., 3) containing [range, azimuth, elevation] per row;
      (batch_size, 3) is the usual case, but any leading dimensions are accepted.
    - scale: Tensor, list or array [rng_res, az_step, el_step] multiplied onto `spherical`.
    - coord: Tensor, list or array [r0, az0_deg, el0_deg] added after scaling. This is an
      offset in spherical space (applied before the conversion), not a Cartesian shift.

    Returns:
    - cartesian: Tensor of shape (..., 3) containing [x, y, z].

    Raises:
    - ValueError: if the last dimension of `spherical` is not 3.

    `scale` and `coord` are cast to float32 on `spherical`'s device and detached, so no
    gradient flows into them.
    """
    if spherical.shape[-1] != 3:
        raise ValueError(f"spherical must have a last dimension of size 3, got shape {tuple(spherical.shape)}")

    # as_tensor handles lists/arrays and tensors alike (copying only when device/dtype differ).
    scale = torch.as_tensor(scale, dtype=torch.float32, device=spherical.device).detach()
    coord = torch.as_tensor(coord, dtype=torch.float32, device=spherical.device).detach()

    scaled = spherical * scale + coord  # Apply scaling and shifting (broadcast over the last dim)
    r, az_deg, el_deg = scaled.unbind(dim=-1)

    az = torch.deg2rad(az_deg)
    el = torch.deg2rad(el_deg)

    hyp = torch.cos(el) * r  # projection of the range onto the horizontal plane
    x = torch.cos(az) * hyp
    y = -torch.sin(az) * hyp
    z = torch.sin(el) * r

    return torch.stack((x, y, z), dim=-1)

In [None]:
def fft_based_hilbert_transform(real_features):
    """
    Hilbert transform of `real_features` along its last dimension, computed via the FFT.

    The spectrum is multiplied by -1j on positive-frequency bins and by +1j on
    negative-frequency bins. The DC bin, and the Nyquist bin when the length is even, are
    zeroed. The real part of the inverse FFT is returned.

    Parameters:
    - real_features (torch.Tensor): The input real features; transformed over dim=-1.

    Returns:
    - transformed_imag (torch.Tensor): Real tensor, same shape as `real_features`.
    """
    n_samples = real_features.shape[-1]
    spectrum = torch.fft.fft(real_features, dim=-1)

    # One multiplier per frequency bin; broadcasting applies it to every leading index.
    multiplier = torch.zeros(n_samples, dtype=spectrum.dtype, device=spectrum.device)
    last_positive = (n_samples - 1) // 2   # highest positive-frequency bin below Nyquist
    first_negative = n_samples // 2 + 1    # lowest negative-frequency bin
    multiplier[1:last_positive + 1] = -1j
    multiplier[first_negative:] = 1j

    # Apply the phase shift and perform the inverse FFT
    return torch.fft.ifft(spectrum * multiplier, dim=-1).real


# Custom loss function for Analytic Neural Network
def custom_loss(outputs, target, real_features, imag_features, beta=1e-3):
    """
    Prediction MSE plus a weighted Hilbert-consistency penalty.

    The penalty is the MSE between the FFT-based Hilbert transform of `real_features` and
    `imag_features`. It pushes the imaginary latent toward the Hilbert transform of the
    real latent.

    Parameters:
    - outputs, target: prediction and target tensors compared with mean-squared error.
    - real_features, imag_features: latent tensors of matching shape; the Hilbert transform
      runs over their last dimension.
    - beta (float): tradeoff weight of the consistency penalty (default 1e-3).

    Returns:
    - Scalar tensor: MSE(outputs, target) + beta * consistency_penalty.
    """
    transformed_imag = fft_based_hilbert_transform(real_features)
    consistency_penalty = F.mse_loss(transformed_imag, imag_features)
    # F.mse_loss with its default 'mean' reduction equals nn.MSELoss()(outputs, target).
    return F.mse_loss(outputs, target) + beta * consistency_penalty

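## Fully Connected Networks

*(Disabled: the fully connected variants in the next cell are commented out; the convolutional models defined in the following section are the ones that are trained.)*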

In [None]:
# # Define the neural network models
# class SteinmetzNetwork(nn.Module):
#     def __init__(self, dN, k, lN):
#         super(SteinmetzNetwork, self).__init__()
#         self.real_net = nn.Sequential(nn.Linear(dN, lN//2), nn.ReLU(), nn.Linear(lN//2, lN//2), nn.ReLU())
#         self.imag_net = nn.Sequential(nn.Linear(dN, lN//2), nn.ReLU(), nn.Linear(lN//2, lN//2), nn.ReLU())
#         self.regressor = nn.Sequential(nn.Linear(lN, k))

#     def forward(self, real, imag):
#         real_features = self.real_net(real)
#         imag_features = self.imag_net(imag)
        
#         # Mean centering features as last step before concatenation
#         # real_features = real_features - real_features.mean(dim=0, keepdim=True)
#         imag_features = imag_features - imag_features.mean(dim=0, keepdim=True)
        
#         combined = torch.cat((real_features, imag_features), dim=1)
#         output = self.regressor(combined)
#         return output, real_features, imag_features

# class NeuralNetwork(nn.Module):
#     def __init__(self, dN, k, lN):
#         super(NeuralNetwork, self).__init__()
#         self.net = nn.Sequential(nn.Linear(2*dN, lN//2), nn.ReLU(), nn.Linear(lN//2, lN), nn.ReLU(), nn.Linear(lN, k))

#     def forward(self, real, imag):
#         input = torch.cat((real, imag), dim=1)
#         output = self.net(input)
#         return output

# class ComplexNeuralNetwork(nn.Module):
#     def __init__(self, dN, k, lN):
#         super(ComplexNeuralNetwork, self).__init__()
#         self.fc1 = ComplexLinear(dN, lN//2)
#         self.fc2 = ComplexLinear(lN//2, lN)
#         self.fc3 = ComplexLinear(lN, k)

#     def forward(self, real, imag):
#         complex_tensor = torch.stack((real, imag), dim=-1)
#         x = torch.view_as_complex(complex_tensor)
#         x = complex_relu(self.fc1(x))
#         x = complex_relu(self.fc2(x))
#         x = self.fc3(x)
#         output = torch.sqrt(torch.real(x)**2 + torch.imag(x)**2)
#         return output

## Convolutional Neural Networks

In [None]:
# Define the neural network models
class SteinmetzNetwork(nn.Module):
    """
    Two-branch CNN with separate convolutional stacks for real and imaginary inputs.

    Each branch applies conv3x3 -> ReLU -> 2x2 avg-pool twice (32 then 64 channels) and
    flattens the result. The imaginary latent is mean-centred over the batch dimension
    (dim 0) before the two latents are concatenated and passed through a two-layer MLP
    head with `k` outputs. The centring has no train/eval switch, so predictions depend
    on the batch composition in eval mode too.

    forward(real, imag) returns (output, real_latent, imag_latent).
    """

    def __init__(self, channels, height, width, k, lN):
        super(SteinmetzNetwork, self).__init__()
        # Creation order is unchanged, so seeded initialisation and state_dict keys match.
        self.conv_real1 = nn.Conv2d(channels, 32, kernel_size=3, padding=1)
        self.conv_real2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv_imag1 = nn.Conv2d(channels, 32, kernel_size=3, padding=1)
        self.conv_imag2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        # Padded 3x3 convs keep H x W; each 2x2 pool floors it by half -> (H//4) x (W//4).
        per_branch = 64 * (height // 4) * (width // 4)
        self.fc1 = nn.Linear(2 * per_branch, lN)
        self.fc2 = nn.Linear(lN, k)

    @staticmethod
    def _encode(x, first_conv, second_conv):
        """Apply conv -> ReLU -> avg-pool(2) for each conv, then flatten all but the batch dim."""
        for conv in (first_conv, second_conv):
            x = F.avg_pool2d(F.relu(conv(x)), 2, 2)
        return torch.flatten(x, 1)

    def forward(self, real, imag):
        real_latent = self._encode(real, self.conv_real1, self.conv_real2)
        imag_latent = self._encode(imag, self.conv_imag1, self.conv_imag2)
        imag_latent = imag_latent - imag_latent.mean(dim=0, keepdim=True)
        hidden = F.relu(self.fc1(torch.cat((real_latent, imag_latent), dim=1)))
        return self.fc2(hidden), real_latent, imag_latent


class NeuralNetwork(nn.Module):
    """
    Real-valued CNN baseline: real and imaginary inputs are stacked as 2*channels channels.

    conv3x3(32) -> ReLU -> avg-pool(2), conv3x3(64) -> ReLU -> avg-pool(2), flatten,
    Linear(lN) -> ReLU -> Linear(k).
    """

    def __init__(self, channels, height, width, k, lN):
        super(NeuralNetwork, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=channels * 2, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        # Spatial size after two floor-halving 2x2 pools (padded convs keep H x W).
        pooled_h, pooled_w = height // 4, width // 4
        self.fc1 = nn.Linear(64 * pooled_h * pooled_w, lN)
        self.fc2 = nn.Linear(lN, k)

    def forward(self, real, imag):
        # Concatenate along dim 1 (the channel dim for NCHW input).
        x = torch.cat((real, imag), dim=1)
        for conv in (self.conv1, self.conv2):
            x = F.avg_pool2d(F.relu(conv(x)), 2, 2)
        hidden = F.relu(self.fc1(torch.flatten(x, 1)))
        return self.fc2(hidden)

class ComplexNeuralNetwork(nn.Module):
    """
    Complex-valued CNN built from complexPyTorch layers, applied to the input real + 1j*imag.

    ComplexConv2d(32) -> complex_relu -> complex_avg_pool2d(2),
    ComplexConv2d(64) -> complex_relu -> complex_avg_pool2d(2), flatten,
    ComplexLinear(lN) -> complex_relu -> ComplexLinear(k).
    forward(real, imag) returns the magnitude |z| of the k complex outputs.
    """

    def __init__(self, channels, height, width, k, lN):
        super(ComplexNeuralNetwork, self).__init__()
        self.conv1 = ComplexConv2d(in_channels=channels, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = ComplexConv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        # Padded 3x3 convs keep H x W; each 2x2 pool floors it by half -> (H//4) x (W//4).
        flattened_size = 64 * (height // 4) * (width // 4)
        self.fc1 = ComplexLinear(flattened_size, lN)
        self.fc2 = ComplexLinear(lN, k)

    def forward(self, real, imag):
        # Pair the real and imaginary tensors element-wise into one complex tensor.
        complex_tensor = torch.view_as_complex(torch.stack((real, imag), dim=-1))
        x = complex_avg_pool2d(complex_relu(self.conv1(complex_tensor)), 2, 2)
        x = complex_avg_pool2d(complex_relu(self.conv2(x)), 2, 2)
        # Flatten per sample so the width always matches fc1's in_features for the
        # height/width this model was built with (a hard-coded 64*5*4 only fits 21x16 input
        # and would otherwise fail or silently fold samples across the batch dimension).
        x = torch.flatten(x, 1)
        x = complex_relu(self.fc1(x))
        x = self.fc2(x)
        # |z| via torch.abs: avoids the NaN gradient that sqrt(re**2 + im**2) produces at z == 0
        # and is robust to overflow of the squared terms.
        return torch.abs(x)

In [None]:
# Training configuration shared by all four models
epochs = 70          # epochs per independent run
iterations = 5       # independent runs with freshly initialised models (for mean/CI curves)
channels = 5; height = 21; width = 16
dN = channels * height * width
k = 2                # networks predict (range, azimuth); elevation is copied from the labels
lN = 128  # Latent Dimensionality
lr = 5e-2            # Adam learning rate for every model
SEED = 0             # run r is seeded with SEED + r -> reproducible weight init and shuffling
criterion = nn.MSELoss()

MODEL_NAMES = ('rvnn', 'cvnn', 'steinmetz', 'analytic')
MODEL_LABELS = {'rvnn': 'RVNN', 'cvnn': 'CVNN', 'steinmetz': 'Steinmetz', 'analytic': 'Analytic'}

# Average Euclidean distance (m) per [run, epoch] for each model.
train_error = {name: np.zeros((iterations, epochs)) for name in MODEL_NAMES}
test_error = {name: np.zeros((iterations, epochs)) for name in MODEL_NAMES}
# Per-model names used by later cells; they alias the same arrays stored in the dicts.
train_error_rvnn, train_error_cvnn = train_error['rvnn'], train_error['cvnn']
train_error_steinmetz, train_error_analytic = train_error['steinmetz'], train_error['analytic']
test_error_rvnn, test_error_cvnn = test_error['rvnn'], test_error['cvnn']
test_error_steinmetz, test_error_analytic = test_error['steinmetz'], test_error['analytic']


def _build_models():
    """Create fresh RVNN, CVNN, Steinmetz and Analytic models on `device`, in that order."""
    kwargs = dict(channels=channels, height=height, width=width, k=k, lN=lN)
    return {
        'rvnn': NeuralNetwork(**kwargs).to(device),
        'cvnn': ComplexNeuralNetwork(**kwargs).to(device),
        'steinmetz': SteinmetzNetwork(**kwargs).to(device),
        # The Analytic network uses the Steinmetz architecture; only its loss differs.
        'analytic': SteinmetzNetwork(**kwargs).to(device),
    }


def _predict_cartesian(models, features, labels, scale, coord):
    """
    Run one batch through every model and convert the predictions to Cartesian coordinates.

    `features` is split channel-wise into real (first `channels`) and imaginary (remaining)
    inputs. The networks only output (range, azimuth), so the ground-truth elevation
    labels[:, 2] is appended before Spher2Cart_1D_torch.

    Returns (dict name -> Cartesian prediction, analytic real latent, analytic imaginary latent).
    """
    real_in, imag_in = features[:, :channels], features[:, channels:]
    raw = {
        'rvnn': models['rvnn'](real_in, imag_in),
        'cvnn': models['cvnn'](real_in, imag_in),
        'steinmetz': models['steinmetz'](real_in, imag_in)[0],
    }
    raw['analytic'], real_feat, imag_feat = models['analytic'](real_in, imag_in)
    elevation = labels[:, 2:3]
    preds = {name: Spher2Cart_1D_torch(torch.cat((out, elevation), dim=1), scale, coord)
             for name, out in raw.items()}
    return preds, real_feat, imag_feat


for run in range(iterations):
    # Seeds weight initialisation and the DataLoader shuffle order for this run.
    torch.manual_seed(SEED + run)

    models = _build_models()
    optimizers = {name: optim.Adam(model.parameters(), lr=lr) for name, model in models.items()}
    model_rvnn, model_cvnn = models['rvnn'], models['cvnn']
    model_steinmetz, model_analytic = models['steinmetz'], models['analytic']

    print(', '.join(f'{MODEL_LABELS[name]} params: {sum(p.numel() for p in model.parameters())}'
                    for name, model in models.items()))

    for epoch in range(epochs):
        # ---- training pass ----
        for model in models.values():
            model.train()
        train_distances = {name: [] for name in MODEL_NAMES}

        for features, labels in train_loader:
            features, labels = features.to(device), labels.to(device)
            preds, real_feat, imag_feat = _predict_cartesian(models, features, labels, scale_tr, coord_tr)
            labels_cart = Spher2Cart_1D_torch(labels, scale_tr, coord_tr)

            # MSE in Cartesian space; the Analytic model adds the Hilbert-consistency penalty.
            losses = {name: criterion(preds[name], labels_cart) for name in ('rvnn', 'cvnn', 'steinmetz')}
            losses['analytic'] = custom_loss(preds['analytic'], labels_cart, real_feat, imag_feat)

            # The models share no parameters, so each is zeroed, back-propagated and stepped
            # independently.
            for name in MODEL_NAMES:
                optimizers[name].zero_grad()
                losses[name].backward()
                optimizers[name].step()
                train_distances[name].append(average_euclidean_distance(preds[name], labels_cart))

        # ---- evaluation pass ----
        for model in models.values():
            model.eval()
        test_distances = {name: [] for name in MODEL_NAMES}

        with torch.no_grad():
            for features, labels in test_loader:
                features, labels = features.to(device), labels.to(device)
                preds, _, _ = _predict_cartesian(models, features, labels, scale_ts, coord_ts)
                labels_cart = Spher2Cart_1D_torch(labels, scale_ts, coord_ts)
                for name in MODEL_NAMES:
                    test_distances[name].append(average_euclidean_distance(preds[name], labels_cart))

        # Aggregate and store Average Euclidean Distances
        for name in MODEL_NAMES:
            train_error[name][run, epoch] = np.mean(train_distances[name])
            test_error[name][run, epoch] = np.mean(test_distances[name])

        # The logged values are average Euclidean distances (m), not MSE.
        summary = ', '.join(f'{MODEL_LABELS[name]} dist: {test_error[name][run, epoch]:.4f}'
                            for name in MODEL_NAMES)
        print(f'Iteration [{run + 1}/{iterations}], Epoch [{epoch + 1}/{epochs}], {summary}')

In [None]:
# Step 2: Compute means and (Student-t) 95% confidence intervals
def compute_mean_and_CI(data, confidence=0.95):
    """
    Per-column mean, confidence-interval half-width and sample standard deviation.

    Parameters:
    - data (np.ndarray): shape (n_runs, n_epochs); each row is one independent run.
    - confidence (float): two-sided confidence level of the interval (default 0.95).

    Returns:
    - mean (np.ndarray): mean over runs, shape (n_epochs,).
    - ci (np.ndarray): half-width of the Student-t confidence interval, shape (n_epochs,).
    - stdev (np.ndarray): sample standard deviation (ddof=1) over runs, shape (n_epochs,).

    With a single run there is no spread estimate, so `ci` and `stdev` are zeros.
    """
    from scipy import stats  # local import: only scipy.io is imported at the top of the notebook

    data = np.asarray(data, dtype=float)
    n_runs = data.shape[0]
    mean = np.mean(data, axis=0)
    if n_runs < 2:
        return mean, np.zeros_like(mean), np.zeros_like(mean)

    # Sample std (ddof=1): the population std underestimates the spread for a handful of runs.
    stdev = np.std(data, axis=0, ddof=1)
    std_error = stdev / np.sqrt(n_runs)  # Standard error of the mean
    # Student-t critical value: with few runs the normal 1.96 gives intervals that are too narrow.
    t_crit = stats.t.ppf(0.5 + confidence / 2.0, df=n_runs - 1)
    return mean, t_crit * std_error, stdev

# Test-set average Euclidean distance, shape (iterations, epochs), one array per model.
X = test_error_rvnn[:, :]; Y = test_error_cvnn[:, :]
Z = test_error_steinmetz[:, :]; A = test_error_analytic[:, :]
mean_X, ci_X, stdev_X = compute_mean_and_CI(X)
mean_Y, ci_Y, stdev_Y = compute_mean_and_CI(Y)
mean_Z, ci_Z, stdev_Z = compute_mean_and_CI(Z)
mean_A, ci_A, stdev_A = compute_mean_and_CI(A)

# Final-epoch mean and standard deviation across runs, ordered RVNN, CVNN, Steinmetz, Analytic.
print(mean_X[-1], mean_Y[-1], mean_Z[-1], mean_A[-1])
print(stdev_X[-1], stdev_Y[-1], stdev_Z[-1], stdev_A[-1])

# Step 3: Plot the results
RESULTS_DIR = 'Results'
os.makedirs(RESULTS_DIR, exist_ok=True)  # savefig/savetxt raise if the folder does not exist

epochs_all = list(range(1, X.shape[1] + 1))
fig, ax = plt.subplots(figsize=(9, 6))

# (mean, ci, colour, band alpha, legend label, band zorder, line zorder)
curves = [
    (mean_X, ci_X, 'blue', 0.1, "RVNN", 0, 20),
    (mean_Y, ci_Y, 'red', 0.2, "CVNN", 5, 25),
    (mean_Z, ci_Z, 'orange', 0.3, "Steinmetz Neural Network", 10, 30),
    (mean_A, ci_A, 'green', 0.3, "Analytic Neural Network", 15, 35),
]
for curve_mean, curve_ci, color, alpha, label, band_z, line_z in curves:
    ax.fill_between(epochs_all, curve_mean - curve_ci, curve_mean + curve_ci,
                    color=color, alpha=alpha, zorder=band_z)
    ax.plot(epochs_all, curve_mean, '-', color=color, label=label, zorder=line_z)

# Additional plot settings
ax.set_xlabel("Number of Epochs", fontsize=16)
ax.set_ylabel("Average Euclidean Distance (m)", fontsize=16)
ax.tick_params(axis='both', labelsize=12)
ax.set_yscale('linear')
ax.legend(prop={'size': 15}, loc='center right', framealpha=0.7).set_zorder(50)
ax.grid(True)
fig.savefig(os.path.join(RESULTS_DIR, f'RASPNet_{SCENARIO}_epochs.png'), bbox_inches='tight')
plt.show()

# NOTE(review): these CSV names do not include SCENARIO, so running another scenario
# overwrites them; consider adding it if several scenarios are kept side by side.
for model_name, errors in (('steinmetz', test_error_steinmetz), ('analytic', test_error_analytic),
                           ('rvnn', test_error_rvnn), ('cvnn', test_error_cvnn)):
    np.savetxt(os.path.join(RESULTS_DIR, f'test_error_{model_name}.csv'), errors, delimiter=',')