In [1]:
import numpy as np

In [2]:
def calc_ps_cont_response(fitParams, uVals, vVals, freqVals, pbDistance, pbPower):
    """Evaluate the summed complex response of one or more model components.

    Each column of ``fitParams`` describes one component:

    * rows 0 and 1 -- position offsets multiplied with ``uVals`` and ``vVals``
      in the phase term ``exp(2j*pi*(u*x + v*y))``;
    * row 2 -- amplitude at the reference frequency;
    * row 3 -- power-law exponent applied to ``freqVals / reference``;
    * row 4 -- reference frequency.

    Parameters
    ----------
    fitParams : array_like
        ``(5, nComponents)`` parameter array; a 1-D vector is treated as a
        single component.
    uVals, vVals, freqVals : array_like
        Per-sample coordinates and frequencies; they must broadcast against
        each other. The output has the shape of ``uVals``.
    pbDistance, pbPower : array_like or None
        Tabulated attenuation curve. ``np.log(pbPower)`` is interpolated at
        ``|offset| * freqVals / speedC`` and added to the exponent. Passing
        ``None`` or an empty sequence for ``pbDistance`` disables the
        attenuation.

    Returns
    -------
    (realVals, imagVals) : tuple of ndarray
        Real and imaginary parts of the summed response.
    """
    speedC = 29979245800

    fitParams = np.asarray(fitParams)
    if fitParams.ndim == 1:
        # Accept a flat parameter vector as a single-component model.
        fitParams = fitParams.reshape(-1, 1)

    # Callers signal "no attenuation" with either None or an empty table;
    # np.interp raises on an empty grid, so both must short-circuit here.
    usePB = pbDistance is not None and np.size(pbDistance) > 0
    # Loop-invariant: the log of the table does not depend on the component.
    logPbPower = np.log(pbPower) if usePB else None

    # np.complex_ was removed in NumPy 2.0; complex128 is the same dtype.
    complexVals = np.zeros_like(uVals, dtype=np.complex128)
    for idx in range(fitParams.shape[1]):
        if usePB:
            offsetMag = np.sqrt(np.sum(fitParams[0:2, idx] ** 2))
            pbAtten = np.interp(offsetMag * freqVals / speedC, pbDistance, logPbPower)
        else:
            pbAtten = 0

        amplitude = fitParams[2, idx] * (freqVals / fitParams[4, idx]) ** fitParams[3, idx]
        phase = (2j * np.pi) * (uVals * fitParams[0, idx] + vVals * fitParams[1, idx])
        complexVals += amplitude * np.exp(pbAtten + phase)

    realVals = np.real(complexVals)
    imagVals = np.imag(complexVals)
    return realVals, imagVals

In [3]:
def calc_chisq_matrix(fitTerms, dataVals, dataWeights, uVals, vVals, freqVals, useMask, pbDistance, pbPower):
    """Build the linearised least-squares system for one fitting iteration.

    The model response is evaluated with ``calc_ps_cont_response`` at the
    current ``fitTerms`` and, for every parameter selected by ``useMask``, at
    ``+/- delStep / 2`` around it. The central difference becomes one column
    of the design matrix, so the solved coefficients are in units of
    ``delStep`` and must be multiplied by ``delVals`` by the caller.

    Parameters are laid out in groups of five: indices 0 and 1 are position
    offsets, 2 is the flux, 3 is the spectral index (its step is derived from
    the flux stored just before it) and 4 is the reference frequency, which is
    never perturbed -- its column and step are left at zero.

    Parameters
    ----------
    fitTerms : ndarray
        Current parameter values, indexed along the first axis.
    dataVals : array_like of complex
        Measured values.
    dataWeights : array_like
        Weight of each measurement; the same weight is used for the real and
        imaginary parts.
    uVals, vVals, freqVals : array_like
        Coordinates and frequencies forwarded to ``calc_ps_cont_response``.
    useMask : array_like
        Non-zero entries mark the parameters to fit.
    pbDistance, pbPower
        Forwarded unchanged to ``calc_ps_cont_response``.

    Returns
    -------
    xVals : ndarray, shape (2 * nData, nFit)
        Design matrix; real-part rows first, imaginary-part rows second.
    yVals : ndarray, shape (2 * nData,)
        Residuals (data minus current model) in the same row order.
    weightVals : ndarray, shape (2 * nData,)
        ``dataWeights`` repeated for the real and imaginary halves.
    delVals : ndarray, shape (nFit,)
        Step size used for each fitted parameter.
    """
    speedC = 29979245800
    paramsPerSource = 5

    useMaskArr = np.where(useMask)[0]
    nFitParams = useMaskArr.size

    delReal0, delImag0 = calc_ps_cont_response(fitTerms, uVals, vVals, freqVals, pbDistance, pbPower)
    # Everything is flattened so that rows of xVals, yVals and weightVals line
    # up one-to-one whether the inputs are 1-D or column vectors.
    yVals = np.concatenate([np.ravel(np.real(dataVals) - delReal0),
                            np.ravel(np.imag(dataVals) - delImag0)])
    weightVals = np.tile(np.ravel(dataWeights), 2)

    delVals = np.zeros(nFitParams)
    xVals = np.zeros((np.size(dataVals) * 2, nFitParams))

    delFlux = 0.1 / np.sqrt(np.nanmax(dataWeights))
    delPos = (0.1 / np.sqrt(np.nanmax(uVals ** 2 + vVals ** 2))) * (speedC / np.median(freqVals))
    delFreq = np.max(freqVals) / np.min(freqVals)

    for idx in range(nFitParams):
        paramIdx = useMaskArr[idx]
        # np.where yields 0-based indices, so the kind is paramIdx % 5.
        # (Subtracting one first only suits 1-based indices: it skips
        # parameter 0 and shifts every other step size by one slot.)
        paramKind = paramIdx % paramsPerSource
        if paramKind in (0, 1):
            delStep = delPos
        elif paramKind == 2:
            delStep = delFlux
        elif paramKind == 3:
            # Exponent change that alters the flux by delFlux across the
            # band; the flux is the parameter stored just before this one.
            fluxMag = np.abs(fitTerms[paramIdx - 1]).item()
            delStep = np.log10((delFlux + fluxMag) / fluxMag) / np.log10(delFreq)
        else:
            continue
        delVals[idx] = delStep

        tempFitTerms = fitTerms.copy()
        tempFitTerms[paramIdx] -= delStep / 2
        delReal1, delImag1 = calc_ps_cont_response(tempFitTerms, uVals, vVals, freqVals, pbDistance, pbPower)

        tempFitTerms[paramIdx] += delStep
        delReal2, delImag2 = calc_ps_cont_response(tempFitTerms, uVals, vVals, freqVals, pbDistance, pbPower)

        xVals[:, idx] = np.concatenate([np.ravel(delReal2 - delReal1),
                                        np.ravel(delImag2 - delImag1)])

    return xVals, yVals, weightVals, delVals

In [4]:
def fit_linear_model(xVals, yVals, weightVals):
    """Solve a weighted linear least-squares problem via the normal equations.

    Minimises ``sum(weightVals * (yVals - xVals @ fitVals) ** 2)``.

    Parameters
    ----------
    xVals : array_like, shape (nData, nParams)
        Design matrix.
    yVals : array_like
        Observations; flattened to length ``nData`` (1-D and column vectors
        are both accepted).
    weightVals : array_like
        Per-observation weights; flattened to length ``nData``.

    Returns
    -------
    fitVals : ndarray, shape (nParams,)
        Best-fit coefficients.
    fitErrs : ndarray, shape (nParams,)
        Coefficient uncertainties, ``sqrt(chiSqVal / nDOF * diag(inv(A)))``
        with ``A`` the weighted normal matrix. Not finite when ``nDOF`` is 0.
    chiSqVal : float
        Weighted sum of squared residuals at the solution.
    nDOF : int
        Number of observations minus number of fitted parameters.

    Raises
    ------
    numpy.linalg.LinAlgError
        If the normal matrix is singular (for example an all-zero column).
    """
    xVals = np.asarray(xVals)
    # Flatten so that column vectors do not broadcast the weights into an
    # (nData, 1, nParams) array and silently produce a wrong normal matrix.
    yVals = np.ravel(yVals)
    weightVals = np.ravel(weightVals)

    # Broadcasting the weights over columns replaces the explicit np.tile.
    alphaMatrix = xVals.T @ (xVals * weightVals[:, np.newaxis])
    betaMatrix = xVals.T @ (yVals * weightVals)
    fitVals = np.linalg.solve(alphaMatrix, betaMatrix)

    errVals = yVals - (xVals @ fitVals)
    nDOF = np.size(yVals) - xVals.shape[1]
    chiSqVal = np.sum((errVals ** 2) * weightVals)
    # Reuse the normal matrix instead of rebuilding it for the covariance.
    fitErrs = np.sqrt((chiSqVal / nDOF) * np.diag(np.linalg.inv(alphaMatrix)))
    return fitVals, fitErrs, chiSqVal, nDOF

In [5]:
def fit_uv_cont(*args):
    """Iteratively fit position offsets, flux and spectral index to visibility data.

    Options are passed as alternating ``name, value`` positional arguments and
    are intended to override the defaults assigned at the top of the body:
    ``avgWin``, ``tolVal``, ``useMask``, ``usePB``, ``verbose``, ``cycleMax``,
    ``modelStruc``, ``visStruc``, ``gainStruc`` and ``fitCal``.

    Each pass linearises the model with ``calc_chisq_matrix``, solves the
    weighted least-squares step with ``fit_linear_model`` and adds the
    (optionally damped) step to ``fitTerms``. The loop stops when the
    chi-square no longer improves by more than ``tolVal`` (relative) and every
    step is below 1% of its uncertainty, when a NaN appears, or when the pass
    limit is exhausted. If the loop stopped early, the fitted position and
    coefficients are written into ``modelStruc``.

    Relies on helpers that are not defined in the visible part of this file:
    ``gain_mask``, ``flag_mask``, ``calc_noise``, ``calc_uvlambda``,
    ``calc_coordoffset`` and ``calc_primarybeam``.

    NOTE(review): the option loop assigns through ``exec``. In Python 3,
    ``exec`` inside a function cannot rebind that function's local variables,
    so the overrides appear to have no effect and every option keeps its
    default -- confirm, and replace with explicit keyword handling.

    NOTE(review): no ``return`` statement is visible in this view of the
    function; a ``modelStruc`` created here would not reach the caller unless
    one follows -- confirm against the full file.
    """
    # Defaults for every recognised option.
    avgWin = True
    tolVal = 1e-10
    # One flag per fit term, in fitTerms order: RA offset, Dec offset, flux,
    # spectral index, reference frequency (held fixed by default).
    useMask = np.array([1, 1, 1, 1, 0], dtype=bool)
    usePB = True
    verbose = False
    cycleMax = 500
    modelStruc = {}
    visStruc = {}
    gainStruc = {}
    fitCal = False

    # Walk the (name, value) pairs; an odd number of arguments raises
    # IndexError on args[idx + 1].
    # NOTE(review): exec() does not update function locals in Python 3, so
    # these assignments look like no-ops -- see the docstring.
    for idx in range(0, len(args), 2):
        tempVal = args[idx + 1]
        exec(f"{args[idx]} = tempVal")

    useMask = useMask.astype(bool)

    # fitCal disables fitting of the two position terms and the primary beam.
    if fitCal:
        useMask[0:2] = False
        usePB = False

    # Everything below needs dataVals, uVals, cenRA, ... which are only
    # assigned inside this branch, so a non-empty visStruc is effectively
    # required.
    if visStruc:
        if not gainStruc:
            gainStruc = visStruc['gains']
        # Cross-correlation data scaled by the gain mask, plus the flag mask.
        # NOTE(review): presumably flag_mask returns NaN for flagged samples
        # so they fail the finite check below -- confirm against the helper.
        dataVals = (visStruc['cross'] * gain_mask(gainStruc, visStruc['metaData']['baseToAnts'],
                                                   visStruc['time']['mjd'], visStruc['source']['name'],
                                                   gainStruc['gainPrefs']) +
                    flag_mask(visStruc['flags'], visStruc['metaData']))
        # Weights are 2 / noise**2.
        dataWeights = 2 * (calc_noise(visStruc)) ** (-2)
        uVals = calc_uvlambda(visStruc['baseCoords']['u'], visStruc['metaData']['freq'])
        vVals = calc_uvlambda(visStruc['baseCoords']['v'], visStruc['metaData']['freq'])
        # Frequency table with its last axis moved first, repeated nTime and
        # nBase times along the two leading axes.
        # NOTE(review): assumes the result matches the shape of dataVals -- TODO confirm.
        freqVals = np.tile(np.moveaxis(visStruc['metaData']['freq'], -1, 0),
                           (visStruc['metaData']['nTime'], visStruc['metaData']['nBase'], 1, 1))

        # A sample is bad when any input is non-finite or when its data,
        # weight or frequency is zero; bad samples are zeroed in all arrays
        # so they drop out of the weighted sums below.
        badData = np.logical_not(np.logical_and(np.isfinite(dataVals * dataWeights * uVals * vVals * freqVals),
                                                 np.not_equal(dataVals * dataWeights * freqVals, 0)))
        dataVals[badData] = 0
        dataWeights[badData] = 0
        uVals[badData] = 0
        vVals[badData] = 0
        freqVals[badData] = 0

        # Weighted average over axis 3 (presumably the spectral-window axis --
        # confirm). Cells whose weights sum to zero become NaN here and are
        # removed by the goodData filter further down.
        if avgWin:
            dataVals = np.sum(dataVals * dataWeights, axis=3) / np.sum(dataWeights, axis=3)
            uVals = np.sum(uVals * dataWeights, axis=3) / np.sum(dataWeights, axis=3)
            vVals = np.sum(vVals * dataWeights, axis=3) / np.sum(dataWeights, axis=3)
            freqVals = np.sum(freqVals * dataWeights, axis=3) / np.sum(dataWeights, axis=3)
            dataWeights = np.sum(dataWeights, axis=3)

        # In fitCal mode additionally take the weighted average over axis 2,
        # skipping entries whose weight is NaN.
        if fitCal:
            dataVals = np.sum(dataVals * dataWeights, axis=2, where=~np.isnan(dataWeights)) / np.sum(dataWeights, axis=2, where=~np.isnan(dataWeights))
            uVals = np.sum(uVals * dataWeights, axis=2, where=~np.isnan(dataWeights)) / np.sum(dataWeights, axis=2, where=~np.isnan(dataWeights))
            vVals = np.sum(vVals * dataWeights, axis=2, where=~np.isnan(dataWeights)) / np.sum(dataWeights, axis=2, where=~np.isnan(dataWeights))
            freqVals = np.sum(freqVals * dataWeights, axis=2, where=~np.isnan(dataWeights)) / np.sum(dataWeights, axis=2, where=~np.isnan(dataWeights))
            dataWeights = np.sum(dataWeights, axis=2, where=~np.isnan(dataWeights))

        # Keep only finite, non-zero samples and flatten each array to an
        # (nGood, 1) column vector.
        goodData = np.logical_and(np.isfinite(dataVals * dataWeights * uVals * vVals * freqVals),
                                  np.not_equal(dataVals * dataWeights * freqVals, 0))
        dataVals = dataVals[goodData].reshape(-1, 1)
        dataWeights = dataWeights[goodData].reshape(-1, 1)
        uVals = uVals[goodData].reshape(-1, 1)
        vVals = vVals[goodData].reshape(-1, 1)
        freqVals = freqVals[goodData].reshape(-1, 1)
        # Reference position: median source coordinates. The factor 15
        # presumably converts RA from hours to degrees -- confirm.
        cenRA = np.nanmedian(visStruc['source']['ra']) * 15
        cenDec = np.nanmedian(visStruc['source']['dec'])

    # No model supplied: seed one from the data (weighted-mean real flux,
    # zero spectral index, mean frequency as the reference frequency).
    if not modelStruc:
        modelStruc = {}
        modelStruc['fieldName'] = visStruc['source']['name']
        modelStruc['raJ2000'] = np.median(visStruc['source']['ra'])
        modelStruc['decJ2000'] = np.median(visStruc['source']['dec'])
        # NOTE(review): start comes from the real part and end from the
        # imaginary part of mjd; presumably mjd is stored as
        # start + 1j * end -- TODO confirm.
        modelStruc['timeRange'] = [np.min(np.real(visStruc['time']['mjd'])), np.max(np.imag(visStruc['time']['mjd']))]
        modelStruc['freqRange'] = [np.min(visStruc['metaData']['freq']), np.max(visStruc['metaData']['freq'])]
        modelStruc['fitCoeffs'] = [np.real(np.nansum(dataVals * dataWeights) / np.nansum(dataWeights)), 0, np.mean(visStruc['metaData']['freq'])]
        modelStruc['fitErrs'] = np.zeros(3)
        modelStruc['souClass'] = 'c'
        modelStruc['nSou'] = 1


    # Offset of the model position from the reference position, converted to
    # radians. The meaning of the three boolean flags is defined by
    # calc_coordoffset, which is not visible here.
    offRA, offDec = calc_coordoffset(cenRA, cenDec, modelStruc['raJ2000'] * 15, modelStruc['decJ2000'], True, False, True)
    offRA = np.deg2rad(offRA)
    offDec = np.deg2rad(offDec)

    # Chi-square of the data against a zero model; not used again in the
    # visible code.
    rawChiSq = np.sum((np.real(dataVals) ** 2) * dataWeights) + np.sum((np.imag(dataVals) ** 2) * dataWeights)
    # Column vector [offRA, offDec, fitCoeffs...].
    # NOTE(review): np.concatenate needs offRA/offDec to be at least 1-D --
    # confirm what calc_coordoffset returns.
    fitTerms = np.reshape(np.concatenate((offRA, offDec, modelStruc['fitCoeffs'])), (-1, 1))
    fitUncert = np.zeros_like(fitTerms)

    if usePB:
        pbPower, pbDistance = calc_primarybeam(visStruc['antData']['antType'][0])
    else:
        # NOTE(review): these are empty lists, not None -- confirm that
        # calc_ps_cont_response treats an empty table as "no primary beam".
        pbPower = []
        pbDistance = []

    lastChiSq = np.nan
    cycleCount = 0

    # cycleCount starts at 0 and the test is <=, so up to cycleMax + 1 passes run.
    while cycleCount <= cycleMax:
        cycleCount += 1
        xVals, yVals, weightVals, delVals = calc_chisq_matrix(fitTerms, dataVals, dataWeights, uVals, vVals, freqVals, useMask, pbDistance, pbPower)
        fitVals, fitErrs, chiSqVal, nDOF = fit_linear_model(xVals, yVals, weightVals)
        # The design-matrix columns are finite differences over delVals, so
        # the solution is in units of those steps; rescale to parameter units.
        fitVals *= delVals
        fitErrs *= delVals

        # Half step when chi-square decreased since the previous pass, full
        # step otherwise. On the first pass lastChiSq is NaN, the comparison
        # is False and a full step is taken.
        scaleFac = 0.5 if chiSqVal < lastChiSq else 1

        # NOTE(review): fitTerms is (n, 1), so fitTerms[useMask] is (k, 1) --
        # confirm that fitVals / fitErrs broadcast against that shape.
        fitTerms[useMask] += (fitVals * scaleFac)
        fitUncert[useMask] = fitErrs

        # Offsets are printed in arcseconds (radians * 3600 * 180 / pi).
        # NOTE(review): fitTerms[0] etc. are length-1 arrays here; a float
        # format spec on a non-scalar ndarray raises TypeError -- verify.
        if verbose:
            print(f'RA: {fitTerms[0] * (3600 * 180 / np.pi):5.2f} Dec: {fitTerms[1] * (3600 * 180 / np.pi):5.2f} Flux: {fitTerms[2]:6.4f} +/- {fitUncert[2]:6.4f} Sp: {fitTerms[3]:5.3f} +/- {fitUncert[3]:5.3f} Chi2: {chiSqVal} ScaleFac: {scaleFac}')

        # Stop when chi-square improved by no more than tolVal (relative) and
        # every step is under 1% of its uncertainty, or when the solve
        # produced NaNs. On the first pass lastChiSq is NaN, so only the NaN
        # checks can trigger.
        if (lastChiSq - chiSqVal <= tolVal * lastChiSq and max(np.abs(fitVals / fitErrs)) < 1e-2) or np.any(np.isnan(fitVals)) or np.any(np.isnan(fitErrs)):
            print(f'Sou: {visStruc["source"]["name"]:15s} RA: {fitTerms[0] * (3600 * 180 / np.pi):5.2f} Dec: {fitTerms[1] * (3600 * 180 / np.pi):5.2f} Flux: {fitTerms[2]:6.4f} +/- {fitUncert[2]:6.4f} Sp: {fitTerms[3]:+5.3f} +/- {fitUncert[3]:5.3f} Chi2: {chiSqVal} ScaleFac: {scaleFac}')
            break

        lastChiSq = chiSqVal

    # True when the loop broke out before its final pass. A break on the very
    # last pass leaves cycleCount == cycleMax + 1 and is treated the same as
    # running out of passes. A NaN break also lands here.
    if cycleCount <= cycleMax:
        # Position was fitted: convert the offsets back to degrees and to an
        # absolute position; RA is stored divided by 15.
        if np.any(useMask[:2]):
            offRA = np.rad2deg(fitTerms[0, :])
            offDec = np.rad2deg(fitTerms[1, :])
            newRA, newDec = calc_coordoffset(cenRA, cenDec, offRA, offDec, True, True, True)
            modelStruc['raJ2000'] = newRA / 15
            modelStruc['decJ2000'] = newDec

        # Flux and/or spectral index were fitted: copy coefficients and
        # uncertainties (rows 2 onward) back into the model in place.
        # NOTE(review): the right-hand sides are (3, 1) slices -- confirm they
        # assign cleanly into the list / length-3 array targets.
        if np.any(useMask[2:4]):
            modelStruc['fitCoeffs'][:] = fitTerms[2:]
            modelStruc['fitErrs'][:] = fitUncert[2:]