In [2]:
import timm
import torch
from thop import profile

### 模型索引

In [3]:
#打印所有模型
# model_list = timm.list_models()
# print(model_list)
#打印所有带预训练的模型
# model_pretrain_list = timm.list_models(pretrained=True)
# print(model_pretrain_list)
#检索特定模型
model_resnet = timm.list_models('*mobilenet*')
print(model_resnet)

['mobilenetv2_035', 'mobilenetv2_050', 'mobilenetv2_075', 'mobilenetv2_100', 'mobilenetv2_110d', 'mobilenetv2_120d', 'mobilenetv2_140', 'mobilenetv3_large_075', 'mobilenetv3_large_100', 'mobilenetv3_large_100_miil', 'mobilenetv3_large_100_miil_in21k', 'mobilenetv3_rw', 'mobilenetv3_small_050', 'mobilenetv3_small_075', 'mobilenetv3_small_100', 'tf_mobilenetv3_large_075', 'tf_mobilenetv3_large_100', 'tf_mobilenetv3_large_minimal_100', 'tf_mobilenetv3_small_075', 'tf_mobilenetv3_small_100', 'tf_mobilenetv3_small_minimal_100']


### 创建模型 计算参数量和计算量

In [2]:
# 创建模型
x = torch.randn((1, 3, 256, 256))
modle_mobilenetv2 = timm.create_model('swinv2_base_window8_256', pretrained=False)
out = modle_mobilenetv2(x)
print(out.shape)
torch.Size([1, 1000])

# 测试模型参数量和计算量
flops , params = profile(modle_mobilenetv2,inputs=(x,))
print(flops)
print(params)

torch.Size([1, 1000])
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_softmax() for <class 'torch.nn.modules.activation.Softmax'>.
15034155008.0
66955624.0


### 获取未池化和未分类模型
  |  全局池化  输入一般为NxCxHxW，输出为NxCx1x1

In [None]:
x = torch.randn((1, 3, 512, 512))
#方法1
m = timm.create_model('xception41', pretrained=False)
outfeatures = m.forward_features(x) ##直接提取网络分类层之前还未池化的特征
#方法2
m = timm.create_model('densenet121', pretrained=False)
m.reset_classifier(0, '')
#方法3
m = timm.create_model('resnet50', pretrained=True, num_classes=0, global_pool='')

print(m(x).shape)


### 获取全局池化后输出

In [None]:
#方法1
m = timm.create_model('resnet50', pretrained=True, num_classes=0)
#方法2
m = timm.create_model('ese_vovnet19b_dw', pretrained=True)
m.reset_classifier(0)


### 改变时输入和输出的通道

In [None]:
#num_classes=100改变输出类别
# in_chans=10改变输入通道数
x = torch.randn((1, 10, 224, 224))
net = timm.create_model('swin_base_patch4_window7_224', pretrained=False,
                 num_classes=0,in_chans=10,global_pool='')
out = net(x)


### 特征图提取

In [3]:
#output_stride最后输出尺寸的缩小值
# out_indices输出特征索引

net = timm.create_model('mobilenetv3_small_050', features_only = True, output_stride=32,
                    out_indices=(1,2,3,4),pretrained=False,num_classes=0,global_pool='')
#输出各层通道数
# print(f'Feature channels: {m.feature_info.channels()}')
# #输出缩小比例
# print(f'Feature reduction: {m.feature_info.reduction()}')

print(net)
# o = net(torch.randn(2, 3, 512, 512))
# for x in o:
#     print(x.shape)


MobileNetV3Features(
  (conv_stem): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): Hardswish()
  (blocks): Sequential(
    (0): Sequential(
      (0): DepthwiseSeparableConv(
        (conv_dw): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=16, bias=False)
        (bn1): BatchNormAct2d(
          16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): ReLU(inplace=True)
        )
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))
          (act1): ReLU(inplace=True)
          (conv_expand): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
          (gate): Hardsigmoid()
        )
        (conv_pw): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNormAct2d(
          8, eps=1e-05, momentum=0.1, 

### 其他

In [10]:
import timm
import torch
from thop import profile

net = timm.create_model('resnet101', pretrained=False,num_classes=0,global_pool='')
print(net)
# o = net(torch.randn(2, 3, 256, 256))
# print(o.shape)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act1): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (drop_block): Identity()
      (act2): ReLU(inplace=True)
      (aa): Identity()
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     

In [None]:
from timm.models.swin_transformer_v2 import BasicLayer as SwinTransformerBlock
from timm.models.mobilevit import MobileVitV2Block
from timm.models.vision_transformer import Block as TransformerBlock
import torch
import torch.nn as nn
from thop import profile
class swinblock(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.Swin_block = SwinTransformerBlock(dim=96, input_resolution=(56,56), depth=4, num_heads=4,
            window_size=7, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
            drop_path=0., norm_layer=nn.LayerNorm, downsample=None)
    def forward(self,x):
        x = self.Swin_block(x)
        return x
class vitblock(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.blocks = nn.Sequential(*[
            TransformerBlock(
                dim=96, num_heads=8, mlp_ratio=4., qkv_bias=True, init_values=None,
                drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, act_layer=nn.GELU)
            for i in range(4)])
    def forward(self,x):
        x = self.blocks(x)
        return x
x = torch.rand(3,3136,96)
# net = swinblock()vitblock
net = vitblock()
print(net)
# net2(x)

flops , params = profile(net,inputs=(x,))
print(flops)
print(params)
