In [2]:
import numpy as np
from matplotlib import pyplot as plt

from importlib import reload
import src.probability
reload(src.probability)
reload(src.multiplier)
reload(src.observations)
from src.probability import Prior, Posterior
from src.multiplier import FourierMultiplier
from src.forward import Heat
from src.observations import PointObservation, DiagObservation


%matplotlib

Using matplotlib backend: QtAgg




In [105]:
# --- Problem configuration ----------------------------------------------------
sig = 5e-2      # enters the posterior as sigSqr = sig**2 (presumably noise std)

N = 200         # discretization size shared by Heat / Prior / Posterior
L = 1           # length parameter shared by Heat / Prior / Posterior
time = 3e-2     # passed to the Heat forward model as `time`
alpha = 1.      # passed to Heat
gamma = -1.     # passed to Prior

# DST / DCT = discrete sine / cosine transform, corresponding to a homogeneous
# Dirichlet / Neumann boundary condition respectively.
# Clusterization of the optimal designs also occurs with the Neumann boundary;
# only the Dirichlet (DST) designs are plotted below.
dst, dct = {}, {}

# For each boundary condition: build forward model, prior and posterior, then
# optimize designs with m = 2, ..., 5 measurements, storing design['x'] by m.
for transform, delta in (('dst', 0.), ('dct', 0.5)):
    fwd = Heat(N=N, L=L, transform=transform, alpha=alpha, time=time)
    pr = Prior(N=N, L=L, transform=transform, gamma=gamma, delta=delta)
    post = Posterior(fwd=fwd, prior=pr, sigSqr=sig**2,
                     L=L, N=N, transform=transform)
    dic = {'dst': dst, 'dct': dct}[transform]
    for m in range(2, 6):
        res = post.optimize(m=m, n_iterations=250)
        design = res['x']
        dic[m] = design
        print(transform, m, design)


dst 2 [0.31163064 0.68336936]
dst 3 [0.31319488 0.31319488 0.68985848]
dst 4 [0.3066907  0.30669071 0.68830929 0.68830932]
dst 5 [0.28343313 0.28343322 0.49749991 0.71156683 0.71156684]
dct 2 [0.     0.9975]
dct 3 [0.         0.49750556 0.9975    ]
dct 4 [0.         0.49469032 0.9975     0.9975    ]
dct 5 [0.         0.         0.49750558 0.9975     0.9975    ]


In [102]:
plt.close('all')

fig, ax = plt.subplots(figsize=(10, 5))

# Draw each optimal Dirichlet (DST) design as a row at height m (the number of
# measurements), labelling every sensor location with its 1-based index.
for m, design_locations in dst.items():
    # Invisible (s=0) scatter: annotations do not update the axes' data
    # limits, so this makes autoscaling include every sensor location.
    ax.scatter(design_locations, np.full(len(design_locations), m), s=0)
    for i, loc in enumerate(design_locations, start=1):
        ax.annotate(str(i), xy=(loc, m), ha='center', va='center', fontsize=20)

# The y-axis is a count of measurements, so show only those integer values.
ax.set_yticks(list(dst))
ax.set_xlabel('Measurement Location', fontsize=22)
ax.set_ylabel('No. of Measurements', fontsize=22)
fig.tight_layout()

fig.savefig("latex/example.eps")