# Creating Images in Multiple Bands from Particle Distributions

In this example we show how to create different types of images from stellar particles. For this purpose we utilise the parametric SFZH functionality to create fake galaxies, derive their spectra from the SPS grid, and then make images using the Synthesizer imaging submodule.

In [None]:
import os
import time
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from unyt import yr, Myr, kpc, arcsec
from astropy.cosmology import Planck18 as cosmo

from synthesizer.grid import Grid
from synthesizer.parametric.sfzh import SFH, ZH, generate_sfzh
from synthesizer.particle.stars import sample_sfhz
from synthesizer.particle.stars import Stars
from synthesizer.galaxy.particle import ParticleGalaxy as Galaxy
from synthesizer.particle.particles import CoordinateGenerator
from synthesizer.filters import FilterCollection as Filters
from synthesizer.kernel_functions import quintic


# Use a serif font for all figures: Times New Roman where installed, with
# DejaVu Serif (shipped with matplotlib) as the fallback. The family must be
# the generic "serif" for the `font.serif` preference list to be consulted;
# a misspelt concrete family name is simply not found and falls back to the
# default sans-serif font.
plt.rcParams["font.family"] = "serif"
plt.rcParams["font.serif"] = ["Times New Roman", "DejaVu Serif"]

# Seed NumPy's global RNG so the sampled particles are reproducible
# (assumes the samplers below draw from the global RNG -- TODO confirm).
np.random.seed(42)

The first port of call is initialising the SPS grid. Here we use a simple test grid with limited properties.

In [None]:
# Where the small test grid lives (relative to the working directory)
# and what it is called.
grid_dir = "../../../tests/test_grid/"
grid_name = "test_grid"

# Load the SPS grid used to derive the particle spectra below.
grid = Grid(grid_name, grid_dir=grid_dir)

With the grid in hand, we need to define a star formation and metallicity history (SFZH) from which to sample. In this toy example we use a constant star formation history and a constant metallicity.

In [None]:
# Metallicity history: a single constant metallicity, Z = 0.01.
Z_p = {"Z": 0.01}
Zh = ZH.deltaConstant(Z_p)

# Star formation history: constant, with a duration of 100 Myr.
sfh_p = {"duration": 100 * Myr}
sfh = SFH.Constant(sfh_p)

# Combine both histories on the grid's age and metallicity axes into the
# SFZH we will sample particles from, normalised to a stellar mass of 10**9.
sfzh = generate_sfzh(
    grid.log10ages,
    grid.metallicities,
    sfh,
    Zh,
    stellar_mass=10**9,
)

We can now sample this SFZH for individual stellar "particles" and create a Stars object. In a real-world example the Stars object can be initialised from simulation data (see the `cosmo` examples for how). Here, we generate random coordinates from a Gaussian and mimic realistic smoothing lengths by making them increase with radius.

Note that setting attributes in this way is only necessary when sampling from an SFZH. When working with data, attributes can be passed as kwargs when initialising a Stars object. This will be updated soon.

In [None]:
stars_start = time.time()

# How many stellar particles to sample
n = 100000

# Particle positions drawn from a 3D Gaussian
coords = CoordinateGenerator.generate_3D_gaussian(n)

# Draw the particles' properties from the SFZH, then attach their positions
# (interpreted in kpc).
stars = sample_sfhz(sfzh, n)
stars.coordinates = coords
stars.coord_units = kpc

# Share the total mass of 10**9 equally between all particles.
stars.initial_masses = np.full(n, 10**9 / n)

# Fake smoothing lengths that grow with distance from the mean position.
cent = np.mean(coords, axis=0)
rs = np.sqrt(np.sum((coords - cent) ** 2, axis=1))
# NOTE(review): radii below 0.1 are replaced by 0.4 rather than clipped to
# 0.1, so the innermost particles end up with larger smoothing lengths than
# those with 0.1 <= r < 0.4 -- confirm this is the intended "lower bound".
rs[rs < 0.1] = 0.4
stars.smoothing_lengths = rs / 4  # smoothing length = radius / 4

# Place the particles at z = 1; used for the flux calculation later on.
stars.redshift = 1

# Largest coordinate span over all axes; used later to set the image FOV.
width = np.max(coords) - np.min(coords)

print("Stars created, took:", time.time() - stars_start)

In [None]:
    galaxy_start = time.time()

    # Wrap the sampled stellar particles in a particle-based Galaxy.
    galaxy = Galaxy(stars=stars)

    print("Galaxy created, took:", time.time() - galaxy_start)

In [None]:
    spectra_start = time.time()

    # Per-particle stellar spectra from the SPS grid, returned as an Sed
    # object. Fluxes are then computed with `get_fnu`, using the cosmology
    # and the stars' redshift, with no IGM model (`igm=None`).
    sed = galaxy.generate_particle_spectra(
        grid,
        sed_object=True,
        spectra_type="stellar",
    )
    sed.get_fnu(cosmo, stars.redshift, igm=None)

    print("Spectra created, took:", time.time() - spectra_start)

In [None]:
    filter_start = time.time()

    # The three JWST NIRCam bands we will image in.
    filter_codes = [
        f"JWST/NIRCam.{band}" for band in ("F090W", "F150W", "F200W")
    ]

    # Build the filter collection, passing the grid's wavelengths as `new_lam`.
    filters = Filters(filter_codes, new_lam=grid.lam)

    print("Filters created, took:", time.time() - filter_start)

In [None]:
    img_start = time.time()

    # Define the image properties. The extent is recomputed from the particle
    # coordinates rather than from `width`, because `width` is rebound to a
    # kpc quantity below. Re-running this cell would otherwise add a bare 1 to
    # a quantity that already carries units.
    extent = np.max(coords) - np.min(coords) + 1  # pad the FOV by 1 (kpc)
    resolution = (extent / 100) * kpc  # 100 pixels across the FOV
    width = extent * kpc

    # Get the histogram (unsmoothed) images in every filter.
    hist_img = galaxy.make_image(
        resolution,
        fov=width,
        img_type="hist",
        sed=sed,
        filters=filters,
        kernel_func=quintic,
        rest_frame=False,
        cosmo=cosmo,
    )

    print("Histogram images made, took:", time.time() - img_start)
    

In [None]:
    img_start = time.time()

    # Same FOV, resolution, spectra and filters as the histogram image, but
    # with img_type="smoothed" and the quintic kernel.
    smooth_img = galaxy.make_image(
        resolution,
        img_type="smoothed",
        kernel_func=quintic,
        fov=width,
        filters=filters,
        sed=sed,
        cosmo=cosmo,
        rest_frame=False,
    )

    print("Smoothed images made, took:", time.time() - img_start)

In [None]:
    hist_imgs = hist_img.imgs
    smooth_imgs = smooth_img.imgs

    print("Successfully made images for:", list(hist_imgs))

    # Measure from `stars_start`, the first timed stage (creating the stars).
    # In a notebook this wall-clock figure also includes any idle time
    # between cell executions.
    print("Total runtime (not including plotting):", time.time() - stars_start)

    # Write figures to ./plots under the current working directory. A notebook
    # has no script path, and savefig does not create missing directories, so
    # create the directory first.
    plot_dir = os.path.join(os.getcwd(), "plots")
    os.makedirs(plot_dir, exist_ok=True)

    # One column per filter: histogram images on top, smoothed images below.
    n_filters = len(filter_codes)
    fig = plt.figure(figsize=(4 * n_filters, 4 * 2))
    gs = gridspec.GridSpec(2, n_filters)

    for row, (label, imgs) in enumerate(
        (("Histogram", hist_imgs), ("Smoothed", smooth_imgs))
    ):
        axes = [fig.add_subplot(gs[row, col]) for col in range(n_filters)]
        for ax, fcode in zip(axes, filter_codes):
            ax.imshow(imgs[fcode])
            # Filter names only need labelling once, above the top row.
            if row == 0:
                ax.set_title(fcode)
        # Label each row on its left-most panel.
        axes[0].set_ylabel(label)

    fig.savefig(
        os.path.join(plot_dir, "flux_in_filters_test.png"),
        bbox_inches="tight",
        dpi=300,
    )

    # Also make an RGB composite from the smoothed images.
    fig, ax, rgb_img = smooth_img.plot_rgb_image(
        rgb_filters={
            "R": ["JWST/NIRCam.F200W"],
            "G": ["JWST/NIRCam.F150W"],
            "B": ["JWST/NIRCam.F090W"],
        },
        img_type="standard",
    )

    fig.savefig(
        os.path.join(plot_dir, "flux_in_filters_RGB_test.png"),
        bbox_inches="tight",
        dpi=300,
    )