<a href="https://colab.research.google.com/github/rubyvanrooyen/NIFTyworkshop/blob/master/NIFTy_Example_CorrelatedFieldModel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!git clone -b NIFTy_6 --single-branch https://gitlab.mpcdf.mpg.de/ift/nifty.git
!pip install ./nifty

In [None]:
import nifty6 as ift
import numpy
import numpy as np
from matplotlib import pylab

In [None]:
# The power spectrum is modelled as a field of its own, so it is learned
# alongside the actual field.
N_pixels = 512  # number of pixels
position_space = ift.RGSpace(shape=[N_pixels])

In [None]:
# Set up the correlated field model.
# See the model notebook 'getting_started_4_CorrelatedFields.ipynb' for how the
# numbers were selected. The trailing comments presumably record the values
# used in that notebook -- TODO confirm.

cfmaker = ift.CorrelatedFieldMaker.make(offset_mean=0.0,       # 0.
                                        offset_std_mean=1e-3,  # 1e-3
                                        offset_std_std=1e-6,   # 1e-6
                                        prefix='')

fluctuations_dict = dict(
    # Amplitude of field fluctuations
    fluctuations_mean=2.0,      # 1.0
    fluctuations_stddev=1.0,    # 1e-2

    # Exponent of power law power spectrum component
    loglogavgslope_mean=-2.0,   # -3.0
    loglogavgslope_stddev=0.5,  #  0.5

    # Amplitude of integrated Wiener process power spectrum component
    flexibility_mean=2.5,       # 1.0
    flexibility_stddev=1.0,     # 0.5

    # How ragged the integrated Wiener process component is
    asperity_mean=0.5,          # 0.1
    asperity_stddev=0.5,        # 0.5
)
cfmaker.add_fluctuations(position_space, **fluctuations_dict)

# finalize() returns the NIFTy operator that gives out these fields
correlated_field = cfmaker.finalize()

In [None]:
# Feed the model some random Gaussian noise on its domain and look at the
# resulting prior sample.
main_sample = ift.from_random(correlated_field.domain)
print("model domain keys:", correlated_field.domain.keys())

prior_realization = correlated_field(main_sample)
plot = ift.Plot()
plot.add(prior_realization)
plot.output()

In [None]:
# Look at the power spectrum: it is the square of the amplitude operator.
A = cfmaker.amplitude
# print(A.domain)
pspec = A ** 2
# print(pspec.domain)

# Evaluate the power spectrum for 20 independent random inputs and overlay them
# in a single plot.
prior_spectra = []
for _ in range(20):
    prior_spectra.append(pspec(ift.from_random(pspec.domain)))

plot = ift.Plot()
plot.add(prior_spectra)
plot.output()

In [None]:
# Define the signal by applying a sigmoid non-linearity to the
# correlated-field operator.
signal = ift.sigmoid(correlated_field)

In [None]:
# Build the line-of-sight (LOS) response and define the signal response.
# The LOS start and end points are drawn at random and must be passed as lists.

n_los = 100

# Random lines of sight, taken from NIFTy's current random number generator
# (rng). Starts are drawn before ends.
los_rng = ift.random.current_rng()
LOS_starts = list(los_rng.random((n_los, 1)).T)
LOS_ends = list(los_rng.random((n_los, 1)).T)

# Alternative: radial lines of sight (all ends fixed at 0.5)
# LOS_starts = list(ift.random.current_rng().random((n_los, 1)).T)
# LOS_ends = list(0.5 + 0*ift.random.current_rng().random((n_los, 1)).T)

# LOSResponse(domain, starts, ends): position_space is the domain, the two
# lists hold the start and end values of the lines of sight.
R = ift.LOSResponse(position_space, starts=LOS_starts, ends=LOS_ends)
data_space = R.target
print(R)
print(R.target)

# Visualise where the (randomly distributed) lines of sight are located by
# applying the adjoint response to a field of ones on the response target.
unit_data = ift.full(R.target, 1.)
ift.single_plot(R.adjoint(unit_data))

# signal_response == signal followed by the response = R(signal) = R @ signal
signal_response = R(signal)

In [None]:
# Noise model: N is a scaling operator on the data space with factor noisevar;
# it is used as the noise covariance when the likelihood is built.
noisevar = 0.01
N = ift.ScalingOperator(data_space, noisevar)
# Draw one float64 noise realisation from N.
noise = N.draw_sample_with_dtype(dtype=np.float64)

In [None]:
# Generate the noiseless data and the mock data.
# signal and signal_response = R(signal) take the same latent input, so the two
# printed domains should agree; the operators differ in their targets
# (signal_response ends in data_space = R.target).
print(signal.domain)
print(signal_response.domain)

# Mock ground truth: a random latent position pushed through the signal response
mock_position = ift.from_random(signal.domain, 'normal')
noiseless_data = signal_response(mock_position)

# Mock data = noiseless data + noise realisation
data = noiseless_data + noise
# print(data)

In [None]:
# Set up the likelihood. N is the noise covariance, so N.inverse is passed as
# the inverse covariance of the Gaussian energy, which is then chained after
# the signal response:
#   GaussianEnergy(mean=data, inverse_covariance=N.inverse) @ (R @ signal)
# A likelihood energy is the negative logarithm of a probability density
# (= information Hamiltonian).
gaussian_energy = ift.GaussianEnergy(mean=data, inverse_covariance=N.inverse)
likelihood = gaussian_energy @ signal_response

In [None]:
# Define the Hamiltonian. It needs to know how to draw the samples (this has
# to be approximated as well), which is what the sampling controller is for.
# Variant of the controller with a name:
# ic_sampling = ift.AbsDeltaEnergyController(name='Sampling', deltaE=0.05, iteration_limit=100)
ic_sampling = ift.AbsDeltaEnergyController(iteration_limit=100, deltaE=0.05)
H = ift.StandardHamiltonian(likelihood, ic_sampling)

In [None]:
# Drawing MGVI samples means running one Wiener filter, which needs a
# conjugate-gradient run.
# Initialise the minimisation with a field of zeros on the Hamiltonian's domain.
initial_mean = ift.MultiField.full(H.domain, 0.)
# Alternatively: initial_mean = 0.1 * ift.from_random(H.domain)

In [None]:
# Plot to see what is going on: ground-truth signal, data and amplitude.
ground_truth_signal = signal(mock_position)

# The information source is (R.adjoint @ N.inverse)(data); what is plotted here
# is R.adjoint_times(data), i.e. without the N.inverse factor.
back_projected_data = R.adjoint_times(data)

# A.force(mock_position) evaluates the amplitude at the ground-truth position
true_amplitude = A.force(mock_position)

plot = ift.Plot()
plot.add(ground_truth_signal, title='Ground Truth')
plot.add(back_projected_data, title='Data')
plot.add([true_amplitude], title='Amplitude')
plot.output(ny=1, nx=3, xsize=24, ysize=6)

In [None]:
# Loop section
# MetricGaussianKL draws KL samples during initialisation and returns an
# Energy; this energy is then minimised using a minimisation algorithm.

# Define the minimiser.
# Add name='Newton' to the controller if you want to watch it minimising:
# ic_newton = ift.AbsDeltaEnergyController(name='Newton', deltaE=0.5, iteration_limit=35)
# Perhaps start with iteration_limit=10 (35 is at the higher end).
ic_newton = ift.AbsDeltaEnergyController(deltaE=0.5, iteration_limit=35)
minimizer = ift.NewtonCG(ic_newton)

# Number of samples used to estimate the KL
# (N_samples == number of lines to sample the probability mask)
N_samples = 20

mean = initial_mean
# Draw new samples to approximate the KL five times
for i in range(5):

    # Draw new samples and minimize the KL.
    # The N_samples random samples are drawn from a Gaussian probability
    # distribution -- the Gaussian approximation at the current mean position
    # that we are updating (the local approximation of the true posterior).
    KL = ift.MetricGaussianKL(mean, H, N_samples)

    # The minimiser gives out a new KL object and an error/convergence flag
    KL, convergence = minimizer(KL)
    # Update the mean to the new mean position = KL.position
    mean = KL.position

    # Plot the current reconstruction
    plot = ift.Plot()
    plot.add(signal(KL.position), title="Latent mean")
    # A is the amplitude operator (the power spectrum is A**2), so the title
    # names the quantity that is actually plotted.
    plot.add([A.force(KL.position + ss) for ss in KL.samples],
             title="Samples amplitude spectrum")
    plot.output(ny=1, ysize=6, xsize=16)

In [None]:
# Posterior analysis: draw a fresh set of samples around the final mean.
KL = ift.MetricGaussianKL(mean, H, N_samples)

# StatCalculator accumulates running statistics (mean and variance are read
# from it in the plotting cells) over everything passed to add().
sc = ift.StatCalculator()

# Each KL sample is added to the mean position before the signal operator is
# evaluated.
for residual in KL.samples:
    sc.add(signal(KL.position + residual))

In [None]:
# Plot the posterior mean (sc.mean) and the posterior standard deviation
# (square root of sc.var).
posterior_mean = sc.mean
posterior_std = ift.sqrt(sc.var)

plot = ift.Plot()
plot.add(posterior_mean, title="Posterior Mean")
plot.add(posterior_std, title="Posterior Standard Deviation")
plot.output()

In [None]:
# Posterior summary in three panels: mean, standard deviation and the sampled
# amplitude spectra.
plot = ift.Plot()
plot.add(sc.mean, title="Posterior Mean")
plot.add(ift.sqrt(sc.var), title="Posterior Standard Deviation")

# A is the amplitude operator (the power spectrum is A**2), so the panel title
# names the quantity that is actually plotted.
powers = [A.force(s + KL.position) for s in KL.samples]
plot.add(
    # Thin lines: posterior samples. Thick lines: ground truth and the
    # amplitude at the posterior mean position.
    powers + [A.force(mock_position),
              A.force(KL.position)],
    title="Sampled Posterior Amplitude Spectrum",
    linewidth=[1.]*len(powers) + [3., 3.])
plot.output(ny=1, nx=3, xsize=24, ysize=6)

In [None]:
# Compare data, ground truth and reconstruction with plain matplotlib.
# The first two panels show quantities in data space (outputs of
# signal_response = R(signal)); the last panel shows quantities in signal space
# (outputs of signal, plus the data mapped back with R.adjoint_times).
fig, (ax0, ax1, ax2) = pylab.subplots(3, 1, figsize=(15, 7), facecolor='white')

# Top panel: the noisy measured data on its own
ax0.plot(data.val, 'k.-', label="Data")
ax0.set_title("Measured data")
ax0.legend()

# Middle panel: data vs. signal response of the ground truth and of the
# reconstructed mean position, R(m)
ax1.plot(data.val, 'k.', label="Data")
ax1.plot(signal_response(mock_position).val, label='Ground Truth')
ax1.plot(signal_response(KL.position).val, 'k', label="Reconstruction")
ax1.set_title("Data space")
ax1.legend()

# Bottom panel: every curve gets its own legend entry so the two
# reconstruction estimates can be told apart.
ax2.plot(R.adjoint_times(data).val, 'k.', label="Back-projected data")
ax2.plot(signal(mock_position).val, label='Ground Truth')
ax2.plot(sc.mean.val, label='Reconstruction (posterior sample mean)')
ax2.plot(signal(KL.position).val, label="Reconstruction (signal at mean position)")
ax2.set_title("Signal space")
ax2.legend()

# Keep the panel titles from overlapping the neighbouring axes
fig.tight_layout()
pylab.show()