In [None]:
# default_exp model

# The model

> From the paper by [Paletta et al.](https://arxiv.org/pdf/2104.12419v1.pdf)

We will try to implement the architecture from the paper `ECLIPSE : Envisioning Cloud Induced Perturbations in Solar Energy` as closely as possible.

![Image](images/eclipse_diagram.png)

In [None]:
#export
from eclipse_pytorch.imports import *
from eclipse_pytorch.layers import *

## 1. Spatial Downsampler
> A resnet encoder to get image features

You can use any spatial downsampler you like, but the paper describes a simple ResNet architecture.

In [None]:
#export
class SpatialDownsampler(nn.Module):
    """Per-frame image encoder: a convolutional stem followed by three residual stages.

    The stem is a 7x7 ``ConvBlock`` at stride 1 producing 64 channels; the three
    ``ResBlock`` stages each use a 3x3 kernel at stride 2 and widen the features
    64 -> 64 -> 128 -> 256. In the notebook demo, 64x64 inputs come out as
    256-channel 8x8 feature maps.

    Args:
        in_channels: number of channels of the input images.
    """

    def __init__(self, in_channels=3):
        super().__init__()
        # Stem first, so sub-modules are created in the same order they run.
        self.conv1 = ConvBlock(in_channels, 64, kernel_size=7, stride=1)
        # (input, output) channel pairs of the residual stages, in execution order.
        stage_channels = [(64, 64), (64, 128), (128, 256)]
        stages = [ResBlock(n_in, n_out, kernel_size=3, stride=2)
                  for n_in, n_out in stage_channels]
        self.blocks = nn.Sequential(*stages)

    def forward(self, x):
        """Run the stem and then the residual stages on ``x``."""
        stem_out = self.conv1(x)
        return self.blocks(stem_out)

In [None]:
# Build the downsampler with its default in_channels=3.
sd = SpatialDownsampler()

In [None]:
# Four random 3-channel 64x64 frames (batch size 1), encoded one at a time and
# stacked on a new dim=2; the resulting 5-D shape is displayed below.
images = [torch.rand(1, 3, 64, 64) for _ in range(4)]
features = torch.stack([sd(image) for image in images], dim=2)
features.shape

torch.Size([1, 256, 4, 8, 8])

## 2. Temporal Encoder

In [None]:
# TemporalBlock is provided by the star imports above. The arguments are
# presumably (in_channels, out_channels): the next cell maps 256 channels to 128.
te = TemporalBlock(256, 128)

In [None]:
# Apply the temporal block to the stacked features; in the shape shown below
# only the channel dim changes (256 -> 128).
temp_encoded = te(features)
temp_encoded.shape

torch.Size([1, 128, 4, 8, 8])

## 3. Future State Predictions

In [None]:
# Same FuturePrediction configuration that the Eclipse class uses further down.
fp = FuturePrediction(128, 128, n_gru_blocks=4, n_res_layers=4)

In [None]:
# `hidden` is a 4-D tensor; `x` carries an extra dim of length 4 at position 1.
# The output shown below has the same shape as `x`.
hidden = torch.rand(1, 128, 8, 8)
x = torch.rand(1, 4, 128, 8, 8)
fp(x, hidden).shape

torch.Size([1, 4, 128, 8, 8])

## 4A. Segmentation Decoder

In [None]:
# A 256 -> 128 channel Bottleneck with upsample=True; the bare `bn` on the last
# line displays the module repr shown below.
bn = Bottleneck(256, 128, upsample=True)
bn

Bottleneck(
  (layers): Sequential(
    (conv_down_project): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (abn_down_project): Sequential(
      (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): ReLU(inplace=True)
    )
    (conv): ConvTranspose2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1), bias=False)
    (abn): Sequential(
      (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): ReLU(inplace=True)
    )
    (conv_up_project): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (abn_up_project): Sequential(
      (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): ReLU(inplace=True)
    )
    (dropout): Dropout2d(p=0.0, inplace=False)
  )
  (projection): Sequential(
    (upsample_skip_proj): Interpolate()
    (conv_skip_proj): Conv2d(256, 128, kernel_size=(1, 1), stride=(1

In [None]:
# Shape of the first sample of the stacked features (leading batch dim indexed away).
features[0].shape

torch.Size([256, 4, 8, 8])

In [None]:
# With upsample=True the output below has twice the spatial size (32 -> 64)
# and half the channels (256 -> 128).
x = torch.rand(1,256,32,32)
bn(x).shape

torch.Size([1, 128, 64, 64])

In [None]:
#export
class Upsampler(nn.Module):
    """Segmentation decoder: a chain of upsampling ``Bottleneck`` blocks plus a 1x1 output conv.

    For ``sizes = [s0, s1, ..., sk]`` the chain is
    ``Bottleneck(s0, s1) -> ... -> Bottleneck(s(k-1), sk) -> Bottleneck(sk, sk)``,
    every one built with ``upsample=True``, followed by
    ``ConvBlock(sk, n_out, kernel_size=1, activation='none')``.
    With the default sizes that is three upsampling blocks; the notebook demo
    shows a 32x32 input coming out at 256x256.

    Args:
        sizes: channel widths of the decoder, first entry = input channels.
            Any non-empty iterable of ints is accepted; it is copied, never mutated.
        n_out: number of output channels of the final 1x1 convolution.
    """
    def __init__(self, sizes=(128, 128, 64), n_out=3):
        super().__init__()
        # Immutable default + local copy: avoids the shared-mutable-default pitfall
        # and lets callers pass tuples or generators as well as lists.
        sizes = list(sizes)
        stages = [Bottleneck(si, sf, upsample=True)
                  for si, sf in zip(sizes[:-1], sizes[1:])]
        # One extra upsampling stage that keeps the last channel width.
        stages.append(Bottleneck(sizes[-1], sizes[-1], upsample=True))
        # 1x1 projection to the requested number of output channels, no activation.
        stages.append(ConvBlock(sizes[-1], n_out, kernel_size=1, activation='none'))
        self.convs = nn.Sequential(*stages)

    def forward(self, x):
        """Decode feature map ``x`` through the whole chain."""
        return self.convs(x)

In [None]:
# Default Upsampler: three upsampling Bottlenecks followed by a 1x1 conv to n_out=3.
us = Upsampler()

# The output below shows 32x32 features decoded to 3 channels at 256x256.
x = torch.rand(1,128,32,32)
us(x).shape

torch.Size([1, 3, 256, 256])

## 4B. Irradiance Module

In [None]:
#export
class IrradianceModule(nn.Module):
    """Regression head mapping a 128-channel feature map to one value per sample.

    Two ``ConvBlock`` layers (128 -> 64 -> 64) are followed by a global adaptive
    max pool; the pooled 64-vector is flattened, batch-normalised and projected
    to a single output by a linear layer. The channel sizes are fixed.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor ending in a 1x1 global max pool.
        feature_layers = [ConvBlock(128, 64), ConvBlock(64, 64), nn.AdaptiveMaxPool2d(1)]
        self.convs = nn.Sequential(*feature_layers)
        # Head: (bs, 64, 1, 1) -> (bs, 64) -> BatchNorm1d -> (bs, 1).
        head_layers = [nn.Flatten(), nn.BatchNorm1d(64), nn.Linear(64, 1)]
        self.linear = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the prediction for feature map ``x`` (shape ``(bs, 1)`` in the notebook demo)."""
        pooled = self.convs(x)
        return self.linear(pooled)

In [None]:
# No constructor arguments: the channel sizes are fixed inside the class.
im = IrradianceModule()

In [None]:
# A batch of 2 feature maps yields one value per sample (shape shown below).
# NOTE: the head contains BatchNorm1d, which in training mode needs more than
# one sample per batch, hence batch size 2 here.
x = torch.rand(2, 128, 32, 32)
im(x).shape

torch.Size([2, 1])

## Everything Together...

In [None]:
#export
class Eclipse(nn.Module):
    """Full ECLIPSE model: frame encoder, temporal block, future prediction and two decoders.

    Not very parametric: the internal channel widths (256 -> 128) are hard-coded.

    Args:
        n_in: number of channels of each input image.
        n_out: number of channels of each predicted mask (passed to ``Upsampler``).
        horizon: number of future steps fed to the future-prediction module.
        debug: when True, ``forward`` prints the shapes of intermediate tensors.
    """
    def __init__(self, n_in=3, n_out=4, horizon=5, debug=False):
        super().__init__()
        # store_attr() comes from the star imports; forward relies on it having
        # set self.horizon and self.debug from the arguments above.
        store_attr()
        self.spatial_downsampler = SpatialDownsampler(n_in)
        self.temporal_model = TemporalBlock(256, 128)
        self.future_prediction = FuturePrediction(128, 128, n_gru_blocks=4, n_res_layers=4)
        self.upsampler = Upsampler(n_out=n_out)
        self.irradiance = IrradianceModule()

    def zero_hidden(self, x, horizon):
        """Return zeros shaped ``(bs, horizon, ch, h, w)`` matching the 4-D tensor ``x``.

        ``x`` must be exactly 4-D ``(bs, ch, h, w)``; the result shares its dtype and device.
        """
        bs, ch, h, w = x.shape
        return x.new_zeros(bs, horizon, ch, h, w)

    def forward(self, imgs):
        """Predict masks and irradiance for the present state and the future steps.

        Args:
            imgs: iterable of image batches, one tensor per time step
                (the notebook demo passes 4 tensors of shape ``(2, 3, 128, 128)``).

        Returns:
            dict with keys ``'masks'`` and ``'irradiance'``; each value is a list with
            one prediction per state, the present state first, then the future ones.
        """
        # Encode each frame independently and stack the results along a new time dim (dim=2).
        x = torch.stack([self.spatial_downsampler(img) for img in imgs], dim=2)

        #encode temporal model
        # Move the time dim to position 1: (bs, ch, t, h, w) -> (bs, t, ch, h, w).
        states = self.temporal_model(x).permute(0, 2, 1, 3, 4).contiguous()
        if self.debug: print(f'{states.shape=}')

        #get hidden state
        # Slice with -1: (not -1) to keep the time dim, so it can be concatenated later.
        present_state = states[:, -1:]
        if self.debug: print(f'{present_state.shape=}')

        # Prepare future prediction input
        # Drop ONLY the time dim. A bare squeeze() would also drop the batch dim
        # when bs == 1 (and any size-1 spatial dim), which breaks the 4-way unpack
        # in zero_hidden.
        hidden_state = present_state.squeeze(1)
        if self.debug: print(f'{hidden_state.shape=}')

        future_prediction_input = self.zero_hidden(hidden_state, self.horizon)

        # Recursively predict future states
        future_states = self.future_prediction(future_prediction_input, hidden_state)

        # Concatenate present state
        future_states = torch.cat([present_state, future_states], dim=1)
        if self.debug: print(f'{future_states.shape=}')

        #decode outputs
        # Both heads are applied to every state along the time dim.
        preds = {'masks': [], 'irradiance': []}
        for state in future_states.unbind(dim=1):
            preds['masks'].append(self.upsampler(state))
            preds['irradiance'].append(self.irradiance(state))
        return preds
        

In [None]:
# debug=True makes forward print the shapes of its intermediate tensors.
eclipse = Eclipse(debug=True)

In [None]:
# Four time steps of random 3-channel 128x128 images, batch size 2.
preds = eclipse([torch.rand(2, 3, 128, 128) for _ in range(4)])

# Shapes of the first entry of each output list (the present-state predictions).
preds['masks'][0].shape, preds['irradiance'][0].shape

states.shape=torch.Size([2, 4, 128, 16, 16])
present_state.shape=torch.Size([2, 1, 128, 16, 16])
hidden_state.shape=torch.Size([2, 128, 16, 16])
future_states.shape=torch.Size([2, 6, 128, 16, 16])


(torch.Size([2, 4, 128, 128]), torch.Size([2, 1]))

## Export

In [None]:
# hide
from nbdev.export import *
# Runs the nbdev export; the output below lists the notebooks it converted.
notebook2script()

Converted 00_model.ipynb.
Converted 01_layers.ipynb.
Converted index.ipynb.
