In [10]:
import torch
import numpy as np
import torch.nn as nn
import torchvision.transforms as transforms

import scipy.io as scio

In [11]:
subject_size = (50, 50)
context_size = (200, 200)

In [12]:
def net_branch(img_size):
    return nn.Sequential(
               transforms.Resize(img_size),
               nn.Conv2d(3, 96, (11, 1), stride=(4, 1)),
               nn.ReLU(),
               nn.BatchNorm2d(96),
               nn.Conv2d(96, 96, (1, 11), stride=(1, 4)),
               nn.ReLU(),
               nn.BatchNorm2d(96),
               nn.MaxPool2d(3, stride=2),
               
               nn.Conv2d(96, 256, (1, 5), padding="same"),
               nn.ReLU(),
               nn.BatchNorm2d(256),
               nn.Conv2d(256, 256, (5, 1), padding="same"),
               nn.ReLU(),
               nn.BatchNorm2d(256),
               nn.MaxPool2d(3, stride=2),
    
               nn.Conv2d(256, 384, (1, 3), padding="same"),
               nn.ReLU(),
               nn.BatchNorm2d(384),
               nn.Conv2d(384, 384, (3, 1), padding="same"),
               nn.ReLU(),
               nn.BatchNorm2d(384),
               
               nn.Conv2d(384, 384, (1, 3), padding="same"),
               nn.ReLU(),
               nn.BatchNorm2d(384),
               nn.Conv2d(384, 384, (3, 1), padding="same"),
               nn.ReLU(),
               nn.BatchNorm2d(384),
    
               nn.Conv2d(384, 256, (1, 3), padding="same"),
               nn.ReLU(),
               nn.BatchNorm2d(256),
               nn.Conv2d(256, 256, (3, 1), padding="same"),
               nn.ReLU(),
               nn.BatchNorm2d(256),
               nn.MaxPool2d(3, stride=2)
        )

In [13]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()

        self.subject = net_branch(subject_size)
        self.context = net_branch(context_size)
        
        self.fusion = nn.Sequential(
        )

    def forward(self, s, c):
        s = self.subject(s)
        s = torch.flatten(s, start_dim=1)
        
        c = self.context(c)
        c = torch.flatten(c, start_dim=1)

        x = torch.cat((s, c), dim=1)
        x = self.fusion(x)
        return x

In [14]:
n = Net()

In [15]:
print(n)

Net(
  (subject): Sequential(
    (0): Resize(size=(50, 50), interpolation=bilinear, max_size=None, antialias=True)
    (1): Conv2d(3, 96, kernel_size=(11, 1), stride=(4, 1))
    (2): ReLU()
    (3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): Conv2d(96, 96, kernel_size=(1, 11), stride=(1, 4))
    (5): ReLU()
    (6): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Conv2d(96, 256, kernel_size=(1, 5), stride=(1, 1), padding=same)
    (9): ReLU()
    (10): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): Conv2d(256, 256, kernel_size=(5, 1), stride=(1, 1), padding=same)
    (12): ReLU()
    (13): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (14): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (15): Conv2d(256,