In [9]:
from astropy.table import Table
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from minisom import MiniSom
import pickle
import os

In [10]:
points = [i / 10 for i in range(1, 51)]
a = []
for i in points:
    for j in points:
        b = {}
        b["dim"] = 30
        b["rate"] = i
        b["sigma"] = j
        b["distance"] = "euclidean"
        b["type"] = "noisy_color_shallow"
        b["iter"] = 2
        a.append(b)

with open("D:/SPHEREx_SOM/record/3rd_grid_search/noisy_lup_shallow/noisy_lup_shallow.pkl",'wb') as handle:
    pickle.dump(a, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [11]:
df_L3_info = Table.read('../dataset/L3_COSMOS2020_Richard_RefCat_2023DEC4_info.fits')
df_L3_info = df_L3_info.to_pandas().sort_values(by = "cosmos_id")

fname = 'D:/SPHEREx_SOM/dataset/sphx_refcat/Noiseless_phot_cosmos_nolines_refcat30k.txt'
data_noiseless = np.loadtxt(fname)[:,3:]

fname = 'D:/SPHEREx_SOM/dataset/sphx_refcat/NoisySphx_shallow_nolines_refcat30k.txt'
data_all = np.loadtxt(fname)[:,3:]

fname = 'D:/SPHEREx_SOM/dataset/sphx_refcat/NoisySphx_deep_nolines_refcat30k.txt'
data_deep = np.loadtxt(fname)[:,3:]

fname = 'D:/SPHEREx_SOM/dataset/sphx_refcat/SPHEREx_1sigma_noise.txt'
data_1sig  = np.loadtxt(fname, skiprows=1)
wl = data_1sig[:,0]
sigma_all = data_1sig[:,1]
sigma_deep = data_1sig[:,2]

In [12]:
def grid_search(data, err, file):
    n = 0
    with open(file, "rb") as fh:
        record = pickle.load(fh)
    if "som" not in record[-1]:
        for i in record:
            n += 1
            if "som" not in i:
                som = MiniSom(i["dim"], i["dim"], data.shape[1], sigma = i["sigma"], learning_rate = i["rate"], activation_distance = i["distance"])
                som.random_weights_init(data)
                som.train(data, err, num_iteration = i["iter"], use_epochs = True) 
                i["topo_err"] = som.topographic_error(data)
                i["quan_err"] = som.quantization_error(data)
                i["som"] = som
                i["band_cut"] = len(data[0])
                i["mag_cut"] = len(data)
                i["preproc"] = "default"
                i["topo_err"] = som.topographic_error(data)
                i["quan_err"] = som.quantization_error(data)
                print(f"File: {file}, progress: {round(n / len(record) * 100, 2)}%, param: {[i["dim"], i["sigma"], i["rate"]]}", end = "\r")
                try:
                    with open(file, 'wb') as handle:
                        pickle.dump(record, handle, protocol=pickle.HIGHEST_PROTOCOL)
                except KeyboardInterrupt:
                    print('KeyboardInterrupt caught, data saved.')
    else:
        pass

In [13]:
def parameter_space(save, file):
    with open(file, "rb") as fh:
        record = pickle.load(fh)
    x_tick_f = list(set([i["rate"] for i in record]))
    x_tick_f.sort()
    y_tick_f = list(set([i["sigma"] for i in record]))
    y_tick_f.sort()
    x_tick_i = [i + 0.5 for i in range(0, len(x_tick_f))]
    y_tick_i = [i + 0.5 for i in range(0, len(y_tick_f))]
    
    quan_map30 = np.zeros((len(y_tick_f), len(x_tick_f)))
    topo_map30 = np.zeros((len(y_tick_f), len(x_tick_f)))
    quan_map50 = np.zeros((len(y_tick_f), len(x_tick_f)))
    topo_map50 = np.zeros((len(y_tick_f), len(x_tick_f)))
    quan_map70 = np.zeros((len(y_tick_f), len(x_tick_f)))
    topo_map70 = np.zeros((len(y_tick_f), len(x_tick_f)))
    quan_map90 = np.zeros((len(y_tick_f), len(x_tick_f)))
    topo_map90 = np.zeros((len(y_tick_f), len(x_tick_f)))

    x_index = dict(zip(x_tick_f, [i - 0.5 for i in x_tick_i]))
    y_index = dict(zip(y_tick_f, [i - 0.5 for i in y_tick_i]))

    for i in record:
        if i["dim"] == 30:
            topo_map30[int(y_index[i["sigma"]]), int(x_index[i["rate"]])] = i["topo_err"]
            quan_map30[int(y_index[i["sigma"]]), int(x_index[i["rate"]])] = i["quan_err"]

        elif i["dim"] == 50:
            topo_map50[int(y_index[i["sigma"]]), int(x_index[i["rate"]])] = i["topo_err"]
            quan_map50[int(y_index[i["sigma"]]), int(x_index[i["rate"]])] = i["quan_err"]

        elif i["dim"] == 70:
            topo_map70[int(y_index[i["sigma"]]), int(x_index[i["rate"]])] = i["topo_err"]
            quan_map70[int(y_index[i["sigma"]]), int(x_index[i["rate"]])] = i["quan_err"]

        elif i["dim"] == 90:
            topo_map90[int(index[i["sigma"]]), int(index[i["rate"]])] = i["topo_err"]
            quan_map90[int(index[i["sigma"]]), int(index[i["rate"]])] = i["quan_err"]

    if np.sum(quan_map30) == 0:
        quan_map30 = np.ones((len(y_tick_f), len(x_tick_f)))
    if np.sum(quan_map50) == 0:
        quan_map50 = np.ones((len(y_tick_f), len(x_tick_f)))
    if np.sum(quan_map70) == 0:
        quan_map70 = np.ones((len(y_tick_f), len(x_tick_f)))
    if np.sum(quan_map90) == 0:
        quan_map90 = np.ones((len(y_tick_f), len(x_tick_f)))
        
    topo_maps = [topo_map30, topo_map50, topo_map70, topo_map90]
    quan_maps = [quan_map30, quan_map50, quan_map70, quan_map90]

    x_tick_f = np.round(x_tick_f, 3)
    
    plt.close()
    plt.figure(figsize = (30, 20))

    for i in range(1, 5):
        dim = (i * 2 + 1) * 10
        plt.subplot(220 + i)
        plt.hist(topo_maps[i - 1].reshape(-1), bins = 20)
        plt.xlabel("Topological Error", fontsize=20)
        plt.ylabel("# of Glaxies", fontsize=20)
    
    if save:
        plt.savefig(save + f"{record[0]['type']}_{record[0]['band_cut']}_topo_err_dist.jpg", bbox_inches = "tight")    
    plt.show()

    plt.close()
    plt.figure(figsize = (30, 20))
    for i in range(1, 5):
        dim = (i * 2 + 1) * 10
        plt.subplot(220 + i)
        plt.hist(np.log10(quan_maps[i - 1]).reshape(-1), bins = 20)
        plt.xlabel("log(Quantization Error)", fontsize=20)
        plt.ylabel("# of Glaxies", fontsize=20)
        plt.yscale("log")
        
    if save:
        plt.savefig(save + f"{record[0]['type']}_{record[0]['band_cut']}_quan_err_dist.jpg", bbox_inches = "tight")  
    plt.show()
    
    plt.close()
    plt.figure(figsize = (30, 20))
    for i in range(1, 5):
        dim = (i * 2 + 1) * 10
        plt.subplot(220 + i)
        plt.pcolor(topo_maps[i - 1], cmap = 'YlGn')
        plt.title(f"Topological Error (Dim: {dim}X{dim} / Cut: both)", fontsize=20)
        plt.xlabel("Rate", fontsize=15)
        plt.ylabel("Sigma", fontsize=15)
        plt.xticks(x_tick_i, x_tick_f, fontsize=8)
        plt.yticks(y_tick_i, y_tick_f, fontsize=8)
        plt.colorbar().ax.tick_params(labelsize=15)

    if save:
        plt.savefig(save + f"{record[0]['type']}_{record[0]['band_cut']}_topo_err_para_space.jpg", bbox_inches = "tight")  
    plt.show()
    
    plt.close()
    plt.figure(figsize = (30, 20))
    for i in range(1, 5):
        dim = (i * 2 + 1) * 10
        plt.subplot(220 + i)
        plt.pcolor(np.log10(quan_maps[i - 1]), cmap = 'YlGn')
        plt.title(f"Quantizationl Error (Dim: {dim}X{dim} / Cut: both)", fontsize=20)
        plt.xlabel("Rate", fontsize=15)
        plt.ylabel("Sigma", fontsize=15)
        plt.xticks(x_tick_i, x_tick_f, fontsize=8)
        plt.yticks(y_tick_i, y_tick_f, fontsize=8)
        plt.colorbar().ax.tick_params(labelsize=15)

    if save != False:
        plt.savefig(save + f"{record[0]['type']}_{record[0]['band_cut']}_quan_err_para_space.jpg", bbox_inches = "tight")  
    plt.show()

In [14]:
def visual_map(file, dim, sigma, rate, data, err, info_data, save = False):
    with open(file, "rb") as fh:
        record = pickle.load(fh)
    for k in record:
        if k["dim"] == dim and round(k["rate"], 3) == rate and k["sigma"] == sigma:
            pass
        else:
            continue
        som = k["som"]
        print(f"Topological error: {som.topographic_error(data)}")
        print(f"Quantization error: {som.quantization_error(data)}")
        print(f"Topological error: {k["topo_err"]}")
        print(f"Quantization error: {k["quan_err"]}")        
        if "z" in k:
            density_map = k["density"]
            magnitude_map = k["magnitude"]
            z_map = k["z"]
            
            %matplotlib inline
            plt.close()
            plt.figure(figsize = (30, 20))
            plt.subplot(221)
            plt.pcolor(som.distance_map().T, cmap='YlGn')  # plotting the distance map as background
            plt.colorbar()
            plt.title("Distance map (U-matrix)")
        
            plt.subplot(222)
            plt.pcolor(density_map.T, cmap='YlGn')
            plt.colorbar()
            plt.title("Density Map")
        
            plt.subplot(223)
            plt.pcolor(magnitude_map.T, cmap="plasma" + "_r")
            plt.colorbar()
            plt.title("HSC I Magnitude Map")
        
            plt.subplot(224)
            plt.pcolor(z_map.T, cmap="plasma")
            plt.colorbar()
            plt.title("True Redshift Map")
            if save != False:
                plt.savefig(save + f"{k['type']}_{len(data[0])}_{dim}_{sigma}_{rate}_{k['iter']}.jpg", bbox_inches = "tight")
            
            plt.show()
            continue
        else:
            pass
        
        density_map = np.zeros(dim ** 2).reshape(dim, dim)
        magnitude_map = np.zeros(dim ** 2).reshape(dim, dim)
        magnitude_map.fill(np.nan)
        z_map = np.zeros(dim ** 2).reshape(dim, dim)
        z_map.fill(np.nan)
        
        labels_map_1 = som.labels_map(data, err, [0] * len(data))
        labels_map_2 = som.labels_map(data, err, tuple(map(tuple, np.concatenate((np.expand_dims(info_data["HSC_i_MAG"].values, axis = 1), np.expand_dims(info_data["z_true"].values, axis = 1)), axis = -1))))
    
        for i in labels_map_1.keys():
            density_map[int(list(i)[0]), int(list(i)[1])] = labels_map_1[i][0]
        
        for i in labels_map_2.keys():
            properties = np.mean(np.array(list(labels_map_2[i].keys())), axis = 0)
            magnitude_map[int(list(i)[0]), int(list(i)[1])] = properties[0]
            z_map[int(list(i)[0]), int(list(i)[1])] = properties[1]

        %matplotlib inline
        plt.close()
        plt.figure(figsize = (30, 20))
        plt.subplot(221)
        plt.pcolor(som.distance_map().T, cmap='YlGn')  # plotting the distance map as background
        plt.colorbar()
        plt.title("Distance map (U-matrix)")
        
        plt.subplot(222)
        plt.pcolor(density_map.T, cmap='YlGn')
        plt.colorbar()
        plt.title("Density Map")
    
        plt.subplot(223)
        plt.pcolor(magnitude_map.T, cmap="plasma" + "_r")
        plt.colorbar()
        plt.title("HSC I Magnitude Map")
    
        plt.subplot(224)
        plt.pcolor(z_map.T, cmap="plasma")
        plt.colorbar()
        plt.title("True Redshift Map")
        if save != False:
            plt.savefig(save + f"{k['type']}_{len(data[0])}_{dim}_{sigma}_{rate}_{k['iter']}.jpg", bbox_inches = "tight")
        
        plt.show()
    
        k["density"] = density_map
        k["magnitude"] = magnitude_map
        k["z"] = z_map

        try:
            with open(file, 'wb') as handle:
                pickle.dump(record, handle, protocol=pickle.HIGHEST_PROTOCOL)
        except KeyboardInterrupt:
            print('KeyboardInterrupt caught, data saved.')

In [15]:
data = pd.DataFrame(data_deep[:, 0::2]).dropna().to_numpy()
info = df_L3_info[-pd.DataFrame(data_deep[:, 0::2]).isna()[0]]
err = data_deep[0, 1::2]

lupmag = np.arcsinh(data / 10 ** 6 / (1.042 * err / 10 ** 6))
proc_data = (lupmag- np.mean(lupmag, axis=0)) / np.std(lupmag, ddof = 1, axis=0)

  info = df_L3_info[-pd.DataFrame(data_deep[:, 0::2]).isna()[0]]


In [16]:
grid_search(data = proc_data, err = np.zeros_like(proc_data), file = "D:/SPHEREx_SOM/record/3rd_grid_search/noisy_lup_shallow/noisy_lup_shallow.pkl")

  return sqrt(-2 * cross_term + input_data_sq + weights_flat_sq.T)


File: D:/SPHEREx_SOM/record/3rd_grid_search/noisy_lup_shallow/noisy_lup_shallow.pkl, progress: 55.28%, param: [30, 3.2, 2.8]

KeyboardInterrupt: 