In [37]:
import torch
import torch.nn as nn


# Encoder model definition.
class Encoder(nn.Module):
    """Convolutional encoder that returns two feature vectors from one input.

    A 3x3 convolution (stride 2, padding 1) with 16 channels is followed by ReLU,
    flattening, and two independent linear heads of width 128 and 64.

    The flattened width 16 * 14 * 14 corresponds to a (N, 1, 28, 28) input:
    floor((28 + 2*1 - 3) / 2) + 1 = 14 per spatial side.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=2, padding=1)
        self.relu = nn.ReLU()
        flat_features = 16 * 14 * 14
        self.fc1 = nn.Linear(flat_features, 128)
        self.fc2 = nn.Linear(flat_features, 64)

    def forward(self, x):
        """Return ``(fc1(h), fc2(h))`` where ``h`` is the flattened ReLU(conv1(x))."""
        activated = self.relu(self.conv1(x))
        flat = activated.view(activated.size(0), -1)
        return self.fc1(flat), self.fc2(flat)


# Decoder model 1 definition.
class Decoder1(nn.Module):
    """Single linear head over the concatenated encoder outputs (192 -> 10)."""

    def __init__(self):
        super().__init__()
        # 192 = 128 (Encoder.fc1 width) + 64 (Encoder.fc2 width)
        self.fc = nn.Linear(128 + 64, 10)

    def forward(self, x1, x2):
        """Concatenate ``x1`` and ``x2`` along dim 1 and apply the linear layer."""
        return self.fc(torch.cat([x1, x2], dim=1))


# Decoder model 2 definition.
class Decoder2(nn.Module):
    """Two-layer MLP head over the concatenated encoder outputs: 192 -> 64 -> ReLU -> 10."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(192, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x1, x2):
        """Concatenate ``x1`` and ``x2`` along dim 1, then run fc1 -> ReLU -> fc2."""
        joined = torch.cat([x1, x2], dim=1)
        hidden = self.relu(self.fc1(joined))
        return self.fc2(hidden)

In [38]:
import torch.onnx

# Seed the RNG so re-running this cell exports identical weights and dummy input
# (the models are untrained, so their weights come purely from random init).
torch.manual_seed(0)

encoder = Encoder()
decoder1 = Decoder1()
decoder2 = Decoder2()

# Example input used for tracing: a batch of one single-channel 28x28 image.
dummy_input = torch.randn(1, 1, 28, 28)
# The encoder outputs are only needed as example inputs for tracing the decoders,
# so no autograd graph is required.
with torch.no_grad():
    dummy_output1, dummy_output2 = encoder(dummy_input)

# No `dynamic_axes` are passed, so every dimension of the exported graphs
# (including the batch size of 1) is fixed to the example shapes.
torch.onnx.export(
    encoder,
    dummy_input,
    "encoder.onnx",
    input_names=["input"],
    output_names=["encoder_output1", "encoder_output2"],
)

# Both decoders name their inputs exactly like the encoder's outputs so the graphs
# can later be wired together by tensor name.
for decoder, onnx_path, decoder_output_name in (
    (decoder1, "decoder1.onnx", "output_dec1"),
    (decoder2, "decoder2.onnx", "output_dec2"),
):
    torch.onnx.export(
        decoder,
        (dummy_output1, dummy_output2),
        onnx_path,
        input_names=["encoder_output1", "encoder_output2"],
        output_names=[decoder_output_name],
    )

In [32]:
import sclblonnx as so


def _load_graph(path):
    """Load the ONNX graph stored at ``path`` via sclblonnx, failing loudly on error.

    sclblonnx's ``graph_from_file`` appears to signal failure by printing a message and
    returning ``False`` rather than raising. NOTE(review): confirm for the installed
    sclblonnx version. Without this check, a missing or unreadable file would only
    surface later as an unrelated AttributeError when the graph is used.

    Raises:
        RuntimeError: if sclblonnx did not return a graph.
    """
    graph = so.graph_from_file(path)
    if graph is False or graph is None:
        raise RuntimeError(f"Could not load ONNX graph from {path!r}")
    return graph


# Load the encoder and decoder ONNX graphs exported above.
encoder_model = _load_graph("encoder.onnx")
decoder_model_1 = _load_graph("decoder1.onnx")
decoder_model_2 = _load_graph("decoder2.onnx")

In [33]:
# Add a per-graph prefix to names so the decoder graphs do not collide with each
# other or with the encoder when merged.


def add_prefix_to_nodes(graph, prefix, exclude_names):
    """Namespace the names in an ONNX graph with ``prefix``, in place.

    The following names are rewritten to ``/<prefix>/<name>``:

    - every node name and node input/output;
    - every initializer, graph input and graph output;
    - every ``value_info`` entry.

    A leading "/" in the original name is not doubled. Two kinds of name are left
    unchanged:

    - names in ``exclude_names``: tensors shared with another graph (e.g. the
      encoder outputs) must stay connectable by name;
    - empty names: in ONNX an empty node input/output marks an omitted optional
      argument and must not become a reference to a (non-existent) tensor.

    The graph is mutated in place, so applying this twice to the same graph adds the
    prefix twice.

    Args:
        graph: ONNX ``GraphProto`` (or any object exposing the same repeated fields).
        prefix: Namespace to insert, without slashes (e.g. ``"decoder1"``).
        exclude_names: Iterable of names that must not be renamed.

    Returns:
        The same ``graph`` object.
    """
    excluded = set(exclude_names)

    def add_prefix(name):
        if not name or name in excluded:
            return name
        if name.startswith("/"):
            return f"/{prefix}{name}"
        return f"/{prefix}/{name}"

    for node in graph.node:
        node.name = add_prefix(node.name)
        node.input[:] = [add_prefix(name) for name in node.input]
        node.output[:] = [add_prefix(name) for name in node.output]
    # Everything else that refers to a tensor by name must be renamed consistently
    # with the node edges. Stale value_info entries would otherwise describe the
    # wrong tensor once graphs with identical internal names are merged.
    for named in (*graph.initializer, *graph.input, *graph.output, *graph.value_info):
        named.name = add_prefix(named.name)
    return graph


# Collect the encoder's graph input and output names. These are shared with the
# decoders and must stay unprefixed so the graphs can be connected by name.
encoder_inputs_outputs = [value.name for value in (*encoder_model.input, *encoder_model.output)]

# Prefix each decoder's names. add_prefix_to_nodes mutates the graph in place, so
# re-running this cell without reloading the graphs prefixes the names a second time.
decoder_model_1 = add_prefix_to_nodes(decoder_model_1, "decoder1", encoder_inputs_outputs)
decoder_model_2 = add_prefix_to_nodes(decoder_model_2, "decoder2", encoder_inputs_outputs)

In [34]:
# Connect the decoders' inputs to the encoder's outputs by name.
# The pairing is positional: decoder input i is fed by encoder output i, matching the
# order in which the names were declared at export time.
encoder_output_names = [output.name for output in encoder_model.output]
encoder_output1 = encoder_output_names[0]
encoder_output2 = encoder_output_names[1]


def _connect_to_encoder(decoder_graph, encoder_names):
    """Rename ``decoder_graph``'s inputs so that input i is named ``encoder_names[i]``.

    Raises:
        ValueError: if the decoder does not have exactly one graph input per encoder
            output, because the positional pairing would then be ambiguous.
    """
    decoder_names = [graph_input.name for graph_input in decoder_graph.input]
    if len(decoder_names) != len(encoder_names):
        raise ValueError(
            f"decoder has {len(decoder_names)} inputs but the encoder has "
            f"{len(encoder_names)} outputs: {decoder_names} vs {encoder_names}"
        )
    for decoder_name, encoder_name in zip(decoder_names, encoder_names):
        # Names shared with the encoder were excluded from prefixing, so they
        # usually already match; only rename when they differ.
        if decoder_name != encoder_name:
            decoder_graph = so.rename_input(decoder_graph, decoder_name, encoder_name)
    return decoder_graph


decoder_model_1 = _connect_to_encoder(decoder_model_1, encoder_output_names)
decoder_model_2 = _connect_to_encoder(decoder_model_2, encoder_output_names)

In [35]:
# Merge all models: first the encoder with decoder1, then the result with decoder2.
# Each io_match pair is (output name in the first graph, input name in the second);
# after the wiring cell both sides use the encoder's output names.
io_match = [(encoder_output1, encoder_output1), (encoder_output2, encoder_output2)]

# sclblonnx appears to report failures by returning False rather than raising.
# NOTE(review): confirm for the installed version. Check explicitly so a failed
# step does not flow silently into the next call.
combined_model = so.merge(encoder_model, decoder_model_1, io_match=io_match)
if combined_model is False:
    raise RuntimeError("Merging the encoder with decoder1 failed; see the sclblonnx output above.")
combined_model = so.merge(combined_model, decoder_model_2, io_match=io_match)
if combined_model is False:
    raise RuntimeError("Merging decoder2 into the combined model failed; see the sclblonnx output above.")

# Save the combined ONNX model and only report success if the write succeeded.
# NOTE(review): graph_to_file builds a new model around the graph and may stamp its
# own default opset instead of the one torch.onnx.export used. Confirm that the
# saved opset matches.
if so.graph_to_file(combined_model, "combined_model.onnx") is False:
    raise RuntimeError("Saving combined_model.onnx failed; see the sclblonnx output above.")
print("Encoder and Decoders have been successfully merged into a single ONNX model.")

Renaming node names in graph.
Matching specified inputs and outputs..
Pasting graphs.
Running Scailable specific checks for WASM conversion. 
Use _sclbl_check=False to turn off
Your graph was successfully checked.
Renaming node names in graph.
Matching specified inputs and outputs..
Pasting graphs.
Running Scailable specific checks for WASM conversion. 
Use _sclbl_check=False to turn off
Your graph was successfully checked.
Encoder and Decoders have been successfully merged into a single ONNX model.


In [36]:
# Save the combined ONNX model. This repeats the save already performed in the merge
# cell, so on a top-to-bottom run it only rewrites the same file.
# sclblonnx's graph_to_file appears to report failure by returning False rather than
# raising (NOTE(review): confirm), so check it before printing the success message.
if so.graph_to_file(combined_model, "combined_model.onnx") is False:
    raise RuntimeError("Saving combined_model.onnx failed; see the sclblonnx output above.")
print("Encoder and Decoders have been successfully merged into a single ONNX model.")

Encoder and Decoders have been successfully merged into a single ONNX model.
