In [1]:
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """Per-position discriminator that scores a pair of feature tensors.

    The two inputs are joined on their last axis and passed through a small
    MLP: Linear(input_dim -> hidden_dim) -> ReLU -> Linear(hidden_dim -> 1)
    -> Sigmoid. One score in [0, 1] is produced for every leading position.

    Args:
        input_dim: Size of the concatenated last axis (default 1024, i.e. two
            512-dim inputs).
        hidden_dim: Width of the single hidden layer (default 512).
    """

    def __init__(self, input_dim=1024, hidden_dim=512):
        super().__init__()
        # Build the layers in the same order as they are applied, so weight
        # initialisation consumes the RNG in that order.
        to_hidden = nn.Linear(input_dim, hidden_dim)
        to_score = nn.Linear(hidden_dim, 1)
        self.mlp = nn.Sequential(to_hidden, nn.ReLU(), to_score, nn.Sigmoid())

    def forward(self, h, d):
        """Score ``h`` paired with ``d`` at every position.

        Args:
            h: Tensor, (batch, 18, 512) in the example below.
            d: Tensor whose leading dims match ``h``; its last axis plus h's
                last axis must equal ``input_dim``.

        Returns:
            Tensor with the same leading dims as the inputs and a final axis
            of size 1, e.g. (batch, 18, 1).
        """
        paired = torch.cat((h, d), dim=-1)
        return self.mlp(paired)

# Example usage: run the discriminator on random inputs and check the output shape.
torch.manual_seed(0)  # fix the RNG so the inputs and initial weights are reproducible
batch_size = 16
h = torch.randn(batch_size, 18, 512)  # Example tensor for h
d = torch.randn(batch_size, 18, 512)  # Example tensor for d

discriminator = Discriminator()
output = discriminator(h, d)

# Enforce the expected shape instead of only stating it in a comment.
assert output.shape == (batch_size, 18, 1), output.shape
print(output.shape)  # torch.Size([16, 18, 1])


torch.Size([16, 18, 1])


In [4]:
import torch
import torch.nn as nn

class MI_Discriminator(nn.Module):
    """MLP discriminator that scores (c, s) pairs and averages over positions.

    ``c`` and ``s`` are concatenated on their last axis, every position is
    scored by an MLP ending in a Sigmoid, and the per-position scores are
    averaged over dim 1 to give one value per batch element.
    """

    def __init__(self, input_dim=1024, hidden_dim=512, num_layers=2):
        """
        input_dim: Size of the last axis after concatenating c and s
            (1024 for two 512-dim inputs).
        hidden_dim: Width of each hidden layer.
        num_layers: Total number of Linear layers in the MLP, including the
            final scoring layer. ``num_layers=1`` gives a single
            Linear(input_dim, 1); ``num_layers=k`` gives k - 1 hidden
            Linear + ReLU layers followed by Linear(hidden_dim, 1).
            Must be >= 1.

        Raises:
            ValueError: if ``num_layers`` is less than 1.
        """
        super(MI_Discriminator, self).__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")

        layers = []
        prev_dim = input_dim

        # Hidden layers: input_dim -> hidden_dim, then hidden_dim -> hidden_dim.
        for _ in range(num_layers - 1):
            layers.append(nn.Linear(prev_dim, hidden_dim))
            layers.append(nn.ReLU())  # Activation function
            prev_dim = hidden_dim  # Update previous dimension

        # Final scoring layer. It must consume prev_dim, not hidden_dim: with
        # num_layers == 1 there are no hidden layers and the input is still
        # input_dim wide.
        layers.append(nn.Linear(prev_dim, 1))
        layers.append(nn.Sigmoid())  # Ensure output is in range [0, 1]

        self.mlp = nn.Sequential(*layers)

    def forward(self, c, s):
        """
        c: Tensor of shape (batch, 18, 512)
        s: Tensor of shape (batch, 18, 512)

        Returns a (batch, 1) tensor: per-position sigmoid scores averaged
        over dim 1.
        """
        x = torch.cat([c, s], dim=-1)  # Shape: (batch, 18, 1024)
        x = self.mlp(x)  # Shape: (batch, 18, 1)
        x = x.mean(dim=1)  # Aggregate across the 18 positions -> (batch, 1)
        return x

# Example Usage: compare a 2-layer and a 4-layer MI discriminator on random inputs.
torch.manual_seed(0)  # fix the RNG so the inputs and initial weights are reproducible
batch_size = 16
c = torch.randn(batch_size, 18, 512)  # Example tensor for c
s = torch.randn(batch_size, 18, 512)  # Example tensor for s

# Initialize with different layer depths
discriminator_2_layers = MI_Discriminator(num_layers=2)
discriminator_4_layers = MI_Discriminator(num_layers=4)

output_2_layers = discriminator_2_layers(c, s)
output_4_layers = discriminator_4_layers(c, s)

# Enforce the expected shapes instead of only stating them in comments.
# Means of sigmoid outputs must also stay within [0, 1].
assert output_2_layers.shape == (batch_size, 1), output_2_layers.shape
assert output_4_layers.shape == (batch_size, 1), output_4_layers.shape
assert ((output_2_layers >= 0) & (output_2_layers <= 1)).all()
assert ((output_4_layers >= 0) & (output_4_layers <= 1)).all()

print(output_2_layers.shape)  # torch.Size([16, 1])
print(output_4_layers.shape)  # torch.Size([16, 1])


torch.Size([16, 1])
torch.Size([16, 1])
