<a href="https://colab.research.google.com/github/KeisukeShimokawa/papers-challenge/blob/master/src/cv/SAM/notebooks/SAM_test.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Dummy feature map used throughout the notebook: batch 10, 64 channels, 16x16 spatial.
inputs = torch.randn((10, 64, 16, 16))

print("input feature map: ", inputs.size())

input feature map:  torch.Size([10, 64, 16, 16])


In [5]:
# Channel widths are all derived from the channel axis of the input feature map.
in_planes = inputs.size(1)
rel_planes, out_planes = in_planes // 16, in_planes // 4
print(in_planes, rel_planes, out_planes)

# Remaining hyper-parameters used by the cells below.
share_planes, kernel_size = 8, 7
stride = dilation = 1

64 4 16


In [6]:
# step1: three independent 1x1 convolutions applied to the same input.
# (Created in the same order as before so parameter initialisation is unchanged.)
conv1 = nn.Conv2d(in_planes, rel_planes, 1)
conv2 = nn.Conv2d(in_planes, rel_planes, 1)
conv3 = nn.Conv2d(in_planes, out_planes, 1)

x1 = conv1(inputs)
x2 = conv2(inputs)
x3 = conv3(inputs)

print(f"x1: {x1.shape}", f"x2: {x2.shape}", f"x3: {x3.shape}", sep="\n")

x1: torch.Size([10, 4, 16, 16])
x2: torch.Size([10, 4, 16, 16])
x3: torch.Size([10, 16, 16, 16])


In [15]:
# Merge the two spatial axes of x1 into one trailing axis and insert a singleton
# axis in front of it; the -1 entry absorbs whatever is left (the channels).
x1_reshape = x1.view(x1.size(0), -1, 1, x1.size(2) * x1.size(3))

print(f"x1 reshape: {x1_reshape.size()}")

x1 reshape: torch.Size([10, 4, 1, 256])


PyTorchの`nn.Unfold`が返す出力テンソルの形状

$$
(N, C \times \prod(\text{kernel\_size}), L)
$$

ここで$L$(取り出されるブロックの総数)は以下の式で計算される。なお$d$は空間方向の各次元(高さ・幅など)を走る添字である。

$$
L=\prod_{d}\left\lfloor\frac{\text{spatial\_size}[d]+2 \times \text{padding}[d]-\text{dilation}[d] \times\left(\text{kernel\_size}[d]-1\right)-1}{\text{stride}[d]}+1\right\rfloor
$$

In [18]:
# Number of sliding blocks L produced by the unfold, following the formula above.
# The feature map is reflection-padded by kernel_size // 2 on every side before
# unfolding, i.e. by kernel_size - 1 in total per axis (for an odd kernel_size).
# With dilation == 1 and stride == 1 this exactly cancels the
# dilation * (kernel_size - 1) term, so L is simply the product of the input
# spatial sizes (16 x 16 = 256).
(
    ((inputs.shape[2] + 2 * (kernel_size // 2) - dilation * (kernel_size - 1) - 1) // stride + 1)
    * ((inputs.shape[3] + 2 * (kernel_size // 2) - dilation * (kernel_size - 1) - 1) // stride + 1)
)

256

In [14]:
# Channel axis of the unfolded tensor: C x prod(kernel_size), where C is the
# channel count of x2 (rel_planes) and the kernel is kernel_size x kernel_size.
rel_planes * kernel_size ** 2

196

In [17]:
# The feature map is reflection-padded explicitly, so the unfold itself runs
# with padding=0 and only has to extract the kernel_size x kernel_size patches.
pad = nn.ReflectionPad2d(kernel_size // 2)
unfold_x2 = nn.Unfold(kernel_size, dilation=dilation, padding=0, stride=stride)

x2_padded = pad(x2)
print(f"x2 padded: {x2_padded.size()}")

x2_unfold = unfold_x2(x2_padded)
print(f"x2 unfold: {x2_unfold.size()}")

# Give the unfolded patches the same 4-D layout as x1_reshape
# (batch, channels, 1, positions) so both can be concatenated on the channel axis.
x2_reshape = x2_unfold.view(x1_reshape.size(0), -1, 1, x1_reshape.size(-1))
print(f"s2 reshape: {x2_reshape.size()}")

x2 padded: torch.Size([10, 4, 22, 22])
x2 unfold: torch.Size([10, 196, 256])
s2 reshape: torch.Size([10, 196, 1, 256])


In [19]:
# Side-by-side check: both tensors share every axis except the channel axis.
print(
    f"x1 reshape: {x1_reshape.size()}",
    f"x2 reshape: {x2_reshape.size()}",
    sep="\n",
)

x1 reshape: torch.Size([10, 4, 1, 256])
x2 reshape: torch.Size([10, 196, 1, 256])


In [25]:
# Stack x1 and the unfolded x2 patches along the channel axis.
w_concat = torch.cat([x1_reshape, x2_reshape], dim=1)
print(f"w concat: {w_concat.size()}")

# Bottleneck that maps the concatenated features to one weight per kernel
# position for each of the out_planes // share_planes weight groups.
conv_w = nn.Sequential(
    nn.BatchNorm2d(rel_planes * (kernel_size ** 2 + 1)),
    nn.ReLU(inplace=True),
    nn.Conv2d(
        rel_planes * (kernel_size ** 2 + 1),
        out_planes // share_planes,
        kernel_size=1,
        bias=False,
    ),
    nn.BatchNorm2d(out_planes // share_planes),
    nn.ReLU(inplace=True),
    nn.Conv2d(
        out_planes // share_planes,
        kernel_size ** 2 * out_planes // share_planes,
        kernel_size=1,
    ),
)

w_conv = conv_w(w_concat)
print(f"w conv: {w_conv.size()}")

# Split the channel axis into (groups, kernel positions).
w_reshape = w_conv.view(x1_reshape.size(0), -1, kernel_size ** 2, x1_reshape.size(-1))
print(f"w reshape: {w_reshape.size()}")

w concat: torch.Size([10, 200, 1, 256])
w conv: torch.Size([10, 98, 1, 256])
w reshape: torch.Size([10, 2, 49, 256])


In [31]:
from torch.nn.modules.utils import _pair


# NOTE(review): pad_mode is only printed in this cell; its meaning is not
# established anywhere in this notebook - confirm against the aggregation op.
pad_mode = 1

# _pair turns each scalar hyper-parameter into a (value, value) tuple,
# as the printed output below shows.
kernel_size_pair, stride_pair = _pair(kernel_size), _pair(stride)
padding_pair, dilation_pair = _pair(kernel_size // 2), _pair(dilation)

print(
    kernel_size_pair,
    stride_pair,
    padding_pair,
    dilation_pair,
    pad_mode,
    sep="\n",
)

(7, 7)
(1, 1)
(3, 3)
(1, 1)
1


## Actual Sample

In [0]:
import torch
import torch.nn as nn


class SAM(nn.Module):
    """Self-attention module reduced to its attention-weight branch, for shape inspection.

    Three 1x1 convolutions project the input to ``x1`` and ``x2``
    (``rel_planes`` channels each) and ``x3`` (``out_planes`` channels).  In
    patchwise mode the weight tensor ``w`` is computed from ``x1`` and the
    unfolded neighbourhoods of ``x2``.  The aggregation step that would combine
    ``x3`` with ``w`` is not available in this notebook, so ``forward`` only
    prints both shapes and returns its input unchanged.

    Args:
        sa_type: ``0`` selects the pairwise branch, which is not implemented
            here; any other value selects the patchwise branch.
        in_planes: number of channels of the input feature map.
        rel_planes: number of channels produced by ``conv1`` and ``conv2``.
        out_planes: number of channels produced by ``conv3``.
        share_planes: divisor applied to ``out_planes`` to obtain the number of
            weight groups (``out_planes // share_planes``).
        kernel_size: side length of the square neighbourhood.  The padding
            below keeps the number of positions equal to that of ``x1`` for
            odd values; even values are presumably unsupported - TODO confirm.
        stride: stride of both unfold operations.
        dilation: dilation of both unfold operations.
    """

    def __init__(self, sa_type, in_planes, rel_planes, out_planes, share_planes, kernel_size=3, stride=1, dilation=1):
        super().__init__()
        self.sa_type, self.kernel_size, self.stride = sa_type, kernel_size, stride

        self.conv1 = nn.Conv2d(in_planes, rel_planes, kernel_size=1)
        self.conv2 = nn.Conv2d(in_planes, rel_planes, kernel_size=1)
        self.conv3 = nn.Conv2d(in_planes, out_planes, kernel_size=1)

        # x1 contributes rel_planes channels and the unfolded x2 contributes
        # rel_planes * kernel_size**2, hence the (kernel_size**2 + 1) factor.
        num_positions = pow(kernel_size, 2)
        concat_planes = rel_planes * (num_positions + 1)
        group_planes = out_planes // share_planes
        self.conv_w = nn.Sequential(nn.BatchNorm2d(concat_planes),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(concat_planes, group_planes, kernel_size=1, bias=False),
                                    nn.BatchNorm2d(group_planes),
                                    nn.ReLU(inplace=True),
                                    nn.Conv2d(group_planes, num_positions * out_planes // share_planes, kernel_size=1))

        self.unfold_i = nn.Unfold(kernel_size=1, dilation=dilation, padding=0, stride=stride)
        self.unfold_j = nn.Unfold(kernel_size=kernel_size, dilation=dilation, padding=0, stride=stride)
        # The padding has to cover the *dilated* kernel extent; otherwise the
        # unfolded x2 has fewer positions than x1 whenever dilation > 1.
        # For dilation == 1 this equals kernel_size // 2.
        self.pad = nn.ReflectionPad2d((dilation * (kernel_size - 1) + 1) // 2)
        # TODO: the aggregation op is not available in this notebook:
        # self.aggregation = Aggregation(kernel_size, stride, (dilation * (kernel_size - 1) + 1) // 2, dilation, pad_mode=1)

    def forward(self, x):
        """Compute the patchwise attention weights and print the resulting shapes.

        Args:
            x: 4-D input; ``conv1``/``conv2``/``conv3`` are ``nn.Conv2d`` layers
                expecting ``in_planes`` channels.

        Returns:
            ``x`` itself, unchanged (the aggregation step is not applied).

        Raises:
            NotImplementedError: if ``sa_type == 0``; the layers the pairwise
                branch needs are not defined in this class.
        """
        if self.sa_type == 0:  # pairwise
            raise NotImplementedError(
                "SAM: the pairwise branch (sa_type == 0) is not implemented in this "
                "notebook; use a non-zero sa_type for the patchwise branch."
            )

        x1, x2, x3 = self.conv1(x), self.conv2(x), self.conv3(x)
        batch = x.shape[0]

        # patchwise
        if self.stride != 1:
            # Sub-sample x1 to the positions visited by the strided unfold.
            x1 = self.unfold_i(x1)
        # Works for both layouts: (N, C, H, W) when stride == 1 and (N, C, L)
        # after unfolding.  The last axis becomes the true number of
        # positions instead of assuming H * W.
        x1 = x1.view(batch, x1.shape[1], 1, -1)
        num_positions = x1.shape[-1]
        x2 = self.unfold_j(self.pad(x2)).view(batch, -1, 1, num_positions)
        w = self.conv_w(torch.cat([x1, x2], 1)).view(batch, -1, pow(self.kernel_size, 2), num_positions)

        # Shape printouts are the purpose of this demo module.
        print(x3.shape)
        print(w.shape)
        return x

In [36]:
# Build the patchwise variant (sa_type=1) from the hyper-parameters defined in
# the exploration cells above, then echo the module structure.
sam_sample = SAM(
    1,
    in_planes,
    rel_planes,
    out_planes,
    share_planes,
    kernel_size=kernel_size,
    stride=stride,
    dilation=dilation,
)

sam_sample

SAM(
  (conv1): Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1))
  (conv2): Conv2d(64, 4, kernel_size=(1, 1), stride=(1, 1))
  (conv3): Conv2d(64, 16, kernel_size=(1, 1), stride=(1, 1))
  (conv_w): Sequential(
    (0): BatchNorm2d(200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): ReLU(inplace=True)
    (2): Conv2d(200, 2, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (3): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): ReLU(inplace=True)
    (5): Conv2d(2, 98, kernel_size=(1, 1), stride=(1, 1))
  )
  (unfold_i): Unfold(kernel_size=1, dilation=1, padding=0, stride=1)
  (unfold_j): Unfold(kernel_size=7, dilation=1, padding=0, stride=1)
  (pad): ReflectionPad2d((3, 3, 3, 3))
)

In [39]:
# Run the module once for the shape printouts inside forward().  The trailing
# semicolon suppresses the echo of the returned tensor without rebinding `_`,
# which IPython uses for the last cell output.
sam_sample(inputs);

torch.Size([10, 16, 16, 16])
torch.Size([10, 2, 49, 256])
