# CASA SIMULATION

The simulation follows the guidelines and examples provided in the CASA documentation. For more detailed information on the simulation process, refer to the CASA simulation script demo: https://casadocs.readthedocs.io/en/v6.2.0/examples/community/simulation_script_demo.html 

In this notebook, we will simulate a 4-h observation in SKA Mid Band 2.


## Prerequisites

### Installing casa

1. Visit https://casa.nrao.edu/casa_obtaining.shtml
1. Install CASA in a location of your choice (a local, per-user installation is preferred)

### Installing additional dependencies

1. Once CASA is installed, start it with the `casa` command, which opens the CASA-specific IPython shell
1. Run the following command inside that shell to install the additional dependencies:

    ```
    %pip install ipykernel astropy
    ```

## Imports

In [7]:
import os
import shutil
import time
from typing import List

import numpy as np
from astropy import units as au

In [None]:
# simutil lives in the private part of casatasks, so this import only works
# where the CASA python packages are available.
from casatasks.private import simutil

# Shared simutil helper; simulate() calls mysu.readantenna() on it to parse the
# antenna configuration file.
mysu = simutil.simutil()
# Wall-clock timestamp taken when this cell runs (presumably used to time the
# notebook run; no consumer is visible in this part of the notebook).
tstart = time.time()

## Utilities

In [9]:
def remove_if_exists(path: str) -> None:
    """
    Delete ``path`` if it exists, whether it is a file, a directory tree
    or a symbolic link.

    Parameters
    ----------
    path: str
        Filesystem path to remove. A path that does not exist is ignored.

    Notes
    -----
    A symbolic link is removed itself; whatever it points to is left
    untouched. Dangling symbolic links are removed as well.
    """
    # lexists() is also True for dangling symlinks, which exists() reports
    # as missing and would therefore leave behind.
    if not os.path.lexists(path):
        return

    # rmtree() refuses to operate on a symlink, so only real directories go
    # through it; everything else (files, links) is unlinked directly.
    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path)
    else:
        os.remove(path)

In [10]:
def remove_whitespace_in_astropy_unit(value: au.Unit) -> str:
    """
    Return the string form of an astropy unit (or quantity) with every
    whitespace character stripped out.

    The argument is converted with ``str()`` first, so anything with a
    sensible string representation is accepted.

    Example: a value whose string form is "950.0 MHz" becomes "950.0MHz"
    """
    text = str(value)
    return "".join(char for char in text if not char.isspace())

## Simulate visibilities from the sky model

In [11]:
# ************* Simulate the observation *********************#


def simulate(
    msname: str,
    ant_config: str,
    clname: str,
    imagename: str,
    itime: float,
    starttime: float,
    endtime: float,
    odate: List[str],
    freq: str,
    deltafreq: str,
    freqresolution: str,
    nchannels: int,
    usemod: str,
):
    """
    Construct a Measurement Set with the desired observation setup and fill
    it with visibilities predicted from the given sky model.

    Any existing file or directory at ``msname`` is deleted first. After the
    prediction, all flags in the Measurement Set are cleared.

    Parameters
    ----------
    msname: str
        Path of the Measurement Set to create.
    ant_config: str
        Path of the antenna configuration file read by ``mysu.readantenna``.
    clname: str
        Path of the component list. Used unless ``usemod`` is "im".
    imagename: str
        Path of the model image. Used when ``usemod`` is "im" or "both".
    itime: float
        Integration Time in seconds
    starttime: float
        Start Time in seconds
    endtime: float
        End Time in seconds
    odate: List[str]
        List containing observation dates.
        Each date is a standard date-time formatted string.
        Only the first entry is used, as the reference time.
    freq: str
        Starting frequency, represented as a string Quantity. e.g. '950.0MHz'
    deltafreq: str
        Increment in frequency
    freqresolution: str
        Frequency resolution
    nchannels: int
        Number of channels in the spectral window.
    usemod: str
        Which sky model to predict from: "im" (model image), "com"
        (component list) or "both". Any other value is treated like "com".
    """

    # Delete the directory
    remove_if_exists(msname)

    dir0 = me.direction("J2000", "15h00m00", "-30d00m00")
    sname = "SKA_MID_SOURCE"
    spwname = "BAND 2"
    sm.open(msname)  # Open the simulator

    # Initialize the ms file; has to be done before setting array configs, spw, etc; i.e. this is the first step of the simulation
    # beam  =  vp.setpbairy(telescope = 'PAPER_SA', dishdiam = '14.0m',blockagediam = '1.0m', dopb = True)
    (x, y, z, d, an, an2, telname, obspos) = mysu.readantenna(ant_config)
    # Set the antenna configuration
    sm.setconfig(
        telescopename=telname,
        x=x,
        y=y,
        z=z,
        dishdiameter=d,
        mount=["alt-az"],
        antname=an,
        coordsystem="global",
        referencelocation=obspos,
    )

    # Set the spectral window and polarization (one data-description-id).
    # Call multiple times with different names for multiple SPWs or pol setups.
    sm.setspwindow(
        spwname=spwname,
        freq=freq,
        deltafreq=deltafreq,
        freqresolution=freqresolution,
        nchannels=nchannels,
        stokes="RR LL",
    )
    # Set the polarization mode (this goes to the FEED subtable)
    sm.setfeed(mode="perfect R L", pol=[""])
    sm.setfield(sourcename=sname, sourcedirection=dir0)
    # Leave autocorrelations out of the MS.
    sm.setauto(autocorrwt=0.0)

    # Set the integration time, and the convention to use for timerange specification.
    # Construct MS metadata and UVW values for one scan and ddid.
    # Call sm.observe() multiple times for multiple scans, optionally with
    # different sourcenames (fields) and spw/pol settings as defined above.
    # Timesteps will be defined in intervals of 'integrationtime', between starttime and stoptime.
    refdate = odate[0]
    reftime = me.epoch("UTC", refdate)

    # Query rise/set once and reuse the record for both values.
    rise_set = me.riseset(dir0)
    # NOTE(review): the 8 * 3600 scaling is kept from the original code; the
    # unit of the rise/set values is not visible here - confirm that this
    # really yields the usable track length in seconds.
    netime = (
        (
            rise_set["set"]["last"]["m0"]["value"]
            - rise_set["rise"]["last"]["m0"]["value"]
        )
        * 8
        * 3600
    )  # seconds
    itime = qa.quantity(itime, "s")

    if endtime > netime:
        # TODO: Re-validate the logic of splitting a long observation into
        # several tracks (any remainder shorter than one track is dropped).
        #
        # The number of tracks must come from the *requested* end time. It
        # used to be recomputed from the per-track stop time (netime / 2),
        # which always truncated to 0, so nothing was observed at all.
        timeloop = int(endtime / netime)
        starttime = -netime / 2
        etime = +netime / 2
        nfld = 1
        for _ in range(timeloop):
            # Shift the reference time by one unit for every track
            # (presumably one day - confirm against the epoch record).
            reftime["m0"]["value"] = reftime["m0"]["value"] + 1
            sm.settimes(
                integrationtime=itime,
                usehourangle=False,
                referencetime=reftime,
            )
            sm.observe(
                sourcename=sname,
                spwname=spwname,
                starttime=qa.quantity(starttime, "s"),
                stoptime=qa.quantity(etime, "s"),
            )
            # Pass a real list: a lazy range object is not a sequence of ints
            # that tool bindings are guaranteed to accept.
            sm.setdata(spwid=[0], fieldid=list(range(nfld)))

    else:
        sm.settimes(
            integrationtime=itime, usehourangle=False, referencetime=reftime
        )
        sm.observe(
            sourcename=sname,
            spwname=spwname,
            starttime=qa.quantity(starttime, "s"),
            stoptime=qa.quantity(endtime, "s"),
        )

    if usemod == "im":
        # Predict from a model image
        sm.predict(imagename=imagename, incremental=True)
    elif usemod == "both":
        # Predict from a model image and component list
        sm.predict(complist=clname, imagename=imagename, incremental=True)
    else:
        # "com" and every unrecognised value: predict from the component list
        sm.predict(complist=clname, incremental=True)
    sm.close()  # Close the tool
    sm.done()
    # Unflag everything (unless you care about elevation/shadow flags)
    flagdata(vis=msname, mode="unflag")

## Simulate Sky model


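The functions defined below simulate a single source with a flux density of 1.5 Jy and a spectral index of 0.0. As written, `makeCompList` defines it as a Gaussian component (20 arcsec major and minor axes) at a reference frequency of 0.96 GHz; a point-source variant is left commented out in the code. See the CASA documentation to specify more complex sky models.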

### Make an empty CASA image

In [6]:
def makeEmptyImage(imname_true, nchan, ref_freq, freq_incr, cell_size):
    """
    Create a blank CASA image of shape [256, 256, 1, nchan] on disk.

    The image is centred on 15h00m00s / -30d00m00s, filled with zeros and
    given the brightness unit "Jy/pixel". An existing image with the same
    name is overwritten.

    Parameters
    ----------
    imname_true:
        Name (path) of the image to create.
    nchan:
        Length of the last image axis (number of channels).
    ref_freq:
        Reference value of the spectral axis, assigned to pixel 0.
    freq_incr:
        Increment of the spectral axis per channel.
    cell_size:
        Pixel size of the direction axes; anything ``qa.quantity`` accepts.
    """
    # Sky position of the image centre
    center_ra = "15h00m00s"
    center_dec = "-30d00m00s"
    image_shape = [256, 256, 1, nchan]

    # Release whatever image the tool may still hold, then create the new one
    ia.close()
    ia.fromshape(imname_true, image_shape, overwrite=True)

    # Start from the image's own coordinate system and adjust it
    coords = ia.coordsys()
    coords.setunits(["rad", "rad", "", "Hz"])

    # Direction axes: pixel size in radians, first axis increment negated
    pixel_rad = qa.convert(qa.quantity(cell_size), "rad")["value"]
    coords.setincrement([-pixel_rad, pixel_rad], "direction")
    ra_rad = qa.convert(center_ra, "rad")["value"]
    dec_rad = qa.convert(center_dec, "rad")["value"]
    coords.setreferencevalue([ra_rad, dec_rad], type="direction")

    # Spectral axis: ref_freq sits at channel 0, stepping by freq_incr
    coords.setreferencevalue(ref_freq, "spectral")
    coords.setreferencepixel([0], "spectral")
    coords.setincrement(freq_incr, "spectral")

    # Write the coordinate system back, set units, zero the pixels
    ia.setcoordsys(coords.torecord())
    # ia.setrestoringbeam(major='18arcsec', minor='18arcsec', pa='0deg')
    ia.setbrightnessunit("Jy/pixel")
    ia.set(0.0)
    ia.close()

### Make a component list and evaluate it onto a CASA image

In [4]:
def makeCompList(clname_true):
    """
    Build a component list holding one Gaussian source and save it on disk.

    Anything already present at ``clname_true`` is removed first. The
    contents of the list are printed before it is saved.

    Parameters
    ----------
    clname_true:
        Path under which the component list is stored.
    """
    # The tool complains if the component list already exists, so clear the
    # path and reset the tool before adding anything.
    remove_if_exists(clname_true)
    cl.done()

    # One source per addcomponent() call; call it again with another 'dir'
    # to add further sources. For a point source use shape='point' and drop
    # the axis settings.
    source = dict(
        dir="J2000 15h00m00s -30d00m00s",
        flux=1.50,  # for a gaussian, this is the integrated area
        fluxunit="Jy",
        freq="0.96GHz",
        shape="gaussian",
        majoraxis="20arcsec",
        minoraxis="20arcsec",
        spectrumtype="spectral index",
        index=0.0,
    )
    cl.addcomponent(**source)

    # Show what was built
    print("Contents of the component list")
    print(cl.torecord())

    # Persist under the requested name and release the tool
    cl.rename(filename=clname_true)
    cl.done()

In [None]:
def evalCompList(clname, imname):
    """
    Evaluate the component list ``clname`` onto the image ``imname``.

    The component record is applied to the image through ``ia.modify`` with
    ``subtract=False``; both tools are released afterwards.

    Parameters
    ----------
    clname:
        Path of the component list to read.
    imname:
        Path of the image that is modified in place.
    """
    cl.open(clname)
    ia.open(imname)

    component_record = cl.torecord()
    ia.modify(component_record, subtract=False)

    ia.close()
    cl.done()

### Add spectral line

In [55]:
def editPixels(imname, freq_coords):
    """
    Add a Gaussian-shaped spectral line to the centre pixel of an image.

    The line profile peaks at 200e-3 at the mean of ``freq_coords`` and has
    a standard deviation of 0.5, in the same unit as ``freq_coords``. It is
    added along the last axis at the pixel in the middle of the first two
    axes, index 0 of the third axis. The image is modified in place.

    Parameters
    ----------
    imname:
        Path of the image to edit.
    freq_coords:
        Array with one coordinate value per channel of the image.
    """
    ia.open(imname)
    cube = ia.getchunk()
    dims = ia.shape()

    # Gaussian profile centred in the middle of the covered range
    line_centre = freq_coords.mean()  # np.random.uniform(freq_start, freq_end)
    line_sigma = 0.5  # np.abs(np.random.normal(0, 0.5))
    offset = (freq_coords - line_centre) / line_sigma
    line_profile = (200e-3) * np.exp(-1 / 2 * offset ** 2)

    # Central pixel of the two leading axes
    row = int(dims[0] / 2)
    col = int(dims[1] / 2)
    cube[row, col, 0, :] = cube[row, col, 0, :] + line_profile

    # Write the edited pixels back into the image
    ia.putchunk(cube)
    ia.close()

##  Run the simulation

### Setting up input parameters

In [5]:
# ************* Inputs for the Simulation ***************#
# Antenna configuration file and the directory that receives every output
telescope_file = "./ska1-mid.cfg"
output_dir = "./data_configurable2"

# Paths of the products written below output_dir
spectral_cube = os.path.join(output_dir, "spectral_cube")
cl_name = os.path.join(output_dir, "one_point.cl")
msname = os.path.join(output_dir, "sim_mid_msfile.ms")

# modify this if you want to use custom image
casa_image = spectral_cube + ".image"

# fits image generated if generate_sky_model is True
fits_image = spectral_cube + ".fits"

# Image params
cell_size = 6.0 * au.arcsecond

# SKA observation parameters
freq_start = 950 * au.MHz
freq_end = 970 * au.MHz
chan_res = 0.100 * au.MHz

# set channels
# Both band edges are counted, hence the "+ 1". The ratio is rounded rather
# than truncated: a floating-point quotient that lands just below a whole
# number (e.g. 0.3 / 0.1) would otherwise silently lose one channel.
nchan = (
    int(
        round(
            ((freq_end - freq_start) / chan_res)
            .to(au.dimensionless_unscaled)
            .value
        )
    )
    + 1
)

# qq = 1 # only point sources
observation_date = ["2000/01/02/05:00:00"]  # observation date
integration_time_sec = 120.0  # Integration time in seconds
start_time_sec = 0.0  # start time in seconds
end_time_sec = 14400.0  # End time in seconds

### Using custom spectral cube

In [6]:
# If generate_sky_model is True, the notebook generates a CASA (and FITS)
# spectral cube based on the component specified in the functions above.
# The generated cube is then used to generate the visibility data.

generate_sky_model = True

# To use a custom CASA spectral cube as input, set generate_sky_model to False,
# then overwrite the casa_image variable with the path to your image cube.

# casa_image =

### Generate sky model

In [7]:
# Create the output directory together with any missing parents. exist_ok
# makes this a no-op when it is already there and avoids the gap between a
# separate existence check and the creation.
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Build the synthetic sky model only when requested; otherwise casa_image is
# expected to point at a user-supplied cube.
if generate_sky_model:
    # 1. Component list describing the source
    makeCompList(clname_true=cl_name)

    # 2. Blank cube matching the observation's frequency setup
    makeEmptyImage(
        imname_true=casa_image,
        nchan=nchan,
        ref_freq=remove_whitespace_in_astropy_unit(freq_start),
        freq_incr=remove_whitespace_in_astropy_unit(chan_res),
        cell_size=remove_whitespace_in_astropy_unit(cell_size),
    )

    # 3. Evaluate the component list onto the CASA image
    evalCompList(clname=cl_name, imname=casa_image)

    # 4. FITS copy of the cube
    # NOTE(review): the export runs before editPixels(), so the FITS file does
    # not contain the spectral line that the CASA image gets below - confirm
    # that this is intended.
    exportfits(imagename=casa_image, fitsimage=fits_image, overwrite=True)

    # 5. Spectral line on top of the source, one coordinate per channel
    freq_coords = np.linspace(freq_start, freq_end, nchan)
    editPixels(imname=casa_image, freq_coords=freq_coords.value)

### Run simulation

In [None]:
# TODO: Are deltafreq and freqresolution same?
# Currently the channel resolution is passed to both of them.

simulate(
    # Outputs and telescope layout
    msname=msname,
    ant_config=telescope_file,
    # Sky model: predict from the model image only, no component list
    usemod="im",
    imagename=casa_image,
    clname=None,
    # Time setup
    odate=observation_date,
    itime=integration_time_sec,
    starttime=start_time_sec,
    endtime=end_time_sec,
    # Frequency setup
    freq=remove_whitespace_in_astropy_unit(freq_start),
    deltafreq=remove_whitespace_in_astropy_unit(chan_res),
    freqresolution=remove_whitespace_in_astropy_unit(chan_res),
    nchannels=nchan,
)