In [None]:
import numpy as np
import plotly.graph_objects as go
import matplotlib.pyplot as plt

In [None]:
def clip_image(data, lower = 0.1, upper = 1.):
    """Return a copy of ``data`` clamped to its ``[lower, upper]`` quantile range.

    Parameters
    ----------
    data : numpy.ndarray
        Image (any dimensionality) to clamp. It is not modified.
    lower : float, optional
        Quantile whose value becomes the new minimum. Default 0.1.
    upper : float, optional
        Quantile whose value becomes the new maximum. Default 1.0 (the data maximum).

    Returns
    -------
    numpy.ndarray
        New array with every value outside the two quantile values replaced by the nearest one.
    """
    source = data.copy()
    bounds = np.quantile(source, [lower, upper])
    return np.clip(source, *bounds)


In [None]:

def _camera_eye(distance, elevation, azimuth):
    """Convert spherical camera angles (degrees) into a plotly ``scene.camera.eye`` dict.

    ``elevation`` is measured from the x-y plane, ``azimuth`` from the +x axis.
    """
    elev = np.deg2rad(elevation)
    azim = np.deg2rad(azimuth)
    return dict(
        x=distance * np.cos(elev) * np.cos(azim),
        y=distance * np.cos(elev) * np.sin(azim),
        z=distance * np.sin(elev),
    )


def _bare_axis():
    """Layout dict for a 3D scene axis with background, grid, line, ticks and tick labels hidden."""
    return dict(showbackground=False, showgrid=False, showline=False, ticks='', showticklabels=False)


def volume(hist, opacity=.1, isomin=None, isomax=None, surface_count=30,
           add_small_number=True, norm_hist=True, cmap="magma", save_path=None, show=True,
           elevation=-60, azimuth=0, distance=300, remove_axis_titles=False, upper=1., lower=0.2,
           clip=True, save_html=False, **kwargs):
    '''Render a 3D array as a plotly volume (stacked iso-surfaces).

    Voxel indices become the x/y/z coordinates and voxel values the rendered quantity.
    Axes, grid, background, legend and colour scale are hidden.

    Parameters
    ----------
    hist : numpy.ndarray
        3D array to render.
    opacity : float, optional
        Opacity of each iso-surface; keep small so inner surfaces stay visible. Default 0.1.
    isomin, isomax : float, optional
        Value range spanned by the iso-surfaces. ``None`` uses the min / max of the
        *unclipped* ``hist`` (clipping and the small offset are applied afterwards).
    surface_count : int, optional
        Number of iso-surfaces. Default 30.
    add_small_number : bool, optional
        Add 1e-10 to every voxel before voxel selection. Only non-zero voxels are passed
        to plotly, so this keeps voxels that are exactly zero in the rendered grid. Default True.
    norm_hist : bool, optional
        Not used by this function; accepted for backward compatibility.
    cmap : str, optional
        plotly colour-scale name. Default "magma".
    save_path : str, optional
        If given, a static image is written there with ``fig.write_image``
        (requires a plotly static-export backend).
    show : bool, optional
        Call ``fig.show()`` at the end. Default True.
    elevation, azimuth : float, optional
        Camera angles in degrees. Defaults -60 and 0.
    distance : float, optional
        Distance of the camera eye from the scene centre, in plotly ``scene.camera.eye`` units.
    remove_axis_titles : bool, optional
        Blank the x/y/z axis titles. Default False.
    upper, lower : float, optional
        Quantiles passed to ``clip_image`` when ``clip`` is True. Defaults 1.0 and 0.2.
    clip : bool, optional
        Clip the data to the ``[lower, upper]`` quantile range before rendering. Default True.
    save_html : bool or str, optional
        ``True`` writes an interactive HTML file to "fig.html"; a string is used as the
        file name instead. Default False.
    **kwargs
        Accepted and ignored.

    Raises
    ------
    ValueError
        If ``hist`` is not three-dimensional.
    '''
    if np.ndim(hist) != 3:
        raise ValueError(f"volume expects a 3D array, got shape {np.shape(hist)}")

    # Iso range defaults to the raw (unclipped) data range.
    if isomin is None:
        isomin = hist.min()
    if isomax is None:
        isomax = hist.max()

    data_hist = hist.copy()
    if clip:
        data_hist = clip_image(data_hist, lower, upper)
    if add_small_number:
        # Not in place: ``+=`` with a float raises for integer-valued histograms.
        data_hist = data_hist + 1e-10

    # Only voxels with a non-zero value are handed to plotly.
    xx, yy, zz = np.nonzero(data_hist != 0)
    values = data_hist[xx, yy, zz]

    fig = go.Figure(
        data=go.Volume(
            x=xx,
            y=yy,
            z=zz,
            value=values,
            isomin=isomin,
            isomax=isomax,
            opacity=opacity,              # small, so inner surfaces remain visible
            surface_count=surface_count,  # more surfaces -> smoother volume rendering
            colorscale=cmap,
            showscale=False,
        ),
        layout=dict(width=1000, height=1000),
    )
    fig.update_layout(scene_xaxis_showticklabels=False,
                      scene_yaxis_showticklabels=False,
                      scene_zaxis_showticklabels=False,
                      coloraxis_showscale=False)
    if remove_axis_titles:
        fig.update_layout(scene_xaxis_title='',
                          scene_yaxis_title='',
                          scene_zaxis_title='')
    fig.update_layout(
        scene=dict(xaxis=_bare_axis(), yaxis=_bare_axis(), zaxis=_bare_axis()),
        showlegend=False,
        plot_bgcolor='rgba(0,0,0,0)',  # transparent background
    )
    fig.update_layout(scene_camera=dict(eye=_camera_eye(distance, elevation, azimuth)))

    if save_html:
        html_path = save_html if isinstance(save_html, str) else "fig.html"
        fig.write_html(html_path, auto_open=False)
    if save_path is not None:
        fig.write_image(save_path)
    if show:
        fig.show()

In [None]:
from tqdm import trange
fields = ["Masses", "Stellar Age", "Metallicity"]
# One colour map per entry of ``fields`` (same order): Masses -> cividis,
# Stellar Age -> hot, Metallicity -> magma.
cmaps = ["cividis", "hot", "magma"]

In [None]:
import pickle
def _load_pickle(path):
    """Unpickle the object stored at ``path``.

    Only use on trusted files: unpickling can execute arbitrary code.
    """
    with open(path, "rb") as fh:
        return pickle.load(fh)


data3d = _load_pickle("galaxy_mPCA.pkl")
data2d = _load_pickle("2dmodel60.pkl")

In [None]:
# Per-field PCA mean, unflattened into three 64^3 cubes (first axis indexed as the field below).
# NOTE(review): assumes ``mean_`` was flattened from a (3, 64, 64, 64) stack in C order — confirm.
mean3d = np.reshape(data3d["pca"].mean_, (3, 64, 64, 64))

In [None]:
import os

# Render the PCA mean cube of every field.
# NOTE(review): assumes ``mean3d`` channels are ordered Metallicity, Stellar Age, Masses — confirm.
fields_ = ["Metallicity", "Stellar Age", "Masses"]
cmaps = {"Masses": "cividis", "Stellar Age": "hot", "Metallicity": "magma"}
# ``write_image`` needs the output directory; on a fresh kernel no earlier cell has created it.
os.makedirs("3d", exist_ok=True)
for field_name, m in zip(fields_, mean3d):
    cmap = cmaps[field_name]
    save_name = "3d/mean_3d_{}.png".format(field_name)
    volume(m, opacity=0.15, isomin=None, isomax=None, surface_count=40, clip=False,
           add_small_number=True, cmap=cmap, save_path=save_name, show=False, distance=1.85,
           elevation=90, azimuth=0, remove_axis_titles=True, lower=0.5)
    

In [None]:



import os 
os.makedirs("3d", exist_ok = True)
cmaps = ["cividis", "hot", "magma"]
# NOTE(review): ``gal`` is not defined in any visible cell; it presumably held
# gal[field][galaxy] 3D cubes built by a removed cell — restore it before running.
for i in trange(3):
    cmap = cmaps[i]
    for j in range(3):
        save_name = f"3d/galaxy_{j}_{fields[i]}.png"
        volume(gal[i][j], opacity=0.15, isomin=None, isomax=None, surface_count=40, clip=False,
               add_small_number=True, cmap=cmap, save_path=save_name, show=False, distance=1.85,
               elevation=45, azimuth=0, remove_axis_titles=True, lower=0.5)
    
    

In [None]:


i = 1  # field index into ``fields`` / ``cmaps``
j = 2  # galaxy index
# Toggle: True writes the PNG, False only displays the figure.
SAVE_FIGURE = False
save_name = f"3d/galaxy_{j}_{fields[i]}.png" if SAVE_FIGURE else None
cmap = cmaps[i]
# NOTE(review): ``gal`` is not defined in any visible cell; restore the cell that built it.
volume(gal[i][j], opacity=0.15, isomin=None, isomax=None, surface_count=20, clip=False,
       add_small_number=True, cmap=cmap, save_path=save_name, show=True, distance=1.85,
       elevation=45, azimuth=0, remove_axis_titles=True, lower=0.5)

# Reconstruction

In [None]:
from tqdm import trange
# HDF5 field names; this order sets the channel order of ``original`` built below.
# NOTE(review): presumably must match the field order the PCA was fitted with — confirm.
fields = "GFM_Metallicity GFM_StellarFormationTime Masses".split()
cmaps = dict([("Masses", "cividis"), ("Stellar Age", "hot"), ("Metallicity", "magma")])


In [None]:
# PCA outputs used by the reconstruction cell: eigenvectors, per-galaxy scores and the mean.
eigen3d, scores3d, means3d = (
    np.load(npy_path) for npy_path in ("eigengalaxies.npy", "scores.npy", "means.npy")
)


In [None]:
import sys
# Make the in-repo ``megs`` package importable. The relative entry is resolved against the
# kernel's current working directory, so the kernel must run from a directory next to ``src``.
sys.path[:0] = ['../src']

from megs.model.mPCA import mPCA
from megs.data import image, DataLoader, Galaxy

In [None]:
import os

# Location of the galaxy HDF5 catalogue. Set the MEGS_GALAXY_DATA environment variable
# instead of editing this user-specific absolute default.
GALAXY_DATA_PATH = os.environ.get(
    "MEGS_GALAXY_DATA",
    "/export/home/ucakir/MEGS/MEGS/src/megs/data/galaxy_data.hdf5",
)
# NOTE(review): ``m_min=8`` is presumably a lower stellar-mass cut — confirm in DataLoader.
data = DataLoader(GALAXY_DATA_PATH, m_min = 8)


In [None]:
# Normalisation helper from ``megs.data.image``; applied per field below as
# ``norm(cube, **norm_function_args[field])``.
norm = image.norm # Normalization function
lower = 0.25
upper = 1.0
# Identical normalisation settings for every field; ``lower``/``upper`` are the quantile
# bounds (presumably clip range after a log(1 + x) transform — confirm in image.norm).
# Built from the variables above so editing ``lower``/``upper`` affects all fields.
norm_function_args = {
    field: {"takelog": True, "plusone": True, "lower": lower, "upper": upper}
    for field in ("Masses", "GFM_Metallicity", "GFM_StellarFormationTime")
}



In [None]:
index = 24  # galaxy index passed to get_image; also selects the score row in the reconstruction cell
original = [
    norm(
        data.get_image(particle_type="stars", index=index, field=field_name, dim=3),
        **norm_function_args[field_name],
    )
    for field_name in fields
]

In [None]:
# Stack the per-field cubes into one array; the leading axis follows the ``fields`` order.
original = np.asarray(original)

In [None]:

ncomp = 215  # number of leading PCA components kept in the reconstruction
eigen = eigen3d
means = means3d
scores = scores3d
og = original
score = scores[index][:ncomp]
# Each eigenvector flattened to the size of the stacked original cube.
eig = eigen[:ncomp].reshape(ncomp, og.size)
means = means.reshape(-1)
# x ≈ score · eigenvectors + mean, then restored to the same layout as ``og``
# (previously hard-coded as (len(fields), 64, 64, 64)).
reconstructed = (np.dot(score, eig) + means).reshape(og.shape)

residual = og - reconstructed
isomin = residual.min()
isomax = residual.max()

In [None]:
# Shared rendering settings for the reconstruction figures. ``volume`` silently ignores
# unknown keywords, so the stale "quant" entry was dropped; "lower" only matters if clip=True.
volume_params = {"opacity": 0.15, "isomin": None, "isomax": None, "surface_count": 40, "clip": False,
                 "add_small_number": True, "show": False, "distance": 1.85,
                 "elevation": 45, "azimuth": 0, "remove_axis_titles": True, "lower": 0.5}

In [None]:
import os
os.makedirs("3d/reconstruction", exist_ok = True)
c = ["magma", "hot", "cividis"]
f = ["Metallicity", "Stellar Age", "Masses"]
# Render every field of the original cube. The original does not depend on ``ncomp``, so the
# file name carries only the galaxy index and field — the names the figure-grid cells read.
for field_index in trange(len(fields)):
    volume(og[field_index], **volume_params, cmap=c[field_index],
           save_path=f"3d/reconstruction/original_{index}_{f[field_index]}.png")
    


In [None]:
field_index = 1
f = ["Metallicity", "Stellar Age", "Masses"]
c = ["magma", "hot", "cividis"]
# Keep the module-level iso range equal to the residual range. This cell used to reset it
# to None, which leaked into cells that read isomin/isomax later.
isomin = residual.min()
isomax = residual.max()
# Interactive preview of one original field; the iso range is left to ``volume`` (None -> data min/max).
p = {"opacity": 0.15, "isomin": None, "isomax": None, "surface_count": 40, "clip": False,
     "add_small_number": True, "show": True, "distance": 1.85,
     "elevation": 45, "azimuth": 0, "remove_axis_titles": True, "lower": 0.1}
volume(og[field_index], **p, cmap=c[field_index])

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import numpy as np

fig = plt.figure(figsize=(10., 10.))
grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(3, 3),  # one row per field; columns: original, reconstructed, residual
                 axes_pad=0.1,  # pad between axes in inch.
                 )

# File-name stems of the residual renders, in row order.
fields_res = ["metallicity", "age", "masses"]
# Display names used in the original/reconstructed file names, in row order. Deliberately not
# named ``fields``: that name holds the HDF5 field list used by the reconstruction cells.
field_labels = ["Metallicity", "Stellar Age", "Masses"]


def crop_image(img, amount=100):
    """Return ``img`` with ``amount`` pixels trimmed from every edge of its first two axes.

    Works for 2D (grey) and 3D (RGB/RGBA) images; ``amount=0`` returns the whole image.
    """
    height, width = img.shape[:2]
    return img[amount:height - amount, amount:width - amount, ...]


for row, (field_res, field) in enumerate(zip(fields_res, field_labels)):
    # ``*_img`` names avoid overwriting the ``og`` array used by the reconstruction cells.
    og_img = crop_image(plt.imread(f"3d/reconstruction/original_{index}_{field}.png"))
    rec_img = crop_image(plt.imread(f"3d/reconstruction/reconstructed_{index}_{field}.png"))
    res_img = crop_image(plt.imread(f"3d/reconstruction/residual_{field_res}.png"))
    for col, img in enumerate((og_img, rec_img, res_img)):
        ax = grid[row * 3 + col]
        ax.imshow(img)
        ax.set_axis_off()

plt.tight_layout(w_pad = 0.0, h_pad = 0.0)
#plt.savefig("galaxy_grid.pdf", dpi = 300, bbox_inches = "tight")

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import matplotlib.gridspec as gridspec
import numpy as np

fig = plt.figure(figsize=(10, 10))

# 3 rows (fields) x 4 columns: original, reconstructed, residual, narrow colour-bar column.
gs = gridspec.GridSpec(3, 4, width_ratios=[1, 1, 1, 0.1])

# File-name stems of the residual renders, in row order.
fields_res = ["metallicity", "age", "masses"]
# Display names used in the other file names. Not named ``fields``, so the HDF5 field
# list used by the reconstruction cells survives this cell.
field_labels = ["Metallicity", "Stellar Age", "Masses"]


def crop_image(img, amount=110):
    """Return ``img`` with ``amount`` pixels trimmed from every edge of its first two axes.

    Works for 2D (grey) and 3D (RGB/RGBA) images; ``amount=0`` returns the whole image.
    """
    height, width = img.shape[:2]
    return img[amount:height - amount, amount:width - amount, ...]


for row, (field_res, field) in enumerate(zip(fields_res, field_labels)):
    # Local ``*_img`` names avoid overwriting the ``og`` array used by the reconstruction cells.
    og_img = crop_image(plt.imread(f"3d/reconstruction/original_{index}_{field}.png"))
    rec_img = crop_image(plt.imread(f"3d/reconstruction/reconstructed_{index}_{field}.png"))
    res_img = crop_image(plt.imread(f"3d/reconstruction/residual_{field_res}.png"))
    for col, img in enumerate((og_img, rec_img, res_img)):
        ax = fig.add_subplot(gs[row, col])
        ax.imshow(img)
        ax.set_axis_off()

labelsize = 20

# Colour bar for the residual column. The range comes straight from ``residual`` rather than
# the module-level isomin/isomax, which several cells reassign.
# NOTE(review): assumes the residual PNGs were rendered over this same value range — confirm.
# ``cbar_norm`` (not ``norm``) keeps the image normalisation helper intact.
cbar_norm = plt.Normalize(vmin=residual.min(), vmax=residual.max())
sm = plt.cm.ScalarMappable(cmap="RdBu_r", norm=cbar_norm)

cax = fig.add_subplot(gs[:, 3])
cbar = plt.colorbar(sm, cax=cax)
cbar.ax.tick_params(labelsize=10)
cbar.set_label("Residue", fontsize=labelsize)

plt.tight_layout()
plt.savefig("3d/reconstruction/3dreconstruction.pdf", dpi=300, bbox_inches="tight")


In [None]:
# NOTE(review): scratch cell — it only displays a Normalize repr; safe to delete.
plt.Normalize(vmin=0, vmax=1)