# Analyze gradients of images with respect to subhalo masses

In [1]:
import sys, os
sys.path.append('../')

import logging
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from simulation.wrapper import augmented_data
from simulation.units import M_s
import astropy

# Compact log lines: a 5-character HH:MM timestamp, a fixed-width logger name
# and level, then the message. INFO and above are shown.
log_format = '%(asctime)-5.5s %(name)-20.20s %(levelname)-7.7s %(message)s'
logging.basicConfig(level=logging.INFO, datefmt='%H:%M', format=log_format)

In [2]:
# Display the installed astropy version so it is recorded in the cell output.
astropy.__version__

'3.1.2'

## Settings

In [5]:
# Panel grid for the figures: one simulated image per panel (n_cols x n_rows images).
n_cols, n_rows = 2, 2

## Make data

In [None]:
# Simulate one image per panel of the grid. calculate_dx_dm=True presumably
# requests the per-subhalo image gradients used below (sub_latents column 3)
# — TODO confirm against simulation.wrapper.augmented_data.
simulation_settings = dict(
    n_images=n_cols * n_rows,
    f_sub=0.05,
    beta=-1.9,
    roi_size=5.,
    mine_gold=True,
    calculate_dx_dm=True,
    draw_alignment=False,
    draw_host_mass=False,
    draw_host_redshift=False,
)
theta, imgs, t_xz, log_r_xz, sub_latents, global_latents = augmented_data(**simulation_settings)

16:32 simulation.wrapper   INFO    Simulating image 1 / 4
16:37 simulation.wrapper   INFO    Simulating image 2 / 4


## Latents

In [None]:
# Names of the columns of global_latents, in column order.
latent_names = [
    "M_200_hst", "D_l", "z_l", "sigma_v",
    "theta_x_0", "theta_y_0", "theta_E",
    "n_sub_roi", "f_sub_realiz",
    "n_sub_in_ring", "f_sub_in_ring",
    "n_sub_near_ring", "f_sub_near_ring",
]

# One line per latent, listing its value for every simulated image.
# zip stops at the shorter of the two, exactly like the pairing of names and columns.
for name, values in zip(latent_names, global_latents.T):
    print(name, "=", values)

## Plot subhalos

In [None]:
# Output path for the figure. Create its directory up front so that savefig
# does not fail when ../figures does not exist yet.
fig_path = "../figures/subhalo_gradients.pdf"
os.makedirs(os.path.dirname(fig_path), exist_ok=True)

plt.figure(figsize=(n_cols*5.,n_rows*4.))

for i in range(n_cols*n_rows):
    ax = plt.subplot(n_rows,n_cols,i+1)

    # Lensed image, shown as log10 of the pixel values.
    plt.imshow(
        np.log10(imgs[i]),
        vmin=2.3,
        vmax=3.2,
        cmap='gist_gray',
        extent=(-3.2,3.2,-3.2,3.2),
        origin="lower"
    )

    # Marker area grows linearly with log10(sub_latents[:, 0] / M_200_hst).
    # Below a ratio of 10**-6.5 the formula turns negative. Matplotlib cannot
    # draw negative sizes and silently drops those markers, so clip the size
    # to a small positive minimum.
    sizes = np.maximum(
        195 + 30 * np.log10(sub_latents[i][:,0] / global_latents[i,0]),
        1.,
    )

    # Subhalos at their (theta_x, theta_y) positions, coloured by the
    # gradient column (sub_latents[:, 3]) rescaled by 1e6 * M_s.
    sc = plt.scatter(
        sub_latents[i][:,1],
        sub_latents[i][:,2],
        c=1.e6*sub_latents[i][:,3]*M_s,
        s=sizes,
        cmap="plasma",
        vmin=0.,
        vmax=3.,
    )

    # Host centre (theta_x_0, theta_y_0) as a small white cross.
    plt.scatter(
        [global_latents[i,4]],
        [global_latents[i,5]],
        c=["white"],
        marker="+",
        s=20.,
    )
    cbar = plt.colorbar(sc)

    plt.xlim(-3.2,3.2)
    plt.ylim(-3.2,3.2)

    plt.xlabel(r"$\theta_x$ [arcsec]")
    plt.ylabel(r"$\theta_y$ [arcsec]")
    cbar.set_label(r'$\sum_{ij} \; |\nabla_{m} \; \bar{x}_{ij}|$ [$10^6\;M_S$]')


plt.tight_layout()
plt.savefig(fig_path)

## Relation between subhalo properties and gradients

In [None]:
# Flatten the per-image subhalo catalogues into 1-D arrays for one joint scatter plot.
# Column meanings follow their use in the plotting cell and the "Latents" labels:
# - sub_latents[i][:, 0]: presumably subhalo mass — TODO confirm
# - sub_latents[i][:, 1:3]: (theta_x, theta_y)
# - sub_latents[i][:, 3]: gradient
# - global_latents[i, 0]: M_200_hst
# - global_latents[i, 4:6]: host centre (theta_x_0, theta_y_0)
# - global_latents[i, 6]: theta_E
n_images_total = n_cols * n_rows

# Subhalo "mass" relative to the host M_200.
f_subs = np.concatenate(
    [sub_latents[i][:, 0] / global_latents[i, 0] for i in range(n_images_total)],
    axis=0,
)

# Gradient column multiplied by M_s (the subhalo plot shows this times 1e6).
grads = np.concatenate(
    [sub_latents[i][:, 3] for i in range(n_images_total)],
    axis=0,
) * M_s

# Projected distance from the host centre, in units of the Einstein radius.
# Measure from (theta_x_0, theta_y_0), not from the image origin, so that
# r / theta_E is relative to the ring even when the lens is offset. The two
# coincide when the host sits at the origin.
rnorms = np.concatenate(
    [
        np.hypot(
            sub_latents[i][:, 1] - global_latents[i, 4],
            sub_latents[i][:, 2] - global_latents[i, 5],
        ) / global_latents[i, 6]
        for i in range(n_images_total)
    ],
    axis=0,
)

In [None]:
# Scatter of all subhalos from all images: x = log10 of the mass ratio to the
# host, y = normalized projected radius, colour = gradient (per M_s, without
# the 1e6 rescaling used in the subhalo plot, hence vmax=3e-6).
plt.figure(figsize=(5,4))
ax = plt.gca()

sc = plt.scatter(
    np.log10(f_subs),
    rnorms,
    c=grads,
    marker="o",
    s=20.,
    cmap="plasma_r",
    vmin=0.,
    vmax=3.e-6,
)
cbar = plt.colorbar(sc)

# Without labels the figure cannot be read on its own.
# The x label assumes sub_latents column 0 is the subhalo mass — TODO confirm.
ax.set_xlabel(r"$\log_{10} \; (m_{\mathrm{sub}} / M_{200})$")
ax.set_ylabel(r"$r / \theta_E$")
cbar.set_label(r'$\sum_{ij} \; |\nabla_{m} \; \bar{x}_{ij}|$')

plt.tight_layout()
plt.show()