In [1]:
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
import tqdm
torch.set_default_dtype(torch.float64)

import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda")

In [3]:
data1 = torch.tensor(np.load("mu_0.9.npy"))
data2 = torch.tensor(np.load("mu_0.95.npy"))
data3 = torch.tensor(np.load("mu_1.05.npy"))
data4 = torch.tensor(np.load("mu_1.1.npy"))
X = torch.cat([data1, data2, data3, data4], axis=0)
(S, N) = X.shape

X_test = torch.tensor(np.load("mu_1.0.npy"))

In [4]:
x_ref = torch.mean(X, dim=0).to(device)

In [5]:
class NRBS(torch.nn.Module):
    def __init__(self, N, n, M1, M2, b):
        super(NRBS, self).__init__()

        mask = torch.zeros(N, M2)
        shift = (M2 - b) / (N - 1)

        for i in range(N):
            mask[i, int(np.ceil(shift*i)): int(np.ceil(shift*i)) + b] = 1

        new_mask = torch.zeros(N, M2)

        for idx in range(N):
            i = idx // 60
            j = idx % 60
            neighbours = [(i, j), (i-1, j), (i+1, j), (i, j+1), (i, j-1)]
            for neighbour_i, neighbour_j in neighbours:
              if (neighbour_i >=0 and neighbour_i < 60) and (neighbour_j >=0 and neighbour_j < 60):
                new_mask[idx] = new_mask[idx] + mask[60*neighbour_i + neighbour_j]

        new_mask[new_mask > 0] = 1

        self.register_buffer('mask', new_mask)


        self.encoder1 = torch.nn.Linear(N, M1)
        self.encoder2 = torch.nn.Linear(M1, n)

        self.decoder1 = torch.nn.Linear(n, M2)
        self.decoder2 = torch.nn.Linear(M2, N)

        torch.nn.init.kaiming_normal_(self.encoder1.weight)
        torch.nn.init.kaiming_normal_(self.encoder2.weight)
        torch.nn.init.kaiming_normal_(self.decoder1.weight)
        torch.nn.init.kaiming_normal_(self.decoder2.weight)

    def encode(self, x):
        x = self.encoder1(x)
        x = x * torch.sigmoid(x)
        x = self.encoder2(x)
        return x

    def decode(self, x):
        x = self.decoder1(x)
        x = x * torch.sigmoid(x)
        x = torch.matmul(x, (self.decoder2.weight * self.get_buffer('mask')).T) + self.decoder2.bias
        return x

    def forward(self, x):
        return self.decode(self.encode(x))


In [6]:
n = 20
lr = 1e-3
epochs = 10000
B = 240
lr_red_factor = 0.1
patience = 10
l1_reg = 0

nrbs = NRBS(N, n, 6728, 33730, 70).to(device)

In [7]:
# mask = nrbs.get_buffer('mask').cpu().detach().numpy()
# fig, ax = plt.subplots(figsize=(10, 10))
# ax.spy(mask, markersize=5, aspect='auto')

In [8]:
X = X.to(device)
X_test = X_test.to(device)
dataloader = torch.utils.data.DataLoader(X, batch_size=B, shuffle=True)
optimizer = torch.optim.Adam(nrbs.parameters(), lr=lr, weight_decay=l1_reg)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, "min", factor=lr_red_factor, patience=patience
)

# loss_func = torch.nn.MSELoss()

# writer = SummaryWriter()
# for epoch in range(epochs):
#     for x in tqdm.tqdm(dataloader):
#         optimizer.zero_grad()
#         x_tilde = nrbs(x - x_ref) + x_ref
#         l = torch.sqrt(torch.sum((x_tilde - x) ** 2))
#         l.backward()
#         optimizer.step()

#     with torch.no_grad():
#         X_tilde = nrbs(X - x_ref) + x_ref
#         l_train = torch.sqrt(torch.sum((X - X_tilde) ** 2)) / torch.sqrt(
#             torch.sum(X**2)
#         )

#         l_train_mse = loss_func(X, X_tilde).item()

#         X_tilde = nrbs(X_test - x_ref) + x_ref
#         l_test = torch.sqrt(torch.sum((X_test - X_tilde) ** 2)) / torch.sqrt(
#             torch.sum(X_test**2)
#         )
        

#     scheduler.step(l_train)

#     writer.add_scalar("loss/train", l_train, epoch)
#     writer.add_scalar("loss/test", l_test, epoch)
#     writer.add_scalar("loss/mse_train", l_train_mse, epoch)
    
#     writer.add_scalar("lr", optimizer.param_groups[0]["lr"], epoch)
#     print("epoch: {:}".format(epoch))
#     print("loss/test: {:}".format(l_test))
#     print("loss/train: {:}".format(l_train))
#     print("loss/mse_train: {:}".format(l_train_mse))
#     print("lr: {:}".format(optimizer.param_groups[0]["lr"]))


In [9]:
torch.save(nrbs, 'models/shallow_mask.pth')

In [10]:
test = torch.load('models/shallow_mask.pth')
X_tilde = test(X_test - x_ref) + x_ref
l_test = torch.sqrt(torch.sum((X_test - X_tilde) ** 2)) / torch.sqrt(
    torch.sum(X_test**2)
)
l_test


tensor(0.5572, device='cuda:0', grad_fn=<DivBackward0>)

In [11]:
X_tilde = nrbs(X_test - x_ref) + x_ref
l_test = torch.sqrt(torch.sum((X_test - X_tilde) ** 2)) / torch.sqrt(
    torch.sum(X_test**2)
)
l_test

tensor(0.5572, device='cuda:0', grad_fn=<DivBackward0>)