In [15]:
from astropy.table import Table
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from minisom import MiniSom
import pickle
import os

In [11]:
# Hyper-parameter grid: every combination of map side length (30, 50, 70, 90),
# sigma (1..10) and learning rate (0.1..1.0). Entries are ordered with the side
# length varying slowest and the learning rate varying fastest.
b = [
    {"dim": (2 * size_step + 1) * 10, "sigma": sigma_value, "rate": rate_step / 10}
    for size_step in range(1, 5)
    for sigma_value in range(1, 11)
    for rate_step in range(1, 11)
]

In [16]:
# Read the FITS info table and convert it to a pandas DataFrame in one step.
df_L3_info = Table.read(
    '../dataset/L3_COSMOS2020_Richard_RefCat_2023DEC4_info.fits'
).to_pandas()

In [17]:
# Read the whitespace-delimited photometry table from disk.
L3_phot_data = np.loadtxt(
    '../dataset/L3_COSMOS2020_Richard_RefCat_2023DEC4_averaged_phot102.txt'
)

# From column 3 onward the columns are taken alternately: 3, 5, 7, ... go to
# phot_data and 4, 6, 8, ... go to phot_err_data.
# NOTE(review): presumably these are value / uncertainty pairs per band and the
# first three columns are identifiers - confirm against the catalogue description.
value_columns = slice(3, None, 2)
error_columns = slice(4, None, 2)
phot_data = L3_phot_data[:, value_columns]
phot_err_data = L3_phot_data[:, error_columns]

In [18]:
def remove_low_SNR(data, data_err, info_data, mode, mag_cut, n_drop_bands = 34):
    """Select objects by i-band magnitude and/or drop the trailing bands.

    Parameters
    ----------
    data : numpy.ndarray
        2-D array; rows are filtered by the magnitude selection and trailing
        columns are removed by the band cut.
    data_err : numpy.ndarray
        Array filtered with exactly the same row mask and column cut as
        ``data``.
    info_data : pandas.DataFrame
        Must contain an ``"HSC_i_MAG"`` column. Its rows are filtered with the
        same mask as ``data``, so it is assumed to be row-aligned with
        ``data`` (TODO confirm for new catalogues).
    mode : str
        ``"mag_cut"`` applies only the magnitude selection, ``"band_cut"``
        only removes trailing columns, ``"both"`` does both. Any other value
        returns the inputs unchanged.
    mag_cut : number or sequence of two numbers
        A single number ``k`` keeps rows with
        ``HSC_i_MAG < mean(HSC_i_MAG) - k * std(HSC_i_MAG, ddof=1)``.
        A pair ``[low, high]`` keeps rows with ``low < HSC_i_MAG < high``.
        Ignored when ``mode == "band_cut"``.
    n_drop_bands : int, optional
        Number of trailing columns removed by the band cut (default 34, the
        value that used to be hard-coded).

    Returns
    -------
    tuple
        ``(data, data_err, info_data)`` after the requested cuts.
    """
    i_band = info_data["HSC_i_MAG"].values

    if mode in ("mag_cut", "both"):
        if isinstance(mag_cut, (int, float, np.integer, np.floating)):
            # Scalar: threshold expressed in sample standard deviations
            # below the mean magnitude.
            max_cut = np.mean(i_band) - np.std(i_band, ddof = 1) * mag_cut
            keep = i_band < max_cut
        else:
            # Pair: explicit open magnitude interval (low, high).
            keep = np.logical_and(i_band < mag_cut[1], i_band > mag_cut[0])

        # One mask, built from the unfiltered magnitudes, is applied to all
        # three containers so they stay row-aligned.
        data = data[keep]
        data_err = data_err[keep]
        info_data = info_data[keep]

    if mode in ("band_cut", "both"):
        # .shape[1] (rather than len(data[0])) also works when the magnitude
        # selection left zero rows.
        data = data[:, : data.shape[1] - n_drop_bands]
        data_err = data_err[:, : data_err.shape[1] - n_drop_bands]

    return data, data_err, info_data

In [19]:
def grid_search(data, file):
    """Train a SOM for every untrained parameter set stored in ``file``.

    ``file`` is a pickle holding a list of dicts, each with ``"dim"``,
    ``"sigma"`` and ``"rate"`` keys. Entries that already carry a ``"som"``
    key are skipped, so an interrupted search resumes where it stopped. After
    each training run the entry gains the keys ``"topo_err"``, ``"quan_err"``,
    ``"som"``, ``"band_cut"`` (number of columns of ``data``), ``"mag_cut"``
    (number of rows of ``data``) and ``"preproc"``, and the whole list is
    written back to ``file``.

    Parameters
    ----------
    data : numpy.ndarray
        2-D training array passed to MiniSom; ``data.shape[1]`` is used as the
        SOM input length.
    file : str
        Path of the pickle file that is read and then overwritten in place.

    Returns
    -------
    None

    Raises
    ------
    KeyboardInterrupt
        Re-raised after the checkpoint has been written when the interrupt
        arrives during saving, so the caller's loop stops instead of silently
        continuing with the next parameter set.
    """
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file; only use this on record files you created yourself.
    with open(file, "rb") as fh:
        record = pickle.load(fh)

    # An empty record has nothing to train. A trained last entry is treated as
    # "file finished", exactly as before.
    if not record or "som" in record[-1]:
        return

    total = len(record)
    for n, entry in enumerate(record, start = 1):
        if "som" in entry:
            continue

        som = MiniSom(entry["dim"], entry["dim"], data.shape[1], sigma = entry["sigma"], learning_rate = entry["rate"])
        som.random_weights_init(data)
        som.train(data = data, num_iteration = 1, use_epochs = True)

        # Each error metric is evaluated once; they scan the whole data set.
        entry["topo_err"] = som.topographic_error(data)
        entry["quan_err"] = som.quantization_error(data)
        entry["som"] = som
        entry["band_cut"] = data.shape[1]
        entry["mag_cut"] = data.shape[0]
        entry["preproc"] = "default"

        # Built outside the f-string so the line also parses on Python < 3.12.
        params = [entry["dim"], entry["sigma"], entry["rate"]]
        print(f"File: {file}, progress: {round(n / total * 100, 2)}%, param: {params}", end = "\r")

        try:
            _save_grid_record(record, file)
        except KeyboardInterrupt:
            # The interrupted write never touched `file`; write the checkpoint
            # for real before reporting success and stopping.
            _save_grid_record(record, file)
            print('KeyboardInterrupt caught, data saved.')
            raise


def _save_grid_record(record, file):
    """Pickle ``record`` to ``file`` via a temporary file and an atomic rename."""
    tmp_file = f"{file}.tmp"
    try:
        with open(tmp_file, "wb") as handle:
            pickle.dump(record, handle, protocol = pickle.HIGHEST_PROTOCOL)
        # os.replace swaps the finished file in, so `file` is never left
        # truncated or half-written.
        os.replace(tmp_file, file)
    except BaseException:
        # Do not leave a stray partial file next to the records.
        if os.path.exists(tmp_file):
            os.remove(tmp_file)
        raise

In [20]:
# Mode "band_cut" only trims the trailing columns; mag_cut is passed but not
# applied in this mode.
data, data_err, info_data = remove_low_SNR(
    data = phot_data,
    data_err = phot_err_data,
    info_data = df_L3_info,
    mode = "band_cut",
    mag_cut = [0, 22],
)

# Standardise every column to zero mean and unit (population, ddof=0) standard deviation.
column_mean = np.mean(data, axis=0)
column_std = np.std(data, axis=0)
data = (data - column_mean) / column_std
data.shape

(29685, 68)

In [21]:
# Run the grid search on every record file in the directory, skipping any
# entry whose name contains "check".
record_dir = "D:/SPHEREx_SOM/record/grid_search/band_cut/"
for record_name in os.listdir(record_dir):
    if "check" not in record_name:
        grid_search(data = data, file = record_dir + record_name)

File: D:/SPHEREx_SOM/record/grid_search/band_cut/low_sigma_high_rate.pkl, progress: 98.5%, param: [90, 1.0, 4]]

KeyboardInterrupt: 