In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from sklearn.cluster import KMeans
import copy

import sys
sys.path.append("lib")

import lib

import torch
torch.set_default_dtype(torch.float64)

from functorch import vmap
from torch.utils.data import DataLoader
import tqdm

device = torch.device("cuda")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
coords_x, coords_y, _ = lib.utils.read_coords("2d_burger_data/time_step_0.csv")
coords = torch.stack((coords_x, coords_y), dim=1)

In [3]:
# dt = 0.004
# num_steps = 501
dt = 0.002
num_steps = 1001
num_nodes = coords_x.shape[0]

datas = lib.utils.read_data(num_steps, num_nodes, '2d_burger_data/time_step_', 'vel_0')
datas.shape

torch.Size([1001, 14641])

In [4]:
u_dot = lib.utils.u_dot(datas, dt)

datas = torch.hstack((datas[:-1], u_dot))
datas.shape

torch.Size([1000, 29282])

In [5]:
m = 50
clustering = KMeans(n_clusters=m, random_state=0, n_init="auto").fit(coords)

In [6]:
group_indices = []
for i in range(m):
    group_indices.append(torch.tensor(np.where(clustering.labels_ == i)[0], device=device))

In [7]:
N = datas.shape[1] // 2
n = 20
# fixed support (length)
mu = int(np.ceil(N/200))
neighbour_distance, neighbour_id = lib.utils.topk_neighbours(coords, mu)

In [8]:
ed_test = lib.nrbs_test.EncoderDecoder(N=N, n=n, mu=mu, m=m, neighbour_id=neighbour_id, neighbour_distance=neighbour_distance,clustering_labels=torch.tensor(clustering.labels_).type(torch.LongTensor),group_indices=group_indices, device=device)

In [9]:
ed_n_m = lib.nrbs_n_m.EncoderDecoder(N=N, n=n, mu=mu, m=m, neighbour_id=neighbour_id, neighbour_distance=neighbour_distance, clustering_labels=torch.tensor(clustering.labels_).type(torch.LongTensor), device=device)

In [10]:
torch.sum(ed_test.nrbs.decoder.weight == ed_n_m.nrbs.decoder.weight) == N*n

tensor(True, device='cuda:0')

In [11]:
torch.sum(ed_test.nrbs.encoder.weight == ed_n_m.nrbs.encoder.weight) == N*n

tensor(True, device='cuda:0')

In [12]:
torch.sum(ed_test.nrbs.bandwidth_layers.weight == ed_n_m.nrbs.bandwidth_layers.weight) == n*n*m

tensor(True, device='cuda:0')

In [13]:
encoded_test = ed_test.nrbs.encode(datas[999:1000,:N].to(device))
encoded_n_m = ed_n_m.nrbs.encode(datas[999:1000,:N].to(device))

In [14]:
torch.sum(encoded_test == encoded_n_m) == n

tensor(True, device='cuda:0')

In [16]:
bs_test = torch.sigmoid(ed_test.nrbs.bandwidth_layers(encoded_test)) / 60
bs_test = bs_test.reshape(-1, n, m)
bs_test.shape

torch.Size([1, 20, 50])

In [17]:
neighbour_distance.unsqueeze(0).unsqueeze(0).expand(1, n, -1, -1).shape

torch.Size([1, 20, 14641, 74])

In [18]:
decoded_test = ed_test.nrbs.decode(encoded_test)
decoded_n_m = ed_n_m.nrbs.decode(encoded_test)

In [36]:
decoded_test

tensor([[0.0001, 0.0001, 0.0001,  ..., 0.0001, 0.0001, 0.0001]],
       device='cuda:0', grad_fn=<SqueezeBackward1>)

In [38]:
torch.sum(decoded_n_m - decoded_test)

tensor(1.8033e-06, device='cuda:0', grad_fn=<SumBackward0>)

In [20]:
bs_test = torch.sigmoid(ed_test.nrbs.bandwidth_layers(encoded_test)) / 60
bs_test = bs_test.reshape(-1, n, m)

In [21]:
bs_n_m = torch.sigmoid(ed_n_m.nrbs.bandwidth_layers(encoded_n_m[0])) / 60
bs_n_m = bs_n_m.reshape(n, m)

In [22]:
torch.sum(bs_n_m == bs_test.unsqueeze(0)) == n * m

tensor(False, device='cuda:0')

In [23]:
bs_n_m = bs_n_m[:, ed_n_m.nrbs.clustering_labels]
bs_n_m.shape

torch.Size([20, 14641])

In [24]:
convolved_n_m = ed_n_m.nrbs.convolve(
                ed_n_m.nrbs.decoder.weight,
                ed_n_m.nrbs.neighbour_id,
                ed_n_m.nrbs.neighbour_distance,
                bs_n_m,
                ed_n_m.nrbs.mu,
            )

In [25]:
convolved_n_m

tensor([[-5.0700e-04, -7.5262e-05, -3.0856e-04,  ..., -1.9609e-04,
          1.1001e-04,  9.4636e-05],
        [-5.0779e-04, -6.9198e-04, -7.1617e-04,  ..., -7.2528e-04,
         -3.0712e-04, -5.6869e-04],
        [-1.1306e-03, -1.0105e-03, -1.2024e-03,  ..., -7.9567e-04,
         -7.4655e-04, -8.8789e-04],
        ...,
        [ 7.5838e-04,  1.1266e-03,  1.0055e-03,  ...,  2.7907e-04,
          2.5878e-04,  1.2974e-04],
        [ 3.0149e-04,  2.5783e-04,  4.8183e-04,  ..., -3.9197e-04,
         -4.0592e-04, -3.7127e-04],
        [-1.7873e-05,  7.1424e-05,  9.1176e-05,  ..., -7.4330e-04,
         -8.8665e-04, -6.0313e-04]], device='cuda:0', grad_fn=<SumBackward1>)

In [26]:
window = torch.relu(
            -(ed_n_m.nrbs.neighbour_distance.unsqueeze(0).expand(n, -1, -1) ** 2)
            / (bs_n_m.unsqueeze(-1) * mu) ** 2
            + 1
        )

In [27]:
torch.sum(window,dim=2,keepdim=True).shape

torch.Size([20, 14641, 1])