## <font color=blue>获取模型中间层输出</font>

#### 获取模型中间层输出方法一：IntermediateLayerGetter（只能获取模型的一级子模块）

In [None]:
import torch
import torchvision
m = torchvision.models.resnet18(pretrained=True)
# Wrap the backbone so that calling it returns an OrderedDict with the outputs
# of the listed children: keys are the child names to capture, values are the
# names used in the returned dict (layer1..layer4 -> feat1..feat4).
# NOTE: IntermediateLayerGetter lives in the private torchvision.models._utils
# module and can only address direct children of the model (not nested
# submodules); children after the last requested one (avgpool, fc) are dropped.
return_layers = {'layer1': 'feat1', 'layer2': 'feat2', 'layer3': 'feat3', 'layer4': 'feat4'}
new_m = torchvision.models._utils.IntermediateLayerGetter(m, return_layers)
# Inspection only: use BatchNorm running statistics and skip building an
# autograd graph for the dummy forward pass.
new_m.eval()
with torch.no_grad():
    out = new_m(torch.rand(1, 3, 224, 224))
# print(m)
print([(k, v.shape) for k, v in out.items()])

#### hook钩子获取模型中间层输出

In [None]:
import timm,torch
import torch.nn as nn
from model.model_arc.backbone import *

class test_hook(nn.Module):
    """Capture intermediate outputs of ``self.backbone`` with forward hooks.

    Calling the module runs the backbone once and returns
    ``(backbone_output, features)``. ``features`` is a list holding the output
    of ``self.backbone.layers[i].blocks`` for every stage index ``i`` in
    ``hook_stages``, in the order those submodules executed during this call.

    Args:
        hook_stages: indices into ``self.backbone.layers`` whose ``blocks``
            submodule is hooked. Defaults to ``(1, 2, 3)``.
    """

    def __init__(self, hook_stages=(1, 2, 3)) -> None:
        super().__init__()
        # Inputs (tuples) and outputs recorded by forward_hook during the most
        # recent forward pass; both are reset at the start of every forward.
        self.features_in_hook = []
        self.features_out_hook = []
        self.hook_stages = tuple(hook_stages)
        # self.backbone = timm.create_model("mobilevitv2_200",
        #         features_only=True, output_stride=32,
        #          out_indices=(1,2,3,4), pretrained=False, num_classes=0, global_pool='')
        self.backbone = my_Swin_512()

    def forward(self, x):
        # Start from fresh buffers so features from earlier calls neither leak
        # into this call's result nor accumulate in memory. Rebinding (instead
        # of clear()) leaves lists returned by earlier calls untouched.
        self.features_in_hook = []
        self.features_out_hook = []
        # NOTE(review): the original hooked `self.backbone.layers.blocks` (no
        # index), which likely raises AttributeError because `layers` is indexed
        # like a container for stages 2 and 3; stage 1 is assumed — confirm.
        handles = []
        try:
            for stage in self.hook_stages:
                handles.append(
                    self.backbone.layers[stage].blocks.register_forward_hook(self.forward_hook)
                )
            x = self.backbone(x)
        finally:
            # Always detach the hooks, even if registration or the backbone
            # raises, so they never fire on later unrelated calls.
            for handle in handles:
                handle.remove()
        return x, self.features_out_hook

    def forward_hook(self, module, data_input, data_output):
        """Record one hooked submodule's input tuple and its output."""
        self.features_in_hook.append(data_input)
        self.features_out_hook.append(data_output)

# Build the hook-wrapped backbone and show its module tree.
net = test_hook()
print(net)

# Dummy input for a trial forward pass (the call itself is left disabled).
x = torch.rand(1, 3, 224, 224)
# List every submodule name, e.g. to choose which ones to hook:
# for name, _ in net.named_modules():
#     print(name)
# out1 = net(x)

#### torchvision FX提取

In [None]:
import timm,torch
from torchvision.models.feature_extraction import get_graph_node_names, create_feature_extractor

# Build the Swin classifier without pretrained weights; `exportable=True` is
# timm's layer-config flag for traceable/exportable variants (presumably needed
# for FX tracing here — confirm against the timm version in use).
model = timm.create_model(
    "swin_base_patch4_window7_224",
    pretrained=False,
    exportable=True,
)
# print(model)

# To discover valid node names for `return_nodes`, list the traced graph nodes:
# train_nodes, eval_nodes = get_graph_node_names(model)
# print(train_nodes)

# Map traced node name -> key under which its value appears in the output dict.
features = {'layers.3.drop2': 'out'}
feature_extractor = create_feature_extractor(model, return_nodes=features)
print(feature_extractor)

x = torch.rand(1, 3, 224, 224)
out = feature_extractor(x)
extracted = out["out"]
print(extracted.shape)

## <font color=blue>模型库</font>

#### 模型操作

In [None]:
#打印整个模型和名字
# for (name, module) in vit.named_modules():
#     print(name)

#打印模型参数
# print(list(vit.encoder_layer_0.ln_1.named_parameters())[0])
# print(vit.state_dict()['encoder.layers.encoder_layer_0.ln_1.weight'])

# 截取模型中部分
# net2 = torch.nn.Sequential(*list(vit.modules())[:2])
# net2 = vit.encoder

#### 本地库引用

In [None]:
from model.model_arc.backbone import *
import torch,timm

# Instantiate the project-local backbone; my_Swin_512 is brought in by the
# star import of model.model_arc.backbone at the top of this cell.
net = my_Swin_512()

#### vit_pytorch库引用

In [None]:
# import torch
# from vit_pytorch.mobile_vit import MobileViT

# mbvit_xs = MobileViT(
#     image_size = (256, 256),
#     dims = [96, 120, 144],
#     channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
#     num_classes = 1000
# )

# img = torch.randn(1, 3, 256, 256)

# pred = mbvit_xs(img) # (1, 1000)
# print(pred.shape)
# print(mbvit_xs)


#### torchvision库引用

In [None]:
#torchvision.models 模型库操作
import torch
import torchvision.models as models

# Construct MobileNetV2 directly from its class (no weights argument is passed).
net = models.MobileNetV2()
print(net)
# Quick output-shape check:
# x = torch.rand(1, 3, 224, 224)
# print(net(x).shape)

#### timm库引用

In [None]:
import timm
import torch

## mobilenetv3_small_050 (alternative; uncomment to use it instead of SwinV2)
# net = timm.create_model("mobilenetv3_small_050")
## swinv2_base_window8_256: configured for 256x256 inputs
net = timm.models.swin_transformer_v2.swinv2_base_window8_256()

print(net)
# Smoke test in inference mode without an autograd graph. The input must be
# 256x256: timm's SwinV2 is built for a fixed resolution and its patch
# embedding checks the image size, so the previous 512x512 input fails.
net.eval()
with torch.no_grad():
    o = net(torch.randn(2, 3, 256, 256))
print(o.shape)

#### smp库引用

In [None]:
import segmentation_models_pytorch as smp
# U-Net with ImageNet encoder weights and a 2-class output head.
model = smp.Unet(
    classes=2,
    encoder_weights="imagenet",
)
print(model)

#### 模型计算量测试

In [None]:
import torch
import torch.nn as nn
from thop import profile

import timm
import torchvision.models as models
from model.model_arc.my_model import *
# from vit_pytorch.mobile_vit import MobileViT
import segmentation_models_pytorch as smp

# Choose the model to profile: keep exactly one `net = ...` line active.
# Defining it here keeps the cell self-contained instead of silently profiling
# whatever `net` an earlier cell left behind (or raising NameError in a fresh
# kernel).
# timm
net = timm.create_model('resnet18', pretrained=False)
# torchvision
# net = models.resnet50(pretrained=False)
# local library
# net = my_timm_swin_Hook()
# smp
# net = smp.DeepLabV3Plus(encoder_weights="imagenet", classes=2)

# Model-block tests: place individual sub-modules to profile here.


# print(net)

# Parameter count and compute cost for a single 1x3x256x256 input.
# NOTE: thop's `profile` counts multiply-accumulate operations (MACs), not
# FLOPs; by the usual convention one MAC is about two FLOPs.
x = torch.rand(1, 3, 256, 256)
# eval() so the profiling forward pass does not update BatchNorm running
# statistics of `net`; no_grad() avoids building an unused autograd graph.
net.eval()
with torch.no_grad():
    macs, params = profile(net, inputs=(x,))
print(macs)
print(params)

In [None]:
from keras import  Sequential
from keras import  callbacks
from keras.callbacks import ModelCheckpoint
from keras.models import load_model
from keras.utils.vis_utils import plot_model
from keras.layers import Conv1D, MaxPooling1D, GlobalAveragePooling1D, LSTM,Dense,Bidirectional,Dropout,Flatten
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
