In [1]:
import torch

In [2]:
# Checkpoint file to inspect, resolved relative to the notebook's working directory.
model_path: str = "test.pth"

In [3]:
# Load the checkpoint onto the CPU. weights_only=True restricts unpickling to tensors
# and plain containers, so an untrusted .pth file cannot run arbitrary code on load.
checkpoint = torch.load(model_path, map_location="cpu", weights_only=True)

# Show the top-level structure of the file.
print("Keys in checkpoint:", checkpoint.keys())

# Training checkpoints nest the weights under 'model'. Otherwise, assume the file is
# itself a bare state_dict, so that `state_dict` is always bound for the loops below
# (previously it was only defined inside the `if`, causing a NameError).
state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint

print("\nFirst 10 weight keys:")
for key in list(state_dict)[:10]:
    print(key)

# List every parameter/buffer with its shape.
print("\nWeight shapes:")
for key, tensor in state_dict.items():
    print(f"{key}: {tuple(tensor.shape)}")

Keys in checkpoint: dict_keys(['iteration', 'model'])

First 10 weight keys:
spectrogram_extractor.stft.conv_real.weight
spectrogram_extractor.stft.conv_imag.weight
logmel_extractor.melW
bn0.weight
bn0.bias
bn0.running_mean
bn0.running_var
bn0.num_batches_tracked
conv_block1.conv1.weight
conv_block1.conv2.weight

Weight shapes:
spectrogram_extractor.stft.conv_real.weight: (513, 1, 1024)
spectrogram_extractor.stft.conv_imag.weight: (513, 1, 1024)
logmel_extractor.melW: (513, 64)
bn0.weight: (64,)
bn0.bias: (64,)
bn0.running_mean: (64,)
bn0.running_var: (64,)
bn0.num_batches_tracked: ()
conv_block1.conv1.weight: (64, 1, 3, 3)
conv_block1.conv2.weight: (64, 64, 3, 3)
conv_block1.bn1.weight: (64,)
conv_block1.bn1.bias: (64,)
conv_block1.bn1.running_mean: (64,)
conv_block1.bn1.running_var: (64,)
conv_block1.bn1.num_batches_tracked: ()
conv_block1.bn2.weight: (64,)
conv_block1.bn2.bias: (64,)
conv_block1.bn2.running_mean: (64,)
conv_block1.bn2.running_var: (64,)
conv_block1.bn2.num_batches_t

In [1]:
import torch.nn as nn
import re

class ReconstructedCNN14(nn.Module):
    """Rebuild a CNN14-style convolutional classifier from the tensor shapes in a state_dict.

    Conv layers are discovered from keys of the form ``conv_block<B>.conv<L>.weight``
    (e.g. ``conv_block1.conv1.weight``) or ``conv_block<B>.<L>.weight``. Only 4-D
    tensors count as conv weights, so 1-D BatchNorm parameters that share a prefix
    are ignored. Each conv is followed by ``BatchNorm2d`` and ``ReLU``. Every block
    except the last ends with a 2x2 max-pool. The blocks are followed by a global
    average pool, dropout, and a linear classifier.

    Only the architecture is rebuilt; weights are NOT loaded. The conv blocks live
    under ``self.layers``, so their state_dict keys carry an extra ``layers.`` prefix
    compared to the source checkpoint.

    Args:
        state_dict: mapping from parameter name to tensor. Only ``.shape`` and key
            presence are read.
        fc_name: name prefix of the classifier layer. ``<fc_name>.weight`` with shape
            ``(num_classes, in_features)`` must be present in ``state_dict``.

    Raises:
        ValueError: if ``<fc_name>.weight`` is not found in ``state_dict``.
    """

    def __init__(self, state_dict, fc_name="fc"):
        super().__init__()
        self.layers = nn.ModuleDict()

        # Parse the weight structure. The optional "conv" prefix accepts both
        # "conv_block1.conv1.weight" and "conv_block1.0.weight".
        conv_pattern = re.compile(r"conv_block(\d+)\.(?:conv)?(\d+)\.weight")

        # block index -> list of (layer index, weight key), collected from 4-D weights only.
        conv_blocks = {}
        for key, tensor in state_dict.items():
            match = conv_pattern.fullmatch(key)
            if match and len(tensor.shape) == 4:
                block_idx, layer_idx = (int(g) for g in match.groups())
                conv_blocks.setdefault(block_idx, []).append((layer_idx, key))

        # Build the conv blocks in index order so that forward() can iterate them directly.
        last_block = max(conv_blocks, default=None)
        for block_idx in sorted(conv_blocks):
            block = nn.Sequential()
            for layer_idx, key in sorted(conv_blocks[block_idx]):
                # Conv2d weight layout is (out_channels, in_channels, kH, kW).
                out_ch, in_ch, kh, kw = (int(d) for d in state_dict[key].shape)
                prefix = key[: -len(".weight")]
                block.add_module(
                    f"conv{layer_idx}",
                    nn.Conv2d(
                        in_ch,
                        out_ch,
                        kernel_size=(kh, kw),
                        # "Same" padding for odd kernels (padding=1 for 3x3).
                        padding=(kh // 2, kw // 2),
                        # Only create a bias when the checkpoint actually has one.
                        bias=f"{prefix}.bias" in state_dict,
                    ),
                )
                block.add_module(f"bn{layer_idx}", nn.BatchNorm2d(out_ch))
                block.add_module(f"relu{layer_idx}", nn.ReLU())

            # Downsample after every block except the last.
            if block_idx != last_block:
                block.add_module("pool", nn.MaxPool2d(kernel_size=2, stride=2))

            self.layers[f"conv_block{block_idx}"] = block

        # Global pooling and classifier.
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))

        fc_weight = state_dict.get(f"{fc_name}.weight")
        if fc_weight is None:
            raise ValueError("FC layer weights not found")
        num_classes, fc_in_features = int(fc_weight.shape[0]), int(fc_weight.shape[1])

        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(
            fc_in_features, num_classes, bias=f"{fc_name}.bias" in state_dict
        )

    def forward(self, x):
        """Run the conv blocks in index order, then global pool, dropout and classifier.

        Args:
            x: input tensor in the (N, C, H, W) layout expected by ``nn.Conv2d``. C
                must equal the first conv's ``in_channels``.

        Returns:
            Logits of shape (N, num_classes).
        """
        # Blocks were inserted in sorted index order, so non-contiguous indices also work.
        for block in self.layers.values():
            x = block(x)

        x = self.global_pool(x).flatten(1)
        x = self.dropout(x)
        return self.fc(x)