In [1]:
cd /home/makinen/repositories/alfi_lensing/imnn_scripts/

/home/makinen/repositories/alfi_lensing/imnn_scripts


In [2]:
import cloudpickle as pickle
import h5py as h5
import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
import matplotlib.pyplot as plt
import numpy as onp

from utils import rotate_sim
from nets import *
from imnn_mod import *

import json
import sys,os
import gc

In [3]:
# Report which platform JAX selected as its default backend.
# `jax.default_backend()` is the public API for this; `jax.lib.xla_bridge` is an
# internal module and is not guaranteed to stay importable across JAX releases.
print(jax.default_backend())

gpu


In [5]:
# Run configuration. Only some keys are read by the cells visible in this notebook;
# the rest are noted as such below.
configs = {
    # directory the noise-free fiducial / derivative .npy files are loaded from
    "datadir": "/data80/makinen/borg_sims_fixed/training_data/",
    # directory the trained weights, Fisher matrix, history and estimates are written to
    "savedir": "/data80/makinen/borg_sims_fixed/imnn_results/twoparam/retrain/",
    # read into `modeldir` below, but not otherwise used in the cells shown here
    "modeldir": "/data80/makinen/borg_sims_fixed/imnn_results/twoparam/",
    # directory the prior simulations and their parameters are loaded from
    "priordir": "/data80/makinen/borg_sims_fixed/prior_data/",
    # directory containing mock_data.h5 (the target observation)
    "target_path": "/data80/nporqueres/borg_sims_fixed/mock_data/",
    # NOTE(review): read into `patience` later, but the fit call passes a literal 100 — confirm
    "patience": 2000,
    # number of filters per convolution branch (expanded to a 4-tuple later)
    "filters":  20,
    # not referenced in the cells shown here
    "plotdir": "/home/makinen/repositories/IMNN_vs_BORG/plots/",
    # cast to bool; selects how `net_scaling` is adjusted
    "do_noise": 1,
    # multiplies both `net_scaling` (when do_noise is set) and the added noise amplitude
    "noise_scale": 1,
    # base value of the network's input divisor (`div_factor`)
    "net_scaling": 0.005,
    # activation name; "almost_leaky" selects `almost_leaky`, anything else falls back to gelu
    "act": "almost_leaky",

    # not referenced in the cells shown here
    "borg_data_configs": { 
        "fiducial_path": "/data80/nporqueres/borg_sims_fixed/fiducial/",
        "omegaM_path": "/data80/nporqueres/borg_sims_fixed/omegaM/",
        "sigma8_path": "/data80/nporqueres/borg_sims_fixed/sigma8/",
        "prior_path": "/data80/nporqueres/borg_sims_fixed/new_uniform_prior/",
        "omegaM_stepsize": 0.05,
        "sigma8_stepsize": 0.015
    }
}

In [8]:
def save_obj(obj, name ):
    """Pickle ``obj`` to disk.

    Args:
        obj: any object the ``pickle`` module in scope can serialise.
        name: output path WITHOUT extension; ``'.pkl'`` is appended.
    """
    target_path = name + '.pkl'
    with open(target_path, 'wb') as handle:
        pickle.dump(obj, handle)
        
def load_obj(name):
    """Load a pickled object from disk.

    Args:
        name: path of the pickle file. If ``name`` is a string that does not
            exist on disk but ``name + '.pkl'`` does, the latter is opened,
            so the same stem handed to ``save_obj`` can be passed here.

    Returns:
        The unpickled object.

    Raises:
        FileNotFoundError: if neither candidate path exists.
    """
    path = name
    # save_obj appends '.pkl'; accept the bare stem as well as the full file name.
    if isinstance(name, str) and not os.path.exists(name) and os.path.exists(name + '.pkl'):
        path = name + '.pkl'
    # NOTE: unpickling can execute arbitrary code — only load files you trust.
    with open(path, 'rb') as f:
        return pickle.load(f)

def reshape_data(dat):
    """De-interleave axis 1 of a 4-D array into a new trailing axis of length 2.

    Even-indexed entries of axis 1 end up at ``[..., 0]`` and odd-indexed
    entries at ``[..., 1]`` (used in this notebook as the real and imaginary
    parts respectively).

    Args:
        dat: array indexed as ``dat[a, b, c, d]``.

    Returns:
        Array with axis 1 halved and one extra trailing axis of size 2.
    """
    even_entries = dat[:, 0::2, :, :]
    odd_entries = dat[:, 1::2, :, :]
    return jnp.stack((even_entries, odd_entries), axis=-1)

savedir    = configs["savedir"]     # where results are written
datadir    = configs["datadir"]     # where the training simulations are stored
modeldir   = configs["modeldir"]
priordir   = configs["priordir"]
do_noise   = bool(configs["do_noise"])


### ------------- LOAD ALL DATA -------------
print("loading data, do noise: ", do_noise)
key = jax.random.PRNGKey(33)
key,rng = jax.random.split(key)

fid = jnp.load(datadir + 'noisefree_fid.npy')
val_fid = jnp.load(datadir + 'noisefree_val_fid.npy')
dervs = jnp.load(datadir + 'noisefree_derv.npy')
val_dervs = jnp.load(datadir + 'noisefree_val_derv.npy')

# concatenate the existing (Om, s8) derivative simulations from both splits
om_s8_dervs = onp.concatenate([dervs, val_dervs])



# create a set of amplitude derivatives
A = 1.0 # fiducial
Aminus = 0.99
Aplus = 1.01

# take some fiducial simulations and rescale their amplitude
# (note: `Aplus` / `Aminus` are rebound from the scalar factors to the scaled arrays)
Aplus = fid[:500]*Aplus
Aminus = fid[:500]*Aminus

# Interleaved derivative array; per seed the ordering is
# (Om-, s8-, A-, Om+, s8+, A+), so 500 seeds x 6 entries = 3000 rows.
full_derivatives = onp.ones((3000, 8, 64, 64))

# Om-
full_derivatives[::6] = om_s8_dervs[0::4]
# s8-
full_derivatives[1::6] = om_s8_dervs[1::4]
# A-
full_derivatives[2::6] = Aminus

# Om+
full_derivatives[3::6] = om_s8_dervs[2::4]
# s8+
full_derivatives[4::6] = om_s8_dervs[3::4]
# A+
full_derivatives[5::6] = Aplus

# do some random shuffling of the data
# swap some sims around from train and validation
fid1 = jnp.concatenate([val_fid, fid], axis=0)
# jax.random.shuffle is deprecated; jax.random.permutation is its replacement.
idx = jax.random.permutation(key, 2000)

val_fid = jnp.array(fid1[idx[1000:]])
fid = jnp.array(fid1[idx[:1000]]); del fid1

# now derivatives
dervs1 = full_derivatives

# chunk into 500 seed-matched groups of 6: (om-,s8-,A-,om+,s8+,A+)
dervs1 = jnp.array(jnp.split(dervs1, 500))
print('split dervs shape', dervs1.shape)
# independent permutation (different key) so whole seed groups move together
idx = jax.random.permutation(rng, 500)

# first 250 shuffled groups -> validation, remaining 250 -> training;
# concatenating flattens (group, 6, ...) back to (group * 6, ...)
val_dervs = jnp.concatenate(dervs1[idx[:250]])
dervs = jnp.concatenate(dervs1[idx[250:]]); del dervs1,full_derivatives

# de-interleave axis 1 into a trailing axis of size 2 (see reshape_data)
fid = (reshape_data(fid))
val_fid = (reshape_data(val_fid))
dervs = (reshape_data(dervs))
val_dervs = (reshape_data(val_dervs))

fid.shape, val_fid.shape, dervs.shape, val_dervs.shape

loading data, do noise:  True




split dervs shape (500, 6, 8, 64, 64)


((1000, 4, 64, 64, 2),
 (1000, 4, 64, 64, 2),
 (1500, 4, 64, 64, 2),
 (1500, 4, 64, 64, 2))

In [9]:
### ------------- IMNN PARAMETERS -------------
# Number of model parameters (OmegaM, sigma8, A); one summary per parameter.
n_params = 3
n_summaries = n_params

# Fiducial parameter values.
# TODO (carried over from the original cell): CHANGE TO OmegaM=0.6
θ_fid = jnp.array([0.3175, 0.800, 1.0])
# Full finite-difference widths (upper minus lower value) per parameter.
δθ = 2*jnp.array([0.05, 0.015, 0.01])

# Parameter vectors at which derivative simulations are evaluated: for each sign
# (-, +) and each parameter j, θ_fid with component j shifted by ∓δθ[j]/2.
# The trailing dimension must be n_params so that every row is one full
# parameter vector; the result has shape (2 * n_params, n_params).
θ_der = (θ_fid + jnp.einsum("i,jk->ijk", jnp.array([-1., 1.]), jnp.diag(δθ) / 2.)).reshape((-1, n_params))

# Number of fiducial simulations and of seed-matched derivative groups.
n_s = 1000
n_d = 250

In [10]:
class InceptBlock3D(nn.Module):
    """Inception-style block: parallel branches concatenated on the channel axis.

    Branches, in the order they are concatenated:
      1. (optional, ``do_5x5``) 1-kernel reduction conv -> (3, 5, 5) conv
      2. (optional, ``do_3x3``) 1-kernel reduction conv -> (3, 3, 3) conv
      3. a ``(3,) * dims`` conv applied directly to the input
      4. a ``(3,) * dims`` max-pool of the input (keeps the input channel count)

    Attributes:
        filters: output feature counts; entries 0, 1 and 2 are used by
            branches 1, 2 and 3.
        filters_reduce: feature counts of the reduction convs; entries 0 and 1
            are used by branches 1 and 2.
        strides: strides applied by the last op of every branch.
        dims: number of spatial dimensions; note that branches 1 and 2 use
            hard-coded 3-element kernels.
        do_5x5: include branch 1.
        do_3x3: include branch 2.
    """
    filters: Sequence[int]
    filters_reduce: Sequence[int]
    strides: Union[None, int, Sequence[int]]
    dims: int
    do_5x5: bool = True
    do_3x3: bool = True

    @nn.compact
    def __call__(self, x):
        reduce_features = self.filters_reduce
        unit_kernel = (1,) * self.dims
        cube_kernel = (3,) * self.dims

        branches = []

        if self.do_5x5:
            # wide-kernel branch: channel reduction followed by a (3, 5, 5) conv
            wide = nn.Conv(features=reduce_features[0], kernel_size=unit_kernel, strides=None)(x)
            wide = nn.Conv(features=self.filters[0], kernel_size=(3, 5, 5), strides=self.strides)(wide)
            branches.append(wide)

        if self.do_3x3:
            # medium-kernel branch: channel reduction followed by a (3, 3, 3) conv
            medium = nn.Conv(features=reduce_features[1], kernel_size=unit_kernel, strides=None)(x)
            medium = nn.Conv(features=self.filters[1], kernel_size=(3, 3, 3), strides=self.strides)(medium)
            branches.append(medium)

        # direct branch: a single 3-wide conv on the raw input (no reduction step)
        direct = nn.Conv(features=self.filters[2], kernel_size=cube_kernel, strides=self.strides)(x)
        branches.append(direct)

        # pooling branch: strided max-pool, no convolution afterwards
        pooled = nn.max_pool(x, cube_kernel, padding='SAME', strides=self.strides)
        branches.append(pooled)

        return jnp.concatenate(branches, axis=-1)

In [17]:
class CNN3D(nn.Module):
    """Stack of six InceptBlock3D stages followed by a 1-kernel conv readout.

    Attributes:
        filters: per-branch output feature counts passed to every stage.
        filters_reduce: reduction feature counts passed to every stage.
        div_factor: the input is divided by this value before the first stage.
        out_shape: number of output features of the final conv.
        do_big_convs: declared for configuration; not used in ``__call__``.
        act: ``"almost_leaky"`` selects ``almost_leaky``; any other value
            falls back to ``nn.gelu``.
    """
    filters: Sequence[int] = (10,10,10,10) #(5,5,5,5) #(2,2,2,2)
    filters_reduce: Sequence[int] = (3, 3, 3, 3)
    div_factor: float = 0.005
    out_shape: int = 2
    do_big_convs: bool = True
    act: str = "gelu"

    @nn.compact
    def __call__(self, x):
        activation = almost_leaky if self.act == "almost_leaky" else nn.gelu

        # (strides, do_5x5, do_3x3) for each stage, applied in order.
        # With the (4, 64, 64, C) input used in this notebook the spatial extent
        # should go 4x32x32 -> 4x16x16 -> 4x8x8 -> 4x4x4 -> 2x2x2 -> 1x1x1.
        stage_specs = (
            ((1, 2, 2), True, True),
            ((1, 2, 2), False, True),
            ((1, 2, 2), False, True),
            ((1, 2, 2), False, True),
            ((2, 2, 2), False, True),
            ((2, 2, 2), False, False),
        )

        x /= self.div_factor
        for stage_strides, use_5x5, use_3x3 in stage_specs:
            x = InceptBlock3D(self.filters, self.filters_reduce, strides=stage_strides,
                              dims=3, do_5x5=use_5x5, do_3x3=use_3x3)(x)
            x = activation(x)

        # readout: project channels to out_shape features, then flatten
        x = nn.Conv(self.out_shape, (1,)*3, 1)(x)
        return x.reshape(-1)

In [20]:
# Scratch arithmetic left over from interactive work; the result is not assigned or reused.
64 / 4

16.0

In [27]:
### ------------- NEURAL NETWORK MODEL -------------

# Hyper-parameters pulled out of the config dictionary.
act = configs["act"]
noise_scale = configs["noise_scale"]
patience = configs["patience"]
filters = (int(configs["filters"]),) * 4   # same filter count for all four branch slots

# Input divisor: scaled by the noise amplitude for noisy runs,
# shrunk by a factor of ten for noise-free runs.
net_scaling = float(configs["net_scaling"])
net_scaling = net_scaling * noise_scale if do_noise else net_scaling / 10.

model = CNN3D(
    filters=filters,
    div_factor=net_scaling,
    out_shape=3,
    act=act,
)
key = jax.random.PRNGKey(42)

# shape of a single simulation after reshape_data
input_shape = (4, 64, 64, 2)

### ------------- DEFINE DATA AUGMENTATION SCHEME -------------
### Noise is added on top of the field simulations.
### The values below are variances (sigma^2), one per entry of the leading axis.
noise_variances = jnp.array(
    [1.79560224e-06, 5.44858988e-06, 9.45448781e-06, 1.32736252e-05]
)

In [28]:
# Initialise the network parameters on a dummy input of one simulation's shape.
w = model.init(key, jnp.ones(input_shape))


def _app(d):
    """Apply the freshly initialised network to a single simulation."""
    return model.apply(w, d)


# Smoke test: map the network over all fiducial simulations and show the output shape.
jax.vmap(_app)(fid).shape

(1000, 3)

In [22]:
#@jax.jit
def noise_simulator(key, sim):
    """Augment one simulation with a random rotation and additive Gaussian noise.

    Args:
        key: JAX PRNG key; split internally into one key for the rotation
            choice and one for the noise draw.
        sim: a single simulation whose leading axis matches the length of
            the module-level ``noise_variances`` (shape (4, 64, 64, 2) in
            this notebook).

    Returns:
        The rotated simulation with noise added, same shape as ``sim``.

    Relies on the module-level ``noise_scale`` and ``noise_variances``.
    """
    key1,key2 = jax.random.split(key)
    # pick one of four rotation indices and let rotate_sim apply it
    k = jax.random.choice(key1, jnp.array([0,1,2,3]), shape=())
    sim = rotate_sim(k, sim)

    # White noise on every pixel, scaled by the standard deviation of the
    # corresponding leading-axis entry. The noise shape follows the simulation
    # instead of a hard-coded literal so the two cannot drift apart.
    sim += (jax.random.normal(key2, shape=sim.shape) * noise_scale * jnp.sqrt(noise_variances).reshape(-1,1,1,1))
    return sim

In [23]:
#### ------------- SET UP IMNN -------------

# Adam optimiser with a fixed learning rate.
optimiser = optax.adam(learning_rate=1e-4)

# Key handed to the IMNN constructor below.
# NOTE(review): same seed (42) as the `key` created in the model-setup cell — confirm this is intended.
model_key = jax.random.PRNGKey(42)
# Advance `key`; note the unpacking order here is (rng, key), the reverse of other cells.
rng, key = jax.random.split(key)

In [26]:
# Display the number of parameters as a sanity check before building the IMNN.
n_params

3

In [29]:
# Build the IMNN. The flat derivative arrays are viewed as
# (n_d, 2, n_params, *input_shape) before being handed over.
IMNN = NoiseNumericalGradientIMNN(
    n_s=n_s,
    n_d=n_d,
    n_params=n_params,
    n_summaries=n_summaries,
    input_shape=input_shape,
    θ_fid=θ_fid,
    δθ=δθ,
    model=model,
    optimiser=optimiser,
    key_or_state=jnp.array(model_key),
    noise_simulator=(lambda rng, d: noise_simulator(rng, d)),
    fiducial=fid,
    derivative=dervs.reshape((n_d, 2, n_params) + input_shape),
    validation_fiducial=val_fid,
    validation_derivative=val_dervs.reshape((n_d, 2, n_params) + input_shape),
    dummy_graph_input=None,  # dummy graph input
    no_invC=False,
    do_reg=True,
    evidence=False,
)

In [15]:
# Debugging aid: switch IPython tracebacks to verbose mode.
%xmode verbose

Exception reporting mode: Verbose


In [30]:
## ------------- TRAIN THE IMNN -------------
# Point the `np` alias at jax.numpy.
np=jnp

# Run a garbage-collection pass before the fit call.
gc.collect()
# Split off a fresh subkey (`rng`) to hand to the fit call.
key,rng = jax.random.split(key)

# NOTE(review): `patience=100` is a literal here, while configs["patience"] is 2000 and is
# read into `patience` earlier without being passed — confirm which value is intended.
# The two positional arguments and γ are defined by the IMNN class in imnn_mod;
# presumably regularisation settings — TODO confirm against that module.
IMNN.fit(10.0, 0.1, γ=1.0, rng=jnp.array(rng), patience=100, min_iterations=2000)

In [21]:
# Display the IMNN's `F` attribute (printed below as the "final IMNN F").
IMNN.F

In [None]:



# Make sure the output directory exists before anything is written to it,
# so results are not lost to a missing-directory error.
os.makedirs(savedir, exist_ok=True)

print('final IMNN F: ', IMNN.F)
print('final IMNN det F: ', jnp.linalg.det(IMNN.F))

save_obj(IMNN.w, savedir + 'IMNN_w')
jnp.save(savedir + "IMNN_F", IMNN.F)

# Convert the history entries to plain NumPy arrays for pickling.
# A new dict is built so that IMNN.history itself is left untouched
# (overwriting its entries in place would change the IMNN object's own state).
history = {name: onp.array(values) for name, values in IMNN.history.items()}

save_obj(history, savedir + 'IMNN_history')

### ------------- PASS PRIOR DATA THROUGH IMNN FUNNEL -------------
np = jnp
#dat = jnp.load("/data80/makinen/borg_sims_fixed/uniform_prior_sims/noisefree_prior_sims.npy")
#params = jnp.load("/data80/makinen/borg_sims_fixed/uniform_prior_sims/noisefree_prior_params.npy")
dat = jnp.load(priordir + "prior_sims_noisefree.npy")
params = jnp.load(priordir + "prior_params.npy")

# de-interleave axis 1 into a trailing axis of size 2 (see reshape_data)
dat = reshape_data(dat)

# ADD NOISE TO DATA
# one independent PRNG key per simulation, fixed seed for reproducibility
noisekey = jax.random.PRNGKey(11)
noisekeys = jax.random.split(noisekey, num=dat.shape[0])
dat = jax.vmap(noise_simulator)(noisekeys, dat)

# Compress the prior simulations in two chunks (split at index 2500) and
# stitch the estimates back together. This is done once; the original cell
# repeated the identical computation a second time.
x1 = IMNN.get_estimate(dat[:2500])
x2 = IMNN.get_estimate(dat[2500:])
x = jnp.concatenate([x1,x2])


# save the prior sims' mapping
jnp.save(savedir + "x_imnn.npy", x)
jnp.save(savedir + "theta.npy", params)


#### ------------- get Natalia's target data WITH NOISE -------------
print("loading mock data for obtaining estimates")

# NOTE(review): numpy and h5py are already imported in the top import cell;
# these repeated imports are redundant.
import numpy as onp
# `np` is rebound to plain NumPy here (it was pointed at jax.numpy above).
np = onp
import h5py as h5

def get_data(Ncat,N0,N1,f):
    """Rasterise the lensing catalogues stored in ``f`` onto 2-D grids.

    Args:
        Ncat: number of catalogues; entries ``'lensing_catalog_<i>'`` for
            ``i`` in ``range(Ncat)`` are read from ``f``.
        N0: grid size along the first axis (indexed by each record's ``'phi'``).
        N1: grid size along the second axis (indexed by each record's ``'theta'``).
        f: mapping-like container (an open HDF5 file in this notebook) that
            provides ``f['lensing_catalog_<i>']['lensing_data']['lensing']``
            as a sequence of records with the fields ``'phi'``, ``'theta'``,
            ``'shearR'`` and ``'shearI'``.

    Returns:
        ``(dataR, dataI)``: two NumPy arrays of shape ``(Ncat, N0, N1)``
        holding ``'shearR'`` and ``'shearI'`` at each record's pixel; pixels
        without a record stay 0. If several records share a pixel, the last
        one wins.
    """
    # Use plain NumPy (`onp`) explicitly: the arrays are filled in place, and
    # the notebook rebinds the `np` alias between NumPy and jax.numpy.
    dataR = onp.zeros((Ncat,N0,N1))
    dataI = onp.zeros((Ncat,N0,N1))

    for cat in range(Ncat):
        survey = f['lensing_catalog_'+str(cat)]['lensing_data']['lensing'][:]
        for lens in survey:
            n0 = int(lens['phi'])
            n1 = int(lens['theta'])
            dataR[cat,n0,n1] = lens['shearR']
            dataI[cat,n0,n1] = lens['shearI']
    return dataR, dataI


#path = '/data80/nporqueres/borg_sims_fixed/'

path = configs["target_path"]

# Open the mock observation inside a context manager so the HDF5 handle is
# closed as soon as everything needed has been read into memory.
with h5.File(path + 'mock_data.h5', 'r') as f:
    Ncat = f['scalars/NCAT'][0]

    N0 = f['scalars/N0'][0]
    N1 = f['scalars/N1'][0]
    N2 = f['scalars/N2'][0]

    L0 = f['scalars/L0'][0]
    L1 = f['scalars/L1'][0]
    L2 = f['scalars/L2'][0]

    targetR, targetI = get_data(Ncat, N0, N1, f)

# Interleave the two components along axis 0 (even rows: targetR, odd rows:
# targetI), which is the layout reshape_data de-interleaves.
_dat = onp.ones((8, 64, 64))
_dat[::2, :, :] = targetR
_dat[1::2, :, :] = targetI

np=jnp
_dat = jnp.array(_dat)
# add a batch axis for reshape_data, then drop it again
target_data = jnp.squeeze(reshape_data(_dat[jnp.newaxis, :, :, :]))

# No noise is added here (the section header describes the target data as "WITH NOISE").
estimates = IMNN.get_estimate(jnp.expand_dims(target_data, 0))

print('IMNN estimates for target sim', estimates)

jnp.save(savedir + 'estimates', estimates)
