In [None]:
# default_exp model

In [None]:
#hide
from nbdev.showdoc import show_doc

# MetNet Model

> Implementation of the main building blocks of the MetNet architecture from the paper

Here we implement the main parts of [MetNet](https://arxiv.org/abs/2003.12140), the model introduced in "MetNet: A Neural Weather Model for Precipitation Forecasting".

![metnet_scheme](images/metnet_scheme.png)

In [None]:
#export
from metnet_pytorch.layers import *
from fastai.vision.all import *

## Downsampler

The downsampler is a plain stack of convolution and max-pooling layers, with nothing fancy, not even activations. From the paper:
> MetNet aims at fully capturing the spatial context in the input patch. A trade-off arises between the
fidelity of the representation and the memory and computation required to compute it. To maintain
viable memory and computation requirements, the first part of MetNet contracts the input tensor
spatially using a series of convolution and pooling layers. The t slices along the time dimension of
the input patch are processed separately. Each slice is first packaged into an input tensor of spatial
dimensions 256 × 256 (see Appendix A for the exact pre-processing operations). Each slice is then
processed by the following neural network layers: a 3 × 3 convolution with 160 channels, a 2 × 2
max-pooling layer with stride 2, three more 3 × 3 convolutions with 256 channels and one more
2 × 2 max pooling layer with stride 2. These operations produce t tensors of spatial dimensions
64 × 64 and 256 channels.

In [None]:
#export
def DownSampler(in_channels):
    return nn.Sequential(nn.Conv2d(in_channels, 160, 3, padding=1),
                         nn.MaxPool2d((2,2), stride=2),
                         nn.BatchNorm2d(160),
                         nn.Conv2d(160, 256, 3, padding=1),
                         nn.BatchNorm2d(256),
                         nn.Conv2d(256, 256, 3, padding=1),
                         nn.BatchNorm2d(256),
                         nn.Conv2d(256, 256, 3, padding=1),
                         nn.MaxPool2d((2,2), stride=2)
                         )

I used fewer convs and added `nn.BatchNorm2d`. I finally ended up using another image encoder, and you can choose anything from torchvision or [timm](https://github.com/rwightman/pytorch-image-models).

In [None]:
ds = DownSampler(3)
test_eq(ds(torch.rand(2, 3, 256, 256)).shape,[2, 256, 64, 64])

As we can check, it divides the spatial resolution by four (256 → 64).
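
For comparison, the layer stack quoted from the paper (without the added batch norm) can be sketched in plain PyTorch. This is only a sketch: the name `paper_downsampler` is made up for this example, and `padding=1` is an assumption, since the quoted passage does not state the padding.

```python
import torch
from torch import nn

def paper_downsampler(in_channels):
    # Layer order as quoted from the paper; padding=1 is an assumption
    # made here so that the 3x3 convs keep the spatial size unchanged.
    return nn.Sequential(
        nn.Conv2d(in_channels, 160, 3, padding=1),
        nn.MaxPool2d(2, stride=2),            # halves height and width
        nn.Conv2d(160, 256, 3, padding=1),
        nn.Conv2d(256, 256, 3, padding=1),
        nn.Conv2d(256, 256, 3, padding=1),
        nn.MaxPool2d(2, stride=2),            # halves height and width again
    )

x = torch.rand(1, 3, 64, 64)                  # small input to keep the check cheap
out = paper_downsampler(3)(x)
print(out.shape)                              # torch.Size([1, 256, 16, 16])
```

Only the two pooling layers change the resolution, so any input side length is divided by four.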

## Temporal Encoder

> The second part of MetNet encodes the input patch along the temporal dimension. The spatially contracted slices are given to a recurrent neural network following the order of time. We use a Convolutional Long Short-Term Memory network with kernel size 3 × 3 and 384 channels for the temporal encoding (Xingjian et al., 2015).
>
> The result is a single tensor of size 64 × 64 and 384 channels, where each location summarizes spatially and temporally one region of the large context in the input patch.

Note that the `TemporalEncoder` below uses a `ConvGRU` rather than the ConvLSTM described in the paper.

In [None]:
#export
class TemporalEncoder(Module):
    def __init__(self, in_channels, out_channels=384, ks=3, n_layers=1):
        self.rnn = ConvGRU(in_channels, out_channels, (ks, ks), n_layers, batch_first=True)
    def forward(self, x):
        x, h = self.rnn(x)
        return (x, h[-1])

In [None]:
te = TemporalEncoder(4, 8, n_layers=1)
x,h = te(torch.rand(2, 10, 4, 12, 12))
test_eq(h.shape, [2,8,12,12])
test_eq(x.shape, [2,10,8,12,12])
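
To make the recurrence concrete, here is a minimal convolutional GRU cell in plain PyTorch. It is an illustrative sketch, not the `ConvGRU` from `metnet_pytorch.layers`: the names `TinyConvGRUCell` and `encode_sequence` are hypothetical, and the gate layout is the standard GRU formulation with convolutions in place of matrix products.

```python
import torch
from torch import nn

class TinyConvGRUCell(nn.Module):
    "Hypothetical single ConvGRU cell, for illustration only"
    def __init__(self, in_ch, hid_ch, ks=3):
        super().__init__()
        self.hid_ch = hid_ch
        # one conv produces both the update (z) and reset (r) gates
        self.gates = nn.Conv2d(in_ch + hid_ch, 2 * hid_ch, ks, padding=ks // 2)
        # a second conv produces the candidate state
        self.cand = nn.Conv2d(in_ch + hid_ch, hid_ch, ks, padding=ks // 2)

    def forward(self, x, h):
        z, r = torch.sigmoid(self.gates(torch.cat([x, h], dim=1))).chunk(2, dim=1)
        n = torch.tanh(self.cand(torch.cat([x, r * h], dim=1)))
        return (1 - z) * n + z * h

def encode_sequence(cell, x):
    "x: [bs, seq_len, ch, h, w] -> (all states, last state)"
    bs, seq_len, _, h, w = x.shape
    state = x.new_zeros(bs, cell.hid_ch, h, w)
    states = []
    for t in range(seq_len):                  # follow the order of time
        state = cell(x[:, t], state)
        states.append(state)
    return torch.stack(states, dim=1), state

cell = TinyConvGRUCell(4, 8)
seq, last = encode_sequence(cell, torch.rand(2, 10, 4, 12, 12))
print(seq.shape, last.shape)
```

The last state plays the role of `h[-1]` above: a single tensor that summarizes the whole input sequence at every spatial location.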

## Conditioning on Target Lead Time


> The lead time is represented as an integer i = (T_y/2) − 1 indicating minutes from 2 to 480. The integer i is tiled along the w × h locations in the patch and is represented as an all-zero vector with a 1 at position i in the vector. By changing the target lead time given as input, one can use the same MetNet model to make forecasts for the entire range of target times that MetNet is trained on.
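
As a quick worked example of the index formula, the helper below maps a lead time in minutes to the one-hot position. The function name `lead_time_to_index` is made up for this sketch, and the 2-minute step is an assumption read off the formula and the quoted range.

```python
def lead_time_to_index(minutes):
    "Map a lead time in minutes (2, 4, ..., 480) to i = (T_y / 2) - 1"
    if minutes % 2 != 0 or not 2 <= minutes <= 480:
        raise ValueError("expected an even lead time between 2 and 480 minutes")
    return minutes // 2 - 1

first = lead_time_to_index(2)     # 0
last = lead_time_to_index(480)    # 239
print(first, last, last - first + 1)   # 0 239 240
```

So the full range of lead times needs a one-hot vector of length 240.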

In [None]:
seq_len=5
i=3
times = (torch.eye(seq_len)[i-1]).float().unsqueeze(-1).unsqueeze(-1)

In [None]:
ones = torch.ones(1,2,2)

In [None]:
times.shape, ones.shape

(torch.Size([5, 1, 1]), torch.Size([1, 2, 2]))

In [None]:
res = times * ones
res

tensor([[[0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.]],

        [[1., 1.],
         [1., 1.]],

        [[0., 0.],
         [0., 0.]],

        [[0., 0.],
         [0., 0.]]])

In [None]:
#export
def condition_time(x, i=0, size=(12, 16), seq_len=15):
    "Create one-hot encoded time image layers, with i in [0, seq_len-1]"
    assert i<seq_len
    times = (torch.eye(seq_len, dtype=x.dtype, device=x.device)[i]).unsqueeze(-1).unsqueeze(-1)
    ones = torch.ones(1,*size, dtype=x.dtype, device=x.device)
    return times * ones

Beware: `i` is zero-indexed, so it ranges from `i=0` to `i=seq_len-1`.

In [None]:
x = torch.rand(3,5,2,8,8)
i = 13
ct = condition_time(x, i, (12,16), seq_len=15)
assert ct[i, :,:].sum() == 12*16  #full of ones
ct.shape, ct[:, 0,0]

(torch.Size([15, 12, 16]),
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.]))

In [None]:
#export
class ConditionTime(Module):
    "Condition Time on a stack of images, adds `horizon` channels to image"
    def __init__(self, horizon, ch_dim=2): 
        self.horizon = horizon
        self.ch_dim = ch_dim
        
    def forward(self, x, fstep=0):
        "x: stack of images [bs, seq_len, ch, h, w]; fstep: index of the forecast step to condition on"
        bs, seq_len, ch, h, w = x.shape
        ct = condition_time(x, fstep, (h,w), seq_len=self.horizon).repeat(bs, seq_len, 1,1,1)
        x = torch.cat([x,ct], dim=self.ch_dim)
        assert x.shape[self.ch_dim] == (ch + self.horizon)  #check if it makes sense
        return x

In [None]:
ct = ConditionTime(3)
x = torch.rand(1,5,2,4,4)
y = ct(x, 1)
y.shape

torch.Size([1, 5, 5, 4, 4])
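
To see what ends up in the extra channels, the same operation can be written in plain PyTorch and inspected. This is a sketch under the same shape conventions as above; `add_lead_time_channels` is a hypothetical name, and it uses `expand` where the module above uses `repeat`.

```python
import torch

def add_lead_time_channels(x, fstep, horizon):
    "x: [bs, seq_len, ch, h, w] -> [bs, seq_len, ch + horizon, h, w]"
    bs, seq_len, ch, h, w = x.shape
    onehot = torch.zeros(horizon, dtype=x.dtype, device=x.device)
    onehot[fstep] = 1.0
    # broadcast the one-hot vector over batch, time and both spatial dims
    planes = onehot.view(1, 1, horizon, 1, 1).expand(bs, seq_len, horizon, h, w)
    return torch.cat([x, planes], dim=2)

x = torch.rand(1, 5, 2, 4, 4)
y = add_lead_time_channels(x, fstep=1, horizon=3)
extra = y[:, :, 2:]                 # the 3 appended channels
print(y.shape)                      # torch.Size([1, 5, 5, 4, 4])
print(extra[0, 0, :, 0, 0])         # tensor([0., 1., 0.])
```

The original channels are untouched, and exactly one of the appended planes is filled with ones.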

In [None]:
#export
def feat2image(x, target_size=(128,128)): 
    "Tile each feature value into a constant image plane; this idea comes from MetNet"
    x = x.transpose(1,2)
    return x.unsqueeze(-1).unsqueeze(-1) * x.new_ones(1,1,1,*target_size)

In [None]:
x = torch.rand(2,4,10)
feat2image(x, target_size=(16,16)).shape

torch.Size([2, 10, 4, 16, 16])
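
A small check of what the tiling does, written as a standalone sketch: every scalar feature becomes a constant plane of the target size. The name `tile_features` is hypothetical, and it uses `expand` instead of multiplying by a tensor of ones.

```python
import torch

def tile_features(x, target_size):
    "x: [bs, n_feats, seq_len] -> [bs, seq_len, n_feats, H, W]"
    x = x.transpose(1, 2)                               # [bs, seq_len, n_feats]
    return x[..., None, None].expand(*x.shape, *target_size)

feats = torch.rand(2, 4, 10)
planes = tile_features(feats, (16, 16))
print(planes.shape)                                     # torch.Size([2, 10, 4, 16, 16])
# feature 1 at timestep 3 of sample 0 fills its whole plane
print(bool((planes[0, 3, 1] == feats[0, 1, 3]).all()))  # True
```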

## Temporal Aggregator

> To make MetNet’s receptive field cover the full global spatial context in the input patch, the third
part of MetNet uses a series of eight axial self-attention blocks (Ho et al., 2019; Donahue and
Simonyan, 2019). Four axial self-attention blocks operating along the width and four blocks operating
along the height are interleaved and have 2048 channels and 16 attention heads each.

Please install it using pip:
```bash
pip install axial_attention
```

In [None]:
#export
from axial_attention import AxialAttention

In [None]:
attn = AxialAttention(
    dim = 16,           # embedding dimension
    dim_index = 1,       # where is the embedding dimension
    heads = 8,           # number of heads for multi-head attention
    num_dimensions = 2,  # number of axial dimensions (images is 2, video is 3, or more)
)

In [None]:
x = torch.rand(2, 16, 64, 64)
test_eq(attn(x).shape, x.shape)
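
To illustrate what "operating along the width" and "along the height" mean, here is a minimal sketch built on `nn.MultiheadAttention`. It is not how the `axial_attention` package is implemented: the class `AxisAttention` is hypothetical, and it only shows the reshaping that restricts attention to one axis.

```python
import torch
from torch import nn

class AxisAttention(nn.Module):
    "Hypothetical self-attention along one spatial axis of a [bs, ch, h, w] tensor"
    def __init__(self, dim, heads, axis):
        super().__init__()
        assert axis in ("height", "width")
        self.axis = axis
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x):
        bs, ch, h, w = x.shape
        if self.axis == "width":
            # every row becomes an independent sequence of w tokens
            seq = x.permute(0, 2, 3, 1).reshape(bs * h, w, ch)
        else:
            # every column becomes an independent sequence of h tokens
            seq = x.permute(0, 3, 2, 1).reshape(bs * w, h, ch)
        out, _ = self.attn(seq, seq, seq)
        if self.axis == "width":
            return out.reshape(bs, h, w, ch).permute(0, 3, 1, 2)
        return out.reshape(bs, w, h, ch).permute(0, 3, 2, 1)

# interleave one width block and one height block
blocks = nn.Sequential(AxisAttention(16, 8, "width"), AxisAttention(16, 8, "height"))
x = torch.rand(2, 16, 6, 6)
out = blocks(x)
print(out.shape)                    # torch.Size([2, 16, 6, 6])
```

A single block only mixes information within a row or a column. Stacking a width block and a height block lets every location reach every other one, which is why the paper interleaves them.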

## The Model MetNet

We will build a small model to try the concept first.
- The model will output all timesteps up to `horizon`.
- We can condition on time before passing the images or after (saving some computations)
- To start, we will output a timeseries, so we will put a `head` that generates one value per timestep. If you don't put any `head` you get the full attention maps.
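
As a rough idea of what such a `head` can look like, the sketch below pools each attention map to a single value per timestep. The helper `simple_head` is hypothetical, written in plain PyTorch for illustration, and is not meant to mirror what fastai's `create_head` builds.

```python
import torch
from torch import nn

def simple_head(hidden_dim, n_out=1):
    # global average pool over the spatial dims, then a linear layer
    return nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(hidden_dim, n_out))

horizon = 5
head = simple_head(128)
# stand-ins for the attention maps produced for each forecast step
maps = [torch.rand(2, 128, 32, 32) for _ in range(horizon)]
preds = torch.stack([head(m) for m in maps], dim=1).squeeze(-1)
print(preds.shape)                  # torch.Size([2, 5])
```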

In [None]:
#export
class MetNet(Module):
    def __init__(self, image_encoder, hidden_dim, ks=3, n_layers=1, n_att_layers=1,
                 head=None, horizon=3, n_feats=0, p=0.2, debug=False):
        self.horizon = horizon
        self.n_feats = n_feats
        self.drop = nn.Dropout(p)
        nf = 256  #from the simple image encoder
        self.image_encoder = TimeDistributed(image_encoder)
        self.ct = ConditionTime(horizon)
        self.temporal_enc = TemporalEncoder(nf,  hidden_dim, ks=ks, n_layers=n_layers)
        self.temporal_agg = nn.Sequential(*[AxialAttention(dim=hidden_dim, dim_index=1, heads=8, num_dimensions=2) for _ in range(n_att_layers)])
        
        if head is None:
            self.head = Noop()
        else:
            self.head = head
        
        self.debug = debug
        
    def encode_timestep(self, x, fstep=1):
        if self.debug:  print(f'Encode Timestep:(i={fstep})')
        if self.debug:  print(f' input shape: {x.shape}')
        
        #Condition Time
        x = self.ct(x, fstep)
        if self.debug: print(f' CondTime->x.shape: {x.shape}')

        ##CNN
        x = self.image_encoder(x)
        if self.debug:  print(f' encoded images shape: {x.shape}')
        
        #Temporal Encoder
        _, state = self.temporal_enc(self.drop(x))
        if self.debug:  print(f' temp_enc out shape: {state.shape}')
        return self.temporal_agg(state)
        
            
    def forward(self, imgs, feats):
        """It takes a rank 5 tensor 
        - imgs [bs, seq_len, channels, h, w]
        - feats [bs, n_feats, seq_len]"""
        if self.debug:  print(f' Input -> (imgs: {imgs.shape}, feats: {feats.shape})')    
        #stack feature as images
        if self.n_feats>0: 
            feats = feat2image(feats, target_size=imgs.shape[-2:])
            imgs = torch.cat([imgs, feats], dim=2)
        if self.debug:  print(f' augmented imgs:   {imgs.shape}')
        
        #Compute all timesteps, probably can be parallelized
        res = []
        for i in range(self.horizon):
            x_i = self.encode_timestep(imgs, i)
            out = self.head(x_i)
            res.append(out)
        res = torch.stack(res, dim=1).squeeze(-1)  #drop only a trailing singleton dim, so a batch of size 1 keeps its batch dim
        if self.debug: print(f'{res.shape=}')
        return res

The parameters are as follows:
- `image_encoder`: An image-to-feature model; it can be a VGG, for instance.
- `hidden_dim`: The number of channels of the temporal encoder ConvGRU cell.
- `ks`: Kernel size of the ConvGRU cell.
- `n_layers`: Number of ConvGRU cells.
- `n_att_layers`: Number of AxialAttention layers in the Temporal Aggregator.
- `head`: The output head of the model.
- `horizon`: How many timesteps to predict.
- `n_feats`: How many features we pass to the model besides the images; they will be encoded as image layers. See the appendix of the paper.
- `p`: Dropout applied to the input of the temporal encoder.
- `debug`: If True, prints the shape at every intermediate step.

The model is structured around an `encode_timestep` method that conditions the input images on each timestep:
- First we take the input image sequence and condition it on the lead time.
- We pass this augmented image sequence through the `image_encoder`.
- We apply the temporal encoder.
- Finally we apply the spatial attention.

In the forward method:
- We encode the numerical features on image channels using `feat2image`
- We stack these with the original image
- We call `encode_timestep` once per forecast step and finally return the stacked predictions

Let's check:

In [None]:
horizon = 5
n_feats = 4

The `image_encoder` must take 3 (RGB image) + `horizon` (for the time conditioning) + `n_feats` (for the extra data planes added to the image) input channels.
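
The channel count can be checked with a couple of lines. The helper name `encoder_in_channels` is made up for this sketch, and the values are the ones used in this example.

```python
def encoder_in_channels(img_channels, horizon, n_feats):
    "Channels seen by the image encoder after feature tiling and time conditioning"
    return img_channels + n_feats + horizon

horizon, n_feats = 5, 4
in_ch = encoder_in_channels(3, horizon, n_feats)
print(in_ch)                        # 12
```

This is the same 12 that shows up as the channel dimension after `CondTime` in the debug output of the forward pass.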

In [None]:
image_encoder = DownSampler(3+horizon+n_feats)

In [None]:
metnet = MetNet(image_encoder, hidden_dim=128, 
                ks=3, n_layers=1, horizon=horizon, 
                head=create_head(128, 1), n_feats=n_feats, debug=True)

Time series data; it could be anything else that is sequential, like the images.

In [None]:
feats = torch.rand(2, n_feats, 10)

In [None]:
imgs = torch.rand(2, 10, 3, 128, 128)

In [None]:
out = metnet(imgs, feats)
out.shape

 Input -> (imgs: torch.Size([2, 10, 3, 128, 128]), feats: torch.Size([2, 4, 10]))
 augmented imgs:   torch.Size([2, 10, 7, 128, 128])
Encode Timestep:(i=0)
 input shape: torch.Size([2, 10, 7, 128, 128])
 CondTime->x.shape: torch.Size([2, 10, 12, 128, 128])
 encoded images shape: torch.Size([2, 10, 256, 32, 32])
 temp_enc out shape: torch.Size([2, 128, 32, 32])
Encode Timestep:(i=1)
 input shape: torch.Size([2, 10, 7, 128, 128])
 CondTime->x.shape: torch.Size([2, 10, 12, 128, 128])
 encoded images shape: torch.Size([2, 10, 256, 32, 32])
 temp_enc out shape: torch.Size([2, 128, 32, 32])
Encode Timestep:(i=2)
 input shape: torch.Size([2, 10, 7, 128, 128])
 CondTime->x.shape: torch.Size([2, 10, 12, 128, 128])
 encoded images shape: torch.Size([2, 10, 256, 32, 32])
 temp_enc out shape: torch.Size([2, 128, 32, 32])
Encode Timestep:(i=3)
 input shape: torch.Size([2, 10, 7, 128, 128])
 CondTime->x.shape: torch.Size([2, 10, 12, 128, 128])
 encoded images shape: torch.Size([2, 10, 256, 32, 32])


torch.Size([2, 5])

In [None]:
#export
def metnet_splitter(m):
    "A simple param splitter for MetNet"
    return [params(m.image_encoder), params(m.temporal_enc)+params(m.temporal_agg)+params(m.head)]
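
What a splitter buys you is separate parameter groups, for example to train the encoder with a smaller learning rate than the rest. The sketch below shows the idea with a plain PyTorch optimizer instead of a fastai `Learner`; the `Toy` module, `toy_splitter` and the learning rates are all hypothetical.

```python
import torch
from torch import nn

class Toy(nn.Module):
    "Hypothetical stand-in with an encoder and a head"
    def __init__(self):
        super().__init__()
        self.image_encoder = nn.Conv2d(3, 8, 3, padding=1)
        self.head = nn.Linear(8, 1)

def toy_splitter(m):
    # same idea as metnet_splitter: one list of params per group
    return [list(m.image_encoder.parameters()), list(m.head.parameters())]

model = Toy()
groups = toy_splitter(model)
opt = torch.optim.Adam([
    {"params": groups[0], "lr": 1e-4},   # encoder: smaller learning rate
    {"params": groups[1], "lr": 1e-3},   # head: larger learning rate
])
print(len(opt.param_groups))             # 2
```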

# Export -

In [None]:
# hide
from nbdev.export import *
notebook2script()

Converted 00_layers.ipynb.
Converted 01_model.ipynb.
Converted index.ipynb.
