In [1]:
using Plots, LinearAlgebra, DifferentialEquations, StatsBase, Distributions, DifferentialEquations.EnsembleAnalysis
using Suppressor, Printf, Integrals, SymPy

In [2]:
# Route all subsequent Plots.jl output (e.g. the surfaces drawn by plot_pdf)
# through the PlotlyJS backend.
plotlyjs()

Plots.PlotlyJSBackend()

In [114]:
"""
    plot_pdf(D, min_=-5, max_=5, num=40)

Draw the density of a bivariate distribution `D` as a surface over the square
`[min_, max_]²`, sampled on a `num × num` grid, and display it.

`D` must accept a length-2 vector in `pdf(D, z)` (e.g. a 2-D `MvNormal`).
Returns `nothing`; the plot is shown via `display`.
"""
function plot_pdf(D, min_=-5, max_=5, num=40)
    x = range(min_, max_, num)
    y = range(min_, max_, num)

    # Plots.jl expects size(z) == (length(y), length(x)) with z[j, i] the value
    # at (x[i], y[j]): rows follow y, columns follow x. Building the grid from
    # product(x, y) would transpose the surface.
    Dz = [pdf(D, [xi, yj]) for yj in y, xi in x]
    p = plot(x, y, Dz, st=:surface, xlims=(min_, max_), ylims=(min_, max_))
    display(p)
    return nothing
end;

"""
    getGaussians(A, b, u0)

Return two closures `(rho_t, rho_c)` that build Gaussian distributions for the
least-squares problem `A u ≈ b` started from `u0`.

With `μ = A⁺ b`, `P` the orthogonal projector onto the row space of `A` and
`Q = I - P` the projector onto its null space:

- `rho_t(t)` is `MvNormal` with mean
  `exp(-AᵀA t) P u0 + (I - exp(-AᵀA t)) P μ + Q u0` and covariance
  `(AᵀA)⁺ (I - exp(-2AᵀA t)) + 2t Q`;
- `rho_c(c)` is `MvNormal` with mean `(AᵀA + cP)⁺ (Aᵀb + c u0) + Q u0` and
  covariance `(AᵀA + cP)⁺ + (1/c) Q`.

Both covariances must be positive definite (e.g. `t > 0`, `c > 0`) for
`MvNormal` to accept them.
"""
function getGaussians(A, b, u0)
    F = svd(A)
    # Keep only the singular directions pinv also treats as nonzero, so P is
    # the row-space projector even when A is rank deficient (with all of
    # svd(A).V, P would be I whenever size(A, 1) >= size(A, 2)).
    tol = min(size(A)...) * eps(eltype(F.S)) * maximum(F.S; init = zero(eltype(F.S)))
    V = F.V[:, F.S .> tol]
    P = V * V'
    Q = I - P

    # Quantities that do not depend on t or c, computed once.
    AtA = A' * A
    mu = pinv(A) * b
    Sig = pinv(AtA)

    # Exact symmetrization: Cholesky-based positive-definiteness checks (as
    # used for MvNormal covariances) reject matrices that are symmetric only
    # up to rounding.
    sym(M) = (M + M') / 2

    function rho_t(t)
        E = exp(-AtA * t)
        mean_t = E * P * u0 + (I - E) * P * mu + Q * u0
        cov_t = Sig * (I - exp(-2 * AtA * t)) + 2t * Q
        return MvNormal(mean_t, sym(cov_t))
    end

    function rho_c(c)
        G = pinv(AtA + c * P)
        mean_c = G * (A' * b + c * u0) + Q * u0
        cov_c = G + (1 / c) * Q
        return MvNormal(mean_c, sym(cov_c))
    end

    return rho_t, rho_c
end;

"""
    makeKL(A, b, u0)

Return closures `(KL1, KL2)` measuring the Kullback–Leibler divergence between
two Gaussians built from the least-squares problem `A u ≈ b` started at `u0`.

With `μ = A⁺ b`, `P` the projector onto the row space of `A` and `Q = I - P`:

- `N1 = N(mean1, cov1)` with
  `mean1 = exp(-AᵀA t) P u0 + (I - exp(-AᵀA t)) P μ + Q u0` and
  `cov1 = (AᵀA)⁺ (I - exp(-2AᵀA t)) + 2t Q`;
- `N2 = N(mean2, cov2)` with `mean2 = (AᵀA + cP)⁺ (Aᵀb + c u0) + Q u0` and
  `cov2 = (AᵀA + cP)⁺ + (1/c) Q`.

Then `KL1(t, c) = KL(N2 ‖ N1)` and `KL2(t, c) = KL(N1 ‖ N2)`.

Both covariances must be positive definite (e.g. `t > 0`, `c > 0`); otherwise
the Cholesky factorization throws `PosDefException`.
"""
function makeKL(A, b, u0)
    d = length(u0)

    F = svd(A)
    # Keep only the singular directions pinv also treats as nonzero, so P is
    # the row-space projector even when A is rank deficient.
    tol = min(size(A)...) * eps(eltype(F.S)) * maximum(F.S; init = zero(eltype(F.S)))
    V = F.V[:, F.S .> tol]
    P = V * V'
    Q = I - P

    AtA = A' * A
    mu = pinv(A) * b
    Sig = pinv(AtA)

    # Means and covariances of N1 (the t-family) and N2 (the c-family).
    function moments(t, c)
        E = exp(-AtA * t)
        mean1 = E * P * u0 + (I - E) * P * mu + Q * u0
        cov1 = Sig * (I - exp(-2 * AtA * t)) + 2t * Q
        G = pinv(AtA + c * P)
        mean2 = G * (A' * b + c * u0) + Q * u0
        cov2 = G + (1 / c) * Q
        return mean1, cov1, mean2, cov2
    end

    # KL(N(m_p, S_p) ‖ N(m_q, S_q)). Cholesky factors give a stable log-det
    # (no det overflow/underflow) and replace explicit inverses with solves.
    function gaussian_kl(m_p, S_p, m_q, S_q)
        Cp = cholesky(Symmetric(S_p))
        Cq = cholesky(Symmetric(S_q))
        delta = m_p - m_q
        return 0.5 * (logdet(Cq) - logdet(Cp) + dot(delta, Cq \ delta) +
                      tr(Cq \ S_p) - d)
    end

    function KL1(t, c)
        mean1, cov1, mean2, cov2 = moments(t, c)
        return gaussian_kl(mean2, cov2, mean1, cov1)
    end

    function KL2(t, c)
        mean1, cov1, mean2, cov2 = moments(t, c)
        return gaussian_kl(mean1, cov1, mean2, cov2)
    end

    return KL1, KL2
end;

"""
    terms(A, b, u0, t, c)

Return the covariance-only parts of `KL(N2 ‖ N1)` as a tuple
`(logdet(cov1) - logdet(cov2), tr(cov1⁻¹ cov2))`, where, with `P` the
projector onto the row space of `A` and `Q = I - P`,

- `cov1 = (AᵀA)⁺ (I - exp(-2AᵀA t)) + 2t Q`,
- `cov2 = (AᵀA + cP)⁺ + (1/c) Q`.

`b` and `u0` only affect the means, so they do not enter these terms; they are
kept in the signature for parity with `makeKL`. Both covariances must be
positive definite (e.g. `t > 0`, `c > 0`).
"""
function terms(A, b, u0, t, c)
    F = svd(A)
    # Keep only the singular directions pinv also treats as nonzero, so P is
    # the row-space projector even when A is rank deficient.
    tol = min(size(A)...) * eps(eltype(F.S)) * maximum(F.S; init = zero(eltype(F.S)))
    V = F.V[:, F.S .> tol]
    P = V * V'
    Q = I - P

    AtA = A' * A
    cov1 = pinv(AtA) * (I - exp(-2 * AtA * t)) + 2t * Q
    cov2 = pinv(AtA + c * P) + (1 / c) * Q

    # Cholesky factors: stable log-determinants and a solve instead of inv.
    C1 = cholesky(Symmetric(cov1))
    C2 = cholesky(Symmetric(cov2))

    return logdet(C1) - logdet(C2), tr(C1 \ cov2)
end;

In [268]:
# Problem setup: m equations in d unknowns, A u = b (here a single 1x1 system).
d = 1
m = 1

A = ones(m, d);
# Right singular vectors of A; V*V' is used below as the projector onto the
# row space of A. NOTE(review): that only holds when every returned singular
# value is nonzero (full-rank A) — confirm before trying rank-deficient setups.
V = svd(A).V

b = ones(m);

# Start from the minimum-norm solution pinv(A)*b plus a random component
# projected by (I - V*V').
# NOTE(review): randn is unseeded, so results are not reproducible whenever
# I - V*V' is nonzero (e.g. d > m).
u0 = pinv(A)*b+(I-V*V')*randn(d);

# Alternative setup: a 2x2 diagonal system starting at the exact solution.
# d = 2
# A = diagm(rand(d).+1)
# b = ones(d)
# u0 = pinv(A)*b;

In [269]:
# Build the t-indexed (rho_t) and c-indexed (rho_c) Gaussian families and the
# two KL directions between them for the current (A, b, u0).
rho_t, rho_c = getGaussians(A, b, u0);
KL1, KL2 = makeKL(A, b, u0);

In [270]:
# Time grid. t = 0 is skipped below (ts[2:end]) because c = 1/(2t) would be
# infinite there.
ts = 0:0.01:1

# KL1(t, c) along the curve c = 1/(2t), one value per grid time.
plot(ts[2:end], [KL1(t, 1 / (2t)) for t in ts[2:end]])

In [271]:
# Evaluate terms(...) along c = 1/(2t) and split each (logdet part, trace part)
# tuple into two vectors. Broadcasting first/last avoids splatting all the
# tuples into zip, which builds a call with one argument per grid point.
allTerms = map(t -> terms(A, b, u0, t, 1/(2*t)), ts[2:end]);
terms1, terms2 = first.(allTerms), last.(allTerms);

In [272]:
# Log-determinant term logdet(cov1) - logdet(cov2) along c = 1/(2t).
plot(ts[2:end], collect(terms1))

In [273]:
# Grid time at which the log-determinant term is largest.
ts[2:end][argmax(terms1)]

0.9

In [266]:
# Grid time at which the trace term tr(cov1⁻¹ cov2) is smallest.
# NOTE(review): this cell's execution count (266) precedes the cell that
# recomputes terms2 (271), so its recorded output may reflect an older setup.
ts[2:end][argmin(terms2)]

0.61

In [267]:
# Trace term minus the dimension: tr(cov1⁻¹ cov2) - d along c = 1/(2t).
plot(ts[2:end], collect(terms2) .- d)

In [120]:
# KL1 without its mean (Mahalanobis) term: 0.5 * (logdet ratio + trace - d).
plot(ts[2:end], 0.5 .* (collect(terms1) .+ collect(terms2) .- d))

In [13]:
# x = -5:0.1:5
# y = -5:0.1:5


# t = 0.5
# Dz1 = map(z -> pdf(rho_t(t), collect(z)), Iterators.product(x, y))
# Dz2 = map(z -> pdf(rho_c(1/(2t)), collect(z)), Iterators.product(x, y))
# p = plot(x, y, Dz1, st=:surface, xlims=(-5, 5), ylims=(-5, 5));
# wireframe!(x, y, Dz2)
# display(p)

In [201]:
# Spot check at a single time: cov1 is the covariance of rho_t(t) and cov2 the
# covariance of rho_c(c) with c = 1/(2t), written out directly.
t = 0.51
cov1 = pinv(A'*A)*(I-exp(-2*A'*A*t))+2t*(I-V*V');
cov2 = pinv(A'*A+V*V'/(2t))+(2t)*(I-V*V');

In [207]:
# Trace term tr(cov1⁻¹ cov2), computed with a linear solve rather than by
# forming the inverse of cov1.
tr(cov1 \ cov2)

99.9451857490585

In [217]:
# Closed form of tr(cov1⁻¹ cov2) at c = 1/(2t). Each nonzero singular value σ
# of A contributes σ²/(1 - exp(-2tσ²)) · 2t/(2tσ² + 1). Each of the
# d - rank(A) null-space directions contributes 2t/2t = 1. Counting the
# nonzero σ (instead of assuming rank(A) == m <= d) keeps this valid for tall
# or rank-deficient A and avoids 0/0 from zero singular values.
let σ = svdvals(A)
    σ = σ[σ .> min(size(A)...) * eps(eltype(σ)) * maximum(σ; init = zero(eltype(σ)))]
    σ² = σ .^ 2
    sum(@. σ² / (1 - exp(-2t * σ²)) * (2t / (2t * σ² + 1))) + d - length(σ)
end

99.94518574905845