In [None]:
# Column names to keep when filtering the multiwavelength catalog.
# Each survey contributes magnitudes, colors and (where present) extra features.

# EMU radio continuum positions
emu_columns = ['EMU_ra_deg_cont', 'EMU_dec_deg_cont']

# DES DR2 columns, currently disabled:
# desdr2_mag = ['DES_mag_auto_g_dered', 'DES_mag_auto_r_dered', 'DES_mag_auto_i_dered', 'DES_mag_auto_z_dered']
# desdr2_colors = ['DES_g_r_dered', 'DES_r_i_dered', 'DES_i_z_dered']

# DES Y6 Gold.
# Technically only one magnitude is needed; the rest can be recovered from the colors.
desy6gold_mag = [
    'DESY6_mag_auto_g_extcorr',
    'DESY6_mag_auto_r_extcorr',
    'DESY6_mag_auto_i_extcorr',
    'DESY6_mag_auto_z_extcorr',
    'DESY6_mag_auto_y_extcorr',
]
desy6gold_colors = [
    'DESY6_g_r_extcorr',
    'DESY6_r_i_extcorr',
    'DESY6_i_z_extcorr',
    'DESY6_z_y_extcorr',
]
desy6gold_features = [
    'DESY6_dnf_z',
    'DESY6_spread_model_g',
    'DESY6_spread_model_r',
    'DESY6_spread_model_i',
    'DESY6_spread_model_z',
]

# VIKING
viking_mag = [
    'VKG_zAperMag3_ab_extcorr',
    'VKG_jAperMag3_ab_extcorr',
    'VKG_yAperMag3_ab_extcorr',
    'VKG_ksAperMag3_ab_extcorr',
    'VKG_hAperMag3_ab_extcorr',
]
viking_colors = [
    'VKG_z_y_am3_extcorr',
    'VKG_y_j_am3_extcorr',
    'VKG_j_h_am3_extcorr',
    'VKG_h_ks_am3_extcorr',
]
viking_features = ['VKG_mergedClassStat']

# CatWISE
catwise_mag = ['CAT_w1mpro_ab', 'CAT_w2mpro_ab']
catwise_colors = ['CAT_w1_w2_ab']

# Every color column, in survey order: DES Y6 Gold, VIKING, CatWISE
colors = [*desy6gold_colors, *viking_colors, *catwise_colors]

## Quantization Error (QE) and Topographic Error (TE)

In [None]:
# Quantization Error (QE)
# Average distance of a data point to the nearest lattice node
# Measures how well the mapping fits the distribution of the data

def som_quantization_error(data, trained_som):
    '''
    Calculates the quantization error of a trained SOMNet object.

    The quantization error is the mean Euclidean distance between each
    sample and the weight vector of its best-matching unit (BMU).

    Args:
        data: Input samples, one per row; anything ``np.asarray`` can
            convert to a float64 array.
        trained_som: trained SimpSOM SOMNet object. This function uses its
            ``find_bmu_ix`` method and ``nodes_list`` attribute.

    Returns:
        quantization error (float)
    '''
    # Normalise to a uniform float64 NumPy array first, then move it to the GPU.
    data_np = np.asarray(data, dtype=np.float64)
    data_cp = cp.array(data_np)

    # BMU index for every sample.
    bmu_indices = cp.asarray(trained_som.find_bmu_ix(data_cp))

    # Stack all node weights once and gather the BMU rows with a single
    # indexing operation. This avoids a per-sample Python loop in which every
    # int() conversion of a GPU scalar forces a device-to-host sync.
    node_weights = cp.array([node.weights for node in trained_som.nodes_list])
    bmu_weights = node_weights[bmu_indices]

    # Distance of each sample to its BMU, averaged over all samples.
    distances = cp.linalg.norm(bmu_weights - data_cp, axis=1)
    total_distance = cp.sum(distances)

    return float(total_distance / data_cp.shape[0])

In [None]:
# Topographic Error (TE)
# Proportion of data points whose BMU and second BMU are NOT neighbors
# Measures how well the shape of the data is preserved in the output space

def som_topographic_error(data, trained_som):
    """
    Computes the topographic error of a trained SOM using GPU (CuPy).

    The topographic error is the proportion of data points whose BMU and
    second BMU are NOT neighbors on the map.

    Args:
        data:    Input samples, one per row; anything ``cp.array`` can convert.
        trained_som: A trained SOMNet object from the simpsom library. This
            function uses its ``find_bmu_ix`` method and the ``weights`` and
            ``pos`` attributes of the entries in ``nodes_list``.

    Returns:
        Topographic error (float between 0 and 1).
    """
    # Move the data to the GPU.
    data_cp = cp.array(data)

    # Stack the weights and grid positions of every node once, so that
    # per-sample lookups below are single vectorised gathers rather than
    # Python loops (each int() on a GPU scalar would force a sync).
    nodes = trained_som.nodes_list
    weights_cp = cp.array([node.weights for node in nodes])
    positions_cp = cp.array([node.pos for node in nodes])

    # BMU index for every sample, as reported by the SOM itself.
    bmu_indices = cp.asarray(trained_som.find_bmu_ix(data_cp))

    # Second BMU: second-closest node by Euclidean distance.
    # The distance helper is given the CuPy backend so it matches the
    # CuPy arrays it receives.
    # NOTE(review): taking column 1 of the sorted indices assumes the SOM's
    # own BMU is the Euclidean nearest node (column 0) - confirm the SOM metric.
    som_dist = sps.distances.Distance(xp=cp)
    distances = som_dist.pairdist(data_cp, weights_cp, metric='euclidean')
    sbmu_indices = cp.argsort(distances, axis=1)[:, 1]

    # Grid positions of the BMU and second BMU of every sample.
    bmu_positions = positions_cp[bmu_indices]
    sbmu_positions = positions_cp[sbmu_indices]

    # Two nodes count as neighbors when neither of the first two position
    # coordinates differs by more than 1.
    # NOTE(review): this presumes unit spacing between adjacent nodes in
    # `pos` - confirm against the grid layout used by the SOM.
    offsets = cp.abs(bmu_positions - sbmu_positions)
    not_neighbors = (offsets[:, 0] > 1) | (offsets[:, 1] > 1)

    # Fraction of samples whose BMU pair is not adjacent.
    topographic_error = cp.sum(not_neighbors) / data_cp.shape[0]

    return float(topographic_error)