In [3]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from data import DatasetConfig, WeatherDataset, DummyDataset, DummyDatasetConfig
from model import HeirarchicalTransformer, HeirarchicalTransformerConfig

In [4]:
# Dataset settings: where the HDF5 data and the training index file live, plus
# the two length parameters (seqlen is later passed to the model as `maxlen`,
# maxlen as its `mem_len`).
dataConf = DatasetConfig(
    index="notebooks/train_idx.txt",
    hdf_fpath="notebooks/weatherGiga3.hdf5",
    maxlen=120,
    seqlen=24,
)
wd = WeatherDataset(dataConf)
# Bare expression so the notebook renders the configuration as the cell output.
dataConf

---- DATASET CONFIGURATION ----
+-----------+-----------------------------+
| key       | value                       |
|-----------+-----------------------------|
| index     | notebooks/train_idx.txt     |
| hdf_fpath | notebooks/weatherGiga3.hdf5 |
| maxlen    | 120                         |
| seqlen    | 24                          |
+-----------+-----------------------------+

In [5]:
# Show the sizes of the dataset-level location and edge matrices, then the
# shape of every tensor in one sample (index 45; presumably an arbitrary pick).
print("Locations", wd.loc_mat.shape)
print("edge_matrix", wd.edge_mat.shape)
print("------")
x = wd[45]
for name, tensor in x.items():
    print(name, "-->", tensor.shape)

Locations torch.Size([484, 3])
edge_matrix torch.Size([484, 484])
------
input --> torch.Size([5, 24, 484, 17])
node_mask --> torch.Size([5, 24, 484])
month_ids --> torch.Size([5, 24])
day_ids --> torch.Size([5, 24])
hour_ids --> torch.Size([5, 24])


In [14]:
# Model hyperparameters. The node count and both length settings are read from
# the dataset objects rather than repeated here as literals.
modelConf = HeirarchicalTransformerConfig(
    # Widths (17 * 3 * 2 == 102; the expression is already an int).
    n_embd=17 * 3 * 2,
    n_global=17 * 3 * 2,
    # Depth / attention heads.
    n_head=3,
    n_layer=1,
    # Input dimensions.
    num_features=17,
    location_features=3,
    num_nodes=wd.edge_mat.size(0),
    # Length settings taken from the dataset configuration.
    maxlen=dataConf.seqlen,
    mem_len=dataConf.maxlen,
)
model = HeirarchicalTransformer(modelConf)
# Bare expression so the parameter count is shown as the cell output.
model.num_params

416415

In [15]:
# Debug aid (disabled): uncomment to print every model parameter's name
# together with its shape.
# for x in model.named_parameters():
#     print(x[0], "--->", x[1].shape)

In [16]:
# Smoke test: run one shuffled batch through the model, feeding it one slice
# of dim 1 at a time and threading the returned `mems` into the next call.
for x in DataLoader(wd, batch_size=10, shuffle=True):
    mems = None  # reset for every batch
    for seg in range(x["input"].size(1)):
        # Take slice `seg` along dim 1 from every tensor in the batch dict.
        in_data = {name: tensor[:, seg] for name, tensor in x.items()}
        logits, mems, loss = model(
            **in_data,
            mems=mems,
            edge_matrix=wd.edge_mat,
            locations=wd.loc_mat,
            get_loss=True,
        )
        print(logits.shape, loss)
    # Only the first batch is needed for this check.
    break

torch.Size([10, 24, 484, 17]) tensor(12.2898, grad_fn=<MeanBackward0>)
torch.Size([10, 24, 484, 17]) tensor(12.2777, grad_fn=<MeanBackward0>)
torch.Size([10, 24, 484, 17]) tensor(12.4691, grad_fn=<MeanBackward0>)
torch.Size([10, 24, 484, 17]) tensor(12.4644, grad_fn=<MeanBackward0>)
torch.Size([10, 24, 484, 17]) tensor(12.4119, grad_fn=<MeanBackward0>)


In [13]:
# Scratch check: layer-normalise a random 12-element vector over its whole
# shape (no weight or bias) and compare the result with the raw values.
x = torch.rand(12)
print(x)
F.layer_norm(x, normalized_shape=x.shape)

tensor([0.3766, 0.4105, 0.2262, 0.4773, 0.4744, 0.7812, 0.2691, 0.4900, 0.3605,
        0.3690, 0.0181, 0.9267])


tensor([-0.2409, -0.0925, -0.8988,  0.1997,  0.1873,  1.5295, -0.7112,  0.2553,
        -0.3112, -0.2738, -1.8094,  2.1661])

In [5]:
# Scratch check: np.log of a negative real number evaluates to nan (NumPy
# also emits a warning), as the recorded output shows.
np.log(-9)

  np.log(-9)


nan

In [3]:
# Scratch check: np.nan_to_num replaces nan with 0.0.
np.nan_to_num(np.nan)

0.0

In [5]:
# Build the dummy dataset from its default configuration (no overrides passed).
dummyDataConf = DummyDatasetConfig()
dummyData = DummyDataset(dummyDataConf)

In [6]:
# Print the shape of every tensor in the first dummy sample.
for name, tensor in dummyData[0].items():
    print(name, "-->", tensor.shape)

input --> torch.Size([2, 5, 30, 5])
node_mask --> torch.Size([2, 5, 30])
month_ids --> torch.Size([2, 5])
day_ids --> torch.Size([2, 5])
hour_ids --> torch.Size([2, 5])
