# Understanding the RIME formalism for visibility simulations

In this notebook, we discuss the radio interferometer measurement equation (RIME) and explain how it can be used to simulate the visibilities of a radio interferometer observation. We follow the explanations and conventions of O. Smirnov. For more background information, we recommend reading his "Revisiting the radio interferometer measurement equation" paper series, which provides an excellent introduction and overview. ([Smirnov, A&A, Volume 527, March 2011](https://www.aanda.org/articles/aa/abs/2011/03/aa16082-10/aa16082-10.html))

Our software package $\texttt{pyvisgen}$ is a Python-based implementation of the [VISGEN tool](https://github.com/piyanatk/MAPS/tree/master/visgen) developed at Haystack Observatory. We implemented the complete matrix calculation framework in PyTorch to enable GPU support. In the following, we give an introduction to the basic functionality.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from math import pi

## Radio signals from brightness distributions

TODO: explain brightness distributions, radio signals, stokes, brightness matrix B (just basic, no polarization yet)


Define a toy sky model of point sources, together with the field of view (FoV), the pixel size, and the pointing position.

In [None]:
# Toy sky model: a few point sources on an otherwise empty 128 x 128 pixel grid.
# Pixel values are flux densities per pixel (Jy/px); indices are (row, column).
sky = torch.zeros((128, 128))
sky[68, 68] = 0.75
sky[45, 21] = 1.00
sky[88, 73] = 1.12
sky[11, 25] = 0.68
sky[23, 34] = 1.04
sky[83, 85] = 0.97

fov = 100  # field of view in arcseconds
px_size = fov / sky.shape[0]  # angular size of one pixel in arcseconds

# Pointing position in degrees
pointing_ra = torch.tensor([7.45])
pointing_dec = torch.tensor([53.86])

# Derive the axis extent from the FoV so the axis labels stay correct if `fov` changes.
half_fov = fov / 2

fig, ax = plt.subplots(1)
im1 = ax.imshow(
    sky,
    cmap="inferno",
    origin="lower",
    extent=[-half_fov, half_fov, -half_fov, half_fov],
)
ax.set_xlabel("rel. R.A. / asec")
ax.set_ylabel("rel. Dec. / asec")
cbar = fig.colorbar(im1, ax=ax, location="right", shrink=1, pad=0.02)
# Raw string: "\c" is not a valid Python escape sequence (SyntaxWarning on 3.12+).
cbar.set_label(r"Flux density / Jy$\cdot$px$^{-1}$")
# Annotations use axes-fraction coordinates so their placement does not depend on the FoV.
# "\\," yields the LaTeX thin space "\," without relying on an invalid escape sequence.
ax.text(
    0.02,
    0.9,
    f"FoV={fov}$\\,$asec\npx={round(px_size, 2)}$\\,$asec",
    ha="left",
    size=11,
    color="white",
    transform=ax.transAxes,
)
ax.text(
    0.45,
    0.9,
    f"pointing R.A.: {round(pointing_ra.item(), 2)}$\\,$deg\npointing Dec.: {round(pointing_dec.item(), 2)}$\\,$deg",
    ha="left",
    size=11,
    color="white",
    transform=ax.transAxes,
)
None

In [None]:
# Stokes parameters per pixel, stored along the last axis in the order (I, Q, U, V).
# Only Stokes I is filled, i.e. the emission is unpolarized (Q = U = V = 0).
I = torch.zeros((sky.shape[0], sky.shape[1], 4), dtype=torch.cdouble)
I[..., 0] = sky  # Stokes I

# Brightness matrix in the linear polarization basis:
#   B = [[I + Q,  U + iV],
#        [U - iV, I - Q ]]
B = torch.zeros((sky.shape[0], sky.shape[1], 2, 2), dtype=torch.cdouble)
B[..., 0, 0] = I[..., 0] + I[..., 1]
B[..., 0, 1] = I[..., 2] + 1j * I[..., 3]
B[..., 1, 0] = I[..., 2] - 1j * I[..., 3]
B[..., 1, 1] = I[..., 0] - I[..., 1]

# Plot the real part of every matrix element. The labels must describe the
# formulas used above (the previous labels "I + V", "Q + iU", ... belong to the
# circular basis and did not match what was computed).
element_labels = {
    (0, 0): "I + Q",
    (0, 1): "U + iV",
    (1, 0): "U - iV",
    (1, 1): "I - Q",
}

fig, axes = plt.subplots(2, 2)
for (row, col), label in element_labels.items():
    panel = axes[row, col]
    panel.imshow(B[..., row, col].real, cmap="inferno", origin="lower")
    panel.text(
        0.04, 0.9, label, ha="left", size=11, color="white", transform=panel.transAxes
    )
    panel.axis("off")
fig.tight_layout()

## Radio interferometer layout and coordinate systems

TODO: basics here, more insights in separate notebook?

In [None]:
# Toy interferometer with ten antennas lying in the Z = 0 plane.
# Positions are earth-centered XYZ coordinates; the table below is in units of
# 10 km and is scaled to metres when the tensor is built.
ants_rel = torch.tensor(
    [
        [-4, 2, 0],
        [8, -2, 0],
        [1, 1, 0],
        [-4, 6, 0],
        [-3, -3, 0],
        [6, -5, 0],
        [8, 2, 0],
        [2, -4, 0],
        [-6, 3, 0],
        [-2, 8, 0],
    ]
) * 1e4

# One name per antenna, for later cells that refer to individual stations.
ant1, ant2, ant3, ant4, ant5, ant6, ant7, ant8, ant9, ant10 = ants_rel

In [None]:
# Estimate the field of view (~ lambda / D) and the best achievable angular
# resolution (~ lambda / longest baseline) of the toy interferometer.

speed_of_light = 3e8  # m/s
obs_freq = 1.4e9  # Hz
dish_diameter = 25  # m
rad_to_asec = 3600 * 180 / pi

wavelength = speed_of_light / obs_freq

interferometer_fov = rad_to_asec * wavelength / dish_diameter
print("FoV:", round(interferometer_fov, 2), " asec")

# The resolution is limited by the longest baseline, i.e. the largest distance
# between ANY pair of antennas -- not just ant1-ant2 (here ant6-ant10 is longer).
max_distance = torch.pdist(ants_rel).max()
interferometer_res = rad_to_asec * wavelength / max_distance
print("Maximum resolution:", round(interferometer_res.item(), 2), " asec")

print("\n")

fig, ax = plt.subplots(1)

ax.plot(ants_rel[:, 0], ants_rel[:, 1], ".")
# Full-width / full-height reference lines through the origin.
ax.axhline(0, linestyle="--", color="gray")
ax.axvline(0, linestyle="--", color="gray")
ax.set_xlabel("X / m")
ax.set_ylabel("Y / m")
None

In [None]:
# Set up the observation; its baselines (obs.baselines.u/v/w) are used by the
# Fourier-kernel cell below.
import datetime

from astropy.time import Time
from pyvisgen.simulation.observation import Observation

obs_settings = {
    # pointing position (degrees), shared with the sky model cell
    "src_ra": pointing_ra,
    "src_dec": pointing_dec,
    # scan schedule -- presumably all durations are in seconds; TODO confirm
    "start_time": datetime.datetime(2023, 4, 22, 12, 0, 0),
    "scan_duration": 60,
    "num_scans": 3,
    "scan_separation": 120,
    "integration_time": 15,
    # spectral setup
    "ref_frequency": 1.3e9,
    "frequency_offsets": [0],
    # NOTE(review): 64e8 Hz (6.4 GHz) is larger than the 1.3 GHz reference
    # frequency; possibly meant 64e6 -- confirm.
    "bandwidths": [64e8],
    # NOTE(review): differs from the 100 asec `fov` used for the sky model above.
    "fov": 1024,
    "image_size": 128,
    "array_layout": "test_layout",
    "corrupted": False,
    "device": "cpu",
    "dense": False,
    "sensitivity_cut": 1e-6,
}
obs = Observation(**obs_settings)

In [None]:
# Coordinate grids for the brightness matrix: for every pixel, its R.A. and Dec.
# (radians) relative to the grid reference position.

fov_rad = fov / 3600 * (pi / 180)  # FoV: arcseconds -> radians

# Angular size of one pixel. This must use the image size: B.shape[-1] is the
# 2x2 brightness-matrix axis and would give fov_rad / 2, making the grid span
# far more than the stated FoV.
res_rad = fov_rad / sky.shape[0]

# Reference position of the grid (both 0 here, i.e. purely relative coordinates)
ra = torch.deg2rad(torch.tensor([0]))
dec = torch.deg2rad(90 - torch.tensor([90]))

n_rows, n_cols = sky.shape[0], sky.shape[1]
r = (torch.arange(n_cols, device="cpu") - n_cols / 2) * res_rad + ra
d = (torch.arange(n_rows, device="cpu") - n_rows / 2) * res_rad + dec

# D varies along rows (axis 0), R along columns (axis 1).
D, R = torch.meshgrid((d, r), indexing="ij")

# Last axis: (R.A., Dec.)
rd_grid = torch.stack([R, D], dim=-1)


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))

im1 = ax1.imshow(rd_grid[..., 0])
ax1.set_xlabel("px")
ax1.set_ylabel("px")
cbar1 = fig.colorbar(im1, ax=ax1, location="right", shrink=0.62, pad=0.02)
cbar1.set_label("Rel. R.A. / rad")

im2 = ax2.imshow(rd_grid[..., 1])
ax2.set_xlabel("px")
ax2.set_ylabel("px")
cbar2 = fig.colorbar(im2, ax=ax2, location="right", shrink=0.62, pad=0.02)
cbar2.set_label("Rel. Dec. / rad")

fig.tight_layout()

In [None]:
# Direction cosines (l, m) for every pixel, computed from the R.A./Dec. grid
# relative to the reference position (ra, dec):
#   l = cos(delta) * sin(alpha - ra)
#   m = sin(delta) * cos(dec) - cos(delta) * sin(dec) * cos(alpha - ra)
# The whole expression is transposed for BOTH l and m, so the two components
# refer to the same pixel. Previously only the sin factor of l was transposed,
# which paired cos(delta) from one pixel with sin(alpha - ra) from another.

ra_grid = rd_grid[..., 0]
dec_grid = rd_grid[..., 1]

lm_grid = torch.zeros(rd_grid.shape, device="cpu")
lm_grid[:, :, 0] = (torch.cos(dec_grid) * torch.sin(ra_grid - ra)).T
lm_grid[:, :, 1] = (
    torch.sin(dec_grid) * torch.cos(dec)
    - torch.cos(dec_grid) * torch.sin(dec) * torch.cos(ra_grid - ra)
).T


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))

im1 = ax1.imshow(lm_grid[..., 0])
ax1.set_xlabel("px")
ax1.set_ylabel("px")
cbar1 = fig.colorbar(im1, ax=ax1, location="right", shrink=0.62, pad=0.02)
cbar1.set_label("l / rad")

im2 = ax2.imshow(lm_grid[..., 1])
ax2.set_xlabel("px")
ax2.set_ylabel("px")
cbar2 = fig.colorbar(im2, ax=ax2, location="right", shrink=0.62, pad=0.02)
cbar2.set_label("m / rad")

fig.tight_layout()

## Visibility calculations

In [None]:
# Fourier kernel K = exp(-2 pi i (u l + v m + w (n - 1))) for every pixel and baseline.
# Resulting axes: (pixel row, pixel column, visibility index).

speed_of_light = 3e8  # m/s
# Must match the frequency of the Observation that produced the baselines
# (ref_frequency=1.3e9, frequency_offsets=[0]); the kernel used 1.2e9 before.
kernel_freq = 1.3e9  # Hz
# Assumes obs.baselines.u/v/w are in metres -- TODO confirm against pyvisgen.
to_wavelengths = kernel_freq / speed_of_light

# NOTE(review): the baselines are concatenated with themselves, so the two
# halves that are "averaged" in the visibility cell are identical here.
u_cmplt = torch.cat([obs.baselines.u, obs.baselines.u])
v_cmplt = torch.cat([obs.baselines.v, obs.baselines.v])
w_cmplt = torch.cat([obs.baselines.w, obs.baselines.w])

l = lm_grid[..., 0]
m = lm_grid[..., 1]
n = torch.sqrt(1 - l**2 - m**2)

ul = torch.einsum("b,ij->ijb", u_cmplt, l) * to_wavelengths
vm = torch.einsum("b,ij->ijb", v_cmplt, m) * to_wavelengths
wn = torch.einsum("b,ij->ijb", w_cmplt, (n - 1)) * to_wavelengths
K = torch.exp(-2 * pi * 1j * (ul + vm + wn))

In [None]:
# Compare the layouts before combining them: K is (pixel row, pixel column,
# visibility index) and B is (pixel row, pixel column, 2, 2).
kernel_shape, brightness_shape = K.shape, B.shape
(kernel_shape, brightness_shape)

In [None]:
# Source coherency per visibility and pixel: X = B * K. The kernel axes are
# reordered from (row, col, visibility) to (visibility, row, col) and broadcast
# over the 2x2 brightness-matrix axes -> X has shape (visibility, row, col, 2, 2).
kernel_vrc = K.permute(2, 0, 1)
X = B.unsqueeze(0) * kernel_vrc[..., None, None]
X.shape

In [None]:
# Integrate over the sky (sum over both pixel axes), then average the first and
# second halves of the visibility axis (the two concatenated baseline copies).
vis_per_entry = X.sum(dim=(1, 2))
half = vis_per_entry.shape[0] // 2
vis = 0.5 * (vis_per_entry[:half] + vis_per_entry[half:])
vis.shape