In [1]:
# Enumerate GPUs in PCI bus order and expose only device 0 to this kernel.
# These must be set before any GPU framework is imported/initialised.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=0

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=0


In [26]:
import numpy as np
import jax.numpy as jnp
import pylab as plt
import pandas as pd
from pathlib import Path
from astropy.io import fits
from astropy import units as u
# from sklearn.metrics import r2_score
from astropy import wcs, coordinates
# from matplotlib.colors import LogNorm
from scipy.ndimage import gaussian_filter 
from mpl_toolkits.axes_grid1 import make_axes_locatable
from jax.lib import xla_bridge

from utils import *

# Report the GPU devices seen by the helper from `utils` (it prints the device list),
# then print the platform of the active JAX/XLA backend (e.g. "gpu" or "cpu").
get_available_gpus()
print(xla_bridge.get_backend().platform)

['/device:GPU:0']
Default GPU Device:/device:GPU:0
gpu


In [28]:
# ---------------------------------------------------------------------------
# Run configuration.
#   * script mode  : values come from the command line
#   * notebook mode: values are the hard-coded defaults below
# Both branches define the same set of names so later cells work in either mode.
# ---------------------------------------------------------------------------
if not in_notebook():
    import argparse

    def _int_or_float(text):
        """Parse a CLI number; whole numbers stay `int` so model folder names are unchanged."""
        try:
            return int(text)
        except ValueError:
            return float(text)

    parser = argparse.ArgumentParser(description='MODEL ACTIVITY ANALYZER.')
    parser.add_argument('--dataset', default='./dataset', type=str, help='path to dataset')
    parser.add_argument('-s', default=32, type=int, help='image length')
    parser.add_argument('-f', default=5, type=int, help='number of channels per sample')
    # Accept fractional factors as well (the notebook default is 1.5).
    parser.add_argument('--fg', default=1, type=_int_or_float, help='channel grow factor')
    parser.add_argument('--bn', default=0, type=int, help='batch norm')
    parser.add_argument('--act', default='relu', type=str, help='activation')
    parser.add_argument('--cinc', default=0, type=int, help='continuum include?')
    parser.add_argument('--BS', default=32, type=int, help='batch size')
    parser.add_argument('--epochs', default=10, type=int, help='number of epochs')
#     parser.add_argument('--model', default='model file name', type=str, help='model file name')
#     parser.add_argument('--prefix', default='', type=str, help='path to save the results')
#     parser.add_argument('--restart', action="store_true")

    args = parser.parse_args()
    data_path = args.dataset

    nd1 = args.s
    nd2 = args.s
    nch = args.f
    fgrow = args.fg
    bnorm = args.bn
    lactivation = args.act
    cinclude = args.cinc
    # No --restart flag exists yet; define the name so both modes expose it.
    restart = 0
    EPOCHS = args.epochs
    BS = args.BS

else:
#     data_path = '/home/vafaeisa/scratch/ska/development/'
    data_path = '/home/vafaeisa/scratch/ska/development_large/'
#     data_path = '/home/vafaeisa/scratch/ska/evaluation/'
    nd1,nch = 32,5
    nd2 = nd1
    fgrow = 1.5
    bnorm = 0
    lactivation = 'relu'
    cinclude = 1
    restart = 0
    EPOCHS = 1
    BS = 32

# File names are appended to data_path by plain string concatenation in later
# cells, so make sure it ends with a path separator.
if not data_path.endswith('/'):
    data_path += '/'

ds = 5        # half-width (pixels) of the spatial window around a source
dff = 60      # half-width (channels) of the frequency window around a source
dsmear = 10   # third argument handed to smear() in the training cell

# The dataset flavour is inferred from the directory name; it selects the file
# names that are opened later ('development_large' must be tested first because
# it also contains 'development').
if 'development_large' in data_path:
    dmode = 'ldev'
elif 'development' in data_path:
    dmode = 'dev'
elif 'evaluation' in data_path:
    dmode = 'eval'
else:
    # Explicit exception instead of `assert`, which is removed under `python -O`.
    raise ValueError(
        "dmod error! cannot infer dataset mode from data_path={!r}; expected it to contain "
        "'development_large', 'development' or 'evaluation'".format(data_path)
    )

# One folder per hyper-parameter combination.
mname = 's{}-f{}-fg{}-bn{}-{}-c{}/'.format(nd1,nch,fgrow,bnorm,lactivation,cinclude)

mpath = 'models/'+mname
# parents=True also creates the top-level 'models' directory.
Path(mpath).mkdir(parents=True, exist_ok=True)
model_name = mpath+'model.h5'

In [29]:
# Resume from a previously saved model when it can be loaded; otherwise build a
# new network. The failure reason is printed so that a real problem (missing
# dependency, corrupted file, undefined name) is not silently mistaken for
# "no checkpoint yet". `Exception` (not a bare except) lets KeyboardInterrupt through.
try:
    model = keras.models.load_model(model_name)
except Exception as err:
    print('could not load {} ({}: {}); building a new model'.format(
        model_name, type(err).__name__, err))
    if cinclude:
        # Two-input variant: sky patch plus a 21-channel second input (continuum).
        model = SimpleConv_count(shape1=(nd1,nd2,nch),shape2=(nd1,nd2,21),n_class=nch,fgrow=fgrow,bnorm=bnorm,lactivation=lactivation)
    else:
        model = SimpleConv(shape=(nd1,nd2,nch),n_class=nch,fgrow=fgrow,bnorm=bnorm,lactivation=lactivation)

model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 32, 32, 5)]  0                                            
__________________________________________________________________________________________________
input_2 (InputLayer)            [(None, 32, 32, 21)] 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 32, 32, 7)    322         input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 32, 32, 7)    1330        input_2[0][0]                    
______________________________________________________________________________________________

In [31]:
# Adam optimiser whose learning rate starts at 1e-3 and decays smoothly by a
# factor of 0.95 for every 50 optimiser steps.
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    1e-3, decay_steps=50, decay_rate=0.95
)
opt = tf.keras.optimizers.Adam(lr_schedule)

# Mean-squared-error objective, no extra metrics.
# (Previously tried alternatives: BinaryCrossentropy loss, "accuracy" metric.)
model.compile(optimizer=opt, loss=keras.losses.MeanSquaredError())

In [33]:
# ls {data_path}

In [34]:
# Open the HI cube and, when it is a model input, the continuum cube.
# The HDU lists stay open on purpose: their data are sliced chunk-by-chunk
# in the training cell.
sky = fits.open(data_path+'sky_{}_v2.fits'.format(dmode))
if cinclude:
    cont = fits.open(data_path+'cont_{}.fits'.format(dmode))

header = sky[0].header

# Truth catalogue, brightest sources first. reset_index() keeps the original
# row labels in an 'index' column and gives a fresh 0..N-1 index.
sources = pd.read_csv(data_path+'sky_{}_truthcat_v2.txt'.format(dmode),delimiter=' ')
sources = sources.sort_values('line_flux_integral',ascending=False).reset_index()

dfreq = header['CDELT3']
freq0 = header['CRVAL3']
# Frequency of every channel in the HI cube.
# Built from an integer range: np.arange with a float step can return one
# element too many because of rounding, which would break len(freqs) == nf.
# NOTE(review): like the original, this treats CRVAL3 as the first channel
# (i.e. CRPIX3 == 1) -- confirm against the header.
nf,nx,ny = sky[0].data.shape
freqs = freq0 + dfreq*np.arange(nf)
fqmhz = freqs/1e6

# Sky -> pixel conversion of the catalogue positions.
coord_sys = wcs.WCS(header)
ra, dec = sources['ra'],sources['dec']
num_sources = len(ra)
radec_coords = coordinates.SkyCoord(ra=ra, dec=dec, unit='deg', frame='fk5')
# Plain degree values (the previous `angle*u.deg` produced deg**2 quantities).
# The third world coordinate is a dummy 0; only the x/y outputs are used here.
coords_ar = np.column_stack([radec_coords.ra.deg, radec_coords.dec.deg,
                             np.zeros(num_sources)])
xy_coords = coord_sys.wcs_world2pix(coords_ar, 0)
x_coords, y_coords = xy_coords[:,0], xy_coords[:,1]
f_coordsf = sources['central_freq']

# Row positions ordered by decreasing integrated line flux.
flux_inds = np.argsort(sources['line_flux_integral'].values)[::-1]

delta = 1
psky = sky[0].data
# smoothed_sky = gaussian_filter(sky[0].data,sigma=(3,3,5))
print(psky.shape)



(6668, 1286, 1286)


In [None]:
# ---------------------------------------------------------------------------
# Chunked training.
# The cube is processed in dc x dc spatial chunks. For each chunk a target cube
# (ysky) is built by boosting the voxels around every catalogued source with a
# pre-computed weight patch; random slices of (input, target) are then fed to
# model.train_on_batch.
# ---------------------------------------------------------------------------
n_epochs = 2
batch_n = 10   # samples drawn per training step
# NOTE(review): EPOCHS and BS from the configuration cell are not used here;
# the values above are kept as they were. Confirm which should win.

dc = 400       # spatial chunk edge length in pixels

# Not used in this cell; kept in case later cells rely on it.
edges = np.linspace(0,nx-dc,nx//dc).astype(int)
edges[-1] = nx-dc

# Weight patch spanning the frequency/spatial window around a source.
gpatch = get_patch((2*dff+1,2*ds+1,2*ds+1),axis=2,sigma0=0.55,muu=0,c=0.4)

# Chunk start offsets along the two spatial array axes (cube size instead of a
# hard-coded 1286). The last chunk on each axis may be narrower than dc.
chunks = np.arange(0,nx,dc)
col_chunks = np.arange(0,ny,dc)
nchunk = len(chunks)*len(col_chunks)

# Pixel coordinates of every source, computed once instead of once per source,
# per chunk and per epoch. wcs_world2pix returns (x, y, z) = (ra, dec, freq)
# pixels; the numpy cube is indexed [freq, dec, ra].
world_coords = np.column_stack([sources['ra'].values,
                                sources['dec'].values,
                                sources['central_freq'].values])
pix_coords = coord_sys.wcs_world2pix(world_coords,0).astype(int)[flux_inds]
ra_pix,dec_pix,freq_pix = pix_coords[:,0],pix_coords[:,1],pix_coords[:,2]


def _rescale01(arr):
    """Shift and scale `arr` so that its values span [0, 1]."""
    arr = arr-arr.min()
    return arr/arr.max()


# The providers read the *current* psky/ysky/csky at call time, so they always
# sample from the chunk being trained on.
if cinclude:
    def data_provider(n):
        """Draw n random (sky slice, continuum slice, target slice) samples."""
        x1,x2,y = [],[],[]
        for _ in range(n):
            xp,xc,yp = get_slice(psky,ysky,nd1,nd2,nch,data2=csky)
            x1.append(xp)
            x2.append(xc)
            y.append(yp)
        return np.array(x1),np.array(x2),np.array(y)
else:
    def data_provider(n):
        """Draw n random (sky slice, target slice) samples."""
        x,y = [],[]
        for _ in range(n):
            xp,yp = get_slice(psky,ysky,nd1,nd2,nch)
            x.append(xp)
            y.append(yp)
        return np.array(x),np.array(y)


for epoch in range(n_epochs):
    ichunk = 0
    for icube in chunks:
        for jcube in col_chunks:
            ichunk += 1   # 1-based counter shown in the progress line

            i1,i2,j1,j2 = icube,icube+dc,jcube,jcube+dc

            psky = sky[0].data[:,i1:i2,j1:j2]
            n_f,n_r,n_c = psky.shape   # actual chunk extent (edge chunks are smaller)

            csky = None
            if cinclude:
                csky = cont[0].data[:,i1:i2,j1:j2]

            ysky = 0.2*psky

            # Keep only sources whose full window lies inside this chunk.
            # Rows of the chunk (i1..) pair with dec, columns (j1..) with ra,
            # matching the [freq, dec, ra] indexing below.
            inside = ((dec_pix-ds >= i1) & (dec_pix+ds < i1+n_r) &
                      (ra_pix-ds >= j1) & (ra_pix+ds < j1+n_c) &
                      (freq_pix-dff >= 0) & (freq_pix+dff < n_f))

            for ra_p,dec_p,freq_p in zip(ra_pix[inside],dec_pix[inside],freq_pix[inside]):
                # Convert full-cube pixel positions to chunk-local indices;
                # without the i1/j1 offsets every chunk except the first one
                # would be indexed out of range.
                fsl = slice(freq_p-dff,freq_p+dff+1)
                rsl = slice(dec_p-ds-i1,dec_p+ds+1-i1)
                csl = slice(ra_p-ds-j1,ra_p+ds+1-j1)
                ysky[fsl,rsl,csl] += 2*gpatch*psky[fsl,rsl,csl]

            # Normalise, smear, then normalise again (smearing changes the range).
            psky = _rescale01(psky)
#             psky = np.concatenate([np.zeros((1,dc,dc)),psky,np.zeros((1,dc,dc))],axis=0)
            psky = smear(psky,0,dsmear)

            ysky = _rescale01(ysky)
#             ysky = np.concatenate([np.zeros((1,dc,dc)),ysky,np.zeros((1,dc,dc))],axis=0)
            ysky = smear(ysky,0,dsmear)

            psky = _rescale01(psky)
            ysky = _rescale01(ysky)

            # Move the frequency axis last so it acts as the channel axis.
            psky = np.swapaxes(psky,2,0)
            ysky = np.swapaxes(ysky,2,0)
            if cinclude:
                csky = np.swapaxes(csky,2,0)
                csky = _rescale01(csky)

            # Roughly one pass over the chunk volume in units of one sample.
            n_iter = psky.size//(nd1*nd2*nch)

            losses = []
            for i in range(n_iter):

                if cinclude:
                    x1,x2,y = data_provider(batch_n)
                    loss = model.train_on_batch([x1,x2],y)
                else:
                    x,y = data_provider(batch_n)
                    loss = model.train_on_batch(x,y)
                losses.append(loss)

                rept = '| iter: {:5.3f}%   <{:50s}> | loss={:4.2f} | chunk {:3d}/{:3d} | epochs: {:3d}/{:3d} |'
                report = rept.format(100*(i+1)/n_iter,int(50*(i+1)/n_iter)*'=',
                                     np.mean(loss),
                                     ichunk,nchunk,
                                     epoch+1,n_epochs)
                print(report,end='\r')

model.save(model_name)

