# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
 
 
class STN(nn.Module):
    """Spatial Transformer Network that warps a single-channel image batch.

    A small convolutional localization network regresses the six parameters
    of a 2-D affine transform per sample. The transform is then applied to
    the input with ``F.affine_grid`` / ``F.grid_sample``.

    The regressor's first layer is an ``nn.LazyLinear``: its input size is
    inferred from the localization output on the first forward pass. The
    layer is registered in ``__init__`` rather than created inside
    ``forward``, so its parameters:

    - are visible to ``parameters()`` and to optimizers,
    - follow ``.to(device)`` / ``.cuda()``,
    - can be loaded from a ``state_dict``.

    Once the first forward pass has fixed that size, every later input must
    have the same spatial resolution.
    """

    def __init__(self):
        super().__init__()
        # The first Conv2d takes exactly 1 input channel.
        self.localization = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
        )

        # Identity affine transform [[1, 0, 0], [0, 1, 0]] in row-major order,
        # used as the initial bias of the last regressor layer.
        self.fc_loc_output_params = torch.tensor(
            [1, 0, 0, 0, 1, 0], dtype=torch.float
        )

        # Regressor for theta. LazyLinear infers in_features on the first call.
        # Building the head inside forward() instead would hide its
        # parameters from optimizers created earlier and ignore .to(device).
        self.fc_loc = nn.Sequential(
            nn.LazyLinear(32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2),
        )
        # Zero weights plus identity bias: before training, theta is the
        # identity transform for every input.
        with torch.no_grad():
            self.fc_loc[2].weight.zero_()
            self.fc_loc[2].bias.copy_(self.fc_loc_output_params)

    def forward(self, x):
        """Warp ``x`` with the affine transform predicted from ``x`` itself.

        Args:
            x: Tensor of shape (N, 1, H, W). The channel count is fixed by the
                first Conv2d, and the 4-D shape is required by the 2-D
                ``affine_grid``.

        Returns:
            Tensor with the same shape as ``x``, resampled on the predicted
            grid.
        """
        xs = self.localization(x)
        # (N, C, h, w) -> (N, C*h*w). Unlike view(), flatten() also works on
        # non-contiguous tensors.
        xs = torch.flatten(xs, 1)
        theta = self.fc_loc(xs).view(-1, 2, 3)
        grid = F.affine_grid(theta, x.size(), align_corners=True)
        return F.grid_sample(x, grid, align_corners=True)
 
 
def test_stn(input_size=(1, 1, 32, 32)):
    """Run a freshly constructed STN on random noise and plot before/after.

    Prints the shape of the transformed tensor. Then it shows the first
    channel of the first sample of the input and of the output side by side
    in one matplotlib figure.

    Args:
        input_size: Shape of the random input tensor, passed to ``torch.rand``.
    """
    model = STN()
    source = torch.rand(input_size)
    warped = model(source)
    print(warped.shape)

    # (title, 2-D image) pairs: sample 0, channel 0 of each tensor.
    panels = (
        ("Original Image", source.numpy()[0][0]),
        ("Transformed Image", warped.detach().numpy()[0][0]),
    )

    plt.figure()
    for position, (caption, image) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(caption)
        plt.imshow(image, cmap='gray')
    plt.show()
 
 
# Demo: visualize the STN on random inputs of a chosen size. Each test_stn call
# builds a fresh STN, so uncommenting the 32x32 call alongside the 64x64 one
# also works. The guard keeps importing this file from opening plot windows.
if __name__ == "__main__":
    # test_stn(input_size=(1, 1, 32, 32))
    test_stn(input_size=(1, 1, 64, 64))
 