In [1]:
# default_exp model

# The model

> From the paper by [Paletta et al.](https://arxiv.org/pdf/2104.12419v1.pdf)

We will try to implement the architecture from the paper `ECLIPSE : Envisioning Cloud Induced Perturbations in Solar Energy` as closely as possible.

![Image](../images/eclipse_diagram.png)

In [2]:
#export
import torch
import torch.nn as nn
import torch.nn.functional as F

from eclipse_pytorch.layers import *

## 1. Spatial Downsampler
> A resnet encoder to get image features

You could use any spatial downsampler you want, but the paper describes a simple ResNet architecture.

In [19]:
#export
class SpatialDownsampler(nn.Module):
    """Convolutional image encoder.

    A 7x7, stride-1 `ConvBlock` stem is followed by three stride-2 `ResBlock`s
    whose channel widths go 64 -> 64 -> 128 -> 256.

    Args:
        in_channels: number of channels of the input image (default 3).
    """

    def __init__(self, in_channels=3):
        super().__init__()
        # Stem: lifts the input to 64 channels without striding.
        self.conv1 = ConvBlock(in_channels, 64, kernel_size=7, stride=1)
        # One ResBlock per consecutive pair of widths, each with stride 2.
        widths = [64, 64, 128, 256]
        stages = [ResBlock(n_in, n_out, kernel_size=3, stride=2)
                  for n_in, n_out in zip(widths[:-1], widths[1:])]
        self.blocks = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the stem, then the residual stages, and return the feature map."""
        stem_out = self.conv1(x)
        return self.blocks(stem_out)

In [20]:
# Build the downsampler with its default number of input channels (3).
sd = SpatialDownsampler()

In [21]:
# Four random tensors of shape (1, 3, 256, 256) used as dummy input images.
images = [torch.rand(1, 3, 256, 256) for _ in range(4)]
# Each image is encoded on its own; the results are stacked along a new dim 2.
features = torch.stack([sd(image) for image in images], dim=2)
features.shape

torch.Size([1, 256, 4, 32, 32])

## 2. Temporal Encoder

In [22]:
# `TemporalBlock` comes from `eclipse_pytorch.layers`.
# NOTE(review): (256, 128) are presumably in/out channels -- the output below has 128 channels.
te = TemporalBlock(256, 128)

In [23]:
# Run the temporal block on the stacked per-image features and inspect the shape.
temp_encoded = te(features)
temp_encoded.shape

torch.Size([1, 128, 4, 32, 32])

## 3. Future State Predictions

In [24]:
# `FuturePrediction` comes from `eclipse_pytorch.layers`; its internals are not shown here.
fp = FuturePrediction(128, 128, n_gru_blocks=4, n_res_layers=4)

In [25]:
# Random tensor passed as the second argument of `fp` (no dim of size 4, unlike the sequence).
hidden = torch.rand(1, 128, 32, 32)

# Swap dims 1 and 2 of `temp_encoded` before calling `fp`, then inspect the output shape.
fp(temp_encoded.permute(0,2,1,3,4), hidden).shape

torch.Size([1, 4, 128, 32, 32])

## 4A. Segmentation Decoder

In [26]:
# `Bottleneck` comes from `eclipse_pytorch.layers`; here it is built with upsampling enabled.
bn = Bottleneck(256, 128, upsample=True)

In [27]:
# Shape of the first batch element of the stacked features.
features[0].shape

torch.Size([256, 4, 32, 32])

In [28]:
# Random 256-channel, 32x32 input; the output shows the bottleneck halving channels
# and doubling the spatial size.
x = torch.rand(1,256,32,32)
bn(x).shape

torch.Size([1, 128, 64, 64])

In [51]:
#export
class Upsampler(nn.Module):
    """Decoder made of upsampling `Bottleneck`s and a final 1x1 `ConvBlock`.

    Args:
        sizes: channel widths. One upsampling `Bottleneck` is created for each
            consecutive pair, plus one more that keeps the last width.
        n_out: number of output channels of the final `ConvBlock`, which is
            built with `kernel_size=1` and `activation='none'`.
    """

    def __init__(self, sizes=[256,128,64], n_out=3):
        super().__init__()
        last = sizes[-1]
        # Bottlenecks between consecutive widths...
        layers = [Bottleneck(c_in, c_out, upsample=True)
                  for c_in, c_out in zip(sizes[:-1], sizes[1:])]
        # ...one extra upsampling stage at the final width...
        layers.append(Bottleneck(last, last, upsample=True))
        # ...and the 1x1 projection to `n_out` channels.
        layers.append(ConvBlock(last, n_out, kernel_size=1, activation='none'))
        self.convs = nn.Sequential(*layers)

    def forward(self, x):
        """Run `x` through every stage in order and return the result."""
        out = self.convs(x)
        return out

In [52]:
# Build the upsampler with its default sizes ([256, 128, 64]) and 3 output channels.
us = Upsampler()

In [53]:
# Relies on `x` from the bottleneck cell above, i.e. the (1, 256, 32, 32) tensor.
us(x).shape

torch.Size([1, 3, 256, 256])

## 4B. Irradiance Module

In [59]:
#export
class IrradianceModule(nn.Module):
    """Head that maps a 128-channel feature map to one value per sample.

    Two `ConvBlock`s (128 -> 64 -> 64) and an adaptive max-pool to 1x1 are
    followed by flatten, `BatchNorm1d(64)` and `Linear(64, 1)`.
    """

    def __init__(self):
        super().__init__()
        # Convolutional part, ending in a global max-pool.
        conv_layers = [ConvBlock(128, 64), ConvBlock(64, 64), nn.AdaptiveMaxPool2d(1)]
        # Regression part operating on the pooled 64-dim vector.
        head_layers = [nn.Flatten(), nn.BatchNorm1d(64), nn.Linear(64, 1)]
        self.convs = nn.Sequential(*conv_layers)
        self.linear = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the output of the linear head applied to the pooled features."""
        pooled = self.convs(x)
        return self.linear(pooled)

In [60]:
# Instantiate the irradiance head (it takes no arguments).
im = IrradianceModule()

In [62]:
# Rebinds `x` to a batch of two 128-channel maps; the output has one value per sample.
x = torch.rand(2, 128, 32, 32)
im(x).shape

torch.Size([2, 1])

## Everything Together...

In [None]:
# `Tensor.new_zeros` requires an explicit size (calling it with no argument raises
# a TypeError); the new tensor takes its dtype and device from `x`.
x.new_zeros(x.shape).shape

In [None]:
class Eclipse(nn.Module):
    """Chains the spatial downsampler, temporal encoder, future-state predictor
    and the two heads (mask upsampler and irradiance module).

    Args:
        n_in: number of channels of each input image, forwarded to `SpatialDownsampler`.
        horizon: number of roll-out steps performed in `forward` (default 5).
    """

    def __init__(self, n_in, horizon=5):
        # nn.Module has to be initialised before sub-modules can be assigned.
        super().__init__()
        self.horizon = horizon
        self.spatial_downsampler = SpatialDownsampler(n_in)
        self.temporal_encoder = TemporalBlock(256, 128)
        self.future_state = FuturePrediction(128, 128, n_gru_blocks=4, n_res_layers=4)
        # The states handed to the decoder have 128 channels, so the upsampler must
        # start at 128 rather than at its default of 256.
        # NOTE(review): the [128, 64, 32] widths are a choice made here -- confirm.
        self.upsampler = Upsampler(sizes=[128, 64, 32])
        self.irradiance = IrradianceModule()

    def set_hidden(self, x):
        """Store an all-zero hidden state matching `x` with its third dim removed.

        Args:
            x: 5-D tensor; unpacked as (bs, ch, t, h, w).
        """
        bs, ch, t, h, w = x.shape
        # new_zeros keeps the dtype and device of `x`.
        self.hidden_state = x.new_zeros(bs, ch, h, w)

    def forward(self, imgs):
        """Encode a sequence of images and roll out `horizon` predictions.

        Args:
            imgs: iterable of image tensors; each one is passed separately
                through the spatial downsampler.

        Returns:
            dict with keys 'masks' and 'irradiance', each a list holding one
            tensor per roll-out step.
        """
        # Per-image features stacked along a new dim 2.
        down_imgs = torch.stack([self.spatial_downsampler(img) for img in imgs], dim=2)
        te_imgs = self.temporal_encoder(down_imgs)

        preds = {'masks': [], 'irradiance': []}
        self.set_hidden(te_imgs)
        hidden = self.hidden_state
        # Same call convention as the FuturePrediction demo cell of this notebook:
        # dims 1 and 2 are swapped before the call and the hidden state has no
        # sequence dim.
        states = te_imgs.permute(0, 2, 1, 3, 4)
        # NOTE(review): the original loop called `self.future_state()` with no
        # arguments; this autoregressive scheme (feed the sequence back in, use its
        # last step as the new hidden state) is a best-effort completion -- confirm
        # against the paper.
        for _ in range(self.horizon):
            states = self.future_state(states, hidden)
            hidden = states[:, -1]
            preds['masks'].append(self.upsampler(hidden))
            preds['irradiance'].append(self.irradiance(hidden))
        return preds

## Export

In [63]:
# hide
from nbdev.export import *
notebook2script()

Converted 00_model.ipynb.
Converted 01_layers.ipynb.
Converted index.ipynb.
