In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class StaticDEQ(nn.Module):
    """
    Static Deep Equilibrium Model

    The input is projected from C to D channels and refined by L applications
    of ``update_function``. Each application is N residual mixing steps. The
    result is then projected back to C channels and passed through the output
    activation.

    Each of the N weights has shape (T, T, D, D) and is indexed as
    ``weight[t_out, t_in, d_in, d_out]``. Every output position therefore mixes
    all input tokens and all input channels.

    Args:
        T: Sequence length
        C: Input/output channel dimension
        D: Hidden dimension
        L: Number of iterations
        N: Number of weight matrices in update function
        hid_activation: Hidden activation function ('sigmoid', 'token_softmax', 'relu', 'gelu', 'tanh', 'none')
        output_activation: Output activation function ('sigmoid', 'token_softmax', 'relu', 'gelu', 'tanh', 'none')
        weight_init: Weight initialization method ('xavier_uniform', 'xavier_normal',
                                                   'kaiming_uniform', 'kaiming_normal', 'orthogonal')
        bias: Whether to use bias in linear projections (default: False)

    Raises:
        ValueError: if ``weight_init`` is not a supported method.
    """

    _INIT_METHODS = ('xavier_uniform', 'xavier_normal', 'kaiming_uniform',
                     'kaiming_normal', 'orthogonal')

    def __init__(self, T, C, D, L, N, hid_activation='relu', output_activation='none',
                 weight_init='xavier_uniform', bias=False):
        super().__init__()

        self.T = T
        self.C = C
        self.D = D
        self.L = L
        self.N = N
        self.hid_activation = hid_activation
        self.output_activation = output_activation
        self.use_bias = bias

        # Input and output projections
        self.input_proj = nn.Linear(C, D, bias=bias)
        self.output_proj = nn.Linear(D, C, bias=bias)

        # N weight tensors of shape (T, T, D, D), indexed [t_out, t_in, d_in, d_out]
        self.weights = nn.ParameterList([
            nn.Parameter(torch.empty(T, T, D, D))
            for _ in range(N)
        ])

        self._initialize_weights(weight_init)

    @staticmethod
    def _init_matrix(tensor, method):
        """Initialize a 2-D tensor in place with the named scheme."""
        if method == 'xavier_uniform':
            nn.init.xavier_uniform_(tensor)
        elif method == 'xavier_normal':
            nn.init.xavier_normal_(tensor)
        elif method == 'kaiming_uniform':
            nn.init.kaiming_uniform_(tensor, nonlinearity='relu')
        elif method == 'kaiming_normal':
            nn.init.kaiming_normal_(tensor, nonlinearity='relu')
        elif method == 'orthogonal':
            nn.init.orthogonal_(tensor)
        else:
            raise ValueError(f"Unknown initialization method: {method}")

    def _initialize_weights(self, method):
        """Initialize the mixing weights, then the projections; zero any biases.

        Raises:
            ValueError: if ``method`` is unknown. It is checked up front so the
                error is reported even when N == 0.
        """
        if method not in self._INIT_METHODS:
            raise ValueError(f"Unknown initialization method: {method}")

        for weight in self.weights:
            if method == 'orthogonal':
                # One orthogonal D x D block per (t_out, t_in) pair.
                for i in range(self.T):
                    for j in range(self.T):
                        nn.init.orthogonal_(weight[i, j])
            else:
                # The flattened view shares storage with ``weight``, so this
                # initializes the parameter in place.
                # NOTE(review): fans come from this (T*T*D, D) view (fan_in = D),
                # although each output sums over T*D inputs. Revisit the scale
                # if activations grow with T.
                self._init_matrix(weight.view(-1, self.D), method)

        self._init_matrix(self.input_proj.weight, method)
        self._init_matrix(self.output_proj.weight, method)

        if self.use_bias:
            nn.init.zeros_(self.input_proj.bias)
            nn.init.zeros_(self.output_proj.bias)

    def _get_activation(self, activation_name):
        """Return the activation callable for ``activation_name``.

        'token_softmax' applies softmax over the last (channel) dimension.

        Raises:
            ValueError: for an unknown activation name.
        """
        if activation_name == 'sigmoid':
            return torch.sigmoid
        elif activation_name == 'token_softmax':
            return lambda x: F.softmax(x, dim=-1)
        elif activation_name == 'relu':
            return F.relu
        elif activation_name == 'gelu':
            return F.gelu
        elif activation_name == 'tanh':
            return torch.tanh
        elif activation_name == 'none':
            return lambda x: x
        else:
            raise ValueError(f"Unknown activation: {activation_name}")

    def update_function(self, X):
        """
        Apply the N residual mixing steps.

        Args:
            X: Input tensor of shape (B, T, D)
        Returns:
            Updated tensor of shape (B, T, D). The residual is already included:
            X + act(W_1 X) + ..., applied sequentially.
        """
        hid_act = self._get_activation(self.hid_activation)

        for weight in self.weights:
            # mixed[b, t, e] = sum_{s, d} weight[t, s, d, e] * X[b, s, d]
            # Distinct subscripts are required: a repeated subscript such as
            # 'ttdd' would make einsum take only the diagonal of the weight.
            mixed = torch.einsum('tsde,bsd->bte', weight, X)
            X = X + hid_act(mixed)

        return X

    def forward(self, X):
        """
        Forward pass through the Static DEQ model.

        Args:
            X: Input tensor of shape (B, T, C)
        Returns:
            Output tensor of shape (B, T, C)
        """
        B, T, C = X.shape
        assert T == self.T, f"Input sequence length {T} doesn't match model's T={self.T}"
        assert C == self.C, f"Input channel dimension {C} doesn't match model's C={self.C}"

        # Input projection: (B, T, C) -> (B, T, D)
        X = self.input_proj(X)

        # update_function already returns X plus its residual updates. Adding X
        # again here would double the state on every iteration.
        for _ in range(self.L):
            X = self.update_function(X)

        # Output projection: (B, T, D) -> (B, T, C)
        X = self.output_proj(X)

        output_act = self._get_activation(self.output_activation)
        return output_act(X)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class HierarchicalDEQ(nn.Module):
    """
    Hierarchical Deep Equilibrium Model

    The state is the concatenation XZ = [X, Z_1, ..., Z_S] along the token
    axis, with shape (B, sum(s_dims), D). Stage s owns rows
    [sum(s_dims[:s]), sum(s_dims[:s+1])). It updates only those rows, but each
    updated row mixes over all of XZ. Weights have shape
    (sum(s_dims), sum(s_dims), D, D), indexed [row_out, row_in, d_in, d_out].

    Args:
        C: Input/output channel dimension
        D: Hidden dimension
        Ls: List of iteration counts for each stage [l_0, ..., l_S]
        Ns: List of weight matrix counts for each stage [n_0, ..., n_S]
        s_dims: List of dimensions [d_0=T, d_1, ..., d_S]
        hid_activation: Hidden activation function ('sigmoid', 'token_softmax', 'relu', 'gelu', 'tanh', 'none')
        output_activation: Output activation function ('sigmoid', 'token_softmax', 'relu', 'gelu', 'tanh', 'none')
        weight_init: Weight initialization method ('xavier_uniform', 'xavier_normal',
                                                   'kaiming_uniform', 'kaiming_normal', 'orthogonal')
        bias: Whether to use bias in linear projections (default: False)
        weight_share: Whether to share weights across stages (default: False)

    NOTE(review): Ls, Ns and s_dims are assumed to have the same length
    (S + 1). This is not validated.
    """

    _INIT_METHODS = ('xavier_uniform', 'xavier_normal', 'kaiming_uniform',
                     'kaiming_normal', 'orthogonal')

    def __init__(self, C, D, Ls, Ns, s_dims, hid_activation='relu', output_activation='none',
                 weight_init='xavier_uniform', bias=False, weight_share=False):
        super().__init__()

        self.C = C
        self.D = D
        self.Ls = Ls
        self.Ns = Ns
        self.s_dims = s_dims
        self.T = s_dims[0]  # d_0 = T
        self.S = len(Ls) - 1  # Number of stages (0 to S)
        self.hid_activation = hid_activation
        self.output_activation = output_activation
        self.use_bias = bias
        self.weight_share = weight_share

        # Total number of rows in the concatenated XZ
        self.total_dim = sum(s_dims)

        # Input and output projections
        self.input_proj = nn.Linear(C, D, bias=bias)
        self.output_proj = nn.Linear(D, C, bias=bias)

        # Learnable latent vectors Z_s for s = 1, ..., S (Z_0 = X, not learnable)
        self.Z = nn.ParameterList([
            nn.Parameter(torch.empty(s_dims[s], D)) for s in range(1, len(s_dims))
        ])

        if weight_share:
            # n_0 weights shared by every stage
            self.weights = nn.ParameterList([
                nn.Parameter(torch.empty(self.total_dim, self.total_dim, D, D))
                for _ in range(Ns[0])
            ])
        else:
            # Separate weights for each stage
            self.weights = nn.ModuleList([
                nn.ParameterList([
                    nn.Parameter(torch.empty(self.total_dim, self.total_dim, D, D))
                    for _ in range(Ns[s])
                ])
                for s in range(self.S + 1)
            ])

        self._initialize_weights(weight_init)

    @staticmethod
    def _init_matrix(tensor, method):
        """Initialize a 2-D tensor in place with the named scheme."""
        if method == 'xavier_uniform':
            nn.init.xavier_uniform_(tensor)
        elif method == 'xavier_normal':
            nn.init.xavier_normal_(tensor)
        elif method == 'kaiming_uniform':
            nn.init.kaiming_uniform_(tensor, nonlinearity='relu')
        elif method == 'kaiming_normal':
            nn.init.kaiming_normal_(tensor, nonlinearity='relu')
        elif method == 'orthogonal':
            nn.init.orthogonal_(tensor)
        else:
            raise ValueError(f"Unknown initialization method: {method}")

    def _initialize_weights(self, method):
        """Initialize latents, mixing weights, then projections; zero any biases.

        Raises:
            ValueError: if ``method`` is unknown. It is checked up front, so the
                error does not depend on which parameter lists are non-empty.
        """
        if method not in self._INIT_METHODS:
            raise ValueError(f"Unknown initialization method: {method}")

        for z in self.Z:
            self._init_matrix(z, method)

        if self.weight_share:
            weights_to_init = list(self.weights)
        else:
            weights_to_init = [w for stage_weights in self.weights for w in stage_weights]

        for weight in weights_to_init:
            if method == 'orthogonal':
                # One orthogonal D x D block per (row_out, row_in) pair.
                for i in range(self.total_dim):
                    for j in range(self.total_dim):
                        nn.init.orthogonal_(weight[i, j])
            else:
                # The flattened view shares storage, so this initializes in place.
                self._init_matrix(weight.view(-1, self.D), method)

        self._init_matrix(self.input_proj.weight, method)
        self._init_matrix(self.output_proj.weight, method)

        if self.use_bias:
            nn.init.zeros_(self.input_proj.bias)
            nn.init.zeros_(self.output_proj.bias)

    def _get_activation(self, activation_name):
        """Return the activation callable for ``activation_name``.

        'token_softmax' applies softmax over the last (channel) dimension.

        Raises:
            ValueError: for an unknown activation name.
        """
        if activation_name == 'sigmoid':
            return torch.sigmoid
        elif activation_name == 'token_softmax':
            return lambda x: F.softmax(x, dim=-1)
        elif activation_name == 'relu':
            return F.relu
        elif activation_name == 'gelu':
            return F.gelu
        elif activation_name == 'tanh':
            return torch.tanh
        elif activation_name == 'none':
            return lambda x: x
        else:
            raise ValueError(f"Unknown activation: {activation_name}")

    def _stage_bounds(self, stage):
        """Return the [start, end) row range that ``stage`` owns inside XZ."""
        start_idx = sum(self.s_dims[:stage])
        return start_idx, start_idx + self.s_dims[stage]

    def update_function(self, XZ, stage):
        """
        Apply the residual update function for one stage, touching only its rows.

        Args:
            XZ: Concatenated tensor of shape (B, sum(s_dims), D)
            stage: Stage index (0 to S)
        Returns:
            Tensor of shape (B, sum(s_dims), D). Only the rows owned by ``stage``
            differ from ``XZ``.
        """
        hid_act = self._get_activation(self.hid_activation)

        if self.weight_share:
            weights = self.weights
            n_steps = self.Ns[0]
        else:
            weights = self.weights[stage]
            n_steps = self.Ns[stage]

        start_idx, end_idx = self._stage_bounds(stage)

        # Rows outside this stage stay fixed. Empty slices are fine for torch.cat.
        head = XZ[:, :start_idx, :]
        part = XZ[:, start_idx:end_idx, :]
        tail = XZ[:, end_idx:, :]

        for n in range(n_steps):
            # Only this stage's output rows are needed, but each mixes over all of XZ:
            # update[b, t, e] = sum_{s, d} W[start + t, s, d, e] * XZ[b, s, d]
            update = torch.einsum('tsde,bsd->bte', weights[n][start_idx:end_idx], XZ)
            part = part + hid_act(update)
            XZ = torch.cat([head, part, tail], dim=1)

        return XZ

    def run_stage(self, X, Z_all, stage):
        """
        Run ``Ls[stage]`` iterations of one stage. For stage > 0, each iteration
        first updates Z_stage and then recursively runs all inner stages.

        Args:
            X: Current X tensor of shape (B, T, D)
            Z_all: List of all Z tensors [Z_1, ..., Z_S] with shape (B, d_s, D).
                Updated in place.
            stage: Current stage (0 to S)
        Returns:
            Tuple (X, Z_all) after the iterations.
        """
        start_idx, end_idx = self._stage_bounds(stage)

        for _ in range(self.Ls[stage]):
            XZ = torch.cat([X] + Z_all, dim=1)
            XZ_new = self.update_function(XZ, stage)

            if stage == 0:
                # Stage 0 owns the first T rows, which are X itself.
                X = XZ_new[:, :self.T, :]
            else:
                Z_all[stage - 1] = XZ_new[:, start_idx:end_idx, :]
                X, Z_all = self.run_stage(X, Z_all, stage - 1)

        return X, Z_all

    def forward(self, X):
        """
        Forward pass through the Hierarchical DEQ model.

        Args:
            X: Input tensor of shape (B, T, C)
        Returns:
            Output tensor of shape (B, T, C)
        """
        B, T, C = X.shape
        assert T == self.T, f"Input sequence length {T} doesn't match model's T={self.T}"
        assert C == self.C, f"Input channel dimension {C} doesn't match model's C={self.C}"

        # Input projection: (B, T, C) -> (B, T, D)
        X = self.input_proj(X)

        # Broadcast the learnable latents over the batch (views, no copy).
        Z_all = [z.unsqueeze(0).expand(B, -1, -1) for z in self.Z]

        # Start from the outermost stage S.
        X, _ = self.run_stage(X, Z_all, self.S)

        # Output projection: (B, T, D) -> (B, T, C)
        X = self.output_proj(X)

        output_act = self._get_activation(self.output_activation)
        return output_act(X)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class HyperDEQ(nn.Module):
    """
    Hyper Deep Equilibrium Model

    For each of N steps, a per-head weight tensor is generated from the current
    state by attention:
    ``weight_n[b, h, t, i, d, j] = A[b, h, t, i] * (Wv @ Wo)[h, d, j]``.
    That weight is then applied L times as a residual update. Applying it to
    the state mixes key tokens ``i`` into query token ``t`` and maps channels
    ``d -> j``.

    Args:
        T: Sequence length
        C: Input/output channel dimension
        D: Hidden dimension
        L: Number of iterations
        N: Number of weight generation steps
        H: Number of heads
        E: Head dimension
        hid_activation: Hidden activation function ('sigmoid', 'token_softmax', 'relu', 'gelu', 'tanh', 'none')
        output_activation: Output activation function ('sigmoid', 'token_softmax', 'relu', 'gelu', 'tanh', 'none')
        weight_init: Weight initialization method ('xavier_uniform', 'xavier_normal',
                                                   'kaiming_uniform', 'kaiming_normal', 'orthogonal')
        bias: Whether to use bias in linear projections (default: False)
        weight_share: Whether to share Wq, Wk, Wv, Wo across n (default: False)
    """

    _INIT_METHODS = ('xavier_uniform', 'xavier_normal', 'kaiming_uniform',
                     'kaiming_normal', 'orthogonal')

    def __init__(self, T, C, D, L, N, H, E, hid_activation='relu', output_activation='none',
                 weight_init='xavier_uniform', bias=False, weight_share=False):
        super().__init__()

        self.T = T
        self.C = C
        self.D = D
        self.L = L
        self.N = N
        self.H = H
        self.E = E
        self.hid_activation = hid_activation
        self.output_activation = output_activation
        self.use_bias = bias
        self.weight_share = weight_share

        # Input and output projections
        self.input_proj = nn.Linear(C, D, bias=bias)
        self.output_proj = nn.Linear(D, C, bias=bias)

        # Weight generation parameters
        if weight_share:
            # Shared across all n
            self.Wq = nn.Parameter(torch.empty(H, D, E))
            self.Wk = nn.Parameter(torch.empty(H, D, E))
            self.Wv = nn.Parameter(torch.empty(H, D, E))
            self.Wo = nn.Parameter(torch.empty(H, E, D))
        else:
            # Separate for each n
            self.Wq = nn.ParameterList([
                nn.Parameter(torch.empty(H, D, E)) for _ in range(N)
            ])
            self.Wk = nn.ParameterList([
                nn.Parameter(torch.empty(H, D, E)) for _ in range(N)
            ])
            self.Wv = nn.ParameterList([
                nn.Parameter(torch.empty(H, D, E)) for _ in range(N)
            ])
            self.Wo = nn.ParameterList([
                nn.Parameter(torch.empty(H, E, D)) for _ in range(N)
            ])

        self._initialize_weights(weight_init)

    @staticmethod
    def _init_matrix(tensor, method):
        """Initialize a 2-D tensor in place with the named scheme."""
        if method == 'xavier_uniform':
            nn.init.xavier_uniform_(tensor)
        elif method == 'xavier_normal':
            nn.init.xavier_normal_(tensor)
        elif method == 'kaiming_uniform':
            nn.init.kaiming_uniform_(tensor, nonlinearity='relu')
        elif method == 'kaiming_normal':
            nn.init.kaiming_normal_(tensor, nonlinearity='relu')
        elif method == 'orthogonal':
            nn.init.orthogonal_(tensor)
        else:
            raise ValueError(f"Unknown initialization method: {method}")

    def _initialize_weights(self, method):
        """Initialize Wq/Wk/Wv/Wo, then the projections; zero any biases.

        Raises:
            ValueError: if ``method`` is unknown (checked up front).
        """
        if method not in self._INIT_METHODS:
            raise ValueError(f"Unknown initialization method: {method}")

        if self.weight_share:
            params_to_init = [self.Wq, self.Wk, self.Wv, self.Wo]
        else:
            params_to_init = [*self.Wq, *self.Wk, *self.Wv, *self.Wo]

        for param in params_to_init:
            # Flatten (H, a, b) -> (H*a, b). The view shares storage, so the
            # parameter is initialized in place and no reshape-back is needed.
            self._init_matrix(param.view(-1, param.shape[-1]), method)

        self._init_matrix(self.input_proj.weight, method)
        self._init_matrix(self.output_proj.weight, method)

        if self.use_bias:
            nn.init.zeros_(self.input_proj.bias)
            nn.init.zeros_(self.output_proj.bias)

    def _get_activation(self, activation_name):
        """Return the activation callable for ``activation_name``.

        'token_softmax' applies softmax over the last (channel) dimension.

        Raises:
            ValueError: for an unknown activation name.
        """
        if activation_name == 'sigmoid':
            return torch.sigmoid
        elif activation_name == 'token_softmax':
            return lambda x: F.softmax(x, dim=-1)
        elif activation_name == 'relu':
            return F.relu
        elif activation_name == 'gelu':
            return F.gelu
        elif activation_name == 'tanh':
            return torch.tanh
        elif activation_name == 'none':
            return lambda x: x
        else:
            raise ValueError(f"Unknown activation: {activation_name}")

    def update_weight_n(self, X_exp, n):
        """
        Generate weight_n using an attention mechanism.

        Args:
            X_exp: Input tensor of shape (B, H, T, D)
            n: Index for weight generation parameters (ignored when weight_share)
        Returns:
            weight_n: Generated weight tensor of shape (B, H, T, T, D, D), indexed
            [b, h, t_query, t_key, d_in, d_out].

        NOTE(review): this materializes B*H*T*T*D*D elements, so memory grows
        quadratically in both T and D.
        """
        if self.weight_share:
            Wq, Wk, Wv, Wo = self.Wq, self.Wk, self.Wv, self.Wo
        else:
            Wq, Wk, Wv, Wo = self.Wq[n], self.Wk[n], self.Wv[n], self.Wo[n]

        # Q, K: (B, H, T, E)
        Q = torch.einsum('bhtd,hde->bhte', X_exp, Wq)
        K = torch.einsum('bhtd,hde->bhte', X_exp, Wk)

        # A: (B, H, T, T), softmax over key tokens
        A = F.softmax(Q @ K.transpose(-2, -1) / math.sqrt(self.E), dim=-1)

        # weight_n[b, h, t, i, d, j] = A[b, h, t, i] * sum_e Wv[h, d, e] * Wo[h, e, j]
        return torch.einsum('bhti,hde,hej->bhtidj', A, Wv, Wo)

    def update_function(self, X_exp, weight_n):
        """
        Compute the residual increment for one iteration.

        Args:
            X_exp: Input tensor of shape (B, H, T, D)
            weight_n: Weight tensor of shape (B, H, T, T, D, D) from update_weight_n
        Returns:
            Increment of shape (B, H, T, D) (the caller adds it to X_exp)
        """
        hid_act = self._get_activation(self.hid_activation)

        # out[b, h, t, j] = sum_{i, d} weight_n[b, h, t, i, d, j] * X_exp[b, h, i, d]
        # Distinct subscripts are required: repeating them (e.g. 'bhttdd') would
        # make einsum take only the diagonal, so attention would never mix tokens.
        mixed = torch.einsum('bhtidj,bhid->bhtj', weight_n, X_exp)

        return hid_act(mixed)

    def forward(self, X):
        """
        Forward pass through the Hyper DEQ model.

        Args:
            X: Input tensor of shape (B, T, C)
        Returns:
            Output tensor of shape (B, T, C)
        """
        B, T, C = X.shape
        assert T == self.T, f"Input sequence length {T} doesn't match model's T={self.T}"
        assert C == self.C, f"Input channel dimension {C} doesn't match model's C={self.C}"

        # Input projection: (B, T, C) -> (B, T, D)
        X = self.input_proj(X)

        # Replicate over heads: (B, T, D) -> (B, H, T, D). This is an expanded
        # view; the first residual add below materializes a new tensor.
        X_exp = X.unsqueeze(1).expand(-1, self.H, -1, -1)

        for n in range(self.N):
            # Weights are regenerated once per n and then held fixed for L iterations.
            weight_n = self.update_weight_n(X_exp, n)
            for _ in range(self.L):
                X_exp = X_exp + self.update_function(X_exp, weight_n)

        # Average over heads: (B, H, T, D) -> (B, T, D)
        X = X_exp.mean(dim=1)

        # Output projection: (B, T, D) -> (B, T, C)
        X = self.output_proj(X)

        output_act = self._get_activation(self.output_activation)
        return output_act(X)