In [30]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
from vit_pytorch import ViT

In [26]:
resnet18 = models.resnet18(weights=None) 
resnet18.conv1  = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
resnet18.avgpool  = nn.AdaptiveAvgPool2d(output_size=(2, 2))
resnet18.fc  = nn.Identity()

In [27]:
resnet18

ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [28]:
inp = torch.randn(1, 1, 256, 256)
resnet18(inp).shape

torch.Size([1, 2048])

In [36]:
vit = ViT(
        image_size = 256,
        channels = 1,
        patch_size = 32,
        num_classes = 5,
        dim = 1024,
        depth = 3,
        heads = 16,
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1
    )
vit.mlp_head = nn.Identity()
vit

ViT(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=32, p2=32)
    (1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (2): Linear(in_features=1024, out_features=1024, bias=True)
    (3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (transformer): Transformer(
    (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (layers): ModuleList(
      (0-2): 3 x ModuleList(
        (0): Attention(
          (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (attend): Softmax(dim=-1)
          (dropout): Dropout(p=0.1, inplace=False)
          (to_qkv): Linear(in_features=1024, out_features=3072, bias=False)
          (to_out): Sequential(
            (0): Linear(in_features=1024, out_features=1024, bias=True)
            (1): Dropout(p=0.1, inplace=False)
          )
        )
        (1): FeedForward(
          (net): Sequential(
     

In [35]:
inp = torch.randn(1, 1, 256, 256)
vit(inp).shape

torch.Size([1, 1024])

In [None]:
def replace_bn_with_identity(module:nn.Module):
    # 遍历当前模块的所有子模块
    for name, child in module.named_children(): 
        # 如果是 BN 层，则替换为 Identity 
        if isinstance(child, nn.BatchNorm2d):
            setattr(module, name, nn.Identity())
        else:
            # 对非 BN 子模块递归调用
            replace_bn_with_identity(child)

def initialize_weights(module:nn.modules):
    # 遍历当前模块的所有子模块
        for name, child in module.named_children(): 
            # 如果是 BN 层，则替换为 Identity 
            if isinstance(child, nn.Conv2d) or isinstance(child, nn.Linear):
                # nn.init.orthogonal_(module.weight, nn.init.calculate_gain('relu'))
                nn.init.xavier_uniform_(child.weight)
                # nn.init.kaiming_uniform_(module.weight)
                if child.bias is not None:
                    nn.init.constant_(child.bias, 0)
            else:
                replace_bn_with_identity(child)
class Hybrid_RESNET18_VIT3_FC(nn.Module):
    def __init__(self, input_channels=1, act_num=5, use_softmax=True):
        super().__init__()
        self.input_shape = (input_channels, 256, 256)
        # resnet18构造
        self.resnet18 = models.resnet18(weights=None) 
        replace_bn_with_identity(self.resnet18)
        self.resnet18.conv1  = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.resnet18.fc  = nn.Identity()  # 使用 Identity 替代 fc 层
        self.use_softmax = use_softmax
        # VIT 构造
        self.vit = ViT(
                image_size = 256,
                channels = input_channels,
                patch_size = 32,
                num_classes = 5,
                dim = 1024,
                depth = 3,
                heads = 16,
                mlp_dim = 2048,
                dropout = 0.1,
                emb_dropout = 0.1
            )
        self.vit.mlp_head = nn.Linear(1024, 512)
        self.use_softmax = use_softmax
        
        self.hybrid_linear = nn.Linear(1024, 1024)
        
        self.critic_linear = nn.Linear(1024, 1)
        self.actor_linear = nn.Linear(1024, act_num)
        initialize_weights(self)

    def forward(self, x:torch.Tensor):
        for k, shape in enumerate(self.input_shape):
            if x.shape[k+1] != shape:
                raise ValueError(f"Input shape should be {self.input_shape}, got {x.shape[1:]}")
        # [b, 512]
        x1 = self.resnet18(x)
        # [b, 512]
        x2 = self.vit(x)
        # [b, 1024]
        x = torch.cat([x1, x2], dim=1)
        x = F.relu(self.hybrid_linear(x))
        
        if self.use_softmax:
            a = F.softmax(self.actor_linear(x), dim=1)
        else:
            a = self.actor_linear(x)
        return a, self.critic_linear(x)

In [38]:
hy = Hybrid_RESNET18_VIT3_FC()

In [39]:
inp = torch.randn(1, 1, 256, 256)
act, v = hy(inp)

In [42]:
act.shape
v.shape

torch.Size([1, 1])