# Flatw'rm2 training

Use this notebook if you want to retrain the flatw'rm2 network for any reason, e.g. because you have a training set of light curves whose cadence does not work well with the current weight file. Make sure to include a large number of light curves with flagged flares, along with some non-flaring targets.

You can either start the training from scratch or use transfer learning, loading the current best weight file to reduce training time.

In [None]:
from matplotlib import pyplot as plt
import numpy as np

import tensorflow as tf

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, GRU, Dropout, LSTM, Bidirectional
from tensorflow.keras.preprocessing.sequence import TimeseriesGenerator
from tensorflow.keras.utils import Sequence


from tensorflow.keras.callbacks import TensorBoard, EarlyStopping, ModelCheckpoint


import datetime
from time import localtime, strftime

from scipy.signal import medfilt


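If you have multiple GPUs in your machine, change the value of `CUDA_VISIBLE_DEVICES` below to the index of your preferred unit. On a machine with a single GPU, the value `"1"` hides that GPU and TensorFlow falls back to the CPU.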

In [None]:
import os
# Index of the GPU to train on, as a string. If the machine has fewer GPUs
# than this index implies, no GPU is visible and TensorFlow uses the CPU.
preferred_gpu = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = preferred_gpu

# Let TensorFlow grow its GPU memory allocation on demand instead of
# reserving the whole device up front.
for device in tf.config.experimental.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(device, True)

This class feeds the input to the neural network during training. As we are working with time-domain data and recurrent networks, the order in which the data is fed to the network is not arbitrary. Also, the batch size shouldn't be set too high: the series is split into `batch_size` consecutive segments, and each segment should be longer than the typical timescale of the flares.

In [None]:
class SplitGenerator(Sequence):
    """Feed a 1-D time series to a recurrent network in ordered windows.

    The series is divided into ``batch_size`` consecutive segments of
    ``len(y_set) // batch_size`` samples. Item ``idx`` holds, for every
    segment ``b``, the window of ``length`` samples starting ``idx`` points
    into that segment, so successive items walk through every segment in
    temporal order. A window running past the end of the series is padded
    with the last samples of the series in reverse order.

    The label of a window is the half-up rounded mean of its flags, i.e.
    (assuming 0/1 flags) 1 when at least half of the window is flagged.

    Parameters
    ----------
    x_set, y_set : array-like
        Inputs and flags of equal length. Every window is reshaped to
        ``length`` values, so each sample must hold a single value
        (e.g. arrays of shape ``(N,)`` or ``(N, 1)``).
    batch_size : int
        Number of segments (= windows per batch).
    length : int
        Number of samples per window.

    Raises
    ------
    ValueError
        If ``batch_size * length`` exceeds the number of samples.
    """

    def proper_round(self, val):
        """Round half up; the builtin ``round`` rounds halves to even."""
        if (float(val) % 1) >= 0.5:
            return np.ceil(val)
        return round(val)

    def __init__(self, x_set, y_set, batch_size=1, length=10):
        # Keras' Sequence/PyDataset base class expects its initialiser to run.
        super().__init__()
        self.x, self.y = x_set, y_set
        self.batch_size = batch_size
        self.length = length
        if self.batch_size * self.length > len(self.y):
            raise ValueError(
                "Batch size and sample size don't match! "
                f"batch_size={self.batch_size} length={self.length}"
            )

    def _window(self, series, start):
        """Return ``length`` samples from ``start``, padded with the reversed tail of ``series``."""
        chunk = series[start:start + self.length]
        missing = self.length - len(chunk)
        if missing > 0:
            # Negative start index: the last `missing` samples, reversed.
            tail = series[-missing:][::-1]
            chunk = np.concatenate((chunk, tail))
        return np.asarray(chunk).reshape(self.length)

    def __getitem__(self, idx):
        """Return ``(x, y)`` with shapes ``(batch_size, length, 1)`` and ``(batch_size, 1)``."""
        if idx >= len(self):
            raise ValueError("Requested index too large")

        segment = len(self.y) // self.batch_size
        batch_x = np.zeros((self.batch_size, self.length))
        batch_y = np.zeros(self.batch_size)

        for b in range(self.batch_size):
            start = b * segment + idx
            batch_x[b] = self._window(self.x, start)
            batch_y[b] = self.proper_round(np.mean(self._window(self.y, start)))

        return (batch_x.reshape(self.batch_size, self.length, 1),
                batch_y.reshape(self.batch_size, 1))

    def __len__(self):
        """Number of batches (window offsets) per epoch."""
        return int(np.ceil(len(self.y) / self.batch_size)) - 1

This function reads the data files. If your cadence is not 1 minute, change the `cadence` argument here. For long-cadence data, you may also want to change the minimum gap length (10 minutes by default): gaps longer than this are filled with synthetic noise before the light curve is interpolated onto a regular grid.

In [None]:
def _clipped_std(values):
    """Standard deviation of ``values`` after discarding points outside the 5th-95th percentile."""
    down, up = np.percentile(values, [5, 95])
    return np.std(values[(values >= down) & (values <= up)])


def read_data(filename, cadence=1./60/24, gap_threshold=10./60/24,
              std_window=0.015, rng=None):
    """Read a flare-labelled light curve, fill its gaps and resample it.

    Parameters
    ----------
    filename : str, path, file-like or list of str
        Anything ``np.genfromtxt`` accepts, holding three columns:
        time (days), flux and flare flag.
    cadence : float
        Step of the output time grid in days (default: 1 minute).
    gap_threshold : float
        Time steps longer than this (days, default: 10 minutes) are treated
        as gaps and filled with synthetic noise. Increase it for long-cadence
        data, otherwise every step counts as a gap.
    std_window : float
        Length in days of the stretch on either side of a gap used to
        estimate the scatter of the fill noise.
    rng : numpy.random.Generator, optional
        Source of the fill noise; defaults to the global ``np.random`` state.

    Returns
    -------
    ndarray, shape (3, M)
        Regular time grid, normalised flux and rounded, median-filtered
        flare flag.
    """
    from scipy.interpolate import interp1d

    if rng is None:
        rng = np.random

    time, flux, is_flare = np.genfromtxt(filename).T

    gap_starts = np.flatnonzero(np.diff(time) > gap_threshold)

    # Process gaps from last to first so inserting points never shifts the
    # indices of the gaps that are still to be filled.
    for start in gap_starts[::-1]:
        stop = start + 1
        begin, end = time[start], time[stop]
        gaplength = end - begin

        before = flux[(time >= begin - std_window) & (time <= begin)]
        after = flux[(time >= end) & (time <= end + std_window)]
        beginstd = _clipped_std(before)
        endstd = _clipped_std(after)

        # Mean level right before/after the gap; clamp the lower bound so a
        # gap close to the start does not become an empty wrap-around slice.
        beginmean = np.mean(flux[max(start - 5, 0):start + 1])
        endmean = np.mean(flux[stop:stop + 4])

        # Number of fill points, from the mean step of the series as filled so far.
        npoints = int(gaplength / np.mean(np.diff(time)))

        # Use the quieter side's scatter for the noise.
        scale = beginstd if beginstd < endstd else endstd
        filling = rng.normal(loc=0, scale=scale, size=npoints)

        timefilling = np.linspace(begin, end, npoints + 2)[1:-1]
        time = np.insert(time, stop, timefilling)
        flux = np.insert(flux, stop, filling + np.linspace(beginmean, endmean, npoints))
        is_flare = np.insert(is_flare, stop, np.zeros_like(timefilling))

    t = np.arange(time[0], time[-1], cadence)
    bad_points = np.isnan(flux)

    iflux = interp1d(time[~bad_points],
                     flux[~bad_points],
                     fill_value="extrapolate",
                     bounds_error=False)
    fl = iflux(t)

    # Normalise to zero mean level and unit peak-to-peak, ignoring the
    # brightest 0.1 % of points.
    umcut = fl < np.percentile(fl, 99.9)
    fl /= np.mean(fl[umcut])
    fl = (fl - 1.) / np.ptp(fl[umcut])

    iflag = interp1d(time[~bad_points], is_flare[~bad_points],
                     fill_value="extrapolate", bounds_error=False)
    # Two median filters (kernel 3, then 9) remove short isolated flag runs.
    filtflag = medfilt(medfilt(iflag(t), 3), 9)
    return np.array([t, fl, np.round(filtflag)])


# Make sure to check this block!

The training set is defined in `train.txt`, one light-curve file per line; the last five files listed are also used as the validation set. If you want to use only part of the listed files for whatever reason, use the `filenumber` option.

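If you want to do a K-fold test, change the value of `split`, e.g. to 0.8 for K=5. In this case, also change `np.arange(1)` to `np.arange(5)` in the training block below. Note that the current `LoadLC` does not use `split` or `kfold_n`, so the K-fold selection has to be implemented there first.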

In [None]:
import glob
def _load_lc_files(files):
    """Read each light-curve file with ``read_data`` and join them into one (3, N) array."""
    curves = []
    for name in files:
        print(name)
        curves.append(read_data(name))
    if not curves:
        return np.empty((3, 0))
    # One concatenation instead of re-growing arrays once per file.
    return np.concatenate(curves, axis=1)


def LoadLC(filenumber=900, split=0.99, kfold_n=0, trv_test_split=0.9):
    """Load the training and validation light curves listed in ``train.txt``.

    ``train.txt`` holds one light-curve path per line; blank lines are
    skipped. Every file is parsed with ``read_data`` and the results are
    concatenated in the listed order.

    Parameters
    ----------
    filenumber : int or None
        Use at most this many files (from the top of ``train.txt``) for
        training; ``None`` uses all of them.
    split, kfold_n, trv_test_split
        Not used by the current implementation; kept so existing calls
        keep working.

    Returns
    -------
    train_data, validation_data : ndarray, shape (3, N)
        Rows are time, normalised flux and flare flag, as returned by
        ``read_data``.
    """
    with open("train.txt", "r") as trainfile:
        trainlist = [line.strip() for line in trainfile if line.strip()]

    train_files = trainlist if filenumber is None else trainlist[:filenumber]
    # NOTE(review): the validation files are the last five entries of
    # train.txt; unless `filenumber` excludes them they are also training
    # files, so validation metrics are not measured on unseen data.
    validation_files = trainlist[-5:]

    print("Train: ")
    train_data = _load_lc_files(train_files)

    print("Validation: ")
    validation_data = _load_lc_files(validation_files)

    return train_data, validation_data





This starts the training from scratch with a three-layer LSTM network (128 units per layer). For transfer learning, comment out the model definition and uncomment the `load_model` line.

In [None]:
for kfold_ni in np.arange(1):

    train_data, validation_data = LoadLC(kfold_n=kfold_ni)

    # Inputs: flux as an (N, 1) column. Labels: flare flags as a (1, N) row,
    # transposed to (N, 1) when handed to the generators.
    X_train, y_train = train_data[1].reshape(-1, 1), train_data[2].reshape(1, -1)
    X_test, y_test = validation_data[1].reshape(-1, 1), validation_data[2].reshape(1, -1)

    window_size = 64
    batch = 1024
    tf.keras.backend.clear_session()

    generator = SplitGenerator(X_train, y_train.T, length=window_size, batch_size=batch)
    generator_val = SplitGenerator(X_test, y_test.T, length=window_size, batch_size=batch)

    # Three stacked LSTM layers with dropout, ending in a single sigmoid unit
    # that gives the flare probability of a window.
    units = 128
    model = Sequential()
    model.add(LSTM(units, return_sequences=True))
    model.add(Dropout(0.2))
    model.add(LSTM(units, return_sequences=True))
    model.add(Dropout(0.2))
    model.add(LSTM(units))
    model.add(Dense(1, activation='sigmoid'))

    # Transfer learning: comment out the model definition above and uncomment this line.
    #model=tf.keras.models.load_model('./LSTM-fold_all-mixedtrain0.h5')

    recall_metric = tf.keras.metrics.Recall()
    model.compile(loss='binary_crossentropy', optimizer='nadam', metrics=['accuracy', recall_metric])

    # One timestamp, so the TensorBoard log dir and checkpoints of a run match.
    run_stamp = strftime("%Y-%m-%d.%H:%M:%S", localtime())
    tensorboard_callback = TensorBoard(
        log_dir='./logs/LSTM-02-fold_all_mixedtrain' + str(kfold_ni) + str(units) + run_stamp,
        histogram_freq=1)
    # NOTE(review): early stopping watches the training loss, while
    # ModelCheckpoint uses its default monitor (val_loss).
    earlystop_callback = EarlyStopping(monitor='loss', mode='min', min_delta=1e-2,
                                       patience=10, restore_best_weights=True)
    checkpoint_file = ('checkpoints/' + run_stamp + 'run_lstm_02-fold_all_mixedtrain'
                       + str(kfold_ni) + '-{epoch:02d}-{accuracy:.2f}.h5')
    checkpoint_callback = ModelCheckpoint(checkpoint_file, save_best_only=True, verbose=1)

    # Up-weight the rare flare class by (all points / flagged points).
    n_flagged = y_train.sum()
    if n_flagged == 0:
        raise ValueError("The training set contains no flare-flagged points; "
                         "cannot compute the flare class weight.")
    flare_weight = y_train.size / n_flagged

    history = model.fit(generator,
                        epochs=100,
                        class_weight={0: 1, 1: flare_weight},
                        validation_data=generator_val,
                        callbacks=[tensorboard_callback,
                                   earlystop_callback,
                                   checkpoint_callback,
                                   ])

    print('Done.')

    model.save('LSTM-fold_all-mixedtrain' + str(kfold_ni) + '.h5')
    print('Model saved.')