In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# --- Simplified StripConv used for testing --- #
class StripConvTest(nn.Module):
    """Four fixed-weight, single-channel convolutions for eyeballing strip kernels.

    The module holds a horizontal (1 x k), a vertical (k x 1) and two k x k
    convolutions. The strip kernels are all ones; the k x k kernels are one on
    the main diagonal and the anti-diagonal respectively and zero elsewhere.
    Every branch is padded by ``kernel_size // 2`` along its kernel axes.
    """

    def __init__(self, kernel_size=3):
        super().__init__()
        self.kernel_size = kernel_size
        half = kernel_size // 2

        # Layers are created in the order h, v, d1, d2.
        self.conv_h = nn.Conv2d(1, 1, (1, kernel_size), padding=(0, half), bias=False)
        self.conv_v = nn.Conv2d(1, 1, (kernel_size, 1), padding=(half, 0), bias=False)
        self.conv_d1 = nn.Conv2d(1, 1, kernel_size, padding=half, bias=False)  # main diagonal
        self.conv_d2 = nn.Conv2d(1, 1, kernel_size, padding=half, bias=False)  # anti-diagonal

        # All-ones strip kernels make the outputs easy to verify by hand.
        for strip in (self.conv_h, self.conv_v):
            nn.init.constant_(strip.weight, 1.0)

        # Identity matrix -> main-diagonal kernel; its left-right mirror ->
        # anti-diagonal kernel. Everything off the diagonal is zero.
        diagonal = torch.eye(kernel_size).reshape(1, 1, kernel_size, kernel_size)
        self.conv_d1.weight.data.copy_(diagonal)
        self.conv_d2.weight.data.copy_(diagonal.flip(-1))

    def forward(self, x):
        """Return the (horizontal, vertical, main-diagonal, anti-diagonal) responses."""
        return self.conv_h(x), self.conv_v(x), self.conv_d1(x), self.conv_d2(x)

# --- Build the 7x7 single-channel input image --- #
def generate_input():
    """Return a (1, 1, 7, 7) float32 tensor holding 0..48 in row-major order."""
    grid = np.arange(49, dtype=np.float32).reshape(1, 1, 7, 7)
    return torch.from_numpy(grid)

# --- Print a result as a 2-D image --- #
def print_tensor(name, tensor):
    """Print ``name`` followed by ``tensor`` as a NumPy array rounded to one decimal.

    Size-1 dimensions are squeezed away and the tensor is detached from the
    autograd graph before the NumPy conversion.
    """
    rounded = np.round(tensor.detach().squeeze().numpy(), 1)
    print(f"{name}:\n{rounded}\n")

# --- Entry point --- #
if __name__ == "__main__":
    # Run the fixed-weight model once on the 0..48 ramp image and dump the
    # input followed by each directional response.
    input_tensor = generate_input()
    model = StripConvTest(kernel_size=3)
    out_h, out_v, out_d1, out_d2 = model(input_tensor)

    for title, result in (
        ("Input", input_tensor),
        ("Horizontal Conv", out_h),
        ("Vertical Conv", out_v),
        ("Main Diagonal Conv", out_d1),
        ("Anti Diagonal Conv", out_d2),
    ):
        print_tensor(title, result)


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class StripConvTest(nn.Module):
    """Trainable strip convolution whose diagonal branches stay diagonal.

    Four bias-free convolutions (horizontal 1 x k, vertical k x 1, main
    diagonal and anti-diagonal k x k) are applied to the same input and
    averaged. The two k x k kernels are multiplied by a 0/1 mask at
    construction, and a gradient hook applies the same mask to their
    gradients, so back-propagation never produces a non-zero gradient for an
    off-diagonal weight.

    Args:
        kernel_size: Side length ``k`` of the kernels; padding is ``k // 2``.
        in_channel: Number of input channels (default 1).
        out_channel: Number of output channels (default 1).

    Attributes:
        d1_mask, d2_mask: Non-persistent buffers shaped like the k x k
            weights, one on the (anti-)diagonal and zero elsewhere. As buffers
            they follow ``model.to(...)`` but are not stored in the state dict.
    """

    def __init__(self, kernel_size=3, in_channel=1, out_channel=1):
        super().__init__()
        pad = kernel_size // 2
        self.kernel_size = kernel_size

        self.conv_h = nn.Conv2d(in_channel, out_channel, (1, kernel_size), padding=(0, pad), bias=False)
        self.conv_v = nn.Conv2d(in_channel, out_channel, (kernel_size, 1), padding=(pad, 0), bias=False)
        self.conv_d1 = nn.Conv2d(in_channel, out_channel, kernel_size, padding=pad, bias=False)
        self.conv_d2 = nn.Conv2d(in_channel, out_channel, kernel_size, padding=pad, bias=False)

        # Horizontal and vertical strips start as all-ones kernels.
        nn.init.constant_(self.conv_h.weight, 1.0)
        nn.init.constant_(self.conv_v.weight, 1.0)

        # Diagonal masks, replicated over every (out, in) channel pair.
        diagonal = torch.eye(kernel_size)
        self.register_buffer(
            "d1_mask", diagonal.expand_as(self.conv_d1.weight).clone(), persistent=False
        )
        self.register_buffer(
            "d2_mask", diagonal.flip(-1).expand_as(self.conv_d2.weight).clone(), persistent=False
        )

        # Zero the off-diagonal taps of the randomly initialised kernels.
        with torch.no_grad():
            self.conv_d1.weight.mul_(self.d1_mask)
            self.conv_d2.weight.mul_(self.d2_mask)

        # Mask the gradients too. The mask is looked up at call time so the
        # hook sees the buffer's current device after the module is moved.
        self.conv_d1.weight.register_hook(lambda grad: grad * self.d1_mask.to(grad.device))
        self.conv_d2.weight.register_hook(lambda grad: grad * self.d2_mask.to(grad.device))

    def forward(self, x):
        """Return the mean of the four directional responses."""
        h = self.conv_h(x)
        v = self.conv_v(x)
        d1 = self.conv_d1(x)
        d2 = self.conv_d2(x)
        out = (h + v + d1 + d2) / 4
        return out


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class DrConv(nn.Module):
    """Directional convolution: horizontal, vertical and two diagonal strips.

    The horizontal and vertical branches are ordinary ``nn.Conv2d`` layers
    with (1 x k) and (k x 1) kernels. Each diagonal branch stores only ``k``
    taps per (out, in) channel pair in a ``(out, in, 1, k)`` parameter. In
    ``forward`` that parameter is broadcast-multiplied by a k x k identity
    (or its left-right mirror), which yields a k x k kernel whose only
    non-zero entries lie on the main diagonal (or anti-diagonal).

    Args:
        in_channel: Number of input channels.
        out_channel: Number of output channels of every branch.
        kernel_size: Strip length ``k``.
        padding: Passed unchanged to all four convolutions. With an integer
            padding the (1 x k) / (k x 1) branches and the k x k diagonal
            branches produce different spatial sizes; ``'same'`` keeps them
            aligned.
        stride: Passed unchanged to all four convolutions. PyTorch does not
            support ``padding='same'`` for strided convolutions.

    Attributes:
        eyes, reyes: k x k identity and its mirror, registered as
            non-persistent buffers so they follow ``.to()`` / ``.cuda()``
            with the parameters without adding keys to the state dict.
    """

    def __init__(self,in_channel,out_channel=1,kernel_size=3,padding='same',stride=1):
        super().__init__()
        self.kernel_size = kernel_size
        self.padding = padding
        self.stride = stride

        self.conv_h = nn.Conv2d(in_channel, out_channel, (1, kernel_size), padding=padding, bias=False,stride=stride)
        self.conv_v = nn.Conv2d(in_channel, out_channel, (kernel_size, 1), padding=padding, bias=False,stride=stride)

        # One row of k taps per (out, in) pair; expanded onto a diagonal in forward().
        self.conv_d1 = torch.nn.Parameter(
            torch.randn(out_channel,in_channel,1, kernel_size), requires_grad=True
        )
        self.conv_d2 = torch.nn.Parameter(
            torch.randn(out_channel,in_channel,1, kernel_size), requires_grad=True
        )
        nn.init.kaiming_uniform_(self.conv_d1)
        nn.init.kaiming_uniform_(self.conv_d2)

        # Buffers rather than plain attributes: a plain tensor would stay on
        # the CPU after model.to(device) and break the multiply in forward().
        self.register_buffer("eyes", torch.eye(kernel_size), persistent=False)
        self.register_buffer("reyes", torch.flip(self.eyes, [-1]), persistent=False)

    def forward(self, x):
        """Return the tuple ``(horizontal, vertical, main_diag, anti_diag)``."""
        h = self.conv_h(x)
        v = self.conv_v(x)
        # (out, in, 1, k) * (k, k) broadcasts to (out, in, k, k): column c of
        # the tap row lands on the single non-zero cell of column c.
        d1 = F.conv2d(x,self.conv_d1*self.eyes,stride=self.stride,padding=self.padding)
        d2 = F.conv2d(x,self.conv_d2*self.reyes,stride=self.stride,padding=self.padding)
        return h, v, d1, d2


In [4]:
from swin import SwinBlock
from abtb import LayerNorm
import torch.nn as nn
import torch.nn.functional as F
class MDFusion(nn.Module):
    """Multi-scale directional fusion block with a residual connection.

    Three ``DrConv`` branches (kernel sizes 3, 5 and 7) each return four
    directional responses. For every direction the three scales are fused
    as the average of their element-wise max and mean. The four fused maps
    are concatenated on the channel axis, normalised, passed through two
    ``SwinBlock`` layers, projected back to ``in_channels`` by a 1x1
    convolution and added to the input.

    Args:
        in_channels: Channels of the input and of the returned tensor.
        out_channels: Channels produced by each directional branch. The
            fused feature map therefore has ``4 * out_channels`` channels.

    NOTE(review): ``SwinBlock`` and ``LayerNorm`` come from project modules
    not visible here; this block assumes they accept and return
    ``(batch, channels, height, width)`` tensors -- confirm.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.c3 = DrConv(in_channels, out_channels, kernel_size=3)
        self.c5 = DrConv(in_channels, out_channels, kernel_size=5)
        self.c7 = DrConv(in_channels, out_channels, kernel_size=7)

        # Width of the concatenated feature map: four directions of
        # out_channels each. Sizing from out_channels lets the block be built
        # with out_channels != in_channels.
        fused_channels = 4 * out_channels
        self.ln = LayerNorm(fused_channels)
        self.attn = nn.Sequential(
            SwinBlock(dim=fused_channels, input_resolution=None, num_heads=4, shift_size=0),
            SwinBlock(dim=fused_channels, input_resolution=None, num_heads=4, shift_size=7),
            nn.Conv2d(fused_channels, in_channels, kernel_size=1),
        )

    @staticmethod
    def _fuse_scales(per_scale):
        """Average the element-wise max and mean of same-shaped feature maps."""
        stacked = torch.stack(per_scale, dim=1)  # (b, n_scales, c, h, w)
        return (stacked.max(dim=1)[0] + stacked.mean(dim=1)) / 2

    def forward(self, x):
        """Return ``x`` plus the fused, attended directional features."""
        # zip regroups the three (h, v, d1, d2) tuples into four per-direction
        # triples, one entry per kernel size.
        per_direction = zip(self.c3(x), self.c5(x), self.c7(x))
        fs = torch.cat([self._fuse_scales(scales) for scales in per_direction], dim=1)
        out = self.ln(fs)
        out = self.attn(out)
        return x + out
        
        
        
        


In [6]:
# Smoke test: push a random batch of two 3-channel 256x256 images through MDFusion.
x = torch.randn(2, 3, 256, 256)
model = MDFusion(in_channels=3, out_channels=3)
# NOTE(review): `x` is rebound to the model output here, so re-running only
# this line feeds the previous output back in as the input.
x = model(x)
# Bare expression: the notebook displays the resulting tensor.
x

tensor([[[[ 1.2225e+00,  3.5067e+00,  2.0813e+00,  ...,  4.7827e-01,
            2.3506e+00,  1.0985e+00],
          [ 2.8899e+00, -1.4897e-02,  2.8195e-02,  ...,  1.9042e+00,
            2.9225e-01,  1.1990e+00],
          [ 3.0239e+00,  2.1733e+00,  8.0982e-01,  ...,  2.5089e+00,
            2.9334e+00,  1.2483e+00],
          ...,
          [ 3.3210e+00,  2.2619e+00,  2.9278e+00,  ...,  1.7615e+00,
            1.6711e+00, -8.5959e-01],
          [ 6.6111e-01,  2.3380e+00,  2.7204e+00,  ...,  1.1581e+00,
            3.9818e-01,  7.6475e-01],
          [ 7.8300e-01,  6.4524e-01,  2.7036e+00,  ...,  1.2490e+00,
            1.1947e+00,  1.0479e+00]],

         [[ 8.3272e-01, -1.2717e+00, -2.0454e+00,  ..., -1.4773e+00,
           -3.7280e-01, -2.2769e+00],
          [-1.7073e+00, -2.6315e+00, -9.4914e-01,  ..., -1.0419e+00,
            9.2039e-01,  5.7414e-01],
          [-1.0306e+00,  8.5564e-01, -1.2150e+00,  ..., -5.8251e-01,
           -1.2181e+00, -1.2888e+00],
          ...,
     

In [None]:
def generate_data(batch_size=16):
    """Return one random training batch ``(inputs, targets)``.

    ``inputs`` is a ``(batch_size, 3, 7, 7)`` tensor sampled uniformly from
    [0, 1); ``targets`` is an all-ones tensor with the same shape and dtype.
    """
    images = torch.rand(batch_size, 3, 7, 7)
    targets = images.new_ones(images.shape)
    return images, targets


In [None]:
def train(epochs=30, lr=0.01):
    """Fit a 3-channel ``StripConvTest`` to all-ones targets and print progress.

    A fresh random batch from ``generate_data`` is drawn every epoch. The
    main-diagonal kernel is printed before and after training so it can be
    checked by eye that only the diagonal taps changed.

    Requires a ``StripConvTest`` whose constructor accepts ``in_channel`` and
    ``out_channel`` and whose ``forward`` returns a single tensor.

    Args:
        epochs: Number of optimisation steps (default 30).
        lr: Adam learning rate (default 0.01).
    """
    model = StripConvTest(in_channel=3, out_channel=3, kernel_size=3)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    # Print the weight tensor itself; printing the layer would only show the
    # Conv2d repr. clone() keeps the snapshot independent of later updates.
    print("Initial d1 kernel:")
    print(model.conv_d1.weight.detach().clone())

    model.train()
    for epoch in range(epochs):
        x, y = generate_data()
        out = model(x)
        loss = loss_fn(out, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The diagonal kernels are not re-masked after each step: the model's
        # gradient hooks are relied on to zero the off-diagonal gradients.

        if (epoch + 1) % 5 == 0:
            print(f"Epoch {epoch+1:02d}: Loss={loss.item():.6f}")

    print("\nAfter training d1 kernel:")
    print(model.conv_d1.weight.detach().clone())


In [None]:
# Run the training demo when this code is executed as the main module.
if __name__ == "__main__":
    train()

In [None]:
# Scratch check: matmul with the identity vs. element-wise multiply.
# x has shape (2, 2, 1, 3): four 1x3 rows arranged as a 2x2 grid.
x = torch.tensor(
    [
        [[[1, 1, 1]], [[1, 2, 2]]],
        [[[2, 1, 1]], [[2, 2, 2]]],
    ],
    dtype=torch.float32,
)
print(x)
y = torch.eye(3)
print(y)
# Matrix product with the identity leaves every row unchanged.
z = x @ y
print(z)
# Element-wise product broadcasts each 1x3 row against the 3x3 identity,
# placing the row's values on the diagonal of a 3x3 matrix.
z = x * y
print(z)

In [None]:
# Scratch check: wrap a random (2, 1, 3) tensor in nn.Parameter and display it.
x=torch.nn.Parameter(torch.randn(2,1,3))
x

In [None]:
# Display `y`. NOTE(review): `y` is not defined in this cell; presumably it is
# the 3x3 identity from the earlier scratch cell -- depends on kernel state.
y

In [None]:
# Print `y` mirrored along its last dimension, then display `y` itself for
# comparison. NOTE(review): relies on `y` from an earlier cell.
print(torch.flip(y, dims=[-1]))
y

In [2]:
import torch
from torch import nn

# Min-pooling demo: negating before and after a max-pool yields the minimum
# of each 3x3 neighbourhood (stride 1, padding 1 keeps the 5x5 size).
p = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
x = torch.randn(1, 1, 5, 5)
out = p(x.neg()).neg()
print("Input:\n", x)
print("Output:\n", out)

Input:
 tensor([[[[-0.8645, -1.3713, -0.4421,  0.3379,  0.4282],
          [-1.1967, -1.4151,  0.0986,  0.7842, -0.6357],
          [-1.1211, -1.0270,  0.4290, -0.9869,  1.4205],
          [ 0.6235, -0.4959,  0.7135, -1.4688,  0.0828],
          [ 0.9833,  0.7404,  0.1762,  0.8674, -0.0630]]]])
Output:
 tensor([[[[-1.4151, -1.4151, -1.4151, -0.6357, -0.6357],
          [-1.4151, -1.4151, -1.4151, -0.9869, -0.9869],
          [-1.4151, -1.4151, -1.4688, -1.4688, -1.4688],
          [-1.1211, -1.1211, -1.4688, -1.4688, -1.4688],
          [-0.4959, -0.4959, -1.4688, -1.4688, -1.4688]]]])
