In [2]:
import numpy as np
import pandas as pd
import os
import random
#import seaborn as sns
import matplotlib.pyplot as plt
import time

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset

from easy_ntk import calculate_NTK
from torch.utils.tensorboard import SummaryWriter

# import argparse
# parser = argparse.ArgumentParser()
# parser.add_argument('--ZMIN',default=0.0,type=float)
# parser.add_argument('--ZMAX',default=1.0,type=float)
# parser.add_argument('--BATCH_SIZE',default=64,type=int)
# parser.add_argument('--N_EPOCHS',default=100,type=int)
# parser.add_argument('--LR',default=1e-4,type=float)
# parser.add_argument('--WIDTH',default=500,type=int)
# parser.add_argument('--TRAIN_FLAG',default=False,type=bool)
# parser.add_argument('--NTK_FLAG',default=False,type=bool)
# parser.add_argument('--NTK_POINTS',default=1000,type=int)
# args = parser.parse_args()

# ZMIN = args.ZMIN
# ZMAX = args.ZMAX
# BATCH_SIZE = args.BATCH_SIZE
# N_EPOCHS = args.N_EPOCHS
# LR = args.LR
# WIDTH = args.WIDTH
# TRAIN_FLAG = args.TRAIN_FLAG
# NTK_FLAG = args.NTK_FLAG
# NTK_POINTS = args.NTK_POINTS

In [3]:
# TensorBoard writer. With no explicit log_dir it saves under './runs'; every
# execution of this cell starts a new run, so don't be afraid to re-run it.
writer = SummaryWriter()

# Experiment configuration (these used to be argparse flags; see the commented-out parser).
ZMIN, ZMAX = 0.0, 1.0   # keep objects with ZMIN <= ZBEST < ZMAX
BATCH_SIZE = 256
N_EPOCHS = 10
LR = 1e-5               # SGD learning rate
WIDTH = 500             # hidden-layer width of the MLP
TRAIN_FLAG = True       # run the training loop
NTK_FLAG = True         # enable the NTK spectrum computation
NTK_POINTS = 500        # number of training inputs used to build the NTK
DEVICE = 'cpu'
###########################################
# Feature columns: six flux measurements, each in the five bands g, r, i, z, y.
# The list is measurement-major / band-minor, so feature index k belongs to band
# 'grizy'[k % 5] -- the per-band luptitude conversion below relies on this layout.
X_COLUMNS = [band + 'F' + measurement
             for measurement in ('KronFlux', 'PSFFlux', 'ApFlux',
                                 'meanflxR5', 'meanflxR6', 'meanflxR7')
             for band in 'grizy']
# Columns read from disk: target redshift, the features, and sky position
# (raMean/decMean are loaded but not used as features).
use_columns = ['ZBEST'] + X_COLUMNS + ['raMean', 'decMean']

DF = pd.read_csv('./../DATA/no_repeats.csv', usecols=use_columns)  #!! Download this file!, see README

Y = DF['ZBEST'].values
X = DF[X_COLUMNS].values

# Pare down to ZMIN <= Y < ZMAX (a NaN redshift fails both comparisons and is dropped).
in_range = (Y >= ZMIN) & (Y < ZMAX)
X = X[in_range]
Y = Y[in_range]

# Asinh magnitudes ("luptitudes") are characterised by a softening parameter b, the
# typical 1-sigma noise of the sky in a PSF aperture in 1'' seeing. Detected flux f
# and asinh magnitude m are related by
#     m = -2.5/ln(10) * [asinh((f/f0)/(2b)) + ln(b)]

f_0 = 3631  # Jy

# Sky-background magnitude in one square arcsecond, per band; used to derive b.
# Source: https://iopscience.iop.org/article/10.1088/0004-637X/756/2/158/pdf table1
g_mu = 21.92
r_mu = 20.83
i_mu = 19.79
z_mu = 19.24
y_mu = 18.24


def _softening_from_sky(mu):
    """Softening b for sky magnitude ``mu``: b = exp(mu*ln(10)/-2.5 - asinh(1/(2*f_0)))."""
    return np.exp((mu*np.log(10)/-2.5) - np.arcsinh((1/(2*f_0))))


b_g = _softening_from_sky(g_mu)
b_r = _softening_from_sky(r_mu)
b_i = _softening_from_sky(i_mu)
b_z = _softening_from_sky(z_mu)
b_y = _softening_from_sky(y_mu)

def convert_flux_to_luptitude(f, b, f_0=3631):
    """Convert flux to an asinh magnitude ("luptitude").

    m = -2.5/ln(10) * [asinh((f/f_0)/(2b)) + ln(b)]

    Parameters
    ----------
    f : float or ndarray
        Flux, in the same units as ``f_0`` (works element-wise on arrays).
    b : float
        Softening parameter.
    f_0 : float, default 3631
        Reference flux.

    Returns
    -------
    Luptitude(s), same shape as ``f``; NaN fluxes stay NaN.
    """
    scaled_flux = (f / f_0) / (2 * b)
    log_part = np.arcsinh(scaled_flux) + np.log(b)
    return -2.5 / np.log(10) * log_part

# Replace every flux feature by its luptitude using its band's softening parameter.
# The columns cycle through g, r, i, z, y, so band k owns columns k, k+5, ..., k+25.
for band_offset, band_b in enumerate((b_g, b_r, b_i, b_z, b_y)):
    band_cols = list(range(band_offset, 30, 5))
    X[:, band_cols] = convert_flux_to_luptitude(X[:, band_cols], b=band_b)
del band_offset, band_b, band_cols
###########################################
# Robust standardisation: centre each feature on its median and scale by the
# IQR-based sigma estimate (IQR / 1.34896 equals sigma for a Gaussian), so outliers
# do not dominate the scale.
# The nan-aware versions matter: with np.median/np.quantile a single missing value
# turns a whole column's statistics into NaN, and the NaN handling below would then
# overwrite that entire column with the -20 sentinel.
# NOTE(review): these statistics are computed on all rows before the
# train/val/test split, which leaks test information into preprocessing;
# consider fitting them on the training rows only.
MEANS = np.nanmedian(X, axis=0)
IQR = np.nanquantile(X, q=[0.75, 0.25], axis=0)
STDS = (IQR[0] - IQR[1]) / 1.34896
# Constant (zero-IQR) or all-missing columns: avoid dividing by zero / NaN.
STDS = np.where(STDS > 0, STDS, 1.0)

X = (X - MEANS) / STDS

# Missing data -> sentinel value far outside the bulk of the distribution.
X[np.isnan(X)] = -20
# Clip extreme outliers into [-20, 20].
X[X < -20] = -20
X[X > 20] = 20

print(np.shape(X))

# Shuffle once with a fixed seed, then split 80% train / 10% test / 10% validation.
SEED = 0
random.seed(SEED)
# random.seed() does not touch NumPy's or PyTorch's generators. The model
# initialisation and the shuffling DataLoader below draw from torch's global RNG,
# so seed it too for reproducible runs.
np.random.seed(SEED)
torch.manual_seed(SEED)

indices = np.arange(len(X))
random.shuffle(indices)

n_rows = len(indices)
train_indices = indices[:int(0.8 * n_rows)]
test_indices = indices[int(0.8 * n_rows):int(0.9 * n_rows)]
val_indices = indices[int(0.9 * n_rows):]
assert len(train_indices) + len(test_indices) + len(val_indices) == n_rows

X = X.astype(np.float32)
Y = Y.astype(np.float32)

X_train = X[train_indices]
Y_train = Y[train_indices]

X_test = X[test_indices]
Y_test = Y[test_indices]

X_val = X[val_indices]
Y_val = Y[val_indices]
###########################################
# Wrap each split as tensors (sharing memory with the NumPy arrays), then as
# TensorDatasets and DataLoaders. Only the training loader shuffles.
X_train_tensor, X_test_tensor, X_val_tensor = (
    torch.from_numpy(arr) for arr in (X_train, X_test, X_val))
Y_train_tensor, Y_test_tensor, Y_val_tensor = (
    torch.from_numpy(arr) for arr in (Y_train, Y_test, Y_val))

train_dataset = TensorDataset(X_train_tensor, Y_train_tensor)
val_dataset = TensorDataset(X_val_tensor, Y_val_tensor)
test_dataset = TensorDataset(X_test_tensor, Y_test_tensor)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader, test_dataloader = (
    DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False) for ds in (val_dataset, test_dataset))
###########################################
# ARCHITECTURE:
def relu(X, normalize=True):
    """ReLU, optionally standardised.

    With ``normalize=True`` the output is shifted by 1/sqrt(2*pi) and scaled by
    sqrt(2*pi/(pi-1)). These are the mean and the inverse standard deviation of
    relu(z) for z ~ N(0, 1), so a standard-normal pre-activation yields an output
    with zero mean and unit variance.
    """
    activated = F.relu(X)
    if not normalize:
        return activated
    shift = 1 / np.sqrt(2 * np.pi)
    gain = np.sqrt(2 * np.pi / (np.pi - 1))
    return gain * (activated - shift)

# Number of input features = number of columns of the preprocessed feature matrix.
N_features = np.shape(X_test)[1]


class MLP(torch.nn.Module):
    """One-hidden-layer perceptron: Linear(N_features -> WIDTH), normalised ReLU, Linear(WIDTH -> 1).

    Layer sizes come from the module-level ``N_features`` and ``WIDTH`` at construction
    time. The attribute names ``l1``/``l2`` appear in the state dict and in the
    TensorBoard histogram tags, so they are kept stable.
    """

    def __init__(self):
        super().__init__()
        self.l1 = nn.Linear(N_features, WIDTH, bias=True)
        self.l2 = nn.Linear(WIDTH, 1, bias=True)

    def forward(self, x):
        hidden = relu(self.l1(x))
        return self.l2(hidden)

def init_weights(m):
    """LeCun-normal initialisation for ``nn.Linear`` layers (use via ``model.apply``).

    Weights are drawn from N(0, 1/fan_in) and biases (if present) from N(0, 1);
    other module types are left untouched.

    The previous version computed ``nn.init.normal_(w) / sqrt(fan_in)`` and discarded
    the quotient, leaving the weights at N(0, 1). That is fan_in times too much
    variance for the normalised ReLU, which assumes roughly unit-variance pre-activations.
    """
    if isinstance(m, nn.Linear):
        fan_in = m.weight.shape[1]
        nn.init.normal_(m.weight, mean=0.0, std=float(1.0 / np.sqrt(fan_in)))
        if isinstance(m.bias, torch.Tensor):
            nn.init.normal_(m.bias)
    
# Build the network on DEVICE, replace the default initialisation with init_weights,
# and attach a plain SGD optimiser.
model = MLP().to(DEVICE)
model.apply(init_weights)
optim = torch.optim.SGD(model.parameters(), lr=LR)
# Number of trainable scalars (logged as NUM_PARAMS in the hparams).
N_parameters = sum(param.numel() for param in model.parameters() if param.requires_grad)
###########################################
# TRAINING/EVAL LOOPS:
loss = torch.nn.MSELoss()


def train(epoch):
    """Run one SGD epoch over ``train_dataloader``; return the mean mini-batch MSE.

    ``epoch`` is accepted for symmetry with ``val``/``test`` and is unused.
    Relies on the module-level ``model``, ``optim``, ``loss`` and ``DEVICE``.
    """
    # val()/test() leave the model in eval mode; switch back before training so
    # mode-dependent layers (dropout, batch-norm) behave if they are ever added.
    model.train()
    losses = []
    for x, label in train_dataloader:
        x = x.to(DEVICE)
        label = label.to(DEVICE)
        # Clear gradients before the backward pass, not only after the step, so
        # nothing left in .grad by earlier cells leaks into the first update.
        optim.zero_grad()
        y = model(x).view(-1)
        # nn.MSELoss takes (input, target); MSE is symmetric so the value is unchanged.
        output = loss(y, label)
        output.backward()
        optim.step()
        losses.append(output.item())
    return np.mean(losses)
        
    
def val(epoch):
    """Return the mean mini-batch MSE of ``model`` over ``val_dataloader``.

    Switches the model to eval mode and disables autograd; ``epoch`` is unused.
    """
    model.eval()
    with torch.no_grad():
        batch_losses = []
        for inputs, targets in val_dataloader:
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
            predictions = model(inputs).view(-1)
            batch_losses.append(loss(targets, predictions).item())
    return np.mean(batch_losses)
        
def test(epoch=None):
    """Evaluate ``model`` on ``test_dataloader``.

    Returns
    -------
    (mean mini-batch MSE, 1-D NumPy array of predictions in ``test_dataloader`` order)

    ``epoch`` is unused and optional, so both ``test(epoch)`` and ``test()`` work.
    """
    model.eval()
    losses = []
    all_outputs = []
    with torch.no_grad():
        for x, label in test_dataloader:
            x = x.to(DEVICE)
            label = label.to(DEVICE)
            y = model(x).view(-1)
            losses.append(loss(y, label).item())
            # .detach() -- previously misspelled .detatch(), which raised AttributeError.
            all_outputs.append(y.detach().cpu().numpy())
    # Concatenate the per-batch arrays. np.array() on batches of unequal length
    # (the last batch is usually smaller) builds a ragged array, which NumPy rejects.
    predictions = (np.concatenate(all_outputs) if all_outputs
                   else np.empty(0, dtype=np.float32))
    return np.mean(losses), predictions
##################################################

def _log_ntk_spectrum(step):
    """Compute the empirical NTK on the first NTK_POINTS training inputs and log its spectrum.

    Writes to TensorBoard at global step ``step``:
    - 'NTK': histogram of log10 eigenvalues;
    - 'NTK/Condition_number', 'NTK/min_eigenvalue', 'NTK/max_eigenvalue'.

    Returns
    -------
    (NTK matrix, eigenvalues, eigenvectors, seconds spent in calculate_NTK)
    """
    start = time.time()
    ntk = calculate_NTK(model, X_train_tensor[0:NTK_POINTS])
    elapsed = time.time() - start
    # Assumes calculate_NTK returns a symmetric (PSD) torch matrix, so eigh applies.
    eigvals, eigvecs = torch.linalg.eigh(ntk)
    lam_min, lam_max = eigvals.min(), eigvals.max()
    # Round-off can push the smallest eigenvalues to <= 0. Clamp before the log so the
    # histogram only sees finite values (NaN/-inf make bins='auto' fail).
    tiny = torch.finfo(eigvals.dtype).tiny
    writer.add_histogram('NTK', torch.log10(eigvals.clamp_min(tiny)), step, bins='auto')
    # Condition number is lambda_max / lambda_min (min/max, previously logged, is its reciprocal).
    writer.add_scalar('NTK/Condition_number', lam_max / lam_min, step)
    writer.add_scalar('NTK/min_eigenvalue', lam_min, step)
    writer.add_scalar('NTK/max_eigenvalue', lam_max, step)
    return ntk, eigvals, eigvecs, elapsed


# NTK spectrum at initialisation (step 0).
if NTK_FLAG:
    NTK, eigenvalues, eigenvectors, NTK_time = _log_ntk_spectrum(0)
    print('NTK time: {:4e}s'.format(NTK_time), flush=True)

##################################################
all_train_loss = []
all_val_loss = []
if TRAIN_FLAG:
    for epoch in range(1, N_EPOCHS + 1):
        epoch_start = time.time()
        train_loss = train(epoch)
        all_train_loss.append(train_loss)
        train_end = time.time()
        val_loss = val(epoch)
        all_val_loss.append(val_loss)

        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('Loss/validation', val_loss, epoch)

        print('train loss: {:4e} val loss: {:4e} train time: {:4e}s'.format(train_loss, val_loss, train_end - epoch_start))

        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            writer.add_histogram(name, param.data, epoch, bins='auto')
            if 'weight' in name:
                # Spectrum of W W^T, i.e. the squared singular values of W.
                gram = torch.matmul(param.data, param.data.T)
                writer.add_histogram('X_' + name, torch.linalg.eigh(gram)[0], epoch, bins='auto')

        # The per-epoch NTK is expensive; honour NTK_FLAG here as well.
        if NTK_FLAG:
            NTK, eigenvalues, eigenvectors, NTK_time = _log_ntk_spectrum(epoch)
#########################################
# Record this run's hyper-parameters for TensorBoard's HParams view (no metrics attached).
writer.add_hparams({'ZMIN': ZMIN,
                    'ZMAX': ZMAX,
                    'BATCH_SIZE': BATCH_SIZE,
                    'N_EPOCHS': N_EPOCHS,  # key previously had a stray trailing space
                    'LR': LR,
                    'WIDTH': WIDTH,
                    'NTK_POINTS': NTK_POINTS,
                    'NUM_PARAMS': N_parameters},
                   {})

writer.flush()  # writes everything to disk

FileNotFoundError: [Errno 2] No such file or directory: './../DATA/no_repeats.csv'

In [None]:
# NOTE(review): this cell re-defines test() from the training cell; keep a single
# definition. ``epoch`` stays optional here so that both ``test()`` and
# ``test(epoch)`` work whichever definition ran last.
def test(epoch=None):
    """Evaluate ``model`` on ``test_dataloader``; return (mean batch MSE, 1-D prediction array)."""
    model.eval()
    losses = []
    all_outputs = []
    with torch.no_grad():
        for x, label in test_dataloader:
            x = x.to(DEVICE)
            label = label.to(DEVICE)
            y = model(x).view(-1)
            # nn.MSELoss takes (input, target); MSE is symmetric so the value is unchanged.
            losses.append(loss(y, label).item())
            all_outputs.extend(y.detach().cpu().numpy())
    return np.mean(losses), np.array(all_outputs)


test_loss, out_y_test = test()

In [None]:
# Predicted vs. true redshift on the test split (test_dataloader does not shuffle, so
# out_y_test is aligned with Y_test). Points on the dashed 1:1 line are perfect.
# The axis labels were previously swapped: x is the true redshift, y the estimate.
fig, ax = plt.subplots()
ax.plot(Y_test, out_y_test, '.')
ax.plot([ZMIN, ZMAX], [ZMIN, ZMAX], 'k--', lw=1)
ax.set_xlabel('Z real')
ax.set_ylabel('Z est ')
ax.set_title('Why this so bad?')
plt.show()

In [35]:
# Close the TensorBoard writer once all logging for this run is done.
writer.close()