In [22]:
"""
Fractal Network Configuration
"""

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FractalConfig:
    """Hyper-parameter settings of the Fractal Network.

    The per-depth lists (``head_dims``, ``proto_dims``, ``proto_nums``) must each
    contain exactly ``max_depth`` entries. When left as ``None`` they are filled
    with uniform defaults (8, 64 and 16 per depth respectively).
    """

    # Scalar parameters
    max_depth: int = 3
    input_dim: int = 128
    output_dim: int = 10
    num_iterations: int = 10  # number of PrototypeMatching refinement iterations
    gate_momentum: float = 0.5  # momentum coefficient used in the gate update
    use_ffn: bool = True  # whether BranchAttention uses an FFN sub-layer
    attention_heads: int = 8  # number of attention heads on the decoder side

    # Per-depth parameters (one entry per depth)
    head_dims: Optional[List[int]] = None  # branching factor H at each depth
    proto_dims: Optional[List[int]] = None  # prototype dimension D at each depth
    proto_nums: Optional[List[int]] = None  # number of prototypes T at each depth

    def __post_init__(self):
        """Fill in default per-depth lists and validate their lengths."""
        if self.head_dims is None:
            self.head_dims = [8] * self.max_depth
        if self.proto_dims is None:
            self.proto_dims = [64] * self.max_depth
        if self.proto_nums is None:
            self.proto_nums = [16] * self.max_depth

        # Every per-depth list must have one entry per depth.
        assert len(self.head_dims) == self.max_depth, \
            f"head_dims length {len(self.head_dims)} != max_depth {self.max_depth}"
        assert len(self.proto_dims) == self.max_depth, \
            f"proto_dims length {len(self.proto_dims)} != max_depth {self.max_depth}"
        assert len(self.proto_nums) == self.max_depth, \
            f"proto_nums length {len(self.proto_nums)} != max_depth {self.max_depth}"

    def _check_depth(self, depth: int, upper: int) -> None:
        """Raise IndexError unless ``0 <= depth <= upper``.

        Without this guard a negative depth would silently wrap around via Python's
        negative list indexing and return the dimension of an unrelated depth.
        """
        if not 0 <= depth <= upper:
            raise IndexError(f"depth {depth} out of range [0, {upper}]")

    def get_ffn_hidden_dim(self, depth: int) -> int:
        """Return the FFN hidden width at ``depth``: twice ``proto_dims[depth]``.

        Raises:
            IndexError: if ``depth`` is not in ``[0, max_depth - 1]``.
        """
        self._check_depth(depth, self.max_depth - 1)
        return self.proto_dims[depth] * 2

    def get_input_dim_for_depth(self, depth: int) -> int:
        """Return the feature width of nodes at ``depth``.

        Depth 0 is the raw input (``input_dim``). A node at depth ``d >= 1`` is an
        output of the depth ``d - 1`` matching step, so its width is
        ``proto_dims[d - 1]``. ``depth == max_depth`` (the leaves) is valid.

        Raises:
            IndexError: if ``depth`` is not in ``[0, max_depth]``.
        """
        self._check_depth(depth, self.max_depth)
        if depth == 0:
            return self.input_dim
        return self.proto_dims[depth - 1]


# デフォルト設定の例
def get_default_config():
    """Build the reference FractalConfig.

    Three depths; the branching factor, the prototype width and the number of
    prototypes all double from one depth to the next.
    """
    depths = range(3)
    return FractalConfig(
        max_depth=len(depths),
        input_dim=128,
        output_dim=10,
        num_iterations=10,
        gate_momentum=0.5,
        use_ffn=True,
        attention_heads=8,
        head_dims=[8 * 2 ** d for d in depths],    # [8, 16, 32]
        proto_dims=[64 * 2 ** d for d in depths],  # [64, 128, 256]
        proto_nums=[16 * 2 ** d for d in depths],  # [16, 32, 64]
    )


# # 使用例
# if __name__ == "__main__":
#     config = get_default_config()
#     print(f"Config created with max_depth={config.max_depth}")
#     print(f"FFN hidden dim at depth 0: {config.get_ffn_hidden_dim(0)}")
#     print(f"Input dim for depth 1: {config.get_input_dim_for_depth(1)}")

In [23]:
"""
Fractal Encoder Module
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple, Optional
from dataclasses import dataclass


@dataclass
class NodeID:
    """Identifier of one node in the fractal tree.

    Its ``to_key()`` tuple has the same ``(depth, parent_id, node_id)`` layout as
    the keys of FractalEncoder's feature dictionary.
    """
    depth: int
    # None for the root; otherwise the underscore-joined path of ancestor indices
    # (e.g. "0" for children of the root, "0_3" one level deeper).
    parent_id: Optional[str] = None
    # Index of this node among its siblings.
    node_id: int = 0

    def to_key(self) -> Tuple[int, Optional[str], int]:
        """Return the tuple form of this identifier, usable as a dictionary key."""
        return (self.depth, self.parent_id, self.node_id)


class PrototypeMatching(nn.Module):
    """Prototype-based matching mechanism (PaQ / PaK).

    Splits one input vector into ``head_dim`` branch outputs. Each branch owns
    ``proto_num`` learnable prototypes of width ``proto_dim``; the prototypes and a
    gate over them are refined jointly for ``num_iterations`` steps.
    """

    def __init__(self, input_dim: int, proto_dim: int, proto_num: int, head_dim: int):
        """
        Args:
            input_dim: input dimension
            proto_dim: prototype dimension D
            proto_num: number of prototypes T
            head_dim: number of heads (branches) H
        """
        super().__init__()
        self.input_dim = input_dim
        self.proto_dim = proto_dim
        self.proto_num = proto_num
        self.head_dim = head_dim

        # Per-head input projections [H, input_dim, D]
        self.Wxq = nn.Parameter(torch.randn(head_dim, input_dim, proto_dim))
        self.Wxk = nn.Parameter(torch.randn(head_dim, input_dim, proto_dim))
        self.Wxv = nn.Parameter(torch.randn(head_dim, input_dim, proto_dim))
        # Per-(head, prototype) D x D projections [H, T, D, D]
        self.Wpq = nn.Parameter(torch.randn(head_dim, proto_num, proto_dim, proto_dim))
        self.Wpk = nn.Parameter(torch.randn(head_dim, proto_num, proto_dim, proto_dim))

        # Learnable prototype vectors [H, T, D]
        self.prototypes = nn.Parameter(torch.randn(head_dim, proto_num, proto_dim))

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform initialisation of all projection weights and prototypes."""
        for param in [self.Wxq, self.Wxk, self.Wxv]:
            nn.init.xavier_uniform_(param)
        # The D x D matrices and the prototypes are initialised through 2-D views
        # with H*T rows; the views share storage with the parameters.
        for param in [self.Wpq, self.Wpk]:
            nn.init.xavier_uniform_(param.view(self.head_dim * self.proto_num, -1))
        nn.init.xavier_uniform_(self.prototypes.view(self.head_dim * self.proto_num, -1))

    def forward(self, x: torch.Tensor, gate: torch.Tensor,
                num_iterations: int = 10, momentum: float = 0.5) -> torch.Tensor:
        """
        Args:
            x: input [batch, input_dim]
            gate: initial gate [batch, head_dim, proto_num]
            num_iterations: number of refinement iterations
            momentum: weight given to the freshly computed gate in each update

        Returns:
            ReLU-activated output [batch, head_dim, proto_dim]
        """
        batch_size = x.shape[0]

        # The input projections do not depend on the loop state: compute them once.
        Qx = torch.einsum('bi,hid->bhd', x, self.Wxq)  # [batch, H, D]
        Kx = torch.einsum('bi,hid->bhd', x, self.Wxk)  # [batch, H, D]
        Vx = torch.einsum('bi,hid->bhd', x, self.Wxv)  # [batch, H, D]

        P = self.prototypes.unsqueeze(0).expand(batch_size, -1, -1, -1)  # [batch, H, T, D]
        G = gate  # [batch, H, T]

        for _ in range(num_iterations):
            # Full per-(head, prototype) matrix product P[b,h,t,:] @ W[h,t,:,:].
            # Distinct subscripts are essential here: a repeated subscript within
            # one operand (e.g. 'htdd') makes einsum take only the matrix diagonal.
            Kp = torch.einsum('bhtd,htde->bhte', P, self.Wpk)  # [batch, H, T, D]
            Qp = torch.einsum('bhtd,htde->bhte', P, self.Wpq)  # [batch, H, T, D]

            # PaK: input query against prototype keys -> candidate gate
            Gx = F.softmax(torch.einsum('bhd,bhtd->bht', Qx, Kp), dim=-1)  # [batch, H, T]

            # PaQ: prototype queries against the input key -> prototype update
            attn = F.softmax(torch.einsum('bhtd,bhd->bht', Qp, Kx), dim=-1)  # [batch, H, T]
            Ap = torch.einsum('bht,bhd->bhtd', attn, Vx)  # [batch, H, T, D]

            # Momentum gate update, then gate-weighted prototype update
            G = (1 - momentum) * G + momentum * Gx
            P = G.unsqueeze(-1) * (P + Ap)

        # Sum over the prototype axis -> one vector per branch
        return F.relu(P.sum(dim=2))  # [batch, H, D]


class FractalEncoder(nn.Module):
    """Fractal encoder: recursively splits each node into ``head_dims[depth]`` children.

    All nodes at a given depth share that depth's PrototypeMatching module. Each
    non-leaf node has its own learnable initial gate of shape
    ``[head_dims[depth], proto_nums[depth]]``.
    """

    def __init__(self, config):
        """
        Args:
            config: FractalConfig
        """
        super().__init__()
        self.config = config

        # One PrototypeMatching module per depth, shared by every node at that depth.
        self.prototype_matching = nn.ModuleList()
        for depth in range(config.max_depth):
            input_dim = config.get_input_dim_for_depth(depth)
            self.prototype_matching.append(
                PrototypeMatching(
                    input_dim=input_dim,
                    proto_dim=config.proto_dims[depth],
                    proto_num=config.proto_nums[depth],
                    head_dim=config.head_dims[depth]
                )
            )

        # Initial gate of every non-leaf node
        self.gate_init = nn.ParameterDict()
        self._init_gates()

    @staticmethod
    def _child_parent_id(parent_id: Optional[str], node_id: int) -> str:
        """Return the ``parent_id`` shared by all children of node ``(parent_id, node_id)``."""
        return f"{parent_id}_{node_id}" if parent_id else str(node_id)

    def _get_gate_key(self, depth: int, parent_id: Optional[str], node_id: int) -> str:
        """Return the ParameterDict key of a node's initial gate."""
        return f"d{depth}_p{parent_id}_n{node_id}"

    def _init_gates(self):
        """Register a randomly initialised gate for every node with depth < max_depth.

        Nodes are visited depth-first, so gates are created in the same order
        ``forward`` visits them.
        """
        def register_gate(depth, parent_id, node_id):
            if depth >= self.config.max_depth:
                return

            key = self._get_gate_key(depth, parent_id, node_id)
            self.gate_init[key] = nn.Parameter(
                torch.randn(self.config.head_dims[depth], self.config.proto_nums[depth])
            )

            # Register the children's gates as well
            child_parent_id = self._child_parent_id(parent_id, node_id)
            for h in range(self.config.head_dims[depth]):
                register_gate(depth + 1, child_parent_id, h)

        register_gate(0, None, 0)

    def forward(self, x: torch.Tensor) -> Dict[Tuple, torch.Tensor]:
        """
        Args:
            x: input [batch, input_dim]

        Returns:
            features_dict: features of every node, leaves included
                key: (depth, parent_id, node_id)
                value: [batch, config.get_input_dim_for_depth(depth)]
                    (input_dim at the root, proto_dims[depth - 1] below it)
        """
        batch_size = x.shape[0]
        features_dict = {}

        def fractalize(x_node, depth, parent_id, node_id):
            """Store this node's feature, then split it and recurse into the children."""
            features_dict[(depth, parent_id, node_id)] = x_node

            if depth >= self.config.max_depth:
                return

            # The node's learned initial gate, broadcast over the batch
            gate_key = self._get_gate_key(depth, parent_id, node_id)
            gate = self.gate_init[gate_key].unsqueeze(0).expand(batch_size, -1, -1)

            children = self.prototype_matching[depth](
                x_node, gate,
                num_iterations=self.config.num_iterations,
                momentum=self.config.gate_momentum
            )  # [batch, H, D]

            child_parent_id = self._child_parent_id(parent_id, node_id)
            for h in range(self.config.head_dims[depth]):
                fractalize(children[:, h, :], depth + 1, child_parent_id, h)  # child: [batch, D]

        # Start from the root
        fractalize(x, 0, None, 0)

        return features_dict

In [24]:
"""
Fractal Decoder Module
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Tuple, List, Optional


class BranchAttention(nn.Module):
    """Pre-norm self-attention over sibling branches, optionally followed by an FFN.

    Both sub-layers are residual: ``x + MHA(LN1(x))`` and, when enabled,
    ``h + FFN(LN2(h))``.
    """

    def __init__(self, dim: int, num_heads: int, use_ffn: bool = True):
        """
        Args:
            dim: token dimension
            num_heads: number of attention heads (must divide ``dim``)
            use_ffn: whether to append the FFN sub-layer
        """
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.use_ffn = use_ffn

        assert dim % num_heads == 0, f"dim {dim} must be divisible by num_heads {num_heads}"
        self.head_dim = dim // num_heads

        # Attention projections; created and registered in q, k, v, out order.
        self.q_proj, self.k_proj, self.v_proj, self.out_proj = (
            nn.Linear(dim, dim) for _ in range(4)
        )

        # Two-layer FFN whose hidden width is twice the token width.
        if use_ffn:
            widened = 2 * dim
            self.ffn = nn.Sequential(nn.Linear(dim, widened), nn.ReLU(), nn.Linear(widened, dim))

        self.layer_norm1 = nn.LayerNorm(dim)
        self.layer_norm2 = nn.LayerNorm(dim) if use_ffn else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: tokens [batch, num_tokens, dim]

        Returns:
            tokens [batch, num_tokens, dim]
        """
        bsz, n_tok, _ = x.shape

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # [batch, tokens, dim] -> [batch, heads, tokens, head_dim]
            return t.view(bsz, n_tok, self.num_heads, self.head_dim).transpose(1, 2)

        normed = self.layer_norm1(x)
        q, k, v = (split_heads(proj(normed)) for proj in (self.q_proj, self.k_proj, self.v_proj))

        # Scaled dot-product attention per head
        weights = F.softmax(q @ k.transpose(-2, -1) / (self.head_dim ** 0.5), dim=-1)
        merged = (weights @ v).transpose(1, 2).contiguous().view(bsz, n_tok, self.dim)
        hidden = x + self.out_proj(merged)

        if not self.use_ffn:
            return hidden
        return hidden + self.ffn(self.layer_norm2(hidden))


class Defractalize(nn.Module):
    """Merge a set of child tokens into a single vector (token conv + mean over tokens)."""

    def __init__(self, token_dim: int, output_dim: int):
        """
        Args:
            token_dim: width of each child token (d_n)
            output_dim: width of the merged vector (d_{n-1})
        """
        super().__init__()
        self.token_dim = token_dim
        self.output_dim = output_dim

        # A kernel_size=1 Conv1d applies the same linear map to every token.
        self.token_conv = nn.Conv1d(token_dim, output_dim, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: tokens [batch, num_tokens, token_dim]

        Returns:
            merged vector [batch, output_dim]
        """
        # Conv1d wants channels first: [batch, token_dim, num_tokens]; the conv gives
        # [batch, output_dim, num_tokens], which is then averaged over the tokens.
        return self.token_conv(x.transpose(1, 2)).mean(dim=-1)


class ResidualConnection(nn.Module):
    """Fuse decoder features with the matching encoder features.

    Despite the name, this is concat -> Linear -> ReLU, not an additive skip.
    """

    def __init__(self, decoder_dim: int, encoder_dim: int, output_dim: int):
        """
        Args:
            decoder_dim: width of the decoder features
            encoder_dim: width of the encoder features
            output_dim: width of the fused output
        """
        super().__init__()
        self.decoder_dim = decoder_dim
        self.encoder_dim = encoder_dim
        self.output_dim = output_dim

        fused_in = decoder_dim + encoder_dim
        self.residual_conv = nn.Linear(fused_in, output_dim)

    def forward(self, decoder_features: torch.Tensor,
                encoder_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            decoder_features: [batch, decoder_dim]
            encoder_features: [batch, encoder_dim]

        Returns:
            fused features [batch, output_dim] (non-negative, after ReLU)
        """
        stacked = torch.cat((decoder_features, encoder_features), dim=-1)
        return F.relu(self.residual_conv(stacked))


class FractalDecoder(nn.Module):
    """Fractal decoder: merges the encoder's node tree bottom-up into one vector.

    For every non-leaf depth ``d`` it holds a BranchAttention over the children,
    a Defractalize merging them, and a ResidualConnection fusing the merge with the
    node's own encoder feature.
    """

    def __init__(self, config):
        """
        Args:
            config: FractalConfig
        """
        super().__init__()
        self.config = config

        self.branch_attention = nn.ModuleList()
        self.defractalize = nn.ModuleList()
        self.residual_connection = nn.ModuleList()

        for depth in range(config.max_depth):
            # Children of a depth-`depth` node sit at depth + 1. Whether they are leaves
            # (raw encoder features) or merged decoder outputs, their width is
            # proto_dims[depth].
            child_dim = config.proto_dims[depth]
            # Width of the depth-`depth` node itself, i.e. of its encoder feature.
            node_dim = config.input_dim if depth == 0 else config.proto_dims[depth - 1]

            # Per depth: attention, then merge, then fuse (creation order kept stable).
            self.branch_attention.append(
                BranchAttention(
                    dim=child_dim,
                    num_heads=config.attention_heads,
                    use_ffn=config.use_ffn
                )
            )
            self.defractalize.append(
                Defractalize(
                    token_dim=child_dim,
                    output_dim=node_dim
                )
            )
            self.residual_connection.append(
                ResidualConnection(
                    decoder_dim=node_dim,
                    encoder_dim=node_dim,
                    output_dim=node_dim
                )
            )

    @staticmethod
    def _child_parent_id(parent_id: Optional[str], node_id: int) -> str:
        """Return the ``parent_id`` shared by all children of node ``(parent_id, node_id)``."""
        return f"{parent_id}_{node_id}" if parent_id else str(node_id)

    def forward(self, encoder_features: Dict[Tuple, torch.Tensor]) -> torch.Tensor:
        """
        Args:
            encoder_features: encoder feature dictionary
                key: (depth, parent_id, node_id)
                value: [batch, width of that node]

        Returns:
            decoded root feature [batch, input_dim]
        """
        def process_node(depth, parent_id, node_id):
            """Return the decoded feature of one node, decoding its children first."""
            node_key = (depth, parent_id, node_id)

            if depth == self.config.max_depth:
                # Leaf: the encoder feature is used as-is (no attention)
                return encoder_features[node_key]

            child_parent_id = self._child_parent_id(parent_id, node_id)
            children_features = []
            for h in range(self.config.head_dims[depth]):
                # Children absent from the dict are skipped; attention just sees fewer tokens.
                if (depth + 1, child_parent_id, h) in encoder_features:
                    children_features.append(process_node(depth + 1, child_parent_id, h))

            if not children_features:
                # No children present (not expected for encoder output):
                # fall back to the node's own encoder feature.
                return encoder_features[node_key]

            children_tensor = torch.stack(children_features, dim=1)  # [batch, H, D_child]
            children_tensor = self.branch_attention[depth](children_tensor)
            integrated = self.defractalize[depth](children_tensor)  # [batch, D_node]

            return self.residual_connection[depth](integrated, encoder_features[node_key])

        # Recursion starts at the root and resolves bottom-up
        return process_node(0, None, 0)

In [30]:
"""
Fractal Autoencoder - 全体統合モジュール
"""

import torch
import torch.nn as nn
from typing import Dict, Tuple, Optional


class OutputHead(nn.Module):
    """Final linear projection from the decoder output to the model output."""

    def __init__(self, input_dim: int, output_dim: int):
        """
        Args:
            input_dim: width of the decoder output (the model's input_dim)
            output_dim: final output width
        """
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``[batch, input_dim]`` to ``[batch, output_dim]``."""
        projected = self.fc(x)
        return projected


class FractalAutoencoder(nn.Module):
    """Full fractal autoencoder: FractalEncoder -> FractalDecoder -> OutputHead."""

    def __init__(self, config):
        """
        Args:
            config: FractalConfig
        """
        super().__init__()
        self.config = config

        self.encoder = FractalEncoder(config)
        self.decoder = FractalDecoder(config)
        # The decoder maps the node tree back to input_dim, so the head starts there.
        self.output_head = OutputHead(input_dim=config.input_dim, output_dim=config.output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: input [batch, input_dim]

        Returns:
            output [batch, output_dim]
        """
        node_features = self.encoder(x)
        root = self.decoder(node_features)
        return self.output_head(root)

    def get_encoder_features(self, x: torch.Tensor) -> Dict[Tuple, torch.Tensor]:
        """Return the encoder's per-node feature dictionary (for debugging)."""
        return self.encoder(x)

    def decode_from_features(self, encoder_features: Dict[Tuple, torch.Tensor]) -> torch.Tensor:
        """Decode directly from an encoder feature dictionary (for debugging)."""
        root = self.decoder(encoder_features)
        return self.output_head(root)


# 使用例とテストコード
def test_fractal_autoencoder():
    """Smoke test: build a small model, run a forward pass, print shapes and parameter counts.

    Returns:
        The constructed FractalAutoencoder.
    """
    from dataclasses import replace

    # Shallow variant of the default config. dataclasses.replace() builds a new
    # instance through __init__, so FractalConfig.__post_init__ re-validates the
    # per-depth list lengths against the new max_depth. Assigning the fields one by
    # one on an existing instance would skip that check.
    config = replace(
        get_default_config(),
        max_depth=2,
        input_dim=64,
        output_dim=10,
        head_dims=[4, 8],
        proto_dims=[32, 64],
        proto_nums=[8, 16],
    )

    model = FractalAutoencoder(config)

    # Dummy input
    batch_size = 2
    x = torch.randn(batch_size, config.input_dim)

    # Forward pass
    output = model(x)
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    assert output.shape == (batch_size, config.output_dim)

    # Inspect the encoder's per-node features
    encoder_features = model.get_encoder_features(x)
    print(f"\nEncoder features:")
    for key, value in encoder_features.items():
        print(f"  Node {key}: shape {value.shape}")

    # The tree has 1 root + H0 + H0*H1 + ... nodes (leaves included)
    expected_nodes, level_width = 1, 1
    for branching in config.head_dims:
        level_width *= branching
        expected_nodes += level_width
    assert len(encoder_features) == expected_nodes

    # Parameter counts
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\nTotal parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    return model


# デバッグ用：勾配チェック
def check_gradients(model, x, target):
    """Run one forward/backward pass under MSE loss and report per-parameter gradients.

    A parameter is flagged if backward never reached it (``grad is None``) or if
    its gradient norm is exactly zero. Parameters with ``requires_grad=False``
    are skipped.

    Args:
        model: nn.Module to check.
        x: model input.
        target: regression target, same shape as ``model(x)``.

    Returns:
        The loss as a Python float.
    """
    criterion = nn.MSELoss()

    # Clear stale gradients so the report reflects only this backward pass.
    # No optimizer is needed because nothing is updated.
    model.zero_grad()

    # Forward
    output = model(x)
    loss = criterion(output, target)

    # Backward
    loss.backward()

    print("\nGradient check:")
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if param.grad is None:
            # Not connected to the loss at all: the most severe gradient-flow failure.
            print(f"  {name}: grad is None")
            print("    WARNING: No gradient (parameter not reached by backward)!")
            continue
        grad_norm = param.grad.norm().item()
        print(f"  {name}: grad_norm = {grad_norm:.6f}")
        if grad_norm == 0:
            print(f"    WARNING: Zero gradient!")

    return loss.item()



In [31]:
# if __name__ == "__main__":
#     # テスト実行
#     model = test_fractal_autoencoder()

#     # 勾配チェック
#     batch_size = 2
#     x = torch.randn(batch_size, model.config.input_dim)
#     target = torch.randn(batch_size, model.config.output_dim)
#     loss = check_gradients(model, x, target)
#     print(f"\nLoss: {loss:.6f}")

Input shape: torch.Size([2, 64])
Output shape: torch.Size([2, 10])

Encoder features:
  Node (0, None, 0): shape torch.Size([2, 64])
  Node (1, '0', 0): shape torch.Size([2, 32])
  Node (2, '0_0', 0): shape torch.Size([2, 64])
  Node (2, '0_0', 1): shape torch.Size([2, 64])
  Node (2, '0_0', 2): shape torch.Size([2, 64])
  Node (2, '0_0', 3): shape torch.Size([2, 64])
  Node (2, '0_0', 4): shape torch.Size([2, 64])
  Node (2, '0_0', 5): shape torch.Size([2, 64])
  Node (2, '0_0', 6): shape torch.Size([2, 64])
  Node (2, '0_0', 7): shape torch.Size([2, 64])
  Node (1, '0', 1): shape torch.Size([2, 32])
  Node (2, '0_1', 0): shape torch.Size([2, 64])
  Node (2, '0_1', 1): shape torch.Size([2, 64])
  Node (2, '0_1', 2): shape torch.Size([2, 64])
  Node (2, '0_1', 3): shape torch.Size([2, 64])
  Node (2, '0_1', 4): shape torch.Size([2, 64])
  Node (2, '0_1', 5): shape torch.Size([2, 64])
  Node (2, '0_1', 6): shape torch.Size([2, 64])
  Node (2, '0_1', 7): shape torch.Size([2, 64])
  Node 