In [None]:
import os
import warnings
from datetime import datetime, timedelta
from zoneinfo import ZoneInfo  # on windows need tzdata package

import numpy as np
# import plot_utils as pu
from astropy import units as u
from astropy.coordinates import ICRS, AltAz, EarthLocation, SkyCoord, get_body
from astropy.table import QTable, unique, vstack
from astropy.time import Time
from astropy.visualization.wcsaxes.frame import EllipticalFrame
from astropy.wcs import WCS
from astroquery.exceptions import NoResultsWarning
from astroquery.simbad import Simbad
from matplotlib import patheffects
from matplotlib import pyplot as plt
from matplotlib.colors import (LinearSegmentedColormap, ListedColormap,
                               LogNorm, to_rgb)
from timezonefinder import TimezoneFinder

foregroundcolour = "#FFF"

# Sky colours from night (index 0) to midday (index 4), paired index-for-index with the
# faintest magnitude drawn under that sky: the brighter the sky, the fewer objects shown.
blues = ["#000", "#171726", "dodgerblue", "#00BFFF", "lightskyblue"]
magnitudes = [6, 4, 2, 0, -1]

# Local hour of day -> index into `blues` and `magnitudes`; values between the listed
# hours are interpolated (np.interp below, and the linear colour map).
brightness_by_time = {0: 0, 3: 1, 5: 2, 7: 3, 12: 4, 15: 3, 18: 2, 21: 1, 24: 0}

# Limiting magnitude for every second of the local day, so that
# `magnitude_mapping[seconds_since_local_midnight]` can be looked up directly.
# Sample times are expressed in minutes: 0, 1/60, 2/60, ... i.e. exactly one sample per
# second. (np.linspace(0, 1440, 86400) would space them 1440/86399 minutes apart, so the
# index would drift away from "seconds since midnight".)
all_times = np.arange(24 * 60 * 60) / 60
magnitude_times = [hour * 60 for hour in brightness_by_time]
magnitude_values = [magnitudes[level] for level in brightness_by_time.values()]

magnitude_mapping = np.interp(all_times, magnitude_times, magnitude_values)

# Colour map from fraction of the day (0..1) to sky colour.
blue_mapping = [(hour / 24, blues[level]) for hour, level in brightness_by_time.items()]
backgroundcolours = LinearSegmentedColormap.from_list("sky", blue_mapping)

tf = TimezoneFinder()
Simbad.reset_votable_fields()
# Faintest magnitude ever drawn; used as the magnitude cutoff of the SIMBAD query.
generic_maximum_magnitude = max(magnitudes)

# Display colour per spectral class letter (stars) or per body name (solar-system objects).
spectral_colours = {
    "O": "lightskyblue",
    "B": "lightcyan",
    "A": "white",
    "F": "lemonchiffon",
    "G": "yellow",
    "K": "orange",
    "M": "lightpink",  # "#f9706b",
    "": "white",
    "moon": "white",
    "mercury": "white",
    "venus": "lemonchiffon",
    "mars": "orange",
    "jupiter": "white",
    "saturn": "white",
    "uranus": "white",
    "neptune": "white",
}
spectral_colours = {key: to_rgb(value) for key, value in spectral_colours.items()}

## User Inputs

In [None]:
location = "Toronto"

# orion
observation_point = (
    45 * u.deg,
    160 * u.deg,
)  # altitude (deg from horizon), azimuth (eastwards from north)

# jupiter
# observation_point = (70 * u.deg, 180 * u.deg)

view_radius = 15 * u.deg

observation_start = {
    "year": 2025,
    "month": 2,
    "day": 18,
    "hour": 19,
    "minute": 0,
    "second": 0,
}
observation_length = 2 * u.hour
observation_frequency = 5 * u.minute

# Frame titles only need seconds if some frame can fall off a whole minute: either the
# start time has a seconds component, or the cadence is not a whole number of minutes
# (e.g. a 90 s cadence alternates between :00 and :30).
frequency_seconds = observation_frequency.to(u.s).value
print_seconds = (observation_start["second"] != 0) or not np.isclose(
    frequency_seconds, 60 * round(frequency_seconds / 60)
)

image_pixels = 250
fps = 2
image_directory = os.getcwd()
gif_fname = f"{image_directory}/SkySim.gif"
delete_frames = True

In [None]:
# Geocode the location name to a lat/long and look up its IANA time zone.
earth_location = EarthLocation.of_address(location)
lat, lon = (angle.to(u.deg).value for angle in (earth_location.lat, earth_location.lon))
timezone_name = tf.timezone_at(lat=lat, lng=lon)
if timezone_name is None:
    raise ValueError(
        f"Could not determine a time zone for {location!r} ({lat=}, {lon=})"
    )
timezone = ZoneInfo(timezone_name)

# Number of frames in the animation (always at least one).
frames = max(
    np.ceil((observation_length / observation_frequency).decompose()).astype(int), 1
)

dt_native_start = datetime(**observation_start, tzinfo=timezone)

# Step in absolute time. Adding a timedelta to an aware local datetime is wall-clock
# arithmetic, so frames spanning a DST change would not be evenly spaced in real time.
# A single frame stays a scalar Time (later cells rely on that distinction).
if frames > 1:
    dt_astropy = Time(dt_native_start) + np.arange(frames) * observation_frequency
else:
    dt_astropy = Time(dt_native_start)

## Convert Observing Point to RA/Dec

In [None]:
# Place the fixed alt/az pointing in a horizontal frame at the observer's location for
# every frame time, then express it as ICRS (RA/Dec) coordinates.
altitude, azimuth = observation_point
earth_frame = AltAz(
    alt=altitude,
    az=azimuth,
    obstime=dt_astropy,
    location=earth_location,
)

icrs_coordinates = earth_frame.transform_to(ICRS())
ra_dec = SkyCoord(icrs_coordinates)

## Query SIMBAD for a catalogue of bright objects in the field of view

In [None]:
# Fetch SIMBAD objects within `view_radius` of the pointing (`ra_dec` may hold one centre
# per frame), excluding object type 'err' and anything fainter than the faintest
# magnitude ever drawn. Extra columns: object type, V magnitude, all ids, spectral type.
requested_fields = ["otype", "V", "ids", "sp_type"]
Simbad.reset_votable_fields()
Simbad.add_votable_fields(*requested_fields)

criteria = f"otype != 'err' AND V < {generic_maximum_magnitude}"
simbad_table = Simbad.query_region(ra_dec, radius=view_radius, criteria=criteria)
query_result = QTable(simbad_table)

In [None]:
# Drop SIMBAD's coordinate-error / provenance columns, which nothing below uses.
columns_to_remove = [
    "coo_err_min",
    "coo_err_angle",
    "coo_wavelength",
    "coo_bibcode",
    "coo_err_maj",
]
query_result.remove_columns(columns_to_remove)

# Rename the remaining SIMBAD columns to the names used throughout this notebook.
query_result.rename_columns(
    ["main_id", "otype", "V", "sp_type"],
    ["id", "object_type", "magnitude", "spectral_type"],
)

In [None]:
# Round coordinates and magnitudes to a fixed precision.
query_result["ra"] = query_result["ra"].round(5)
query_result["dec"] = query_result["dec"].round(5)
query_result["magnitude"] = query_result["magnitude"].round(3)

# Reduce each spectral type to its leading class letter (e.g. "G2V" -> "G") when that
# letter has a colour in `spectral_colours`; otherwise use "" (drawn white).
query_result["spectral_type"] = [
    sp_type[0] if len(sp_type) > 0 and sp_type[0] in spectral_colours else ""
    for sp_type in query_result["spectral_type"].data
]

# Human-readable name: SIMBAD's "|"-separated identifier list contains entries such as
# "NAME Betelgeuse". Use those aliases (without the prefix, joined with "/") when any
# exist, otherwise fall back to the main identifier.
name_prefix = "NAME "
names_column = []
for main_id, ids in zip(query_result["id"].data, query_result["ids"]):
    aliases = [
        ident.strip()[len(name_prefix):]
        for ident in ids.split("|")
        if ident.strip().startswith(name_prefix)
    ]
    names_column.append("/".join(aliases) if aliases else main_id)
query_result["name"] = names_column
query_result.remove_column("ids")

## Filter Data: Remove Child Objects and Duplicates

In [None]:
# Remove objects that SIMBAD's hierarchy table (h_link) lists as children of any object
# in the result, then drop duplicate rows.


def _adql_string_list(values):
    """
    Format ``values`` as a parenthesised ADQL ``IN`` list of string literals.

    Single quotes inside a value are doubled, as SQL/ADQL string literals require.
    (Formatting a Python tuple with ``str()`` breaks for a one-element tuple, which gets a
    trailing comma, and for values containing a quote, which Python wraps in double
    quotes -- ADQL reads those as identifiers.)
    """
    literals = ", ".join("'" + str(value).replace("'", "''") + "'" for value in values)
    return f"({literals})"


if len(query_result) > 0:  # an empty IN () list is invalid ADQL
    parent_id_list = _adql_string_list(query_result["id"])
    parent_query_adql = f"""
    SELECT main_id AS "child_id",
    parent_table.id AS "parent_id"
    FROM (SELECT oidref, id FROM ident WHERE id IN {parent_id_list}) AS parent_table,
    basic JOIN h_link ON basic.oid = h_link.child
    WHERE h_link.parent = parent_table.oidref;
"""
    with warnings.catch_warnings(action="ignore", category=NoResultsWarning):
        hierarchies = Simbad.query_tap(parent_query_adql)
    children = np.unique(np.asarray(hierarchies["child_id"]))
    # Keep only rows whose id is not one of the children.
    query_result = query_result[~np.isin(np.asarray(query_result["id"]), children)]

query_result = unique(query_result)

## Generate Planet Tables

In [None]:
# One table of solar-system bodies per frame, with the same columns as `query_result`,
# so it can be stacked onto the SIMBAD catalogue when that frame is rendered.
ss_bodies_times = []

ss_bodies = QTable()
ss_bodies["name"] = [
    "mercury",
    "venus",
    "mars",
    "jupiter",
    "saturn",
    "uranus",
    "neptune",
]
# Offset added to 5*log10(r * delta) below. Presumably each planet's absolute magnitude
# V(1,0) -- TODO confirm source. No phase-angle term is applied.
ss_bodies["magnitude.offset"] = [
    -0.613,
    -4.384,
    -1.601,
    -9.395,
    -8.914,
    -7.11,
    -7,
]

# A single-frame run leaves `dt_astropy` as a scalar Time (not subscriptable) at this
# point in the notebook; normalise to a per-frame sequence without rebinding the global.
frame_times = (
    [dt_astropy] if isinstance(dt_astropy, Time) and dt_astropy.isscalar else dt_astropy
)

for i in range(frames):
    # Empty table with query_result's columns and dtypes.
    # NOTE(review): `units` is positional and assumes ra/dec are the 2nd and 3rd of seven
    # columns in query_result -- TODO confirm if the SIMBAD field list changes.
    time_table = QTable(
        names=query_result.colnames,
        dtype=query_result.dtype,
        units=[None, "deg", "deg", None, None, None, None],
    )
    frame_time = frame_times[i]
    sun = get_body("sun", frame_time, location=earth_location)
    earth = get_body("earth", frame_time, location=earth_location)

    for name, mag_offset in ss_bodies[["name", "magnitude.offset"]]:
        body_coord = get_body(name, frame_time, location=earth_location)
        # Distances (au) from the body to the Sun and to the Earth.
        sun_distance = body_coord.separation_3d(sun).to(u.au).value
        earth_distance = body_coord.separation_3d(earth).to(u.au).value

        row = {
            "ra": body_coord.ra,
            "dec": body_coord.dec,
            "magnitude": round(
                5 * np.log10(sun_distance * earth_distance) + mag_offset, 3
            ),
            # The body name doubles as the key into `spectral_colours`.
            "spectral_type": name,
            "name": name,
            "id": name,
        }
        time_table.add_row(row)
    time_table["ra"] = time_table["ra"].round(5)
    time_table["dec"] = time_table["dec"].round(5)
    ss_bodies_times.append(time_table)

## Calculate image parameters

In [None]:
# Angular size of one pixel: the field spans 2 * view_radius across `image_pixels`.
degrees_per_pixel = (view_radius / (image_pixels / 2)).to(u.deg).value

# Half of the 23 arcmin apparent star size (based on Vega spread in SIMBAD image),
# expressed in pixels.
airy_disk_minimum = 23 * u.arcmin / 2
airy_disk_pixels = airy_disk_minimum.to(u.deg) / (degrees_per_pixel * u.deg)

# Treat the Airy radius as 3 standard deviations of a Gaussian, then double it:
# the standard deviation is used as a "diameter" whereas the Airy value is a "radius".
std_dev = 2 * (airy_disk_pixels / 3)


def get_intensity(radius, flux, sigma):
    """
    Light received at distance ``radius`` (pixels) from a star of brightness ``flux``.

    Evaluates ``flux * exp(-radius**2 / sigma**2)`` element-wise. Note there is no factor
    of 2 in the denominator, so this Gaussian's standard deviation is ``sigma / sqrt(2)``.
    """
    exponent = -(radius**2) / (sigma**2)
    return flux * np.exp(exponent)

## Add stars to image

In [None]:
# One RGB image per frame, laid out as (frame, channel, y, x), filled with the sky
# colour for that frame's local time of day.
image = np.zeros((frames, 3, image_pixels, image_pixels))
backgrounds = []

# A single-frame run leaves `dt_astropy` as a scalar Time; later cells index it per
# frame, so wrap it in a list. Guarded so re-running this cell does not wrap it twice.
if isinstance(dt_astropy, Time) and dt_astropy.isscalar:
    dt_astropy = [dt_astropy]

for i in range(frames):
    utc_datetime = dt_astropy[i].to_datetime().replace(tzinfo=ZoneInfo("UTC"))
    local_datetime = utc_datetime.astimezone(tz=timezone)

    # Fraction of the local day elapsed (to the minute), used to pick the sky colour.
    day_percentage = timedelta(
        hours=local_datetime.hour, minutes=local_datetime.minute
    ).total_seconds() / (24 * 60 * 60)

    # Drop alpha from the RGBA colour, then broadcast it over every pixel of the frame.
    backgroundcolour = np.array(backgroundcolours(day_percentage)[:-1])
    image[i] = backgroundcolour[:, np.newaxis, np.newaxis]
    backgrounds.append(backgroundcolour)

In [None]:
# Patch of pixel offsets around a star over which its light is spread:
# half-width ceil(10 * std_dev) pixels, i.e. out to 10 standard deviations.
maxradius = np.ceil(10 * std_dev).astype(int)
radius_vector = np.arange(-maxradius, maxradius + 1)
# Two meshgrid offset arrays: area[0] varies along columns, area[1] along rows.
area = np.array(np.meshgrid(radius_vector, radius_vector))

# Distance of each patch point from the patch centre.
radial_distance = np.sqrt(area[0] ** 2 + area[1] ** 2)

unique_radii = np.unique(radial_distance)

# For each unique radius, the (row, col) indices into `area` at that radius, so the
# brightness is evaluated once per radius rather than once per pixel.
radius_locations = [
    np.array(np.where(radial_distance == radius)).T for radius in unique_radii
]

# A single-frame run leaves `ra_dec` / `dt_astropy` as scalars (or already wrapped in a
# one-element list by another cell). Normalise into local per-frame sequences instead of
# rebinding the globals, so re-running this cell cannot wrap them twice.
frame_centres = (
    [ra_dec] if isinstance(ra_dec, SkyCoord) and ra_dec.isscalar else ra_dec
)
frame_times = (
    [dt_astropy] if isinstance(dt_astropy, Time) and dt_astropy.isscalar else dt_astropy
)

# NOTE: stars are blended into `image` in place; re-run the background cell before
# re-running this one, otherwise every star is drawn again on top of the last run.
wcs_by_frame = []
for i in range(frames):
    # WCS centred on this frame's pointing.
    wcs = WCS(naxis=2)
    wcs.wcs.crpix = [image_pixels / 2] * 2
    wcs.wcs.cdelt = [degrees_per_pixel, degrees_per_pixel]
    wcs.wcs.crval = [frame_centres[i].ra.value, frame_centres[i].dec.value]
    wcs.wcs.ctype = ["RA", "DEC"]
    wcs.wcs.cunit = [u.deg, u.deg]
    wcs_by_frame.append(wcs)

    utc_datetime = frame_times[i].to_datetime().replace(tzinfo=ZoneInfo("UTC"))
    local_datetime = utc_datetime.astimezone(tz=timezone)
    # Seconds since local midnight (to the minute): index into `magnitude_mapping`.
    day_seconds = int(
        timedelta(
            hours=local_datetime.hour, minutes=local_datetime.minute
        ).total_seconds()
    )
    limiting_magnitude = magnitude_mapping[day_seconds]

    all_objects = vstack([query_result, ss_bodies_times[i]])

    # Only objects brighter (numerically smaller magnitude) than the sky allows.
    bright_objects = all_objects[all_objects["magnitude"] < limiting_magnitude]
    if len(bright_objects) == 0:
        # Nothing to draw. Without this guard `object_xy` below is an empty 1-D array
        # and `object_xy[:, 0]` raises IndexError (e.g. daytime with no planet in view).
        continue

    object_xy = np.array(
        [
            np.round(SkyCoord(ra=ra, dec=dec).to_pixel(wcs)).astype(int)
            for ra, dec in bright_objects[["ra", "dec"]]
        ]
    )
    bright_objects["x"] = object_xy[:, 0]
    bright_objects["y"] = object_xy[:, 1]

    # Keep objects whose patch can overlap the image (centre within maxradius of it).
    in_view = (
        (-maxradius <= bright_objects["x"])
        & (bright_objects["x"] < image.shape[-1] + maxradius)
        & (-maxradius <= bright_objects["y"])
        & (bright_objects["y"] < image.shape[-2] + maxradius)
    )
    visible_objects = bright_objects[in_view]

    if len(visible_objects) == 0:
        continue

    visible_objects["flux"] = 10 ** (
        -visible_objects["magnitude"] / 2.5
    )  # relative to V-band reference flux
    visible_objects["flux"] = visible_objects["flux"].round(5)
    # Log-scale the fluxes into (0, 1]; the +0.2 offset keeps the faintest visible
    # object at a small non-zero weight.
    visible_objects["scaled_flux"] = np.log10(visible_objects["flux"])
    visible_objects["scaled_flux"] = (
        visible_objects["scaled_flux"] - np.nanmin(visible_objects["scaled_flux"]) + 0.2
    )
    visible_objects["scaled_flux"] /= np.nanmax(visible_objects["scaled_flux"])
    visible_objects["scaled_flux"] = visible_objects["scaled_flux"].round(5)

    print(
        f"{local_datetime.time()} max mag={limiting_magnitude:.2f} # objects={len(visible_objects)}"
    )

    for x, y, flux, spectral_type in visible_objects[
        ["x", "y", "scaled_flux", "spectral_type"]
    ]:
        star_rgb = spectral_colours[spectral_type]

        # do all points at some given radius at once
        for radius, points in zip(unique_radii, radius_locations):
            brightness = get_intensity(radius, flux, std_dev)

            for area_x, area_y in points:
                # pixel location of this patch point
                x_ = x + area[1][area_x, area_y]
                y_ = y + area[0][area_x, area_y]

                if (0 <= x_ < image.shape[-1]) and (0 <= y_ < image.shape[-2]):
                    # Blend the star colour over the current pixel, weighted by brightness.
                    current_rgb = image[i, :, y_, x_]
                    image[i, :, y_, x_] = np.average(
                        [current_rgb, star_rgb],
                        weights=[1 - brightness, brightness],
                        axis=0,
                    )

## Plot result

In [None]:
gif_frame_paths = [f"{image_directory}/SkySim_{i}.png" for i in range(frames)]


def write_frame(i, image_single, dt_astropy_single, wcs):
    """
    Render one frame with WCS axes in an elliptical frame and save it as a PNG.

    Parameters
    ----------
    i : int
        Frame index; the figure is written to ``gif_frame_paths[i]``.
    image_single : ndarray
        Float RGB image for this frame in (channel, y, x) order.
    dt_astropy_single : astropy.time.Time
        Single frame time; shown in the title converted to the location's time zone.
    wcs : astropy.wcs.WCS
        World coordinate system used as the axes projection.
    """
    fig, ax = plt.subplots(
        subplot_kw={"projection": wcs, "frame_class": EllipticalFrame}
    )
    try:
        utc_datetime = dt_astropy_single.to_datetime().replace(tzinfo=ZoneInfo("UTC"))
        local_datetime = utc_datetime.astimezone(tz=timezone)

        # imshow expects (y, x, channel).
        rgb = np.transpose(image_single, (1, 2, 0))
        ax.imshow(rgb, vmin=0, vmax=1)

        # pu.style_wcs_axes(ax, axis_ticks=(False, False))
        ax.invert_xaxis()  # mirror the image horizontally
        ax.coords.frame.set_linewidth(0)

        if print_seconds:
            datetime_string = local_datetime.strftime("%Y-%m-%d %X %Z")
        else:
            datetime_string = local_datetime.strftime("%Y-%m-%d %H:%M %Z")
        # Single quotes inside the f-strings: reusing the outer quote type is only
        # valid syntax from Python 3.12 onwards.
        altaz_string = (
            f"Altitude: {observation_point[0].to_string(format='latex')}, "
            f"Azimuth: {observation_point[1].to_string(format='latex')}, "
            f"FOV: {(2 * view_radius).to_string(format='latex')}"
        )
        ax.set_title(f"{location} {datetime_string}\n{altaz_string}")

        fig.savefig(gif_frame_paths[i])
    finally:
        # Release the figure; otherwise every call leaves an open figure in memory.
        plt.close(fig)

import multiprocessing

# Render all frames. Workers must be forked so they inherit `write_frame` and the notebook
# globals it reads: under the "spawn" start method (the default on Windows and macOS)
# a function defined in a notebook cannot be unpickled in the child processes.
# Fall back to rendering serially where fork is unavailable.
frame_times = (
    [dt_astropy] if isinstance(dt_astropy, Time) and dt_astropy.isscalar else dt_astropy
)
frame_args = [(i, image[i], frame_times[i], wcs_by_frame[i]) for i in range(frames)]

if "fork" in multiprocessing.get_all_start_methods():
    with multiprocessing.get_context("fork").Pool() as pool:
        pool.starmap(write_frame, frame_args)
else:
    for args in frame_args:
        write_frame(*args)

In [None]:
# Optional GIF assembly (currently disabled -- every line is commented out).
# When enabled it reads the PNG frames in `gif_frame_paths`, sets each frame's delay from
# `fps`, writes the animation to `gif_fname`, and deletes the frames if `delete_frames`.
# It requires the third-party `wand` package, which is not imported elsewhere.
# from wand.image import Image

# gif_sequence = Image()
# for fname in gif_frame_paths:
#     frame = Image(filename=fname)
#     frame.delay = int(100 / fps)
#     gif_sequence.sequence.append(frame)

# gif_sequence.save(filename=gif_fname)

# del gif_sequence

# if delete_frames:
#     for fname in gif_frame_paths:
#         os.remove(fname)