In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from collections import OrderedDict
import numpy as np
import math

In [2]:
# Hyperparameters shared by every cell below.
input_channels = 49       # channels of the input tensor `x`
output_channels = 3       # channels produced by the fusion head
num_layers = 15           # number of encoder stages (and of decoder stages)
num_features = 64         # width of the hidden feature maps
scale = 2                 # spatial upsampling factor
# NOTE: this used to read `use_climatology=True,` -- the trailing comma bound
# the one-element tuple `(True,)` instead of a bool.  A non-empty tuple is
# always truthy, so `False,` would still have enabled the climatology branch.
use_climatology = True
# Channel count expected for the auxiliary tensor fed to the fusion layer.
climatology_channels = output_channels + 1
# Stand-alone in-place ReLU applied after the residual additions.
relu = nn.ReLU(inplace=True)

In [49]:
def _make_conv_layers():
    """Build the encoder as a stack of 3x3 convolution + ReLU stages.

    Reads the notebook-level ``input_channels``, ``num_features`` and
    ``num_layers``.  The first stage maps ``input_channels`` to
    ``num_features``; each of the remaining ``num_layers - 1`` stages keeps
    the width at ``num_features``.  Padding of 1 with a 3x3 kernel keeps the
    spatial size unchanged.

    Returns:
        nn.Sequential whose children are the individual
        ``nn.Sequential(Conv2d, ReLU)`` stages, so each stage can be indexed.
    """
    # Input width of every stage, in order: one entry per stage.
    stage_in_widths = [input_channels] + [num_features] * (num_layers - 1)
    stages = [
        nn.Sequential(
            nn.Conv2d(width, num_features, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        )
        for width in stage_in_widths
    ]
    return nn.Sequential(*stages)

def _make_deconv_layers():
    """Build the decoder as a stack of transposed-conv + conv stages.

    Reads the notebook-level ``input_channels``, ``num_features`` and
    ``num_layers``.  The first ``num_layers - 1`` stages are
    ``ConvTranspose2d -> ReLU -> Conv2d -> ReLU`` at width ``num_features``.
    The last stage is ``ConvTranspose2d -> ReLU -> Conv2d`` and projects back
    to ``input_channels`` with no trailing activation.  All kernels are 3x3
    with padding 1 and stride 1, so the spatial size is unchanged.

    Returns:
        nn.Sequential whose children are the individual stages, so each stage
        can be indexed.
    """

    def transposed_conv():
        # 3x3 transposed convolution that preserves width and spatial size.
        return nn.ConvTranspose2d(
            num_features,
            num_features,
            kernel_size=3,
            stride=1,
            padding=1,
            output_padding=0,
        )

    hidden_stages = [
        nn.Sequential(
            transposed_conv(),
            nn.ReLU(inplace=True),
            nn.Conv2d(num_features, num_features, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        for _ in range(num_layers - 1)
    ]
    # Final stage: project back to the input channel count, no ReLU at the end.
    output_stage = nn.Sequential(
        transposed_conv(),
        nn.ReLU(inplace=True),
        nn.Conv2d(num_features, input_channels, kernel_size=3, stride=1, padding=1),
    )
    return nn.Sequential(*hidden_stages, output_stage)

def _make_subpixel_conv_layer():
    return nn.Sequential(
        nn.Conv2d(
            input_channels,
            input_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        ),
        nn.ReLU(inplace=True),
        upsample,
        nn.Conv2d(
            input_channels,
            input_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        ),
        nn.ReLU(inplace=True),
    )

In [50]:
# Instantiate the encoder and decoder stacks.
conv_layers = _make_conv_layers()
deconv_layers = _make_deconv_layers()

# Bilinear upsampling by `scale`.  It must exist before the next call because
# `_make_subpixel_conv_layer` reads the notebook-level `upsample`.
upsample = nn.Upsample(scale_factor=scale, mode="bilinear", align_corners=False)
subpixel_conv_layer = _make_subpixel_conv_layer()

# Fusion head, only built when climatology is enabled.  Its input width is
# 2 * input_channels + climatology_channels because the forward cells
# concatenate `x`, `residual_up` and `x_aux` along the channel axis.
if use_climatology:
    fusion_layer = nn.Sequential(
        nn.Conv2d(
            2 * input_channels + climatology_channels,
            num_features,
            kernel_size=3,
            padding=1,
        ),
        nn.ReLU(inplace=True),
        # 1x1 projection to the requested number of output channels.
        nn.Conv2d(num_features, output_channels, kernel_size=1),
    )

In [23]:
# Random stand-in tensors used only to check shapes through the network.
# `x` has 49 channels (the value of `input_channels`).
x = torch.randn((64, 49, 32, 64))
# `x_aux` has 4 channels (the value of `climatology_channels`) and twice the
# spatial size of `x`, matching `scale = 2`.
x_aux = torch.randn((64, 4, 64, 128))

In [26]:
class SRCNN(nn.Module):
    """Three convolutions followed by bilinear upsampling.

    The layers are a 9x9 convolution to ``num_features`` channels, a 5x5
    convolution to ``num_features // 2`` channels and a 5x5 convolution to
    ``out_channels``.  ReLU follows the first two convolutions only.  Every
    convolution is padded so the spatial size is preserved; the result is
    then resized by ``scale_factor`` with bilinear interpolation.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of the output tensor.
        num_features: width after the first convolution.
        scale_factor: spatial upsampling factor applied at the end.
    """

    def __init__(self, in_channels, out_channels, num_features, scale_factor=2):
        super().__init__()
        hidden = num_features // 2
        self.conv1 = nn.Conv2d(in_channels, num_features, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(num_features, hidden, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(hidden, out_channels, kernel_size=5, padding=2)
        self.relu = nn.ReLU()
        self.upsample = nn.Upsample(
            scale_factor=scale_factor, mode="bilinear", align_corners=False
        )

    def forward(self, x):
        """Run the three convolutions, then upsample the result."""
        features = self.relu(self.conv1(x))
        features = self.relu(self.conv2(features))
        # No activation after the last convolution; upsampling comes last.
        return self.upsample(self.conv3(features))


In [27]:
# SRCNN baseline sized from the notebook configuration instead of repeating
# the literals 49 / 3 / 64, so it cannot drift from the config cell.
model = SRCNN(
    in_channels=input_channels,
    out_channels=output_channels,
    num_features=num_features,
    scale_factor=scale,
)

In [28]:
# Shape check: run the SRCNN baseline on the synthetic input `x` and display
# the shape of its output.
model(x).shape

torch.Size([64, 3, 64, 128])

In [52]:
# Keep two references to the untouched input before `x` is overwritten below:
# an upsampled copy for the fusion concatenation ...
residual_up = upsample(x)
# ... and the input itself for the residual addition after the decoder.
residual = x

In [53]:
# Encoder pass.  The output of every second stage (stage numbers 2, 4, ...)
# is stored for the decoder's skip connections, up to
# ceil(num_layers / 2) - 1 stored tensors.
conv_feats = []
for i in range(num_layers):
    x = conv_layers[i](x)
    if i % 2 == 1 and len(conv_feats) + 1 < math.ceil(num_layers / 2):
        conv_feats.append(x)

In [54]:
# Decoder pass.  Stored encoder features are consumed from last to first and
# added after the stages for which (stage number + num_layers) is even.
conv_feats_idx = 0
for i in range(num_layers):
    x = deconv_layers[i](x)
    if (i + 1 + num_layers) % 2 != 0 or conv_feats_idx >= len(conv_feats):
        continue
    conv_feats_idx += 1
    # After the increment, -conv_feats_idx indexes the next unused feature
    # counting from the end of the list.
    conv_feat = conv_feats[-conv_feats_idx]
    x = relu(x + conv_feat)

In [55]:
# Global residual connection: add the original input back, then apply ReLU.
x = relu(x + residual)

In [56]:
# Pass through the upsampling head built by `_make_subpixel_conv_layer`:
# conv -> ReLU -> `upsample` (bilinear nn.Upsample) -> conv -> ReLU.
# Despite the variable name, no pixel-shuffle / sub-pixel convolution is used.
x = subpixel_conv_layer(x)

In [57]:
# Fusion head: applied only when climatology is enabled and auxiliary data
# (`x_aux`) is available.
if use_climatology and (x_aux is not None):
    # Stack, along the channel axis, the upsampled features, the upsampled
    # input and the auxiliary data, in that order.
    x = torch.cat([x, residual_up, x_aux], dim=1)
    # Pass through fusion layer
    x = fusion_layer(x)  # [Nbatch,Nchannel,Nlat,Nlon]

In [58]:
# Display the shape of `x` after all the cells above have run.
x.shape

torch.Size([64, 3, 64, 128])