In [1]:
import torch as th
import torch.nn as nn

from torch_gfrt import EigvalSortStrategy
from torch_gfrt.gfrft import GFRT
from torch_gfrt.gft import GFT
from torch_gfrt.layer import GFRTLayer

# Problem size: number of graph vertices, and number of signal samples
# (columns) carried by each vertex.
NUM_NODES = 100
TIME_LENGTH = 200
# Run on the GPU whenever PyTorch can see one, otherwise fall back to CPU.
DEVICE = th.device('cuda' if th.cuda.is_available() else 'cpu')

In [2]:
# Random weighted adjacency matrix. Seed *before* the first random draw so this
# cell rebuilds exactly the same graph on every run (Restart & Run All, or
# re-running this cell on its own).
GRAPH_SEED = 1
th.manual_seed(GRAPH_SEED)
A = th.rand(NUM_NODES, NUM_NODES, device=DEVICE)
# A is left asymmetric (a directed graph); symmetrise with `A = A + A.T`
# to experiment with an undirected graph instead.
# Zero the diagonal so the graph has no self-loops.
A = A - th.diag(th.diag(A))

In [3]:
# Graph Fourier transform of A, with eigenvalues ordered by the
# TOTAL_VARIATION sort strategy.
gft = GFT(A, EigvalSortStrategy.TOTAL_VARIATION)
# The fractional transform is constructed from the GFT matrix.
gft_matrix = gft.gft_mtx
gfrt = GFRT(gft_matrix)

In [4]:
# Ground-truth fractional order used to generate the targets.
original_order = 0.35

# Random input signals: shape (NUM_NODES, TIME_LENGTH), one row per vertex.
th.manual_seed(seed=0)
X = th.randn((NUM_NODES, TIME_LENGTH), device=DEVICE)

# Targets: GFRT of order `original_order` applied along the vertex axis (dim=0).
Y = gfrt.gfrt(X, original_order, dim=0)

In [5]:
def mse_loss(predictions: th.Tensor, targets: th.Tensor, dim: int = 0) -> th.Tensor:
    """Average L2 norm of the residual, reduced along ``dim``.

    Despite its name, this is *not* a mean squared error. Each slice along
    ``dim`` (each column, for the default ``dim=0``) contributes its
    un-squared Euclidean norm, and those norms are averaged. The name is kept
    so existing call sites keep working.

    Args:
        predictions: Model output; must broadcast against ``targets``.
        targets: Reference values. Complex inputs are accepted; the norm is
            taken over the magnitudes of the residual entries.
        dim: Axis the norm is reduced over. Defaults to ``0``, matching the
            original behaviour.

    Returns:
        A 0-dim real tensor holding the averaged norm.
    """
    residual = predictions - targets
    # `torch.norm` is deprecated. For an integer `dim` with p='fro' it dispatches
    # to exactly this vector 2-norm, so the loss values are unchanged.
    return th.linalg.vector_norm(residual, ord=2, dim=dim).mean()

# Two learnable GFRT layers applied in sequence, starting from orders 1.0 and 0.0.
# Training minimises mse_loss(model(X), Y) with respect to both orders.
model = nn.Sequential(
    GFRTLayer(gfrt, 1.0, dim=0),
    GFRTLayer(gfrt, 0.0, dim=0),
)
print(model)
optim = th.optim.Adam(params=model.parameters(), lr=1e-3)
epochs = 2000
first_layer, second_layer = model  # handles used for logging the learned orders

# Note: the range is inclusive of `epochs`, so epochs + 1 optimisation steps run.
for epoch in range(0, epochs + 1):
    optim.zero_grad()
    output = mse_loss(model(X), Y)
    if not epoch % 100:
        a1 = first_layer.order.item()
        a2 = second_layer.order.item()
        print(f"Epoch {epoch:4d} | Loss {output.item():<4.4f} | a1 = {a1:.4f} | a2 = {a2:.4f}")
    output.backward()
    optim.step()

# Report the learned orders after the final update.
a1 = first_layer.order.item()
a2 = second_layer.order.item()
print(f"Original a: {original_order:.4f}, Final a1: {a1:.4f} | Final a2: {a2:.4f}")
print(f"Final sum: {a1 + a2:.4f}")

Sequential(
  (0): GFRT(order=1.0, size=100, dim=0)
  (1): GFRT(order=0.0, size=100, dim=0)
)
Epoch    0 | Loss 59.7111 | a1 = 1.0000 | a2 = 0.0000
Epoch  100 | Loss 37.1993 | a1 = 0.9081 | a2 = -0.0919
Epoch  200 | Loss 23.7966 | a1 = 0.8302 | a2 = -0.1698
Epoch  300 | Loss 11.8672 | a1 = 0.7527 | a2 = -0.2473
Epoch  400 | Loss 0.3659 | a1 = 0.6726 | a2 = -0.3274
Epoch  500 | Loss 0.0026 | a1 = 0.6750 | a2 = -0.3250
Epoch  600 | Loss 0.0160 | a1 = 0.6751 | a2 = -0.3249
Epoch  700 | Loss 0.0237 | a1 = 0.6748 | a2 = -0.3252
Epoch  800 | Loss 0.0028 | a1 = 0.6750 | a2 = -0.3250
Epoch  900 | Loss 0.0167 | a1 = 0.6751 | a2 = -0.3249
Epoch 1000 | Loss 0.0246 | a1 = 0.6748 | a2 = -0.3252
Epoch 1100 | Loss 0.0029 | a1 = 0.6750 | a2 = -0.3250
Epoch 1200 | Loss 0.0171 | a1 = 0.6751 | a2 = -0.3249
Epoch 1300 | Loss 0.0251 | a1 = 0.6748 | a2 = -0.3252
Epoch 1400 | Loss 0.0029 | a1 = 0.6750 | a2 = -0.3250
Epoch 1500 | Loss 0.0174 | a1 = 0.6751 | a2 = -0.3249
Epoch 1600 | Loss 0.0254 | a1 = 0.6748 