In [None]:
import matplotlib.pyplot as plt
import torch

import spdnn
import eigenfunctions as eF
import eigenoptim as eOptim

In [None]:
## Generate Example features
s1 = 5  # size of the input SPD matrix X
s2 = 3  # size of the BiMap output / target matrix

# A @ A.T is symmetric positive semi-definite; the small ridge makes it strictly
# positive definite. (Xdat @ Xdat, as before, is not even symmetric.)
Xdat = torch.rand(s1, s1)
X = Xdat @ Xdat.T + 1e-3 * torch.eye(s1)
tgt = torch.rand(s2, s2)
tgt = tgt @ tgt.T + 1e-3 * torch.eye(s2)

# Initial weight: the first s2 rows of an orthogonal matrix, so that
# W @ W.T == I (semi-orthogonal). torch.linalg.eigh on a symmetric matrix
# returns an orthogonal eigenvector matrix; the removed torch.eig on a
# non-symmetric matrix did not guarantee orthonormal eigenvectors.
Wdat = torch.rand(s1, s1)
e, v = torch.linalg.eigh(Wdat @ Wdat.T)
Wdat = v[:s2]
W = Wdat.clone().detach().requires_grad_(True)

## BiMap Example
___Demonstrate that the output features are SPD, that the weights remain semi-orthogonal after each update, and that the loss converges.___

In [None]:
# Optimizer over W (presumably constrained to the Stiefel manifold, per its
# name) and the BiMap op exposed through its `.apply` entry point.
optim = eOptim.StiefelOpt(W, lr=0.001)
bimap_func = eF.BiMap.apply

# Per-epoch records: SPD check of the output, "weights changed" flag, loss.
spd_arr, wne_arr, loss_arr = [], [], []

def check_spd(mat1, *, atol=1e-6):
    """Return True if ``mat1`` is a symmetric positive definite matrix.

    Args:
        mat1: 2-D tensor to test. Any autograd history is ignored.
        atol: absolute tolerance used when checking symmetry, so that
            round-off asymmetry (e.g. from W @ X @ W.T) is tolerated.

    Returns:
        ``True`` if ``mat1`` is square and symmetric within tolerance, and all
        its eigenvalues are strictly positive; ``False`` otherwise.
    """
    # Work in float64 on a detached copy: eigvalsh needs a floating dtype,
    # and the check must not touch the autograd graph.
    mat = mat1.detach().to(torch.float64)
    if mat.ndim != 2 or mat.shape[0] != mat.shape[1]:
        return False
    # Positive-definiteness is only meaningful for symmetric matrices here;
    # eigvalsh reads only one triangle, so symmetry must be checked explicitly.
    if not torch.allclose(mat, mat.T, atol=atol):
        return False
    # A symmetric matrix is PD iff every eigenvalue is > 0. The previous check
    # used Frobenius norms of leading submatrices, which are positive for any
    # nonzero matrix, and started at the empty 0x0 slice (norm 0), so it
    # always returned False.
    return bool((torch.linalg.eigvalsh(mat) > 0).all())

# Loop through weight updates and check parameters and outputs.
orth_arr = []  # ||W @ W.T - I|| after each update (semi-orthogonality drift)
for epoch in range(20):
    W_old = W.clone().detach()
    # Clear the previous epoch's gradient: backward() accumulates into W.grad,
    # and it is not visible here that StiefelOpt.step() resets it.
    W.grad = None
    output = bimap_func(X, W)
    # Mean squared error against the target `tgt` built in the data cell
    # (the previous code referenced an undefined name `expect`).
    loss = ((output - tgt).norm() ** 2) / (s2 ** 2)
    loss_arr.append(loss.item())
    loss.backward()
    optim.step()
    spd_arr.append(check_spd(output))
    with torch.no_grad():
        # True only if every entry of W changed during this step.
        wne_arr.append(bool((W != W_old).all()))
        orth_arr.append((W @ W.T - torch.eye(W.shape[0])).norm().item())

# Ensure parameters actually updated
assert all(wne_arr), "some weight entries did not change during an update"
# Ensure output is SPD
assert all(spd_arr), "BiMap produced a non-SPD output"
# Report how far W drifted from semi-orthogonality over training
print(f"max ||W W^T - I|| during training: {max(orth_arr):.2e}")

# Plot loss over time
plt.plot(loss_arr)
plt.xlabel("Epoch")
plt.ylabel("Loss (MSE)")
    
    

