In [None]:
%pylab inline
%load_ext autoreload
%autoreload 2
import os
# `display` and `HTML` belong to the public IPython.display module. Importing them
# from IPython.core.display has been deprecated since IPython 7.14 and is no longer
# available in recent IPython releases.
from IPython.display import display, HTML
# Widen the classic-notebook container so wide outputs use the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))
# Expose only GPU 0 to this process; set before TensorFlow is imported below.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import sys
sys.path.insert(0, '..')
import tensorflow as tf
import tf_slim as slim
from graphgan.utils import get_3d_direction
from graphgan.layers import *
from graphgan.gradient_penalty import gradient_penaly
from graphgan.datasets import graph_input_fn
from graphgan.datasets import project_ellipticities, project_3d_shape, project_ellipticities_np
from astropy.table import Table, join
from functools import partial
from halotools_ia.correlation_functions  import ed_3d,ee_3d, ed_3d_one_two_halo_decomp

from sklearn.preprocessing import RobustScaler
import tensorflow_probability as tfp
import pickle
from pandas import *
import warnings
# Silence every warning category for the rest of the session. Note that this also
# hides deprecation warnings emitted by the libraries imported above.
warnings.filterwarnings('ignore')

# Print the list of GPU devices TensorFlow can see.
print(" Available: ",  (tf.config.list_physical_devices('GPU')))

In [None]:
# Load the catalog and apply the basic selection cuts.

# NOTE(review): pickle.load can execute arbitrary code stored in the file; only load
# catalogs that come from a trusted source.
# The context manager closes the file handle as soon as the catalog has been read.
with open('/hildafs/projects/phy200017p/yjagvara/some_data/TNG100-1_99_non-reduced_galaxy_shapes_multi_scale_1024_MLP_only_cent.pkl', "rb") as catalog_file:
    tng = pickle.load(catalog_file)

# Keep only objects with a strictly positive dark-matter mass.
tng = tng[tng['dm_mass']>0]
# Keep objects whose log10(mass * 1e10) exceeds 9, for the dark-matter and the 'mass'
# column in turn (presumably masses are stored in units of 1e10 -- TODO confirm).
# The boolean mask is used directly: wrapping it in a list (`tng[[mask]]`) relied on
# NumPy's old list-as-tuple indexing, which NumPy >= 1.23 interprets as a 2-D index
# array instead of a row mask.
tng = tng[log10(tng['dm_mass']*10**10) > 9]
tng = tng[log10(tng['mass']*10**10) > 9]
# Ratios of the 'b' and 'c' columns to the 'a' column.
tng['q'] = tng['b']/tng['a']
tng['s'] = tng['c']/tng['a']

In [None]:
def project_shape(a3d, b3d, c3d, q3d, s3d):
    """Project three weighted axis vectors onto their first two components and
    return the resulting 2D ellipticity components (e1, e2).

    NOTE(review): shapes are not validated here. The indexing below assumes each of
    ``a3d``, ``b3d``, ``c3d`` is laid out as (1, 3, N) and each of ``q3d``, ``s3d``
    as (1, N) -- confirm against callers.

    Args:
        a3d, b3d, c3d: components of the three axis vectors; index 0:2 of the
            component axis is kept, index 2 is projected out.
        q3d, s3d: weights applied to the second and third axis (the first axis
            gets a weight of one).

    Returns:
        Tensor with e1 and e2 stacked along axis 1.
    """
    # Stack the three axes and their weights along a new leading dimension.
    s = tf.stack([a3d, b3d, c3d])
    w = tf.stack([tf.ones_like(q3d), q3d, s3d])

    # Cross term between the kept components (0:2) and the projected component (2),
    # summed over the three axes with 1/w**2 weighting.
    k = tf.reduce_sum(s[:,:,0:2]*tf.expand_dims(s[:,:,2], axis=-2) / tf.expand_dims(w[:,:]**2, axis=-2), axis=0)
    # Weighted squared norm of the projected component.
    a2 = tf.reduce_sum(s[:,:,2]**2/w[:,:]**2, axis=0)

    # Weighted sum of outer products of the kept components ...
    first_term = tf.reduce_sum(tf.einsum('ijko,ijlo->ijklo', s[:,:,0:2,...], s[:,:,0:2,...]) / tf.expand_dims(tf.expand_dims(w[:,:]**2,-2),-2), axis=0)
    # ... minus the contribution removed by projecting out the third component.
    second_term = tf.einsum('ijo,iko->ijko', k,k)/tf.expand_dims(tf.expand_dims(a2,-2),-2)
    Winv = first_term - second_term

    # tf.transpose reverses the axes, bringing the object axis to the front and the
    # leading singleton axis to the back. Only that trailing axis is squeezed: an
    # unqualified tf.squeeze would also remove the object axis when it has length 1,
    # which breaks the W[:, i, j] indexing below.
    W = tf.linalg.inv(tf.squeeze(tf.transpose(Winv), axis=-1))
    d = tf.sqrt(tf.linalg.det(W))
    denom = (W[:,0,0] + W[:,1,1] + 2*d)
    num_1 = (W[:,0,0] - W[:,1,1])

    e1 = num_1/denom
    e2 = 2 * W[:,0,1]/denom

    return tf.stack([e1, e2], axis=1)

In [None]:
# Compute the projected 2D ellipticity components from the stellar shape columns.
# The call below uses project_3d_shape imported from graphgan.datasets, not the
# notebook-local project_shape defined in the previous cell.
# Each axis is packed as np.array([[x, y, z]]), i.e. with a leading singleton axis
# (shape (1, 3, N) assuming the columns are 1-D of length N).
a3d = np.array([[tng['av_x'] , tng['av_y'] , tng['av_z'] ]])
b3d = np.array([[tng['bv_x'] , tng['bv_y'] , tng['bv_z'] ]])
c3d = np.array([[tng['cv_x'] , tng['cv_y'] , tng['cv_z'] ]])
# Ratios b/a and c/a, also with a leading singleton axis.
q3d = np.array([tng['b'] /tng['a'] ])
s3d = np.array([tng['c'] /tng['a'] ])
    
e12_modif = project_3d_shape(a3d, b3d, c3d, q3d, s3d)
 
# Store the two output columns of the projection in the catalog.
# NOTE(review): presumably e12_modif is an (N, 2) tensor holding (e1, e2) -- confirm
# against graphgan.datasets.project_3d_shape.
tng['e1'] = e12_modif[:,0]
tng['e2'] = e12_modif[:,1]


In [None]:
# Reorient all galaxies with respect to the tidal field.
# This first part computes the group size and the direction from each galaxy
# towards the position of its group.
# NOTE: several columns are modified in place below, so this cell is not idempotent;
# reload the catalog before running it a second time.

# Computes the size of groups
# idx: index of the first row of each unique GroupID; inv: for every row, the index
# of its GroupID in gids; counts: number of rows per GroupID.
gids, idx, inv, counts  = np.unique(tng['GroupID'],  return_index=True, return_inverse=True, return_counts=True)
tng['group_size'] = counts[inv]

# Convert distances to Mpc
# NOTE(review): only the group positions are divided by 1000; the galaxy-position
# conversion below is commented out. Confirm that gal_pos_* is already in the same
# units as group_* / 1000, otherwise the differences computed below mix units.
# tng['gal_pos_x'] /= 1000.
# tng['gal_pos_y'] /= 1000.
# tng['gal_pos_z'] /= 1000.

tng['group_x'] /= 1000.
tng['group_y'] /= 1000.
tng['group_z'] /= 1000.

# Computes direction to the central
# group_*[idx][inv] broadcasts the group position of the first row of each group
# to every member of that group.
tng['cen_x'] = tng['group_x'][idx][inv] - tng['gal_pos_x']
tng['cen_y'] = tng['group_y'][idx][inv] - tng['gal_pos_y']
tng['cen_z'] = tng['group_z'][idx][inv] - tng['gal_pos_z']
ncen = np.sqrt(tng['cen_x']**2 + tng['cen_y']**2 + tng['cen_z']**2 ) 
# NOTE(review): ncen is modified in place two lines below; whether 'cen_r' keeps the
# zeros depends on whether the table copies the data on assignment -- confirm.
tng['cen_r'] = ncen

# Objects at zero separation from the group position (not used again in this cell).
inds_cent = ncen == 0
# Avoid dividing by zero: those objects keep a zero direction vector.
ncen[ncen == 0] = 1
# Normalise the direction to unit length.
tng['cen_x'] = tng['cen_x']/ncen
tng['cen_y'] = tng['cen_y']/ncen
tng['cen_z'] = tng['cen_z']/ncen

# First orient the 0.1-scale tidal frame relative to the direction towards the centre:
# projection of the tidal `a` vector onto the unit vector (cen_x, cen_y, cen_z).
a = (tng['tid_av_x_0.1_1024']*tng['cen_x'] +
     tng['tid_av_y_0.1_1024']*tng['cen_y'] +
     tng['tid_av_z_0.1_1024']*tng['cen_z'])

# Per-object sign factor: -1 where the projection is negative, +1 everywhere else.
flip_a = ones_like(a)
flip_a[where(a < 0)] *= -1.0
# Apply that sign, in place, to the tidal `a` and `b` vectors (the tidal `c` vector
# is left untouched).
for _tid_col in ('tid_av_x_0.1_1024', 'tid_av_y_0.1_1024', 'tid_av_z_0.1_1024',
                 'tid_bv_x_0.1_1024', 'tid_bv_y_0.1_1024', 'tid_bv_z_0.1_1024'):
    tng[_tid_col] *= flip_a

# Project the dark-matter (DM) halo axes onto the 0.1-scale tidal-field vectors so
# that the halo frame can be given a consistent orientation.
# aTid, bTid, cTid: DM a / b / c axis dotted with the tidal `c` vector.
aTid = (tng['dm_av_x']*tng['tid_cv_x_0.1_1024'] +
        tng['dm_av_y']*tng['tid_cv_y_0.1_1024'] +
        tng['dm_av_z']*tng['tid_cv_z_0.1_1024'])
bTid = (tng['dm_bv_x']*tng['tid_cv_x_0.1_1024'] +
        tng['dm_bv_y']*tng['tid_cv_y_0.1_1024'] +
        tng['dm_bv_z']*tng['tid_cv_z_0.1_1024'])
cTid = (tng['dm_cv_x']*tng['tid_cv_x_0.1_1024'] +
        tng['dm_cv_y']*tng['tid_cv_y_0.1_1024'] +
        tng['dm_cv_z']*tng['tid_cv_z_0.1_1024'])
# caTid: DM c axis dotted with the tidal `a` vector.
caTid = (tng['dm_cv_x']*tng['tid_av_x_0.1_1024'] +
         tng['dm_cv_y']*tng['tid_av_y_0.1_1024'] +
         tng['dm_cv_z']*tng['tid_av_z_0.1_1024'])

# Per-object sign factors (+1 / -1):
#   flip_a = -1 where the DM a axis has a negative projection on the tidal c vector,
#   flip_c = -1 where the DM c axis has a negative projection on the tidal a vector,
#   flip_b = flip_a * flip_c, so the three signs always multiply to +1 and the
#            handedness of the (a, b, c) triad is preserved.
# Note that both the a and the c axis can change sign here.
flip_a = ones_like(aTid)
flip_a[where(aTid < 0)] *= -1.0
flip_c = ones_like(cTid)
flip_c[where(caTid < 0)] *= -1.0
flip_b = flip_a * flip_c

# Keep the projections consistent with the flipped axes.
aTid *= flip_a
bTid *= flip_b
cTid *= flip_c

# Update the DM halo orientation in place.
for _dm_col in ('dm_av_x', 'dm_av_y', 'dm_av_z'):
    tng[_dm_col] *= flip_a
for _dm_col in ('dm_bv_x', 'dm_bv_y', 'dm_bv_z'):
    tng[_dm_col] *= flip_b
for _dm_col in ('dm_cv_x', 'dm_cv_y', 'dm_cv_z'):
    tng[_dm_col] *= flip_c

# Misalignment of the stellar component expressed in the (already reoriented) DM frame.
# a, b, c: stellar `a` axis dotted with the DM a / b / c axis.
a = (tng['dm_av_x']*tng['av_x'] +
     tng['dm_av_y']*tng['av_y'] +
     tng['dm_av_z']*tng['av_z'])
b = (tng['dm_bv_x']*tng['av_x'] +
     tng['dm_bv_y']*tng['av_y'] +
     tng['dm_bv_z']*tng['av_z'])
c = (tng['dm_cv_x']*tng['av_x'] +
     tng['dm_cv_y']*tng['av_y'] +
     tng['dm_cv_z']*tng['av_z'])
# cc: stellar `c` axis dotted with the DM c axis.
cc = (tng['dm_cv_x']*tng['cv_x'] +
      tng['dm_cv_y']*tng['cv_y'] +
      tng['dm_cv_z']*tng['cv_z'])

# Sign factors that bring the stellar frame into the same orientation as the DM frame:
# the a axis follows the sign of `a`, the c axis follows the sign of `cc`, and the
# b axis takes the product of the two.
flip_a_stel = ones_like(a)
flip_a_stel[where(a < 0)] *= -1.0
flip_c_stel = ones_like(cc)
flip_c_stel[where(cc < 0)] *= -1.0
flip_b_stel = flip_a_stel * flip_c_stel

# Apply the same signs to the projections computed above.
a *= flip_a_stel
b *= flip_b_stel
c *= flip_c_stel

# Update the stellar shape orientation in place.
for _stel_col in ('av_x', 'av_y', 'av_z'):
    tng[_stel_col] *= flip_a_stel
for _stel_col in ('bv_x', 'bv_y', 'bv_z'):
    tng[_stel_col] *= flip_b_stel
for _stel_col in ('cv_x', 'cv_y', 'cv_z'):
    tng[_stel_col] *= flip_c_stel

In [None]:
# Replace each mass column by log10(value * 1e10).
# NOTE: the columns are overwritten, so this cell must be run exactly once per
# catalog load (a second run would take the logarithm of a logarithm).
for _mass_col in ('tot_mass', 'group_mass', 'mass', 'dm_mass'):
    tng[_mass_col] = log10(tng[_mass_col]*1e10)

In [None]:
# `catalog` is a second name bound to the same object as `tng`; no copy is made here.
catalog = tng

In [None]:
def _robust_scale_clip(values, limit=5):
    """Robust-scale a 1-D column and clip the result to [-limit, limit].

    The column is reshaped to a single-feature 2-D array, passed through a freshly
    fitted sklearn RobustScaler, clipped, and squeezed back to 1-D.
    """
    scaled = RobustScaler().fit_transform(values.reshape((-1,1)))
    return clip(scaled, -limit, limit).squeeze()


# Columns to rescale; each result is stored next to its source as '<name>_scaled'.
_COLUMNS_TO_SCALE = (
    'dm_mass', 'mass', 'group_mass',
    'tid_a_0.1_1024', 'tid_b_0.1_1024', 'tid_c_0.1_1024',
    'tid_a_0.5_1024', 'tid_b_0.5_1024', 'tid_c_0.5_1024',
    'tid_a_1.0_1024', 'tid_b_1.0_1024', 'tid_c_1.0_1024',
)

# Every column gets its own scaler (fitted on that column only).
for _name in _COLUMNS_TO_SCALE:
    catalog[_name + '_scaled'] = _robust_scale_clip(catalog[_name])

In [None]:
# Module-level settings read by the generator and discriminator functions below.
# Radial weighting scheme handed to spatial_adjacency.
weighting='exp'
# Direction set from graphgan.utils; the length of its last axis is used as filter size.
directions = get_3d_direction()
filter_size=directions.shape[-1]
# NOTE(review): this is called after TensorFlow has already been used by earlier
# cells; TF2 enables eager execution by default, so confirm this call is still needed.
tf.compat.v1.enable_eager_execution()
# Define generator function
def conditional_generator_fn(inputs, 
                             is_training=True,
                             reuse=None,
                             scope='Generator',
                             fused_batch_norm=False):
    """Graph-convolutional generator producing a 2-channel output per node.

    Args:
        inputs: 9-tuple ``(W0, W1, W2, pm0, pm1, pm2, xsp, X, noise)``.
            ``W0, W1, W2`` are the indices, values and dense shape of the sparse
            adjacency matrix; ``xsp`` is passed to ``spatial_adjacency`` together
            with that matrix; ``X`` are the conditioning features and ``noise`` the
            latent input. ``pm0, pm1, pm2`` are unpacked but not used here.
        is_training: forwarded to every batch-normalisation layer, so that
            ``is_training=False`` uses the accumulated moving statistics instead of
            batch statistics. The default (True) matches the previous behaviour.
        reuse, scope, fused_batch_norm: accepted but currently unused.

    Returns:
        Tensor with 2 output channels, passed through ``tanh``.
    """
    W0, W1, W2, pm0, pm1, pm2, xsp, X, noise = inputs
    # W_i define the sparse matrix used to construct the adjacency.
    adj = tf.SparseTensor(tf.cast(W0, tf.int64),W1,W2)
    # Computes 3D adjacency matrix
    mr = spatial_adjacency(xsp, adj, directions, filter_size, radial_weighting=weighting, radial_scale=0.4,
                           learn_scale=False)

    # Input level: mix the noise with the conditioning features.
    net = graph_conv2(tf.concat([noise, X],axis=1), mr, num_outputs=128, activation_fn=tf.nn.leaky_relu)
    net = slim.batch_norm(net, center=True, scale=True, is_training=is_training)

    net = graph_conv2(net, mr, num_outputs=128,  activation_fn=tf.nn.leaky_relu)
    net = slim.batch_norm(net, center=True, scale=True, is_training=is_training)

    net = graph_conv2(net, mr, num_outputs=16,  activation_fn=tf.nn.leaky_relu )
    net = slim.batch_norm(net, center=True, scale=True, is_training=is_training)

    # Output layer: 2 channels, tanh activation, weights initialised close to zero.
    net_a2d = graph_conv2(net, mr, 2, activation_fn=tf.nn.tanh, one_hop=None, weights_initializer=tf.compat.v1.truncated_normal_initializer(mean= 0.0, stddev=0.001))

    # Assemble the output (currently a single head).
    out_net = tf.concat([net_a2d], axis=1)

    return out_net

def conditional_discriminator_fn(y, conditioning):
    """
    Discriminator network scoring an alignment signal given the graph conditioning.

    args:
    y: alignment signal; when it has more than 2 columns it is first reduced with
       project_3d_shape (columns 0:3, 3:6, 6:9 as the three axis vectors and
       columns 9 and 10 as the two scalar inputs).
    conditioning: 9-tuple (W0, W1, W2, pm0, pm1, pm2, xsp, X, noise); W* describe
       the sparse adjacency matrix, pm* the sparse pooling matrix, xsp is passed
       to spatial_adjacency, X are the conditioning features. noise is unused.

    returns: output of a single-unit linear layer applied after pooling.
    """
    W0, W1, W2, pm0, pm1, pm2, xsp, X, noise = conditioning

    # Both the adjacency and the pooling matrix arrive as (indices, values, shape).
    adj = tf.SparseTensor(tf.cast(W0, tf.int64), W1, W2)
    pool = tf.SparseTensor(tf.cast(pm0, tf.int64), pm1, pm2)

    # Spatial adjacency used by every graph convolution below.
    mr = spatial_adjacency(xsp, adj, directions, filter_size,
                           radial_weighting=weighting, learn_scale=False,
                           radial_scale=0.2)

    n_channels = y.get_shape().as_list()[1]
    if n_channels > 2:

        def _with_leading_axis(block):
            # Transpose the slice and add a leading singleton axis.
            return tf.expand_dims(tf.transpose(block), axis=0)

        y = project_3d_shape(_with_leading_axis(y[...,0:3]),
                             _with_leading_axis(y[...,3:6]),
                             _with_leading_axis(y[...,6:9]),
                             _with_leading_axis(y[...,9]),
                             _with_leading_axis(y[...,10]))

    # Four graph convolutions on the concatenated signal and conditioning features.
    net = graph_conv2(tf.concat([y, X],axis=1), mr, num_outputs=128, activation_fn=tf.nn.leaky_relu)
    for width in (128, 64, 32):
        net = graph_conv2(net, mr, width, activation_fn=tf.nn.leaky_relu)

    # Spatial pooling: multiply by the sparse pooling matrix.
    pooled = tf.compat.v1.sparse_tensor_dense_matmul(pool, net)

    # Linear read-out, one score per pooled row.
    return slim.fully_connected(pooled, 1, activation_fn=None)

In [None]:
import tensorflow_gan as tfgan
from tensorflow_gan.python import namedtuples

# Estimator run configuration: write summaries and checkpoints every 500 steps.
my_config = tf.estimator.RunConfig(
    save_summary_steps = 500,
    save_checkpoints_steps = 500,
    keep_checkpoint_max = 50000,       # Retain up to the 50000 most recent checkpoints.
)

def silly_custom_discriminator_loss(gan_model,reduction='', add_summaries=True):
    """Return the mean of the squared discriminator outputs on real samples.

    ``reduction`` and ``add_summaries`` are accepted so the function can be called
    with the same arguments as the other loss functions; both are ignored.
    """
    real_scores = gan_model.discriminator_real_outputs
    return tf.reduce_mean(real_scores**2)
#
# Initialize GANEstimator with options and hyperparameters.
gan_estimator = tfgan.estimator.GANEstimator(
    generator_fn=conditional_generator_fn,
    discriminator_fn=conditional_discriminator_fn,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    # Discriminator loss = Wasserstein loss + 15 * gradient penalty
    #                      + 0.001 * mean squared discriminator output on real samples.
    # All three terms receive the same positional and keyword arguments.
    discriminator_loss_fn=lambda *args, **kwargs: (tfgan.losses.wasserstein_discriminator_loss(*args, **kwargs) + 
                                                   15*gradient_penaly(*args, **kwargs)+
                                                   0.001*silly_custom_discriminator_loss(*args, **kwargs)),   
    # Same Adam settings for both networks.
    generator_optimizer=tf.compat.v1.train.AdamOptimizer(0.001, beta1=0.0, beta2=0.95),
    discriminator_optimizer=tf.compat.v1.train.AdamOptimizer(0.001, beta1=0.0, beta2=0.95),
    # NOTE(review): GANTrainSteps is (generator_train_steps, discriminator_train_steps)
    # in TF-GAN, so (5, 1) means 5 generator steps per discriminator step. Confirm
    # this ratio is intended.
    get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps=namedtuples.GANTrainSteps(5, 1)),
    # Checkpoints and summaries are written below this directory.
    model_dir='./test3_500steps', 
    config=my_config)


In [None]:
# Feature and label columns shared by the training and testing input functions.
# They are defined once so the two input pipelines cannot drift apart.
# NOTE(review): 'central_bool' and the 'mlp_av_*' columns are not created in this
# notebook; presumably they are already present in the loaded catalog -- confirm.
_SCALAR_FEATURES = ('mass_scaled','central_bool', 'group_mass_scaled','dm_mass_scaled',
                    'tid_a_0.1_1024_scaled', 'tid_b_0.1_1024_scaled' ,'tid_c_0.1_1024_scaled' ,
                    'tid_a_0.5_1024_scaled' ,'tid_b_0.5_1024_scaled' ,'tid_c_0.5_1024_scaled' ,
                    'tid_a_1.0_1024_scaled' ,'tid_b_1.0_1024_scaled' ,'tid_c_1.0_1024_scaled' )
_VECTOR_FEATURES = ('mlp_av_x','mlp_av_y','mlp_av_z',
                    'tid_av_x_0.1_1024', 'tid_av_y_0.1_1024', 'tid_av_z_0.1_1024',
                    'tid_bv_x_0.1_1024', 'tid_bv_y_0.1_1024', 'tid_bv_z_0.1_1024',
                    'tid_cv_x_0.1_1024', 'tid_cv_y_0.1_1024', 'tid_cv_z_0.1_1024',
                    'tid_av_x_0.5_1024', 'tid_av_y_0.5_1024', 'tid_av_z_0.5_1024',
                    'tid_bv_x_0.5_1024', 'tid_bv_y_0.5_1024', 'tid_bv_z_0.5_1024',
                    'tid_cv_x_0.5_1024', 'tid_cv_y_0.5_1024', 'tid_cv_z_0.5_1024',
                    'tid_av_x_1.0_1024', 'tid_av_y_1.0_1024', 'tid_av_z_1.0_1024',
                    'tid_bv_x_1.0_1024', 'tid_bv_y_1.0_1024', 'tid_bv_z_1.0_1024',
                    'tid_cv_x_1.0_1024', 'tid_cv_y_1.0_1024', 'tid_cv_z_1.0_1024')
_SCALAR_LABELS = ('q','s')
_VECTOR_LABELS = ('av_x', 'av_y', 'av_z',
                  'bv_x', 'bv_y', 'bv_z',
                  'cv_x', 'cv_y', 'cv_z')

# Keyword arguments common to both input functions.
_input_kwargs = dict(scalar_features=_SCALAR_FEATURES,
                     vector_features=_VECTOR_FEATURES,
                     scalar_labels=_SCALAR_LABELS,
                     vector_labels=_VECTOR_LABELS,
                     noise_size=32, batch_size=64)

# Training input: shuffled, randomly rotated, repeated indefinitely.
training_fn = partial(graph_input_fn, catalog,
                      shuffle=True, rotate=True, repeat=True, **_input_kwargs)

# Testing input: a single deterministic pass without rotation.
testing_fn = partial(graph_input_fn, catalog,
                     shuffle=False, rotate=False, repeat=False, **_input_kwargs)

In [2]:
# Assuming every cell above ran without error, executing this cell starts the training.
import os
# NOTE(review): this variable is set after TensorFlow has been imported and used in
# earlier cells; if the GPU has already been initialised by then it may have no
# effect. Consider setting it in the first cell, before `import tensorflow`.
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
# Train for 300000 steps using the shuffled, repeating training input function.
gan_estimator.train(training_fn, steps=300000)