In [1]:
import torch
import torch.nn as nn
from collections import OrderedDict
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
class ResidualUnit(nn.Module):
    """Residual block with two 3x3 convolutions and an identity shortcut.

    The data path is ``[BN ->] ReLU -> Conv -> [BN ->] ReLU -> Conv`` and the
    block input is added to the result. Both convolutions use stride 1 and
    padding 1 and keep the channel count, so the output has the same shape
    as the input.

    Args:
        in_channels: Number of channels of the input (and of the output).
        use_bn: When True, a ``BatchNorm2d`` is applied before each ReLU.
    """

    def __init__(self, in_channels, use_bn = False):
        super().__init__()
        self.use_bn = use_bn
        if use_bn:
            # Batch-norm layers exist only when requested, so they only show
            # up in the state dict for use_bn=True.
            self.bn1 = nn.BatchNorm2d(in_channels)
            self.bn2 = nn.BatchNorm2d(in_channels)
        # Shape-preserving 3x3 convolution settings shared by both layers.
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1, bias=True)
        self.conv1 = nn.Conv2d(in_channels, in_channels, **conv_kwargs)
        self.conv2 = nn.Conv2d(in_channels, in_channels, **conv_kwargs)
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()  # stateless; a single ReLU could serve both stages

    def forward(self, x):
        """Return ``x`` plus the output of the two-stage conv branch."""
        # First stage: optional BN, then ReLU and convolution.
        branch = self.bn1(x) if self.use_bn else x
        branch = self.conv1(self.relu1(branch))
        # Second stage: same pattern on the intermediate feature map.
        if self.use_bn:
            branch = self.bn2(branch)
        branch = self.conv2(self.relu2(branch))
        # Identity shortcut; the in-place add only touches the fresh conv output.
        branch += x
        return branch

class ResidualPipeline(nn.Module):
    """Conv -> stack of residual units -> ReLU -> Conv.

    All convolutions are 3x3 with stride 1 and padding 1, so the spatial size
    of the input is preserved; only the channel count changes.

    Args:
        in_channels: Number of channels of the input tensor.
        n_units: Number of ``ResidualUnit`` blocks in the stack. Zero yields
            an empty ``nn.Sequential``, which passes its input through.
        use_bn: Forwarded to every ``ResidualUnit``.
        hidden_channels: Width of the feature maps inside the pipeline.
            Defaults to 64, the value that was previously hard-coded.
        out_channels: Number of channels produced by the final convolution.
            Defaults to 2, the value that was previously hard-coded.
    """

    def __init__(self, in_channels, n_units, use_bn = False, hidden_channels = 64, out_channels = 2):
        super().__init__()
        # conv1 and conv2 are created before the residual stack so that
        # parameter initialisation consumes the RNG in the same order as before.
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(hidden_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True)
        self.resnet_stack = nn.Sequential(
            *(ResidualUnit(hidden_channels, use_bn) for _ in range(n_units))
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map ``x`` to a tensor with ``out_channels`` channels and the same spatial size."""
        x = self.conv1(x)
        x = self.resnet_stack(x)
        x = self.relu(x)  # extra activation; the original author notes it is not in the paper
        x = self.conv2(x)
        return x
    
class STResnet(nn.Module):
    """Three parallel residual pipelines fused by learned element-wise weights.

    Each of the three inputs goes through its own ``ResidualPipeline``. The
    three results are multiplied element-wise by the learnable tensors
    ``w_c``, ``w_p`` and ``w_t``, summed, and passed through ``tanh``.

    Args:
        c_channel: Input channels of the pipeline applied to ``x_c``.
        p_channel: Input channels of the pipeline applied to ``x_p``.
        t_channel: Input channels of the pipeline applied to ``x_t``.
        n_residual_units: Number of residual units in every pipeline.
        map_dim: Shape of one pipeline output without the batch dimension;
            the fusion weights are created with shape ``(1, *map_dim)``.
        use_bn: Forwarded to every pipeline.
    """

    def __init__(self, c_channel, p_channel, t_channel, n_residual_units, map_dim, use_bn = False):
        super().__init__()

        def build_pipe(channels):
            # All three branches share depth and batch-norm setting.
            return ResidualPipeline(channels, n_residual_units, use_bn)

        self.c_pipe = build_pipe(c_channel)
        self.p_pipe = build_pipe(p_channel)
        self.t_pipe = build_pipe(t_channel)

        # Leading singleton dimension lets the weights broadcast over the
        # batch; inputs without a batch dimension are not supported.
        weight_shape = (1, *map_dim)
        self.w_c = nn.Parameter(torch.randn(weight_shape))
        self.w_p = nn.Parameter(torch.randn(weight_shape))
        self.w_t = nn.Parameter(torch.randn(weight_shape))
        self.tanh = nn.Tanh()

    def forward(self, x_c, x_p, x_t):
        """Run the three pipelines and return the tanh of their weighted sum."""
        # Fusion: element-wise (Hadamard) product of each branch output with
        # its weight tensor, accumulated in c, p, t order.
        fused = self.w_c * self.c_pipe(x_c)
        fused = fused + self.w_p * self.p_pipe(x_p)
        fused = fused + self.w_t * self.t_pipe(x_t)
        return self.tanh(fused)

In [3]:
from torch.utils.data import DataLoader
from BikeNYC import BikeNYCDataset

# Build the dataset from the saved flow array. The positional arguments
# (4, 1, 1, 0.8, True) are interpreted by BikeNYCDataset, which is defined in
# another module -- NOTE(review): presumably sequence lengths, a split ratio
# and a flag; confirm against BikeNYC.py.
bikenyc_dataset = BikeNYCDataset(
    './Datasets/BikeNYC/flow_data.npy', 4, 1, 1, 0.8, True
)
train_dataloader = DataLoader(dataset=bikenyc_dataset, batch_size=64, shuffle=True)

# Pull one shuffled batch and print the shape of each tensor as a sanity check.
x_c, x_p, x_t, y = next(iter(train_dataloader))
print(*(tensor.shape for tensor in (x_c, x_p, x_t, y)))

min: 0.0 max: 737.0
torch.Size([64, 8, 21, 12]) torch.Size([64, 2, 21, 12]) torch.Size([64, 2, 21, 12]) torch.Size([64, 2, 21, 12])


In [4]:
# Smoke test: build the network with the channel counts and map size of the
# sample batch and check that a forward pass keeps the target shape.
model = STResnet(8,2,2,4,(2,21,12),True).to(device)
# for name, _ in model.named_parameters():
#     print(name)
# The model's parameters live on `device`, so the batch has to be moved there
# too; otherwise the forward pass fails with a device mismatch on CUDA machines.
out = model(x_c.to(device), x_p.to(device), x_t.to(device))
print(out.shape)

torch.Size([64, 2, 21, 12])
