In [126]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import xarray as xr

import numpy as np

from einops import rearrange, repeat

In [127]:
class Encoder(nn.Module):
    def __init__(self, n_spatial, n_spatio_temporal, dropout):
        super().__init__()
        self.act = nn.ReLU()
        self.dropout = dropout
        
        self.input_fc = nn.Linear(n_spatial, 100)
        self.input_lstm = nn.LSTM(n_spatio_temporal, 300, num_layers=2, batch_first=True, dropout=self.dropout)
        
        self.fc2 = nn.Linear(400, 200)
        
    def forward(self, x_spatial, x_spatio_temporal, y_spatio_temporal):  # x_spatial: (batch_size, n_context_stations, n_spatial), x_spatio_temporal: (batch_size, n_context_stations, window_size, n_spatio_temporal), y_spatio_temporal: (batch_size, n_context_stations, window_size, 1)
        batch_size = x_spatial.shape[0]
        xy_spatio_temporal = torch.cat([x_spatio_temporal, y_spatio_temporal], dim=-1)
        
        z_spatial = torch.vmap(lambda x: F.dropout(self.act(self.input_fc(x)), p=self.dropout), randomness="same")(x_spatial)
        
        xy_spatio_temporal = rearrange(xy_spatio_temporal, 'batch_size n_context_stations window_size n_spatio_temporal -> (batch_size n_context_stations) window_size n_spatio_temporal')
        z_spatio_temporal, _ = self.input_lstm(xy_spatio_temporal)
        z_spatio_temporal = rearrange(z_spatio_temporal[:, -1, :], '(batch_size n_context_stations) lstm_out -> batch_size n_context_stations lstm_out', batch_size=batch_size)
        
        z_concat = torch.cat([z_spatial, z_spatio_temporal], dim=-1)
        z_concat = torch.vmap(lambda x: F.dropout(self.act(self.fc2(x)), p=self.dropout), randomness="same")(z_concat)
        
        return z_concat
    
class decoder(nn.Module):
    def __init__(self, n_spatial, n_spatio_temporal, dropout):
        super().__init__()
        self.act = nn.ReLU()
        self.dropout = dropout
        
        self.input_fc = nn.Linear(n_spatial, 100)
        self.input_lstm = nn.LSTM(n_spatio_temporal, 300, num_layers=2, batch_first=True, dropout=self.dropout)
        
        self.fc2 = nn.Linear(400, 200)
        
    def forward(self, x_spatial, x_spatio_temporal):  # x_spatial: (batch_size, n_target_stations, n_spatial), x_spatio_temporal: (batch_size, n_target_stations, window_size, n_spatio_temporal), y_spatio_temporal: (batch_size, n_target_stations, window_size, 1)
        # We don't use vmap out of the box because it doesn't support the LSTM layer.
        batch_size = x_spatial.shape[0]
        
        z_spatial = torch.vmap(lambda x: F.dropout(self.act(self.input_fc(x)), p=self.dropout), randomness="same")(x_spatial)
        
        x_spatio_temporal = rearrange(x_spatio_temporal, 'batch_size n_target_stations window_size n_spatio_temporal -> (batch_size n_target_stations) window_size n_spatio_temporal')
        z_spatio_temporal, _ = self.input_lstm(x_spatio_temporal)
        z_spatio_temporal = rearrange(z_spatio_temporal[:, -1, :], '(batch_size n_target_stations) lstm_out -> batch_size n_target_stations lstm_out', batch_size=batch_size)
        
        z_concat = torch.cat([z_spatial, z_spatio_temporal], dim=-1)
        z_concat = torch.vmap(lambda x: F.dropout(self.act(self.fc2(x)), p=self.dropout), randomness="same")(z_concat)
        
        return z_concat

class AttentionNet(nn.Module):
    def __init__(self, dropout):
        super().__init__()
        self.dropout = dropout
        self.fc = nn.Linear(400, 200)
        self.fc2 = nn.Linear(200, 1)
        
    def forward(self, z_context, z_target):  # z_context: (batch_size, num_context_stations, 200), z_target: (batch_size, num_target_stations, 200)
        num_context_stations = z_context.shape[1]
        
        def single_forward(z_target, z_context):
            z_target = repeat(z_target, "z_dim -> num_content_stations z_dim", num_content_stations=num_context_stations)
            z = torch.cat([z_target, z_context], dim=-1) # (num_context_stations, 200) + (num_context_stations, 200) = (num_context_stations, 400)
            z = F.dropout(F.relu(self.fc(z)), p=self.dropout) # (num_context_stations, 400) -> (num_context_stations, 200)
            z = self.fc2(z) # (num_context_stations, 200) -> (num_context_stations, 1)
            z = F.softmax(z, dim=0)  # (num_context_stations, 1)
            return rearrange(z, 'num_context_stations 1 -> num_context_stations')
        
        # (num_context_stations, 200) + (num_target_stations, 200) -> (num_context_stations, num_target_stations)
        multi_forward = torch.vmap(single_forward, in_dims=(0, None), out_dims=1, randomness="same")
        # (batch_size, num_context_stations, 200) + (batch_size, num_target_stations, 200) -> (batch_size, num_context_stations, num_target_stations)
        attention = torch.vmap(multi_forward, randomness="same")(z_target, z_context)
        return attention
    
class ADAIN(nn.Module):
    def __init__(self, n_spatial, n_spatio_temporal, dropout):
        super().__init__()
        self.encoder = Encoder(n_spatial, n_spatio_temporal, dropout)
        self.decoder = decoder(n_spatial, n_spatio_temporal, dropout)
        self.attention = AttentionNet(dropout)
        self.fc = nn.Linear(200, 1)
        
    def forward(self, x_context_spatial, x_context_spatio_temporal, y_context_spatio_temporal, x_target_spatial, x_target_spatio_temporal):
        z_context = self.encoder(x_context_spatial, x_context_spatio_temporal, y_context_spatio_temporal) # (batch_size, num_context_stations, 200)
        z_target = self.decoder(x_target_spatial, x_target_spatio_temporal) # (batch_size, num_target_stations, 200)
        attention = self.attention(z_context, z_target) # (batch_size, num_context_stations, num_target_stations)
        def get_output(attention):
            attention = rearrange(attention, 'batch_size num_context_stations -> batch_size num_context_stations 1')
            output =  attention * z_context # (batch_size, num_context_stations, 200)
            output = torch.sum(output, dim=1) # (batch_size, 200)
            output = self.fc(output) # (batch_size, 1)
            return output
            
        # (batch_size, num_context_stations, num_target_stations) 
        output = torch.vmap(get_output, in_dims=2, out_dims=1)(attention)
        print(output.shape)
        return output
        
batch_size = 9 # over time
n_spatial = 7
n_spatio_temporal = 11
n_context_stations = 5
n_target_stations = 13
window_size = 24

x_context_spatial = torch.randn(batch_size, n_context_stations, n_spatial)
x_context_spatio_temporal = torch.randn(batch_size, n_context_stations, window_size, n_spatio_temporal)
y_context_spatio_temporal = torch.randn(batch_size, n_context_stations, window_size, 1)
x_target_spatial = torch.randn(batch_size, n_target_stations, n_spatial)
x_target_spatio_temporal = torch.randn(batch_size, n_target_stations, window_size, n_spatio_temporal)

model = ADAIN(n_spatial, n_spatio_temporal, 0.01)
model(x_context_spatial, x_context_spatio_temporal, y_context_spatio_temporal, x_target_spatial, x_target_spatio_temporal).shape

torch.Size([9, 13, 1])


torch.Size([9, 13, 1])

### Creation of dataset

In [128]:
# import namedtuple
from collections import namedtuple

config = {"features": ["lat", "lon", "AT", "BP", "SR"], "target": "PM2.5", "context_fraction": 0.5}
config = namedtuple("Config", config.keys())(*config.values())
static_features = ["lat", "lon"]

class CustomDataset(Dataset):
    def __init__(self, ds, window_size):
        self.ds = ds
        self.window_size = window_size
        self.ts = self.ds.time.values

    def __len__(self):
        return len(self.ts) - self.window_size

    def __getitem__(self, idx):
        t_past = self.ts[idx:idx + self.window_size]
        t = self.ts[idx + self.window_size]
        given_static_features = [fet for fet in config.features if fet in static_features]
        X_static = self.ds.sel(time=t).to_dataframe()[given_static_features].values
        X_dynamic = np.concatenate([self.ds.sel(time=t_past)[var].values.T[..., None] for var in config.features if var not in static_features], axis=-1)
        y_dynamic = self.ds.sel(time=t_past)[config.target].values.T[..., None]

        X_static = torch.tensor(X_static, dtype=torch.float32)
        X_dynamic = torch.tensor(X_dynamic, dtype=torch.float32)
        y_dynamic = torch.tensor(y_dynamic, dtype=torch.float32)
        print(X_static.shape, X_dynamic.shape, y_dynamic.shape)
        
        idx = np.random.permutation(len(X_static))
        num_context = int(config.context_fraction * len(X_static))
        X_context_spatial = X_static[idx[:num_context]]
        X_context_spatio_temporal = X_dynamic[idx[:num_context]]
        y_context_spatio_temporal = y_dynamic[idx[:num_context]]
        
        print(X_context_spatial.shape, X_context_spatio_temporal.shape, y_context_spatio_temporal.shape)
        
        # y_target = y[idx[num_context:]]
        # return X_context, y_context, X_target, y_target

In [129]:
with xr.open_dataset('/home/patel_zeel/aqmsp/aqmsp_data/datasets/cpcb/ijcai24/data.nc') as ds:
    pass
ds

In [130]:
# ds.sel(time="2022")

In [131]:
dataset = CustomDataset(ds, 24)
for i in range(len(dataset)):
    dataset[i]
    break

torch.Size([47, 2]) torch.Size([47, 24, 3]) torch.Size([47, 24, 1])
torch.Size([23, 2]) torch.Size([23, 24, 3]) torch.Size([23, 24, 1])
