In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# 1. 定义基础残差块
class BasicBlock(nn.Module):
    expansion = 1  # 输出通道扩展倍数（Bottleneck 会用 4）

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # 卷积层 1
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # 卷积层 2
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # 下采样（如果输入输出维度不一致，需要调整 shortcut）
        self.downsample = downsample

    def forward(self, x):
        identity = x  # 保存输入

        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        # 如果需要调整输入维度，走 downsample
        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity  # 残差连接
        out = F.relu(out)
        return out


# 2. 定义 ResNet 主体
class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000):
        """
        block: 残差块类型（BasicBlock 或 Bottleneck）
        layers: 每个 stage 的 block 数量，例如 [2,2,2,2] 对应 ResNet-18
        num_classes: 分类类别数
        """
        super(ResNet, self).__init__()
        self.in_channels = 64

        # stem
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # 4 个 stage
        self.layer1 = self._make_layer(block, 64,  layers[0], stride=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # 全局平均池化 + 全连接层
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, blocks, stride):
        """
        构建一个 stage（包含多个残差块）
        out_channels: 输出通道数
        blocks: 残差块数量
        stride: 第一个 block 的 stride（是否下采样）
        """
        downsample = None
        if stride != 1 or self.in_channels != out_channels * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * block.expansion),
            )

        layers = []
        layers.append(block(self.in_channels, out_channels, stride, downsample))
        self.in_channels = out_channels * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_channels, out_channels))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


# 3. 构建 ResNet-18
def resnet18(num_classes=1000):
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)


# 测试网络
if __name__ == "__main__":
    model = resnet18(num_classes=10)
    x = torch.randn(1, 3, 224, 224)
    y = model(x)
    print(y.shape)  # torch.Size([1, 10])


torch.Size([1, 10])


In [None]:
nc: 1000  # 主任务分类数

stem:
  - {from: [0], module: Conv, args: {out_channels: 64, kernel_size: 7, stride: 2, padding: 3}}
  - {from: [-1], module: BatchNorm, args: {}}
  - {from: [-1], module: ReLU, args: {}}
  - {from: [-1], module: MaxPool, args: {kernel_size: 3, stride: 2, padding: 1}}

backbone:
  # Stage 1
  - {from: [-1], module: ResBlock, args: {channels: 64, stride: 1, dilation: 1}}
  - {from: [-1], module: ResBlock, args: {channels: 64, stride: 1, dilation: 1}}
  # Stage 2
  - {from: [-1], module: ResBlock, args: {channels: 128, stride: 2, dilation: 1}}
  - {from: [-1], module: ResBlock, args: {channels: 128, stride: 1, dilation: 1}}
  # Stage 3
  - {from: [-1], module: ResBlock, args: {channels: 256, stride: 2, dilation: 1}}
  - {from: [-1], module: ResBlock, args: {channels: 256, stride: 1, dilation: 1}}
  # Stage 4
  - {from: [-1], module: ResBlock, args: {channels: 512, stride: 2, dilation: 1}}
  - {from: [-1], module: ResBlock, args: {channels: 512, stride: 1, dilation: 1}}

head:
  - {from: [-1], module: GlobalAvgPool, args: {}}
  - {from: [-1], module: FC, args: {out_features: 1000}}  # 主任务分类头
  - {from: [-2], module: FC, args: {out_features: 10}}    # 辅助任务分类头


In [None]:
import torch
import torch.nn as nn
import yaml

# ------------------------------
# 基础模块
# ------------------------------
class ResBlock(nn.Module):
    def __init__(self, channels, stride=1, dilation=1):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, stride=stride, padding=dilation, dilation=dilation, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, 3, stride=1, padding=dilation, dilation=dilation, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += identity
        out = self.relu(out)
        return out

# ------------------------------
# 模块工厂
# ------------------------------
def module_factory(module_name, args, prev_out=None):
    modules = {
        'Conv': lambda: nn.Conv2d(**args),
        'BatchNorm': lambda: nn.BatchNorm2d(args.get('num_features', prev_out.shape[1])),
        'ReLU': lambda: nn.ReLU(inplace=True),
        'MaxPool': lambda: nn.MaxPool2d(**args),
        'GlobalAvgPool': lambda: nn.AdaptiveAvgPool2d((1, 1)),
        'FC': lambda: nn.Linear(args.get('in_features', prev_out.shape[1]), args['out_features']),
        'ResBlock': lambda: ResBlock(**args)
    }
    if module_name not in modules:
        raise ValueError(f"Unknown module: {module_name}")
    return modules[module_name]()

# ------------------------------
# 支持多输入多输出的网络
# ------------------------------
class YAMLNet(nn.Module):
    def __init__(self, yaml_path, in_channels=3, verbose=True):
        super().__init__()
        with open(yaml_path, 'r') as f:
            self.cfg = yaml.safe_load(f)

        self.layers = nn.ModuleList()
        self.from_list = []
        chs = [in_channels]

        # 构建层
        for layer_cfg in self.cfg.get('stem', []) + self.cfg.get('backbone', []) + self.cfg.get('head', []):
            f = layer_cfg['from']  # 列表形式
            m_name = layer_cfg['module']
            args = layer_cfg['args'].copy()
            prev_out = None if len(f)==0 else torch.zeros(1, chs[f[0]], 1, 1)  # 临时推测通道
            layer = module_factory(m_name, args, prev_out)
            self.layers.append(layer)
            self.from_list.append(f)

            # 更新输出通道
            if 'out_channels' in args:
                c2 = args['out_channels']
            elif 'channels' in args:
                c2 = args['channels']
            elif 'out_features' in args:
                c2 = args['out_features']
            else:
                c2 = chs[-1]
            chs.append(c2)

            if verbose:
                print(f"{len(self.layers)-1:03}: {m_name}, from {f}, args={args}, out_ch={c2}")

    def forward(self, x):
        outputs = {0: x}
        out_heads = []
        head_start_idx = len(self.layers) - len(self.cfg['head'])
        for i, layer in enumerate(self.layers):
            f = self.from_list[i]
            inp = torch.cat([outputs[j] for j in f], dim=1) if len(f) > 1 else outputs[f[0]]
            out = layer(inp)
            outputs[i+1] = out

            # 收集 head 输出
            if i >= head_start_idx:
                out_heads.append(out)

        return tuple(out_heads) if len(out_heads) > 1 else out_heads[0]

# ------------------------------
# 测试
# ------------------------------
if __name__ == "__main__":
    model = YAMLNet("resnet18_multi_task.yaml", in_channels=3, verbose=True)
    x = torch.randn(1, 3, 224, 224)
    y = model(x)
    if isinstance(y, tuple):
        for i, out in enumerate(y):
            print(f"Output head {i} shape:", out.shape)
    else:
        print("Output shape:", y.shape)


In [None]:
nc: 1000  # 分类数

stem:
  - {type: "image", channels: 3, height: 224, width: 224}  # 图像输入
  - {type: "vector", channels: 128}                        # 音频特征向量

backbone:
  # Stage 1 图像处理
  - {from: [0], module: Conv, args: {out_channels: 64, kernel_size: 7, stride: 2, padding: 3}}
  - {from: [-1], module: BatchNorm, args: {}}
  - {from: [-1], module: ReLU, args: {}}
  - {from: [-1], module: MaxPool, args: {kernel_size: 3, stride: 2, padding: 1}}

  # Stage 2 融合音频
  - {from: [-1, 1], module: ResBlock, args: {channels: 128, stride: 1, dilation: 1}}

head:
  - {from: [-1], module: FC, args: {out_features: 1000}}  # 最终输出层


In [None]:
nc: 1000  # 分类数

stem:
  - {type: "image", channels: 3, height: 224, width: 224}  # 图像输入
  - {type: "vector", channels: 128}                        # 音频特征向量

backbone:
  # Stage 1 图像处理
  - {from: [0], module: Conv, args: {out_channels: 64, kernel_size: 7, stride: 2, padding: 3}}
  - {from: [-1], module: BatchNorm, args: {}}
  - {from: [-1], module: ReLU, args: {}}
  - {from: [-1], module: MaxPool, args: {kernel_size: 3, stride: 2, padding: 1}}

  # Stage 2 融合音频
  - {from: [-1, 1], module: ResBlock, args: {channels: 128, stride: 1, dilation: 1}}

head:
  - {from: [-1], module: FC, args: {out_features: 1000}}  # 最终输出


In [None]:
import torch
import torch.nn as nn
import yaml

# ----------------------------
# 简单 ResBlock 示例
# ----------------------------
class ResBlock(nn.Module):
    def __init__(self, channels, stride=1, dilation=1):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=stride,
                               padding=dilation, dilation=dilation, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1,
                               padding=dilation, dilation=dilation, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        identity = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += identity
        out = self.relu(out)
        return out

# ----------------------------
# 简单 FC 层，可包含 flatten
# ----------------------------
class FC(nn.Module):
    def __init__(self, out_features):
        super().__init__()
        self.fc = nn.Linear(in_features=None, out_features=out_features)  # in_features 动态设置
        self.out_features = out_features

    def forward(self, x):
        if x.ndim > 2:
            x = torch.flatten(x, 1)
        # 动态初始化 in_features
        if self.fc.in_features is None:
            self.fc = nn.Linear(x.shape[1], self.out_features).to(x.device)
        return self.fc(x)

# ----------------------------
# 模型构建函数
# ----------------------------
class DynamicNet(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        layers_cfg = cfg['backbone'] + cfg['head']
        self.from_list = [l['from'] for l in layers_cfg]
        self.layers = nn.ModuleList()
        for l in layers_cfg:
            module_cls = globals()[l['module']]
            self.layers.append(module_cls(**l['args']))

    def forward(self, inputs):
        # inputs: list of input tensors, 对应 stem
        outputs = {i: inp for i, inp in enumerate(inputs)}
        for i, layer in enumerate(self.layers):
            f = self.from_list[i]
            f = f if isinstance(f, list) else [f]
            x = torch.cat([outputs[j] for j in f], dim=1) if len(f) > 1 else outputs[f[0]]
            out = layer(x)
            outputs[len(outputs)] = out
        return out  # 只返回 head 最终输出

# ----------------------------
# 测试
# ----------------------------
if __name__ == "__main__":
    # 加载 YAML
    with open("resnet_multiinput.yaml") as f:
        cfg = yaml.safe_load(f)

    # 创建模型
    model = DynamicNet(cfg)

    # 假设输入：图像 + 音频向量
    img = torch.randn(2, 3, 224, 224)
    audio = torch.randn(2, 128, 56, 56)  # 映射到图像 feature map 尺寸
    out = model([img, audio])

    print("输出 shape:", out.shape)  # 应为 [2, 1000]


In [None]:
nc: [10, 5]  # 两个任务的类别数

stem:
  - {type: "image", channels: 3, height: 32, width: 32}  # 输入图像1
  - {type: "image", channels: 3, height: 32, width: 32}  # 输入图像2

backbone:
  # 对图像1处理
  - {from: [0], module: Conv2d, args: {in_channels: 3, out_channels: 16, kernel_size: 3, stride: 1, padding: 1}}
  - {from: [-1], module: ReLU, args: {}}
  - {from: [-1], module: MaxPool2d, args: {kernel_size: 2, stride: 2}}

  # 对图像2处理
  - {from: [1], module: Conv2d, args: {in_channels: 3, out_channels: 16, kernel_size: 3, stride: 1, padding: 1}}
  - {from: [-1], module: ReLU, args: {}}
  - {from: [-1], module: MaxPool2d, args: {kernel_size: 2, stride: 2}}

head:
  - {from: [2], module: LinearDyn, args: {out_features: 10}}  # 输出1
  - {from: [5], module: LinearDyn, args: {out_features: 5}}   # 输出2


In [None]:
stem:
  - {type: "image", channels: 3, height: 32, width: 32}  # 输入图像1
  - {type: "image", channels: 3, height: 32, width: 32}  # 输入图像2

backbone:
  - {from: [0], module: Conv2d, args: {in_channels: 3, out_channels: 16, kernel_size: 3, stride: 1, padding: 1}}
  - {from: [-1], module: ReLU, args: {}}
  - {from: [-1], module: MaxPool2d, args: {kernel_size: 2, stride: 2}}

  - {from: [1], module: Conv2d, args: {in_channels: 3, out_channels: 16, kernel_size: 3, stride: 1, padding: 1}}
  - {from: [-1], module: ReLU, args: {}}
  - {from: [-1], module: MaxPool2d, args: {kernel_size: 2, stride: 2}}

head:
  - {from: [2], module: Flatten, args: {}}
  - {from: [-1], module: Linear, args: {in_features: 16*16*16, out_features: 10}}  # 输出1
  - {from: [5], module: Flatten, args: {}}
  - {from: [-1], module: Linear, args: {in_features: 16*16*16, out_features: 5}}   # 输出2



In [None]:
import torch
import torch.nn as nn
import yaml

def build_model(cfg):
    layers_cfg = cfg['backbone'] + cfg['head']
    from_list = [l['from'] for l in layers_cfg]
    layers = nn.ModuleList()
    
    for l in layers_cfg:
        module_cls = getattr(nn, l['module'])
        layers.append(module_cls(**l['args']))
    
    def forward_fn(inputs):
        outputs = {i: inp for i, inp in enumerate(inputs)}
        results = []
        for i, layer in enumerate(layers):
            f = from_list[i]
            f = f if isinstance(f, list) else [f]
            x = torch.cat([outputs[j] for j in f], dim=1) if len(f) > 1 else outputs[f[0]]
            out = layer(x)
            outputs[len(outputs)] = out
            # 收集 head 输出
            if isinstance(layer, nn.Linear):
                results.append(out)
        return results

    model = nn.Module()
    model.layers = layers
    model.forward = forward_fn
    return model

# ----------------------------
# 测试流程
# ----------------------------
if __name__ == "__main__":
    with open("simple_twoinput_minimal.yaml") as f:
        cfg = yaml.safe_load(f)

    model = build_model(cfg)
    img1 = torch.randn(2, 3, 32, 32)
    img2 = torch.randn(2, 3, 32, 32)
    out1, out2 = model([img1, img2])

    print("输出1 shape:", out1.shape)  # [2, 10]
    print("输出2 shape:", out2.shape)  # [2, 5]


In [None]:
nc: [10, 5]  # 两个任务的分类数

# stem 定义输入
stem:
  - {type: "image", channels: 3, height: 32, width: 32}  # 输入1
  - {type: "image", channels: 3, height: 32, width: 32}  # 输入2

# backbone 定义中间层
backbone:
  # 输入1
  - {from: [0], module: Conv, args: {out_channels: 16, kernel_size: 3, stride: 1, padding: 1}}
  - {from: [-1], module: ReLU, args: {}}
  - {from: [-1], module: MaxPool, args: {kernel_size: 2, stride: 2}}

  # 输入2
  - {from: [1], module: Conv, args: {out_channels: 16, kernel_size: 3, stride: 1, padding: 1}}
  - {from: [-1], module: ReLU, args: {}}
  - {from: [-1], module: MaxPool, args: {kernel_size: 2, stride: 2}}

  # 拼接输入1和输入2
  - {from: [2,5], module: "Concat", args: {dim: 1}}

head:
  - {from: [-1], module: FC, args: {out_features: 10}}  # 输出1
  - {from: [-2], module: FC, args: {out_features: 5}}   # 输出2
  
import torch
import torch.nn as nn
import yaml

# ----------------------------
# 工厂函数
# ----------------------------
def module_factory(module_name, args, prev_out=None):
    """仅生成可训练模块，Concat 不在此处理"""
    if module_name == "Conv":
        in_ch = prev_out.shape[1] if prev_out is not None else args.get('in_channels')
        return nn.Conv2d(in_channels=in_ch, **{k: v for k, v in args.items() if k != 'in_channels'})
    
    elif module_name == "BatchNorm":
        num_features = prev_out.shape[1] if prev_out is not None else args['num_features']
        return nn.BatchNorm2d(num_features)
    
    elif module_name == "ReLU":
        return nn.ReLU(inplace=True)
    
    elif module_name == "MaxPool":
        return nn.MaxPool2d(**args)
    
    elif module_name == "GlobalAvgPool":
        return nn.AdaptiveAvgPool2d((1, 1))
    
    elif module_name == "FC":
        in_features = prev_out.numel() // prev_out.shape[0] if prev_out is not None else args['in_features']
        return nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features, args['out_features'])
        )
    
    elif module_name == "ResBlock":
        return ResBlock(**args)  # 自定义模块
    
    else:
        raise ValueError(f"Unknown module: {module_name}")

# ----------------------------
# 构建模型
# ----------------------------
def build_model(cfg):
    layers_cfg = cfg['backbone'] + cfg['head']
    from_list = [l['from'] for l in layers_cfg]

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers_cfg = layers_cfg
            self.from_list = from_list
            self.layers = nn.ModuleList()
            for cfg in layers_cfg:
                if cfg['module'] != 'Concat':
                    # 用 Identity 占位，真实层在 forward 动态生成
                    self.layers.append(nn.Identity())
                else:
                    self.layers.append(None)

        def forward(self, inputs):
            outputs = {i: inp for i, inp in enumerate(inputs)}
            results = []
            for i, cfg in enumerate(self.layers_cfg):
                f = self.from_list[i]
                f = f if isinstance(f, list) else [f]
                
                # 如果是多输入拼接
                if cfg['module'] == 'Concat':
                    out = torch.cat([outputs[j] for j in f], dim=cfg['args'].get('dim',1))
                else:
                    x = torch.cat([outputs[j] for j in f], dim=1) if len(f) > 1 else outputs[f[0]]
                    layer = module_factory(cfg['module'], cfg['args'], prev_out=x)
                    out = layer(x)
                
                outputs[len(outputs)] = out
                
                # 如果是 FC，认为是 head 输出
                if cfg['module'] == 'FC':
                    results.append(out)
            return results

    return Net()

# ----------------------------
# 测试
# ----------------------------
if __name__ == "__main__":
    with open("multiinput_multioutput_concat.yaml") as f:
        cfg = yaml.safe_load(f)
    
    model = build_model(cfg)
    img1 = torch.randn(2, 3, 32, 32)
    img2 = torch.randn(2, 3, 32, 32)
    out1, out2 = model([img1, img2])

    print("输出1 shape:", out1.shape)  # [2, 10]
    print("输出2 shape:", out2.shape)  # [2, 5]


In [12]:
import torch
import torch.nn as nn
import importlib

def dynamic_class_instantiate_from_string(class_path: str, **kwargs):
    module_name, class_name = class_path.rsplit(".", 1)
    module = importlib.import_module(module_name)
    cls = getattr(module, class_name)
    return cls(**kwargs)

class DynamicNet(nn.Module):
    def __init__(self, yaml_dict):
        super().__init__()
        self.yaml_dict = yaml_dict
        self.stem_count = len(yaml_dict.get("stem", []))
        self.head_idxs = []
        self.layers = nn.ModuleList()
        self.c_out = []  # 保存每一层的输出通道数
        self._init_stem_channels()
        self._build_layers()

    def _init_stem_channels(self):
        """把 stem 的输出通道信息加入 c_out"""
        for stem_def in self.yaml_dict.get("stem", []):
            self.c_out.append(stem_def["channels"])

    def _build_layers(self):
        all_defs = self.yaml_dict.get("backbone", []) + self.yaml_dict.get("head", [])
        for i, layer_def in enumerate(all_defs):
            module_name = layer_def["module"]
            args = dict(layer_def.get("args", {}))  # copy dict
            from_idxs = layer_def["from"]

            # 计算 in_channels，如果模块需要
            if module_name != "Concat" and "in_channels" in dynamic_class_init_params(module_name):
                if "in_channels" not in args or args["in_channels"] is None:
                    if len(from_idxs) == 1:
                        args["in_channels"] = self.c_out[from_idxs[0]]
                    else:
                        # 对于多个输入，拼接维度
                        in_ch = sum(self.c_out[idx] for idx in from_idxs)
                        args["in_channels"] = in_ch

            if module_name != "Concat" and module_name != "FC":
                module = dynamic_class_instantiate_from_string(module_name, **args)
                if hasattr(module, "out_channels"):
                    c2 = module.out_channels
                elif hasattr(module, "weight"):
                    c2 = module.weight.shape[0]
                else:
                    c2 = None
            else:
                module = None
                if module_name == "FC":
                    c2 = args.get("out_features")
                else:
                    c2 = None  # Concat 层在 forward 时计算

            self.layers.append(module)
            self.c_out.append(c2)

            # 记录 head 输出索引
            if layer_def["module"] == "FC":
                self.head_idxs.append(i)


def dynamic_class_init_params(class_path):
    """获取类构造函数参数名"""
    module_name, class_name = class_path.rsplit(".", 1)
    module = importlib.import_module(module_name)
    cls = getattr(module, class_name)
    import inspect
    return inspect.signature(cls).parameters.keys()


In [13]:
yaml_dict = {
    "stem": [
        {"type": "image", "channels": 3, "height": 32, "width": 32},
        {"type": "image", "channels": 3, "height": 32, "width": 32},
    ],
    "backbone": [
        {"from": [0], "module": "torch.nn.Conv2d", "args": {"out_channels": 16, "kernel_size": 3, "stride": 1, "padding": 1}},
        {"from": [-1], "module": "torch.nn.ReLU", "args": {}},
        {"from": [-1], "module": "torch.nn.MaxPool2d", "args": {"kernel_size": 2, "stride": 2}},
        {"from": [1], "module": "torch.nn.Conv2d", "args": {"out_channels": 16, "kernel_size": 3, "stride": 1, "padding": 1}},
        {"from": [-1], "module": "torch.nn.ReLU", "args": {}},
        {"from": [-1], "module": "torch.nn.MaxPool2d", "args": {"kernel_size": 2, "stride": 2}},
        {"from": [2, 5], "module": "Concat", "args": {"dim": 1}},
    ],
    "head": [
        {"from": [-1], "module": "torch.nn.Linear", "args": {"out_features": 10}},
        {"from": [-2], "module": "torch.nn.Linear", "args": {"out_features": 5}},
    ]
}


# 输入张量
model = DynamicNet(yaml_dict)

x1 = torch.randn(1, 3, 32, 32)
x2 = torch.randn(1, 3, 32, 32)

outs = model([x1, x2])

for i, out in enumerate(outs):
    print(f"Head {i} output shape: {out.shape}")


TypeError: Linear.__init__() missing 1 required positional argument: 'in_features'

In [16]:
import torch
import torch.nn as nn
import importlib

def dynamic_class_instantiate_from_string(class_path: str, **kwargs):
    """根据字符串路径动态实例化类"""
    module_name, class_name = class_path.rsplit(".", 1)
    module = importlib.import_module(module_name)
    cls = getattr(module, class_name)
    return cls(**kwargs)

class DynamicNet(nn.Module):
    def __init__(self, yaml_dict):
        super().__init__()
        self.yaml_dict = yaml_dict
        self.layers = nn.ModuleList()
        self.c_out = []  # 保存每一层输出通道
        self.head_idxs = []

        self._init_stem_channels()
        self._build_layers()

    def _init_stem_channels(self):
        stem = self.yaml_dict["stem"]
        assert len(stem) == 1, "当前只支持单输入"
        self.c_out.append(stem[0]["channels"])
        self.input_shape = (1, stem[0]["channels"], stem[0]["height"], stem[0]["width"])

    def _build_layers(self):
        # 使用一个虚拟输入推算 Linear 输入尺寸
        dummy = torch.randn(self.input_shape)

        for i, layer_def in enumerate(self.yaml_dict["backbone"] + self.yaml_dict["head"]):
            module_name = layer_def["module"]
            from_idxs = layer_def["from"]
            args = layer_def.get("args", {})

            # 自动填充 in_channels
            if module_name in ["torch.nn.Conv2d", "torch.nn.BatchNorm2d"]:
                if "in_channels" not in args or args["in_channels"] is None:
                    args["in_channels"] = self.c_out[from_idxs[0]]

            # 对 Linear 自动填充 in_features
            if module_name == "torch.nn.Linear":
                if "in_features" not in args or args["in_features"] is None:
                    with torch.no_grad():
                        inp = dummy
                        out_idxs = from_idxs
                        inp = inp if len(out_idxs) == 1 else torch.cat([dummy for _ in out_idxs], dim=1)
                        inp_flat = torch.flatten(inp, 1)
                        args["in_features"] = inp_flat.shape[1]

            module = (
                dynamic_class_instantiate_from_string(module_name, **args)
                if module_name != "Concat"
                else None
            )

            self.layers.append(module)

            # 前向更新 dummy
            if module is not None:
                inp = dummy if len(from_idxs) == 1 else torch.cat([dummy for _ in from_idxs], dim=1)
                if isinstance(module, nn.Linear) and inp.ndim > 2:
                    inp = torch.flatten(inp, 1)
                dummy = module(inp)
            else:
                # Concat
                dummy = torch.cat([dummy for _ in from_idxs], dim=1)

            # 更新输出通道
            if module_name == "Concat":
                self.c_out.append(sum(self.c_out[idx] for idx in from_idxs))
            elif module_name == "torch.nn.Linear":
                self.c_out.append(args["out_features"])
            elif module_name in ["torch.nn.Conv2d", "torch.nn.BatchNorm2d"]:
                self.c_out.append(args.get("out_channels", self.c_out[from_idxs[0]]))
            else:
                self.c_out.append(self.c_out[from_idxs[0]])

            if i >= len(self.yaml_dict["backbone"]):
                self.head_idxs.append(i)

    def forward(self, x):
        outputs = [x]
        for i, layer in enumerate(self.layers):
            layer_def = (self.yaml_dict["backbone"] + self.yaml_dict["head"])[i]
            from_idxs = layer_def["from"]

            if layer is not None:
                inp = outputs[from_idxs[0]] if len(from_idxs) == 1 else torch.cat([outputs[idx] for idx in from_idxs], dim=1)
                if isinstance(layer, nn.Linear) and inp.ndim > 2:
                    inp = torch.flatten(inp, 1)
                out = layer(inp)
            else:
                # Concat
                out = torch.cat([outputs[idx] for idx in from_idxs], dim=1)

            outputs.append(out)

        return outputs[self.head_idxs[0]]

# --------------------------
# YAML 字典
# --------------------------
yaml_dict = {
    "stem": [{"type": "image", "channels": 3, "height": 32, "width": 32}],
    "backbone": [
        {"from": [0], "module": "torch.nn.Conv2d", "args": {"out_channels": 16, "kernel_size": 3, "stride": 1, "padding": 1}},
        {"from": [-1], "module": "torch.nn.ReLU", "args": {}},
        {"from": [-1], "module": "torch.nn.MaxPool2d", "args": {"kernel_size": 2, "stride": 2}},
    ],
    "head": [
        {"from": [-1], "module": "torch.nn.Linear", "args": {"out_features": 10}}
    ]
}

# --------------------------
# 测试
# --------------------------
x = torch.randn(1, 3, 32, 32)
model = DynamicNet(yaml_dict)
out = model(x)
print(out.shape)  # torch.Size([1, 10])


torch.Size([1, 16, 16, 16])


In [17]:
import torch
import torch.nn as nn
import importlib

def dynamic_class_instantiate_from_string(class_path: str, **kwargs):
    """根据字符串路径动态实例化类"""
    module_name, class_name = class_path.rsplit(".", 1)
    module = importlib.import_module(module_name)
    cls = getattr(module, class_name)
    return cls(**kwargs)


class DynamicNet(nn.Module):
    def __init__(self, yaml_dict):
        super().__init__()
        self.yaml_dict = yaml_dict
        self.layers_def = yaml_dict.get("backbone", []) + yaml_dict.get("head", [])
        self.layers = nn.ModuleList()
        self.layer_outputs = []  # 用于存储每层输出通道数或特征数

        self._build_layers()

    def _build_layers(self):
        for i, layer_def in enumerate(self.layers_def):
            module_name = layer_def["module"]
            args = layer_def.get("args", {}).copy()
            from_idxs = layer_def.get("from", [-1])
            
            # 处理in_channels/in_features
            if module_name not in ["Concat"]:
                if "in_channels" in args and args["in_channels"] is None:
                    # 线性或卷积层需要推算输入通道
                    args["in_channels"] = self.layer_outputs[from_idxs[0]]
                if "in_features" in args and args["in_features"] is None:
                    args["in_features"] = self.layer_outputs[from_idxs[0]]
                
                layer = dynamic_class_instantiate_from_string(module_name, **args)
            else:
                layer = None  # Concat 在 forward 中处理

            self.layers.append(layer)

            # 更新 layer_outputs，卷积使用 out_channels，Linear使用 out_features
            if module_name == "Concat":
                ch = sum(self.layer_outputs[idx] for idx in from_idxs)
            elif module_name.endswith("Conv2d"):
                ch = args["out_channels"]
            elif module_name.endswith("Linear"):
                ch = args["out_features"]
            else:
                # 对于 ReLU、MaxPool 等不改变通道数
                ch = self.layer_outputs[from_idxs[0]]
            self.layer_outputs.append(ch)

        # 保存 head 的索引
        self.head_idxs = list(range(len(self.layer_outputs) - len(self.yaml_dict.get("head", [])), len(self.layer_outputs)))

    def forward(self, x):
        outputs = []
        outputs.append(x)  # stem 输出
        for i, (layer_def, layer) in enumerate(zip(self.layers_def, self.layers)):
            from_idxs = layer_def.get("from", [-1])
            inputs = [outputs[idx + 1] for idx in from_idxs]  # +1 因为 stem 占位
            if layer_def["module"] == "Concat":
                out = torch.cat(inputs, dim=layer_def["args"]["dim"])
            else:
                inp = inputs[0] if len(inputs) == 1 else torch.cat(inputs, dim=1)
                # 如果是 Linear，先 flatten
                if isinstance(layer, nn.Linear) and inp.ndim > 2:
                    inp = torch.flatten(inp, 1)
                out = layer(inp)
            outputs.append(out)
        
        head_outputs = [outputs[idx + 1] for idx in self.head_idxs]
        return head_outputs[0] if len(head_outputs) == 1 else head_outputs


# --------------------------
# 示例 YAML 字典
# --------------------------
yaml_dict = {
    "stem": [
        {"type": "image", "channels": 3, "height": 32, "width": 32}
    ],
    "backbone": [
        {"from": [0], "module": "torch.nn.Conv2d", "args": {"in_channels": 3, "out_channels": 16, "kernel_size": 3, "stride": 1, "padding": 1}},
        {"from": [-1], "module": "torch.nn.ReLU", "args": {}},
        {"from": [-1], "module": "torch.nn.MaxPool2d", "args": {"kernel_size": 2, "stride": 2}},
    ],
    "head": [
        {"from": [-1], "module": "torch.nn.Linear", "args": {"in_features": None, "out_features": 10}}
    ]
}

# --------------------------
# 测试
# --------------------------
x = torch.randn(1, 3, 32, 32)
model = DynamicNet(yaml_dict)
out = model(x)
print(out.shape)  # torch.Size([1, 10])


IndexError: list index out of range

In [1]:
import torch
import torch.nn as nn
import importlib

def dynamic_class_instantiate_from_string(class_path: str, **kwargs):
    """根据字符串路径动态实例化类"""
    module_name, class_name = class_path.rsplit(".", 1)
    module = importlib.import_module(module_name)
    cls = getattr(module, class_name)
    return cls(**kwargs)


class ModelBuilder(nn.Module):
    def __init__(self, yaml_dict):
        super().__init__()
        self.yaml_dict = yaml_dict
        self.layers = nn.ModuleList()
        self.output_shapes = {}
        self.input_shape = None
        self.build_model()

    def build_model(self):
        # 1. stem: 定义输入
        stem = self.yaml_dict.get("stem", [])
        if stem:
            input_cfg = stem[0]
            self.input_shape = (input_cfg["channels"], input_cfg["height"], input_cfg["width"])
        
        # 2. backbone + head
        self.parse_layers("backbone")
        self.parse_layers("head")

    def parse_layers(self, section):
        for layer_cfg in self.yaml_dict.get(section, []):
            from_idx = layer_cfg.get("from", [-1])[0]
            args = dict(layer_cfg["args"])  # 复制，避免修改原 dict
            
            # 自动推断 in_features
            if "in_features" in args and args["in_features"] is None:
                prev_shape = self.output_shapes[from_idx]
                args["in_features"] = prev_shape[0] * prev_shape[1] * prev_shape[2]

            # 动态实例化模块
            layer = dynamic_class_instantiate_from_string(layer_cfg["module"], **args)
            self.layers.append(layer)

            # 前向计算一次 shape（dummy input）
            dummy_input = torch.zeros((1, *self.input_shape))
            for i, l in enumerate(self.layers):
                if isinstance(l, nn.Linear):
                    dummy_input = dummy_input.view(dummy_input.size(0), -1)
                dummy_input = l(dummy_input)
                self.output_shapes[i] = dummy_input.shape[1:]

    def forward(self, x):
        for layer in self.layers:
            if isinstance(layer, nn.Linear):
                x = x.view(x.size(0), -1)
            x = layer(x)
        return x


In [3]:
yaml_dict = {
    "stem": [
        {"type": "image", "channels": 3, "height": 32, "width": 32}
    ],
    "backbone": [
        {"from": [0], "module": "torch.nn.Conv2d", "args": {"in_channels": 3, "out_channels": 16, "kernel_size": 3, "stride": 1, "padding": 1}},
        {"from": [-1], "module": "torch.nn.ReLU", "args": {}},
        {"from": [-1], "module": "torch.nn.MaxPool2d", "args": {"kernel_size": 2, "stride": 2}},
    ],
    "head": [
        {"from": [-1], "module": "torch.nn.Linear", "args": {"in_features": None, "out_features": 10}}
    ]
}

model = ModelBuilder(yaml_dict)
print(model)

x = torch.randn(1, 3, 32, 32)
y = model(x)
print("Output shape:", y.shape)


KeyError: -1

In [1]:
import torch
import torch.nn as nn

def dynamic_class_instantiate_from_string(class_path, **kwargs):
    """
    简单动态实例化类，例如 "torch.nn.Conv2d"
    """
    parts = class_path.split(".")
    module_name, class_name = ".".join(parts[:-1]), parts[-1]
    module = __import__(module_name, fromlist=[class_name])
    cls = getattr(module, class_name)
    return cls(**kwargs)

class DAGNet(nn.Module):
    def __init__(self, layers_config, input_name="input"):
        """
        layers_config: dict
            key: 层名
            value: dict, 包含 'module' 和 'args' 和 'from'
            例如：
            {
                "conv1": {"module": "torch.nn.Conv2d", "args": {"in_channels": 3, "out_channels":16, "kernel_size":3, "padding":1}, "from": ["input"]},
                "relu1": {"module": "torch.nn.ReLU", "args": {}, "from": ["conv1"]},
                "conv2": {"module": "torch.nn.Conv2d", "args": {"in_channels":16, "out_channels":32, "kernel_size":3, "padding":1}, "from": ["relu1"]},
                "out": {"module": "torch.nn.Linear", "args": {"in_features":32*32*32, "out_features":10}, "from": ["conv2"]}
            }
        """
        super().__init__()
        self.input_name = input_name
        self.layers = nn.ModuleDict()
        self.edges = {}  # 保存每层输入

        for name, conf in layers_config.items():
            self.layers[name] = dynamic_class_instantiate_from_string(conf["module"], **conf.get("args", {}))
            self.edges[name] = conf["from"]

    def forward(self, x):
        outputs = {self.input_name: x}
        for name, layer in self.layers.items():
            input_names = self.edges[name]
            # 多输入 concat
            inputs = [outputs[i] for i in input_names]
            if len(inputs) == 1:
                out = layer(inputs[0])
            else:
                # 默认 dim=1 concat
                out = layer(torch.cat(inputs, dim=1))
            outputs[name] = out
        # 返回最后一层输出
        return out

# --------------------------
# 测试
# --------------------------
layers_config = {
    "conv1": {"module": "torch.nn.Conv2d", "args": {"in_channels": 3, "out_channels":16, "kernel_size":3, "padding":1}, "from": ["input"]},
    "relu1": {"module": "torch.nn.ReLU", "args": {}, "from": ["conv1"]},
    "conv2": {"module": "torch.nn.Conv2d", "args": {"in_channels":16, "out_channels":32, "kernel_size":3, "padding":1}, "from": ["relu1"]},
    "flatten": {"module": "torch.nn.Flatten", "args": {}, "from": ["conv2"]},
    "fc": {"module": "torch.nn.Linear", "args": {"in_features":32*32*32, "out_features":10}, "from": ["flatten"]}
}

x = torch.randn(1, 3, 32, 32)
model = DAGNet(layers_config)
out = model(x)
print(out.shape)  # torch.Size([1, 10])


torch.Size([1, 10])


In [4]:
import torch
import torch.nn as nn

class DAGNet(nn.Module):
    def __init__(self, layers_config):
        """
        layers_config: dict
            key: layer name
            value: dict {
                "module": callable or str,
                "args": dict,
                "from": list of input layer names
            }
        """
        super().__init__()
        self.layers_config = layers_config
        self.layers = nn.ModuleDict()
        self.input_names = []
        self.output_names = []

        # 实例化每一层
        for name, cfg in layers_config.items():
            module = cfg["module"]
            args = cfg.get("args", {})
            # 支持直接传 module 类或 str 名
            if isinstance(module, str):
                module = eval(module)
            self.layers[name] = module(**args)

    def forward(self, x):
        """
        x: list or tuple of input tensors, 顺序对应 self.input_names
        """
        if not self.input_names:
            raise ValueError("self.input_names must be set before forward")

        outputs = {}
        # 初始化输入
        for i, name in enumerate(self.input_names):
            outputs[name] = x[i]

        # 遍历所有层
        for name, cfg in self.layers_config.items():
            from_layers = cfg.get("from", [])
            layer = self.layers[name]

            # 获取输入张量
            if len(from_layers) == 0:
                inp = outputs[self.input_names[0]]  # 没有指定来源，用第一个输入
            elif len(from_layers) == 1:
                inp = outputs[from_layers[0]]
            else:
                # 多输入 concat
                inp = torch.cat([outputs[f] for f in from_layers], dim=1)

            outputs[name] = layer(inp)

        # 返回输出
        if not self.output_names:
            # 默认返回最后一层
            return outputs[name]
        else:
            return tuple(outputs[name] for name in self.output_names)

# --------------------------
# 示例：多输入、多输出
# --------------------------
layers_config = {
    "conv1": {"module": "nn.Conv2d", "args": {"in_channels":3, "out_channels":16, "kernel_size":3, "padding":1}, "from":["img1"]},
    "conv2": {"module": "nn.Conv2d", "args": {"in_channels":3, "out_channels":16, "kernel_size":3, "padding":1}, "from":["img2"]},
    "concat": {"module": "nn.Conv2d", "args": {"in_channels":32, "out_channels":32, "kernel_size":3, "padding":1}, "from":["conv1","conv2"]},
    "flatten": {"module": "nn.Flatten", "args": {}, "from":["concat"]},
    "fc": {"module": "nn.Linear", "args": {"in_features":32*32*32, "out_features":10}, "from":["flatten"]}
}

x1 = torch.randn(1,3,32,32)
x2 = torch.randn(1,3,32,32)

model = DAGNet(layers_config)
model.input_names = ["img1","img2"]
model.output_names = ["fc"]

out = model([x1, x2])
print(out)
print(out[0].shape)
# print(out[1].shape)
# print(out.shape)  # torch.Size([1, 10])


(tensor([[-0.0665, -0.2501, -0.0827, -0.0027,  0.3577, -0.0009, -0.2966, -0.0463,
         -0.1123,  0.3288]], grad_fn=<AddmmBackward0>),)
torch.Size([1, 10])


In [3]:
import torch
import torch.nn as nn

class DAGNet(nn.Module):
    def __init__(self, layers_config):
        """
        layers_config: list of dict
            每个元素是一个层的配置：
            {
                "name": str,   # 层名字
                "module": callable or str,
                "args": dict,
                "from": list of input layer names
            }
        """
        super().__init__()
        self.layers_config = layers_config
        self.layers = nn.ModuleDict()
        self.input_names = []
        self.output_names = []

        # 实例化每一层
        for cfg in layers_config:
            name = cfg["name"]
            module = cfg["module"]
            args = cfg.get("args", {})
            # 支持直接传 module 类或 str 名
            if isinstance(module, str):
                module = eval(module)
            self.layers[name] = module(**args)

    def forward(self, x):
        """
        x: list or tuple of input tensors, 顺序对应 self.input_names
        """
        if not self.input_names:
            raise ValueError("self.input_names must be set before forward")

        outputs = {}
        # 初始化输入
        for i, name in enumerate(self.input_names):
            outputs[name] = x[i]

        # 遍历所有层（按顺序）
        for cfg in self.layers_config:
            name = cfg["name"]
            from_layers = cfg.get("from", [])
            layer = self.layers[name]

            # 获取输入张量
            if len(from_layers) == 0:
                inp = outputs[self.input_names[0]]  # 默认用第一个输入
            elif len(from_layers) == 1:
                inp = outputs[from_layers[0]]
            else:
                # 多输入 concat
                inp = torch.cat([outputs[f] for f in from_layers], dim=1)

            outputs[name] = layer(inp)

        # 返回输出
        if not self.output_names:
            # 默认返回最后一层
            return outputs[name]
        else:
            return tuple(outputs[name] for name in self.output_names)


# --------------------------
# 示例：多输入、多输出
# --------------------------
layers_config = [
    {"name":"conv1", "module":"nn.Conv2d", "args":{"in_channels":3, "out_channels":16, "kernel_size":3, "padding":1}, "from":["img1"]},
    {"name":"conv2", "module":"nn.Conv2d", "args":{"in_channels":3, "out_channels":16, "kernel_size":3, "padding":1}, "from":["img2"]},
    {"name":"concat", "module":"nn.Conv2d", "args":{"in_channels":32, "out_channels":32, "kernel_size":3, "padding":1}, "from":["conv1","conv2"]},
    {"name":"flatten", "module":"nn.Flatten", "args":{}, "from":["concat"]},
    {"name":"fc", "module":"nn.Linear", "args":{"in_features":32*32*32, "out_features":10}, "from":["flatten"]}
]

x1 = torch.randn(1,3,32,32)
x2 = torch.randn(1,3,32,32)

model = DAGNet(layers_config)
model.input_names = ["img1","img2"]
model.output_names = ["fc"]

out = model([x1, x2])
print(out)              # tuple，因为 output_names 是列表
print(out[0].shape)     # torch.Size([1, 10])


(tensor([[-0.1884, -0.0595, -0.0263,  0.1358, -0.0650, -0.0262,  0.2738,  0.0362,
         -0.0309,  0.0726]], grad_fn=<AddmmBackward0>),)
torch.Size([1, 10])


In [4]:
import torch
import torch.nn as nn

class DAGNet(nn.Module):
    def __init__(self, config):
        """
        config: dict
            {
                "input_names": [...],
                "output_names": [...],
                "layers": [
                    {"name":..., "module":..., "args":..., "from":[...]},
                    ...
                ]
            }
        """
        super().__init__()
        self.input_names = config.get("input_names", [])
        self.output_names = config.get("output_names", [])
        self.layers_config = config["layers"]
        self.layers = nn.ModuleDict()

        # 实例化每一层
        for cfg in self.layers_config:
            name = cfg["name"]
            module = cfg["module"]
            args = cfg.get("args", {})
            if isinstance(module, str):
                module = eval(module)
            self.layers[name] = module(**args)

    def forward(self, x):
        if not self.input_names:
            raise ValueError("self.input_names must be set before forward")

        outputs = {}
        # 初始化输入
        for i, name in enumerate(self.input_names):
            outputs[name] = x[i]

        # 按顺序遍历每层
        for cfg in self.layers_config:
            name = cfg["name"]
            from_layers = cfg.get("from", [])
            layer = self.layers[name]

            if len(from_layers) == 0:
                inp = outputs[self.input_names[0]]
            elif len(from_layers) == 1:
                inp = outputs[from_layers[0]]
            else:
                inp = torch.cat([outputs[f] for f in from_layers], dim=1)

            outputs[name] = layer(inp)

        # 返回输出
        if not self.output_names:
            return outputs[name]
        else:
            return tuple(outputs[name] for name in self.output_names)


# --------------------------
# 示例配置
# --------------------------
model_config = {
    "input_names": ["img1", "img2"],
    "output_names": ["fc"],
    "layers": [
        {"name":"conv1", "module":"nn.Conv2d", "args":{"in_channels":3, "out_channels":16, "kernel_size":3, "padding":1}, "from":["img1"]},
        {"name":"conv2", "module":"nn.Conv2d", "args":{"in_channels":3, "out_channels":16, "kernel_size":3, "padding":1}, "from":["img2"]},
        {"name":"concat", "module":"nn.Conv2d", "args":{"in_channels":32, "out_channels":32, "kernel_size":3, "padding":1}, "from":["conv1","conv2"]},
        {"name":"flatten", "module":"nn.Flatten", "args":{}, "from":["concat"]},
        {"name":"fc", "module":"nn.Linear", "args":{"in_features":32*32*32, "out_features":10}, "from":["flatten"]}
    ]
}

x1 = torch.randn(1,3,32,32)
x2 = torch.randn(1,3,32,32)

model = DAGNet(model_config)
out = model([x1, x2])
print(out)
print(out[0].shape)


(tensor([[-0.3555,  0.1723,  0.0846, -0.0868,  0.2804, -0.1375, -0.1838,  0.2455,
         -0.1618,  0.3790]], grad_fn=<AddmmBackward0>),)
torch.Size([1, 10])


In [5]:
import torch
import torch.nn as nn

class DAGNet(nn.Module):
    def __init__(self, config):
        """
        config: dict
            input_nodes: list of dict {"name": str, "shape": tuple}
            output_nodes: list of dict {"name": str, "shape": tuple}
            layers: list of dict {
                "name": str,
                "module": nn.Module class or str,
                "args": dict,
                "from": list of input layer names
            }
        """
        super().__init__()
        self.input_nodes = config.get("input_nodes", [])
        self.output_nodes = config.get("output_nodes", [])
        self.layers_config = config["layers"]
        self.layers = nn.ModuleDict()

        # 实例化每一层
        for cfg in self.layers_config:
            name = cfg["name"]
            module = cfg["module"]
            args = cfg.get("args", {})
            if isinstance(module, str):
                module = eval(module)
            self.layers[name] = module(**args)

    def forward(self, x):
        outputs = {}
        # 初始化输入并检查形状
        for i, node in enumerate(self.input_nodes):
            inp = x[i]
            expected_shape = node.get("shape")
            if expected_shape and inp.shape[1:] != expected_shape:
                raise ValueError(f"Input {node['name']} expected shape {expected_shape}, got {inp.shape[1:]}")
            outputs[node["name"]] = inp

        # 遍历所有层
        for cfg in self.layers_config:
            name = cfg["name"]
            from_layers = cfg.get("from", [])
            layer = self.layers[name]

            # 获取输入张量
            if len(from_layers) == 0:
                inp = outputs[self.input_nodes[0]["name"]]
            elif len(from_layers) == 1:
                inp = outputs[from_layers[0]]
            else:
                # 多输入 concat
                inp = torch.cat([outputs[f] for f in from_layers], dim=1)

            # 对 Linear 层自动计算 in_features
            if isinstance(layer, nn.Linear) and 'in_features' not in layer.__dict__:
                layer.in_features = inp.numel() // inp.shape[0]
                layer.weight = nn.Parameter(torch.empty(layer.out_features, layer.in_features))
                layer.bias = nn.Parameter(torch.empty(layer.out_features))

            outputs[name] = layer(inp)

        # 返回输出
        if not self.output_nodes:
            return outputs[name]
        else:
            return tuple(outputs[n["name"]] for n in self.output_nodes)


# --------------------------
# 示例配置
# --------------------------
config = {
    "input_nodes": [
        {"name": "img1", "shape": (3, 32, 32)},
        {"name": "img2", "shape": (3, 32, 32)}
    ],
    "output_nodes": [
        {"name": "fc", "shape": (10,)}
    ],
    "layers": [
        {"name": "conv1", "module": "nn.Conv2d", "args": {"in_channels":3, "out_channels":16, "kernel_size":3, "padding":1}, "from":["img1"]},
        {"name": "conv2", "module": "nn.Conv2d", "args": {"in_channels":3, "out_channels":16, "kernel_size":3, "padding":1}, "from":["img2"]},
        {"name": "concat", "module": "nn.Conv2d", "args": {"in_channels":32, "out_channels":32, "kernel_size":3, "padding":1}, "from":["conv1","conv2"]},
        {"name": "flatten", "module": "nn.Flatten", "args": {}, "from":["concat"]},
        {"name": "fc", "module": "nn.Linear", "args": {"in_features":32*32*32, "out_features":10}, "from":["flatten"]}
    ]
}

# --------------------------
# 测试
# --------------------------
x1 = torch.randn(1,3,32,32)
x2 = torch.randn(1,3,32,32)

model = DAGNet(config)
out = model([x1, x2])

print("输出张量:", out)
print("输出形状:", out[0].shape)  # torch.Size([1, 10])


输出张量: (tensor([[ 0.0771,  0.0575, -0.2854, -0.0025,  0.0203, -0.2220,  0.2000,  0.0918,
          0.2136, -0.1407]], grad_fn=<AddmmBackward0>),)
输出形状: torch.Size([1, 10])
