In [1]:
import torch as th
import torch.nn as nn

from torch_gfrft import EigvalSortStrategy
from torch_gfrft.gfrft import GFRFT
from torch_gfrft.gft import GFT
from torch_gfrft.layer import GFRFTLayer

# Experiment size: number of graph nodes and number of time samples per node.
NUM_NODES = 100
TIME_LENGTH = 200
# Run on the GPU when one is available, otherwise fall back to the CPU.
DEVICE = th.device('cuda' if th.cuda.is_available() else 'cpu')

In [2]:
# Random non-negative weighted adjacency matrix on NUM_NODES nodes.
# Seed first: the only other seed in this notebook (th.manual_seed(0) in the
# data cell) is set *after* A is drawn, so without this the graph -- and every
# loss value printed during training -- would change from run to run.
th.manual_seed(0)
A = th.rand(NUM_NODES, NUM_NODES, device=DEVICE)
# A is left non-symmetric (directed graph) by default; flip this to True to
# symmetrise it into an undirected graph.
SYMMETRIZE = False
if SYMMETRIZE:
    A = A + A.T
# Zero the diagonal so the graph has no self-loops.
A = A - th.diag(th.diag(A))

In [3]:
# Graph Fourier transform of A, with its eigenvalues sorted using the
# TOTAL_VARIATION strategy; the fractional transform is then built from the
# resulting GFT matrix.
sort_strategy = EigvalSortStrategy.TOTAL_VARIATION
gft = GFT(A, sort_strategy)
gfrft = GFRFT(gft.gft_mtx)

In [4]:
# Ground-truth fractional order that the trainable layers should recover.
original_order = 0.35

# Reseed so the input signal is identical on every run.
th.manual_seed(0)
# X has shape (NUM_NODES, TIME_LENGTH); the transform acts along dim 0 (nodes).
X = th.randn((NUM_NODES, TIME_LENGTH), device=DEVICE)
Y = gfrft.gfrft(X, original_order, dim=0)

In [5]:
def mse_loss(predictions: th.Tensor, targets: th.Tensor) -> th.Tensor:
    """Mean over columns of the Euclidean norm of ``predictions - targets``.

    Despite the name, this is not a mean *squared* error: for each column
    (each index along dim 1) the L2 norm of the residual is taken over dim 0,
    and those per-column norms are then averaged.

    Args:
        predictions: Model output; same shape as ``targets``.
        targets: Reference values. May be complex; the norm is real-valued.

    Returns:
        A real scalar tensor.
    """
    residual = predictions - targets
    # torch.norm is deprecated; for an int dim, torch.norm(p='fro', dim=0)
    # dispatches to exactly this vector 2-norm, so the value is unchanged.
    return th.linalg.vector_norm(residual, ord=2, dim=0).mean()

# Two cascaded GFRFT layers sharing the same `gfrft` operator, initialised at
# orders 1.0 and 0.0. They are trained jointly to map X onto Y; the final
# report compares the sum a1 + a2 against `original_order`.
model = nn.Sequential(
    GFRFTLayer(gfrft, 1.0, dim=0),
    GFRFTLayer(gfrft, 0.0, dim=0),
)
print(model)
optim = th.optim.Adam(model.parameters(), lr=1e-3)
epochs = 1000
first_layer, second_layer = model[0], model[1]

th.manual_seed(0)
# range(epochs + 1) so that epoch `epochs` itself is logged; the optimiser
# therefore takes epochs + 1 steps in total.
for epoch in range(epochs + 1):
    optim.zero_grad()
    output = mse_loss(model(X), Y)
    if not epoch % 100:
        # Logged orders are the values *before* this epoch's update.
        a1, a2 = first_layer.order.item(), second_layer.order.item()
        print(f"Epoch {epoch:4d} | Loss {output.item():<4.4f} | a1 = {a1:.4f} | a2 = {a2:.4f}")
    output.backward()
    optim.step()

final_a1, final_a2 = first_layer.order.item(), second_layer.order.item()
print(f"Original a: {original_order:.4f}, Final a1: {final_a1:.4f} | Final a2: {final_a2:.4f}")
print(f"Final sum: {final_a1 + final_a2:.4f}")

Sequential(
  (0): GFRFT(order=1.0, size=100, dim=0)
  (1): GFRFT(order=0.0, size=100, dim=0)
)
Epoch    0 | Loss 58.1500 | a1 = 1.0000 | a2 = 0.0000
Epoch  100 | Loss 35.9139 | a1 = 0.9078 | a2 = -0.0922
Epoch  200 | Loss 22.7625 | a1 = 0.8299 | a2 = -0.1701
Epoch  300 | Loss 11.5617 | a1 = 0.7540 | a2 = -0.2460
Epoch  400 | Loss 0.1966 | a1 = 0.6737 | a2 = -0.3263
Epoch  500 | Loss 0.0130 | a1 = 0.6749 | a2 = -0.3251
Epoch  600 | Loss 0.0199 | a1 = 0.6749 | a2 = -0.3251
Epoch  700 | Loss 0.0022 | a1 = 0.6750 | a2 = -0.3250
Epoch  800 | Loss 0.0166 | a1 = 0.6751 | a2 = -0.3249
Epoch  900 | Loss 0.0319 | a1 = 0.6752 | a2 = -0.3248
Epoch 1000 | Loss 0.0047 | a1 = 0.6750 | a2 = -0.3250
Original a: 0.3500, Final a1: 0.6750 | Final a2: -0.3250
Final sum: 0.3499
