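# KRNO usage

This notebook shows how to build KRNO models with `models.KhatriRaoNO_v2.easy_init` and evaluate them on user-chosen input, latent, and output quadrature grids. It covers a purely temporal problem first, then a 2D spatio-temporal one.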

In [1]:
import torch
from torch import nn

from khatriraonop import models, quadrature

## 1. Temporal problems

In [2]:
# ---------------------------------------------------------------------------------
# Build a KRNO model for a purely temporal problem.
# The input domain is one-dimensional: its only coordinate is time t.
d = 1  # [t]

# Seed the global torch RNG before the first random draw so the randomly
# initialized weights (and any random data drawn afterwards) are reproducible
# under Restart Kernel -> Run All.
torch.manual_seed(0)

# `easy_init` is a helper that builds a KRNO model with the default
# configuration, overridden by the keyword arguments below.
model = models.KhatriRaoNO_v2.easy_init(
    d,                        # dimensionality of the input domain
    in_channels=2,            # number of input channels
    out_channels=2,           # number of output channels
    lifting_channels=128,     # number of lifting channels
    integral_channels=20,     # number of channels in each integral layer
    n_integral_layers=3,      # number of KRNO integral layers
    projection_channels=128,  # number of projection channels
    n_hidden_units=32,        # hidden units per layer of the network parametrizing each component-wise kernel
    n_hidden_layers=3,        # hidden layers of the network parametrizing each component-wise kernel
    nonlinearity=nn.SiLU(),   # activation function
)
# print(model)  # uncomment to inspect the architecture

In [3]:
# Observation times inside the past window [0, 0.5):
# the input signal is known at these 5 unevenly spaced instants.
past_times = torch.tensor([0.03, 0.12, 0.18, 0.31, 0.45])

# Query times inside the future window [0.5, 1]:
# the output is requested at these 7 unevenly spaced instants.
predict_times = torch.tensor([0.53, 0.67, 0.74, 0.79, 0.86, 0.9, 0.98])

In [4]:
# Generate dummy input data: one random sample per observation time.
batch_size = 8

# Layout is (BS, N, C): batch size, number of time steps, number of channels.
# N is derived from `past_times` rather than hard-coded, so the data always
# matches the input quadrature grid built from those times. C must match the
# model's `in_channels` (2).
u = torch.randn(batch_size, len(past_times), 2)
print('input shape:', u.shape)

input shape: torch.Size([8, 5, 2])


In [5]:
# Quadrature data for the two unevenly spaced time grids.
# NOTE(review): element 0 presumably holds the nodes and element 1 the
# trapezoidal weights -- confirm against `quadrature.trapezoidal_vecs_uneven`.
quad_grid_in = quadrature.trapezoidal_vecs_uneven(past_times)
quad_grid_out = quadrature.trapezoidal_vecs_uneven(predict_times)

# Wrap each component in a one-element list, giving ([nodes], [weights]).
# Presumably there is one list entry per input dimension, and d = 1 here.
in_grid = tuple([quad_grid_in[k]] for k in (0, 1))
out_grid = tuple([quad_grid_out[k]] for k in (0, 1))

In [6]:
# Map u, sampled at `past_times`, to v evaluated at `predict_times`.
v = model.super_resolution(out_grid, in_grid, u)
print(f'ouput shape: {v.shape}')

ouput shape: torch.Size([8, 7, 2])


## 2. Spatio-temporal problems 

### 2D spatio-temporal problem

**Example: the shallow-water problem**

In [7]:
# 2D spatio-temporal problem, using the shallow-water problem as the example.
# The input domain has three coordinates: [t, x, y].
d = 3

# Grid resolution and field count for the example.
Sx = 32   # number of grid points along x
Sy = 32   # number of grid points along y
Nc = 3    # physical fields / channels: \rho, u, v
lag = 5   # number of time points in the input window

# Build the KRNO model with the default configuration, overridden below.
model = models.KhatriRaoNO_v2.easy_init(
    d,                        # dimensionality of the input domain
    in_channels=Nc,           # input fields: \rho, u, v
    out_channels=Nc,          # output fields: \rho, u, v
    lifting_channels=128,     # number of lifting channels
    integral_channels=20,     # number of channels in each integral layer
    n_integral_layers=3,      # number of KRNO integral layers
    projection_channels=128,  # number of projection channels
    n_hidden_units=32,        # hidden units per layer of the network parametrizing each component-wise kernel
    n_hidden_layers=3,        # hidden layers of the network parametrizing each component-wise kernel
    nonlinearity=nn.SiLU(),   # activation function
)
# print(model)  # uncomment to inspect the architecture

In [8]:
# One quadrature rule per input dimension, in the order [t, x, y]:
# midpoint rule in time, trapezoidal rule in both spatial directions.
quad_fns = [quadrature.midpoint_vecs] + [quadrature.trapezoidal_vecs] * 2

# Every dimension spans the interval [-1, 1].
domain_lo = [-1.0, -1.0, -1.0]
domain_hi = [1.0, 1.0, 1.0]

# Input and output grids use the same resolution (lag x Sx x Sy), so the
# model maps the input onto a grid of identical size.
in_grid = quadrature.get_quad_grid(quad_fns, [lag, Sx, Sy], domain_lo, domain_hi)
out_grid = quadrature.get_quad_grid(quad_fns, [lag, Sx, Sy], domain_lo, domain_hi)

In [9]:
# Dummy input laid out as (batch, time, x, y, channels).
batch_size = 2
input_shape = (batch_size, lag, Sx, Sy, Nc)
u = torch.randn(*input_shape)
print('input shape:', u.shape)


input shape: torch.Size([2, 5, 32, 32, 3])


**Using a GPU (if available; otherwise the CPU)**

In [10]:
# Prefer the GPU when one is present; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Device:', device)

# The model, the data, and both quadrature grids must live on the same device.
model, u = model.to(device), u.to(device)
out_grid = quadrature.quad_grid_to_device(out_grid, device)
in_grid = quadrature.quad_grid_to_device(in_grid, device)


Device: cuda


In [11]:
# Map u from `in_grid` onto `out_grid` (same resolution here).
v = model.super_resolution(out_grid, in_grid, u)
print(f'ouput shape: {v.shape}')

ouput shape: torch.Size([2, 5, 32, 32, 3])


### Using a lower-resolution quadrature grid in the intermediate integral layers

In [12]:
# Coarser grid for the intermediate integral layers: the time resolution is
# unchanged, and the spatial resolution is halved in x and y.
# Integer floor division avoids the float round-trip of int(Sx / 2).
quad_grid_latent = quadrature.get_quad_grid(
    quad_fns, [lag, Sx // 2, Sy // 2], [-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]
)
latent_grid = quadrature.quad_grid_to_device(quad_grid_latent, device)


In [13]:
# Same mapping as before, but the intermediate layers now use `latent_grid`.
v = model.super_resolution(out_grid, in_grid, u, latent_grid=latent_grid)
print(f'ouput shape: {v.shape}')

ouput shape: torch.Size([2, 5, 32, 32, 3])


### Forecasting at super-resolution in both space and time

In [14]:
# Output grid refined by `sr_factor` in every dimension (t, x and y), so the
# model is queried at super-resolution in both space and time.
sr_factor = 4
out_grid = quadrature.get_quad_grid(
    quad_fns,
    [sr_factor * lag, sr_factor * Sx, sr_factor * Sy],
    [-1.0, -1.0, -1.0],
    [1.0, 1.0, 1.0],
)
out_grid = quadrature.quad_grid_to_device(out_grid, device)

In [15]:
# The input stays on the original grid; the output is produced on the refined `out_grid`.
v = model.super_resolution(out_grid, in_grid, u, latent_grid=latent_grid)
print(f'ouput shape: {v.shape}')

ouput shape: torch.Size([2, 20, 128, 128, 3])


When `in_grid`, `latent_grid`, and `out_grid` are the same for training and inference, model performance can be improved by calling the model directly on a Cartesian grid, as shown below. This enables the affine maps in the first and last integral layers.

In [None]:
# Fresh KRNO model, configured like the spatio-temporal model above.
# NOTE(review): the constructor arguments are identical to the earlier model.
# The visible difference is how the model is evaluated: a forward call on a
# Cartesian grid instead of `super_resolution`. Confirm this is what enables
# the affine maps.
krno_kwargs = dict(
    in_channels=Nc,           # input fields: \rho, u, v
    out_channels=Nc,          # output fields: \rho, u, v
    lifting_channels=128,     # number of lifting channels
    integral_channels=20,     # number of channels in each integral layer
    n_integral_layers=3,      # number of KRNO integral layers
    projection_channels=128,  # number of projection channels
    n_hidden_units=32,        # hidden units per layer of the network parametrizing each component-wise kernel
    n_hidden_layers=3,        # hidden layers of the network parametrizing each component-wise kernel
    nonlinearity=nn.SiLU(),   # activation function
)
model = models.KhatriRaoNO_v2.easy_init(d, **krno_kwargs).to(device)

# Input and output grids coincide, so a single Cartesian grid (converted from
# the quadrature grid, then moved to `device`) serves as both.
cart_grid = quadrature.cart_grid_to_device(
    quadrature.quad_to_cartesian_grid(in_grid), device
)

v1 = model(cart_grid, u)
print(f'ouput shape: {v1.shape}')

ouput shape: torch.Size([2, 5, 32, 32, 3])
