In [1]:
from astropy.table import Table
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from minisom import MiniSom
import pickle
import os

In [2]:
# Seed the grid-search folder with one *untrained* record per (learning rate, sigma) pair.
# grid_search() later trains every record that has no "som" key and writes the trained
# SOM back into the same file. Existing files are therefore left untouched here:
# re-running this cell (e.g. Restart & Run All) must not wipe already-trained results.
# Delete the folder's files by hand if you really want to restart the search.
points = [i / 10 for i in range(1, 51)]  # NOTE(review): not used in the visible notebook cells

record_dir = "D:/SPHEREx_SOM/record/3rd_grid_search/noisy_lupcolor_shallow/SOM"
os.makedirs(record_dir, exist_ok=True)

for i in range(1, 101):          # learning rates 0.001 ... 0.100
    for j in range(20, 45):      # sigma 20 ... 44
        b = {}
        b["dim"] = 30
        b["rate"] = i / 1000
        b["sigma"] = j
        b["distance"] = "euclidean"
        b["type"] = "noisy_lupcolor_shallow"
        b["iter"] = 1
        path = f"{record_dir}/{b['type']}_{b['dim']}_{b['sigma']}_{b['rate']}_{b['iter']}.pkl"
        if os.path.exists(path):
            continue  # keep whatever is already there (possibly a trained SOM)
        with open(path, 'wb') as handle:
            pickle.dump(b, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [3]:
# Catalogue metadata, ordered by cosmos_id.
# NOTE(review): later cells assume the photometry tables below are in the same row order — confirm.
df_L3_info = (
    Table.read('../dataset/L3_COSMOS2020_Richard_RefCat_2023DEC4_info.fits')
    .to_pandas()
    .sort_values(by = "cosmos_id")
)

# SPHEREx reference-catalogue photometry; the first three columns of each table are dropped.
# Later cells treat the remaining even columns as fluxes and odd columns as their errors.
refcat_dir = 'D:/SPHEREx_SOM/dataset/sphx_refcat/'
data_noiseless = np.loadtxt(refcat_dir + 'Noiseless_phot_cosmos_nolines_refcat30k.txt')[:, 3:]
data_all = np.loadtxt(refcat_dir + 'NoisySphx_shallow_nolines_refcat30k.txt')[:, 3:]
data_deep = np.loadtxt(refcat_dir + 'NoisySphx_deep_nolines_refcat30k.txt')[:, 3:]

# Per-channel 1-sigma noise table (first line skipped): column 0 -> wl,
# column 1 -> sigma_all, column 2 -> sigma_deep.
fname = refcat_dir + 'SPHEREx_1sigma_noise.txt'
data_1sig = np.loadtxt(fname, skiprows=1)
wl, sigma_all, sigma_deep = data_1sig[:, 0], data_1sig[:, 1], data_1sig[:, 2]

In [4]:
def grid_search(data, err, folder):
    """Train a SOM for every not-yet-trained hyper-parameter record in ``folder``.

    Each ``*.pkl`` file in ``folder`` holds a dict with at least ``dim``,
    ``sigma``, ``rate``, ``distance`` and ``iter``. Records that already contain
    a ``"som"`` key are skipped, so an interrupted search can simply be resumed.
    For every other record a ``MiniSom`` is trained on ``data``. The record is
    then rewritten in place with the trained map (``som``), its errors
    (``topo_err``, ``quan_err``) and bookkeeping fields (``band_cut``,
    ``mag_cut``, ``preproc``).

    Parameters
    ----------
    data : array of shape (n_samples, n_features)
        Training vectors; ``data.shape[1]`` is the SOM input length.
    err : array-like
        Passed straight through to ``som.train``. NOTE(review): this relies on a
        MiniSom variant whose ``train`` accepts an error argument; upstream
        MiniSom's ``train`` has no such parameter.
    folder : str
        Directory containing the record pickles.

    Notes
    -----
    Each record is first written to ``<file>.tmp`` and then moved over the
    original with ``os.replace``. An interrupt therefore never leaves a
    truncated pickle behind. A KeyboardInterrupt stops the search; the
    affected record stays untrained and is picked up again on the next call.
    """
    files = sorted(name for name in os.listdir(folder) if name.endswith(".pkl"))
    for n, name in enumerate(files, start=1):
        path = os.path.join(folder, name)
        # pickle.load can execute arbitrary code: only point this at self-generated records.
        with open(path, "rb") as fh:
            record = pickle.load(fh)
        if "som" in record:
            continue  # already trained in an earlier run

        som = MiniSom(record["dim"], record["dim"], data.shape[1],
                      sigma = record["sigma"],
                      learning_rate = record["rate"],
                      activation_distance = record["distance"])
        som.random_weights_init(data)
        som.train(data, err, num_iteration = record["iter"], use_epochs = True)

        record["som"] = som
        # Both metrics scan the whole dataset, so evaluate each exactly once.
        record["topo_err"] = som.topographic_error(data)
        record["quan_err"] = som.quantization_error(data)
        record["band_cut"] = len(data[0])
        record["mag_cut"] = len(data)
        record["preproc"] = "default"

        progress = round(n / len(files) * 100, 2)
        params = [record['dim'], record['sigma'], record['rate']]
        # Pad so a shorter line fully overwrites the previous one after "\r".
        print(f"Progress: {progress}%, param: {params}".ljust(79), end = "\r")

        tmp_path = path + ".tmp"
        try:
            with open(tmp_path, "wb") as handle:
                pickle.dump(record, handle, protocol=pickle.HIGHEST_PROTOCOL)
            os.replace(tmp_path, path)
        except BaseException:
            # Never leave a half-written temp file; the original record is untouched.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

In [11]:
def parameter_space(save, folder, min_sigma = 20):
    """Summarise a grid search as error histograms and (sigma, rate) heat maps.

    Every record pickle in ``folder`` with ``sigma >= min_sigma`` is placed on a
    grid. Rows are the distinct sigma values and columns the distinct learning
    rates, both sorted ascending. One grid per SOM size (30, 50, 70, 90) is
    filled with the stored ``topo_err`` / ``quan_err``; grid cells without a
    record stay 0. Four 2x2 figures are drawn, one panel per SOM size:

    1. histogram of topographic error over the grid cells,
    2. histogram of log10(quantization error),
    3. topographic-error heat map over (rate, sigma),
    4. log10(quantization-error) heat map over (rate, sigma).

    A SOM size with no records at all gets a quantization grid of ones, so its
    log10 panels show 0 instead of -inf.

    Parameters
    ----------
    save : str or False
        Directory the figures are written to; a falsy value disables saving.
        File names use ``type`` and ``band_cut`` of the last record read.
    folder : str
        Directory containing the trained record pickles.
    min_sigma : float, default 20
        Records with a smaller ``sigma`` are ignored.

    Raises
    ------
    ValueError
        If no record in ``folder`` passes the ``min_sigma`` cut.
    """
    dims = (30, 50, 70, 90)  # SOM side lengths, one panel each in the 2x2 figures
    sigmas, rates, cells, last_record = _collect_grid_errors(folder, min_sigma, dims)
    if not rates:
        raise ValueError(f"no records with sigma >= {min_sigma} found in {folder!r}")

    row_of = {sigma: row for row, sigma in enumerate(sigmas)}
    col_of = {rate: col for col, rate in enumerate(rates)}
    shape = (len(sigmas), len(rates))
    topo_maps = {dim: np.zeros(shape) for dim in dims}
    quan_maps = {dim: np.zeros(shape) for dim in dims}
    for dim, sigma, rate, topo_err, quan_err in cells:
        topo_maps[dim][row_of[sigma], col_of[rate]] = topo_err
        quan_maps[dim][row_of[sigma], col_of[rate]] = quan_err
    for dim in dims:
        # A size with no records would give log10(0) = -inf below; show 0 instead.
        if np.sum(quan_maps[dim]) == 0:
            quan_maps[dim] = np.ones(shape)
    log_quan_maps = {dim: np.log10(grid) for dim, grid in quan_maps.items()}

    # pcolor draws unit cells, so tick marks sit at the cell centres.
    x_pos = [col + 0.5 for col in range(len(rates))]
    y_pos = [row + 0.5 for row in range(len(sigmas))]
    x_labels = np.round(rates, 3)

    _grid_histograms(dims, topo_maps, "Topological Error", log_scale = False)
    _grid_save_show(save, last_record, "topo_err_dist")

    _grid_histograms(dims, log_quan_maps, "log(Quantization Error)", log_scale = True)
    _grid_save_show(save, last_record, "quan_err_dist")

    _grid_heatmaps(dims, topo_maps, "Topological Error", x_pos, x_labels, y_pos, sigmas)
    _grid_save_show(save, last_record, "topo_err_para_space")

    _grid_heatmaps(dims, log_quan_maps, "Quantization Error", x_pos, x_labels, y_pos, sigmas)
    _grid_save_show(save, last_record, "quan_err_para_space")


def _collect_grid_errors(folder, min_sigma, dims):
    # Reads each record pickle once and keeps only the fields the plots need, since records
    # also hold full trained SOMs. Returns (sorted sigmas, sorted rates, cells, last record read).
    sigmas, rates, cells = set(), set(), []
    last_record = None
    for name in sorted(os.listdir(folder)):
        if not name.endswith(".pkl"):
            continue
        with open(os.path.join(folder, name), "rb") as fh:
            record = pickle.load(fh)
        last_record = record
        if record["sigma"] < min_sigma:
            continue
        sigmas.add(record["sigma"])
        rates.add(record["rate"])
        if record["dim"] in dims:
            cells.append((record["dim"], record["sigma"], record["rate"],
                          record["topo_err"], record["quan_err"]))
    return sorted(sigmas), sorted(rates), cells, last_record


def _grid_histograms(dims, maps, xlabel, log_scale):
    # Draws one histogram of grid-cell values per SOM size into a fresh 2x2 figure.
    plt.close()
    plt.figure(figsize = (30, 20))
    for panel, dim in enumerate(dims, start=1):
        plt.subplot(220 + panel)
        plt.hist(maps[dim].reshape(-1), bins = 20)
        plt.title(f"Dim: {dim}X{dim}", fontsize=20)
        plt.xlabel(xlabel, fontsize=20)
        # Each grid cell is one (sigma, rate) combination, not a galaxy.
        plt.ylabel("# of Parameter Sets", fontsize=20)
        if log_scale:
            plt.yscale("log")


def _grid_heatmaps(dims, maps, quantity, x_pos, x_labels, y_pos, y_labels):
    # Draws one (rate, sigma) heat map per SOM size into a fresh 2x2 figure.
    plt.close()
    plt.figure(figsize = (30, 20))
    for panel, dim in enumerate(dims, start=1):
        plt.subplot(220 + panel)
        plt.pcolor(maps[dim], cmap = 'YlGn')
        plt.title(f"{quantity} (Dim: {dim}X{dim} / Cut: both)", fontsize=20)
        plt.xlabel("Rate", fontsize=15)
        plt.ylabel("Sigma", fontsize=15)
        # Label only every 5th cell to keep the axes readable.
        plt.xticks(x_pos[::5], x_labels[::5], fontsize=8)
        plt.yticks(y_pos[::5], y_labels[::5], fontsize=8)
        plt.colorbar().ax.tick_params(labelsize=15)


def _grid_save_show(save, last_record, suffix):
    # Writes the current figure into the ``save`` directory (when given), then displays it.
    if save:
        name = f"{last_record['type']}_{last_record['band_cut']}_{suffix}.jpg"
        plt.savefig(os.path.join(save, name), bbox_inches = "tight")
    plt.show()

In [14]:
def visual_map(folder, dim, sigma, rate, iteration, data, err, info_data, save = False):
    """Show U-matrix, density, mean-magnitude and mean-redshift maps of one trained SOM.

    The record is read from ``folder + f"_{dim}_{sigma}_{rate}_{iteration}.pkl"``.
    ``folder`` is therefore a path *prefix* (directory plus record type), not a
    directory. The per-node maps are computed on the first call and cached back
    into the record under ``"density"``, ``"magnitude"`` and ``"z"``. Later
    calls reuse them instead of re-mapping ``data``.

    Parameters
    ----------
    folder : str
        Path prefix of the record pickle (see above).
    dim, sigma, rate, iteration
        Hyper-parameters that select the record file. ``dim`` is also the side
        length of the node maps.
    data : array of shape (n_samples, n_features)
        Vectors that are mapped onto the SOM.
    err : array-like
        Passed through to ``som.labels_map``. NOTE(review): relies on a MiniSom
        variant whose ``labels_map`` takes an error argument — confirm.
    info_data : pandas.DataFrame
        Needs ``HSC_i_MAG`` and ``z_true`` columns, row-aligned with ``data``.
    save : str or False
        Directory the figure is written to; a falsy value disables saving.
    """
    file = folder + f"_{dim}_{sigma}_{rate}_{iteration}.pkl"
    # pickle.load can execute arbitrary code: only point this at self-generated records.
    with open(file, "rb") as fh:
        record = pickle.load(fh)
    som = record["som"]
    # Recomputed on the given ``data`` vs. stored at training time (differ if ``data`` differs).
    print(f"Topological error (recomputed): {som.topographic_error(data)}")
    print(f"Quantization error (recomputed): {som.quantization_error(data)}")
    print(f"Topological error (stored): {record['topo_err']}")
    print(f"Quantization error (stored): {record['quan_err']}")

    cached = "z" in record
    if cached:
        density_map = record["density"]
        magnitude_map = record["magnitude"]
        z_map = record["z"]
    else:
        density_map, magnitude_map, z_map = _visual_node_maps(som, data, err, info_data, dim)

    _visual_plot_maps(som, density_map, magnitude_map, z_map)
    if save:
        name = f"{record['type']}_{len(data[0])}_{dim}_{sigma}_{rate}_{record['iter']}.jpg"
        plt.savefig(os.path.join(save, name), bbox_inches = "tight")
    plt.show()

    if cached:
        return

    record["density"] = density_map
    record["magnitude"] = magnitude_map
    record["z"] = z_map
    # Write to a temp file and swap it in, so an interrupt can never truncate the record.
    tmp_file = file + ".tmp"
    try:
        with open(tmp_file, "wb") as handle:
            pickle.dump(record, handle, protocol=pickle.HIGHEST_PROTOCOL)
        os.replace(tmp_file, file)
    except BaseException:
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
        raise


def _visual_node_maps(som, data, err, info_data, dim):
    # Builds per-node sample counts plus count-weighted mean HSC i magnitude and true redshift.
    # Nodes that receive no sample keep 0 (density) or NaN (magnitude, redshift).
    density_map = np.zeros((dim, dim))
    magnitude_map = np.full((dim, dim), np.nan)
    z_map = np.full((dim, dim), np.nan)

    # labels_map is expected to return {node: Counter(label -> count)}, as in upstream MiniSom.
    # With every label equal to 0, each node's Counter[0] is its sample count.
    counts_by_node = som.labels_map(data, err, [0] * len(data))
    labels = tuple(map(tuple, np.column_stack((info_data["HSC_i_MAG"].values,
                                               info_data["z_true"].values))))
    labels_by_node = som.labels_map(data, err, labels)

    for node, counter in counts_by_node.items():
        x, y = (int(c) for c in node)
        density_map[x, y] = counter[0]

    for node, counter in labels_by_node.items():
        x, y = (int(c) for c in node)
        # Identical (mag, z) labels collapse into one Counter key. Weight by the counts so
        # duplicates are not under-represented in the mean.
        values = np.array(list(counter.keys()))
        weights = np.array(list(counter.values()))
        magnitude_map[x, y], z_map[x, y] = np.average(values, axis=0, weights=weights)

    return density_map, magnitude_map, z_map


def _visual_plot_maps(som, density_map, magnitude_map, z_map):
    # Draws the 2x2 panel (U-matrix, density, magnitude, redshift) into a fresh figure.
    # Maps are indexed [x, y]; transposing puts x on pcolor's horizontal axis.
    plt.close()
    plt.figure(figsize = (30, 20))
    panels = (
        (221, som.distance_map(), 'YlGn', "Distance map (U-matrix)"),
        (222, density_map, 'YlGn', "Density Map"),
        (223, magnitude_map, "plasma_r", "HSC I Magnitude Map"),
        (224, z_map, "plasma", "True Redshift Map"),
    )
    for position, values, cmap, title in panels:
        plt.subplot(position)
        plt.pcolor(values.T, cmap=cmap)
        plt.colorbar()
        plt.title(title)

In [7]:
# Noisy shallow photometry: even columns are fluxes, odd columns their errors.
flux_shallow = data_all[:, 0::2]

# One positional mask keeps photometry and catalogue row-aligned. It drops every galaxy
# with a NaN in *any* band, exactly what DataFrame.dropna() did for the fluxes.
# A NumPy mask selects df_L3_info rows by position. The earlier boolean-Series mask
# looked only at band 0 and was aligned on df_L3_info's post-sort index labels.
valid_rows = ~np.isnan(flux_shallow).any(axis=1)
data = flux_shallow[valid_rows]
info = df_L3_info[valid_rows]
# NOTE(review): per-band errors are taken from the first galaxy only, i.e. assumed
# identical for all rows — confirm against the catalogue.
err = data_all[0, 1::2]

# asinh ("luptitude") magnitudes with softening 1.042 * err (the 10**3 factors cancel).
# Any per-band constant offset or common positive scale of the magnitudes is removed by
# the per-column standardisation below.
lupmag = -np.arcsinh(data / 10 ** 3 / (2 * 1.042 * err / 10 ** 3))
# Adjacent-band colours: lupmag[:, k] - lupmag[:, k + 1]
color = -np.diff(lupmag, axis = -1)
# Standardise each colour to zero mean and unit sample standard deviation.
proc_data = (color - np.mean(color, axis=0)) / np.std(color, ddof = 1, axis=0)


In [10]:
# Train every pending record. err is all zeros here.
# NOTE(review): presumably this means noise-free training in the custom MiniSom — confirm.
grid_search(
    data=proc_data,
    err=np.zeros_like(proc_data),
    folder="D:/SPHEREx_SOM/record/3rd_grid_search/noisy_lupcolor_shallow/SOM/",
)

Progress: 49.0%, param: [30, 29, 0.1]9]]



Progress: 99.0%, param: [30, 44, 0.1]9]]

In [None]:
# Error histograms and (rate, sigma) heat maps for the whole grid, saved as JPEGs.
parameter_space(
    save="D:/SPHEREx_SOM/record/3rd_grid_search/noisy_lupcolor_shallow/diagram/",
    folder="D:/SPHEREx_SOM/record/3rd_grid_search/noisy_lupcolor_shallow/SOM/",
)

In [None]:
# Inspect one grid point. `folder` is a file-name prefix (directory + record type),
# completed by visual_map with "_{dim}_{sigma}_{rate}_{iteration}.pkl".
visual_map(
    folder="D:/SPHEREx_SOM/record/3rd_grid_search/noisy_lupcolor_shallow/SOM/noisy_lupcolor_shallow",
    dim=30,
    sigma=40,
    rate=0.096,
    iteration=1,
    data=proc_data,
    err=np.zeros_like(proc_data),
    info_data=info,
    save="D:/SPHEREx_SOM/record/3rd_grid_search/noisy_lupcolor_shallow/diagram/",
)