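# Tests for `nonlinear_dce`

Sanity checks for the MLE loss, `h_func`, and the log-Cholesky parametrization.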

In [1]:
import nonlinear_dce
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class DummyModel(torch.nn.Module):
    """Parameter-free placeholder passed as ``model`` to ``DagmaDCE``.

    It defines no layers and no ``forward``; in the cells shown in this
    notebook the resulting ``dagma`` object is only used to call ``mle_loss``.
    """


# Seed torch's global RNG so the random tensors drawn below (and in later
# cells) are identical on every Restart & Run All. No seed is set anywhere
# else in this notebook, so outputs saved before this line was added need to
# be refreshed by re-running once.
torch.manual_seed(0)

dummy_model = DummyModel()

dagma = nonlinear_dce.DagmaDCE(model=dummy_model, use_mle_loss=True)
dims = [5, 7, 1]
dagma_mlp = nonlinear_dce.DagmaMLP_DCE(dims=dims)

# n rows and d columns for the test tensors; d is taken from the first entry of dims.
n, d = 100, dims[0]
target = torch.randn(n, d)
output = target + 0.1 * torch.randn(n, d)  # target plus small Gaussian noise

# A @ A.T is symmetric positive semi-definite; adding a small multiple of the
# identity makes Sigma strictly positive definite (invertible, finite logdet).
A = torch.randn(d, d)
Sigma = A @ A.T + 1e-3 * torch.eye(d)

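## MLE loss

Compare `DagmaDCE.mle_loss` with a direct evaluation of $\frac{1}{n}\sum_{i=1}^{n} r_i^\top \Sigma^{-1} r_i + \log\det\Sigma$, where $r_i$ is the $i$-th row of `target - output`. The two printed values should agree.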

In [3]:
# Loss as computed by the library implementation, using the tensors built in the setup cell.
loss_value = dagma.mle_loss(output=output, target=target, Sigma=Sigma)
print("MLE loss:", loss_value.item())

MLE loss: 7.3140552211697


In [4]:
# Independent re-computation of the MLE loss for comparison with the library value:
#   (1/n) * sum_i r_i^T Sigma^{-1} r_i + logdet(Sigma),   r_i = target[i] - output[i]
diff = target - output

# The inverse does not depend on the row, so compute it once instead of once per iteration.
Sigma_inv = torch.inverse(Sigma)

quad_form_sum = 0
for residual in diff:  # one row of diff (length d) per iteration, n rows in total
    # (d,) @ (d, d) @ (d, 1) -> one-element tensor holding r^T Sigma^{-1} r
    quad_form_sum += residual @ Sigma_inv @ residual.unsqueeze(1)

logdet = torch.logdet(Sigma)
loss_value2 = quad_form_sum / n + logdet
print("MLE loss:", loss_value2.item())

MLE loss: 7.314055221169701


## h_func

In [5]:
# Two random d x d matrices with entries drawn uniformly from [0, 1);
# W1 is drawn first, then W2.
W1, W2 = torch.rand(d, d), torch.rand(d, d)

h_value = dagma_mlp.h_func(W1, W2)
print("h_value:", h_value.item())

h_value: 8.84974343251615


## Log-Cholesky round trip

`SPDLogCholesky(reverse_SPDLogCholesky(Sigma))` should reproduce `Sigma`.

In [8]:
# Draw a fresh random symmetric positive-definite matrix for the round-trip check.
# Note: this rebinds A and Sigma, replacing the values created in the setup cell.
A = torch.randn(d, d)
Sigma = A @ A.T + 1e-3 * torch.eye(d)  # A A^T is PSD; the 1e-3 * I term makes it strictly PD
print("Sigma: ", Sigma)

Sigma:  tensor([[ 9.3026,  0.3791,  4.0431, -0.6954, -1.1468],
        [ 0.3791,  3.2425,  3.3744,  3.3295,  0.5797],
        [ 4.0431,  3.3744,  5.1537,  3.5565, -0.5255],
        [-0.6954,  3.3295,  3.5565,  4.9086, -0.8183],
        [-1.1468,  0.5797, -0.5255, -0.8183,  2.7815]])


In [9]:
# Map Sigma to its parameter matrix M. In the recorded output M is lower-triangular,
# and M[0, 0] = 1.1151 ~= log(sqrt(9.3026)), i.e. presumably the Cholesky factor of
# Sigma with the log taken on the diagonal -- confirm against nonlinear_dce.
M = nonlinear_dce.reverse_SPDLogCholesky(Sigma)
print("M: ", M)

M:  tensor([[ 1.1151,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.1243,  0.5858,  0.0000,  0.0000,  0.0000],
        [ 1.3256,  1.7867, -0.7945,  0.0000,  0.0000],
        [-0.2280,  1.8692,  1.1488, -1.5735,  0.0000],
        [-0.3760,  0.3487, -1.4390,  0.4692, -0.7401]])


In [10]:
# Apply the forward map to M and check that it reproduces the Sigma that M was
# computed from. The result is stored under its own name so the reference Sigma
# is not overwritten, and the round trip is verified by an assertion instead of
# by comparing two printouts by eye.
Sigma_roundtrip = nonlinear_dce.SPDLogCholesky(M)
assert torch.allclose(Sigma_roundtrip, Sigma, rtol=1e-4, atol=1e-6), (
    "SPDLogCholesky(reverse_SPDLogCholesky(Sigma)) does not reproduce Sigma"
)
print("Sigma: ", Sigma_roundtrip)

Sigma:  tensor([[ 9.3026,  0.3791,  4.0431, -0.6954, -1.1468],
        [ 0.3791,  3.2425,  3.3744,  3.3295,  0.5797],
        [ 4.0431,  3.3744,  5.1537,  3.5565, -0.5255],
        [-0.6954,  3.3295,  3.5565,  4.9086, -0.8183],
        [-1.1468,  0.5797, -0.5255, -0.8183,  2.7815]])
