In [6]:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# ---------------------------
# Configuration Parameters
# ---------------------------
config = {
    'batch_size': 64,
    'test_batch_size': 1000,
    'epochs': 5,
    'learning_rate': 1e-3,
    'num_comp_neurons': 1,  # Number of computation neurons in CustomNeuron
    'log_interval': 100,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
}

class CustomNeuron(nn.Module):
    def __init__(self, in_features, out_features, num_comp_neurons=1, bias=True):
        super(CustomNeuron, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_comp_neurons = num_comp_neurons
        self.bias = bias

        if self.num_comp_neurons < 1:
            raise ValueError("num_comp_neurons must be at least 1.")

        # Selection layer: decides which computation neuron to activate
        self.selection_layer = nn.Linear(in_features, out_features * self.num_comp_neurons, bias=bias)

        # Computation neurons: multiple linear transformations
        self.comp_weights = nn.Parameter(torch.Tensor(self.num_comp_neurons, out_features, in_features))
        if bias:
            self.comp_biases = nn.Parameter(torch.Tensor(self.num_comp_neurons, out_features))
        else:
            self.comp_biases = None

        self.reset_parameters()

    def reset_parameters(self):
        # Initialize computation neurons
        nn.init.xavier_uniform_(self.comp_weights)
        if self.comp_biases is not None:
            nn.init.zeros_(self.comp_biases)
        # Initialize selection layer
        nn.init.xavier_uniform_(self.selection_layer.weight)
        if self.selection_layer.bias is not None:
            nn.init.zeros_(self.selection_layer.bias)

    def forward(self, input):
        """
        Args:
            input: Tensor of shape [batch_size, in_features]
        Returns:
            output: Tensor of shape [batch_size, out_features]
        """
        batch_size, in_features = input.size()

        # Detach input for selection neuron to prevent gradients from flowing back to previous layers
        input_detached = input.detach()

        # Compute selection logits using detached input
        selection_logits = self.selection_layer(input_detached)  # [b, o*n]
        selection_logits = selection_logits.view(batch_size, self.out_features, self.num_comp_neurons)  # [b, o, n]

        # Compute selection probabilities using softmax over computation neurons
        selection_probs = F.softmax(selection_logits, dim=-1)  # [b, o, n]

        # Hard selection using argmax
        selected_idx = torch.argmax(selection_probs, dim=-1)  # [b, o]
        selected_mask_hard = F.one_hot(selected_idx, num_classes=self.num_comp_neurons).float()  # [b, o, n]

        # Use Straight-Through Estimator (STE)
        selected_mask = (selected_mask_hard - selection_probs).detach() + selection_probs  # [b, o, n]

        # Compute outputs from all computation neurons using the original input
        comp_weights_transposed = self.comp_weights.permute(0, 2, 1)  # [n, in_features, o]
        outputs_all = torch.einsum('bi,nio->bno', input, comp_weights_transposed)  # [b, n, o]

        # Add biases if present
        if self.comp_biases is not None:
            comp_biases = self.comp_biases.unsqueeze(0)  # [1, n, o]
            outputs_all = outputs_all + comp_biases  # [b, n, o]

        # Apply selected mask
        selected_mask_transposed = selected_mask.permute(0, 2, 1)  # [b, n, o]
        output = torch.sum(outputs_all * selected_mask_transposed, dim=1)  # Sum over n -> [b, o]

        return output  # [b, out_features]

# ---------------------------
# MNIST Dataset
# ---------------------------
def load_data(config):
    # Data transformations
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))  # Mean and std for MNIST
    ])

    # Training and test datasets
    train_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('../data', train=False, transform=transform)

    # Data loaders
    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=config['test_batch_size'], shuffle=False)

    return train_loader, test_loader

# ---------------------------
# MLP Model with Custom Neuron
# ---------------------------
class MLPWithCustomNeuron(nn.Module):
    def __init__(self, num_comp_neurons):
        super(MLPWithCustomNeuron, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = CustomNeuron(28 * 28, 256, num_comp_neurons=num_comp_neurons)
        self.relu1 = nn.ReLU()
        self.fc2 = CustomNeuron(256, 128, num_comp_neurons=num_comp_neurons)
        self.relu2 = nn.ReLU()
        self.fc3 = CustomNeuron(128, 10, num_comp_neurons=num_comp_neurons)

    def forward(self, x):
        x = self.flatten(x)       # [batch_size, 784]
        x = self.fc1(x)           # [batch_size, 256]
        x = self.relu1(x)
        x = self.fc2(x)           # [batch_size, 128]
        x = self.relu2(x)
        x = self.fc3(x)           # [batch_size, 10]
        return x

# ---------------------------
# Training and Testing Functions
# ---------------------------
def train(model, device, train_loader, optimizer, epoch, config):
    model.train()
    criterion = nn.CrossEntropyLoss()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)  # [batch_size, 10]
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % config['log_interval'] == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

def test(model, device, test_loader):
    model.eval()
    criterion = nn.CrossEntropyLoss(reduction='sum')  # Sum the loss over the batch
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)  # [batch_size, 10]
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)  # Get index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print(f'\nTest set: Average loss: {test_loss:.4f}, '
          f'Accuracy: {correct}/{len(test_loader.dataset)} '
          f'({100. * correct / len(test_loader.dataset):.2f}%)\n')

# ---------------------------
# Main Function
# ---------------------------
def main():
    # Load data
    train_loader, test_loader = load_data(config)

    # Initialize model
    model = MLPWithCustomNeuron(num_comp_neurons=config['num_comp_neurons']).to(config['device'])
    optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'])

    # Training loop
    for epoch in range(1, config['epochs'] + 1):
        train(model, config['device'], train_loader, optimizer, epoch, config)
        test(model, config['device'], test_loader)

if __name__ == '__main__':
    main()



Test set: Average loss: 0.1201, Accuracy: 9641/10000 (96.41%)


Test set: Average loss: 0.0918, Accuracy: 9692/10000 (96.92%)


Test set: Average loss: 0.0823, Accuracy: 9734/10000 (97.34%)


Test set: Average loss: 0.0941, Accuracy: 9738/10000 (97.38%)


Test set: Average loss: 0.0866, Accuracy: 9735/10000 (97.35%)



In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# ---------------------------
# Configuration Parameters
# ---------------------------
config = {
    'batch_size': 64,
    'test_batch_size': 1000,
    'epochs': 5,
    'learning_rate': 1e-3,
    'num_comp_neurons': 3,  # Number of computation neurons in CustomNeuron
    'log_interval': 100,
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
}

class CustomNeuron(nn.Module):
    def __init__(self, in_features, out_features, num_comp_neurons=1, bias=True):
        super(CustomNeuron, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_comp_neurons = num_comp_neurons
        self.bias = bias

        if self.num_comp_neurons < 1:
            raise ValueError("num_comp_neurons must be at least 1.")

        # Selection layer: decides which computation neuron to activate
        self.selection_layer = nn.Linear(in_features, out_features * self.num_comp_neurons, bias=bias)

        # Computation neurons: multiple linear transformations
        # Corrected shape: [n, i, o] instead of [n, o, i]
        self.comp_weights = nn.Parameter(torch.Tensor(self.num_comp_neurons, in_features, out_features))
        if bias:
            self.comp_biases = nn.Parameter(torch.Tensor(self.num_comp_neurons, out_features))
        else:
            self.comp_biases = None

        self.reset_parameters()

        # Variables to store selection information for verification
        self.last_selection_probs = None
        self.last_selected_idx = None
        self.last_entropy = None  # Initialize entropy

    def reset_parameters(self):
        # Initialize computation neurons
        nn.init.xavier_uniform_(self.comp_weights)
        if self.comp_biases is not None:
            nn.init.zeros_(self.comp_biases)
        # Initialize selection layer
        nn.init.xavier_uniform_(self.selection_layer.weight)
        if self.selection_layer.bias is not None:
            nn.init.zeros_(self.selection_layer.bias)

    def forward(self, input):
        """
        Args:
            input: Tensor of shape [batch_size, in_features]
        Returns:
            output: Tensor of shape [batch_size, out_features]
        """
        batch_size, in_features = input.size()

        # Debugging: Print input shape
        #print(f'[DEBUG] Input shape: {input.shape}')  # Should be [b, in_features]

        # Detach input for selection neuron to prevent gradients from flowing back to previous layers
        input_detached = input.detach()

        # Compute selection logits using detached input
        selection_logits = self.selection_layer(input_detached)  # [b, o*n]
        selection_logits = selection_logits.view(batch_size, self.out_features, self.num_comp_neurons)  # [b, o, n]

        # Compute selection probabilities using softmax over computation neurons
        selection_probs = F.softmax(selection_logits, dim=-1)  # [b, o, n]
        self.last_selection_probs = selection_probs  # Store for verification

        # Compute entropy of selection probabilities
        entropy = -torch.sum(selection_probs * torch.log(selection_probs + 1e-12), dim=-1)  # [b, o]
        self.last_entropy = entropy.mean().item()  # Average entropy over batch and output features

        # Hard selection using argmax
        selected_idx = torch.argmax(selection_probs, dim=-1)  # [b, o]
        self.last_selected_idx = selected_idx  # Store for verification

        selected_mask_hard = F.one_hot(selected_idx, num_classes=self.num_comp_neurons).float()  # [b, o, n]

        # Use Straight-Through Estimator (STE)
        selected_mask = (selected_mask_hard - selection_probs).detach() + selection_probs  # [b, o, n]



        # Compute outputs from all computation neurons using the original input
        # Using torch.einsum for correct tensor operations
        # input: [b, i]
        # self.comp_weights: [n, i, o]
        # outputs_all: [b, n, o] via 'bi,nio->bno'
        #print(f'[DEBUG] comp_weights shape: {self.comp_weights.shape}')  # Should be [n, i, o]

        # Check if in_features match
        expected_in_features = self.in_features
        actual_in_features = self.comp_weights.shape[1]  # Index 1 corresponds to 'i'
        if actual_in_features != expected_in_features:
            raise ValueError(f'In CustomNeuron: comp_weights has in_features={actual_in_features}, '
                             f'but expected {expected_in_features}.')

        # Perform tensor contraction using einsum
        outputs_all = torch.einsum('bi,nio->bno', input, self.comp_weights)  # [b, n, o]

        # Add biases if present
        if self.comp_biases is not None:
            comp_biases = self.comp_biases.unsqueeze(0)  # [1, n, o]
            outputs_all = outputs_all + comp_biases  # [b, n, o]

        # Apply selected mask
        selected_mask_transposed = selected_mask.permute(0, 2, 1)  # [b, n, o]
        output = torch.sum(outputs_all * selected_mask_transposed, dim=1)  # Sum over n -> [b, o]

        return output  # [b, out_features]

# ---------------------------
# MNIST Dataset
# ---------------------------
def load_data(config):
    # Data transformations
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))  # Mean and std for MNIST
    ])

    # Training and test datasets
    train_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST('../data', train=False, transform=transform)

    # Data loaders
    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=config['test_batch_size'], shuffle=False)

    return train_loader, test_loader

# ---------------------------
# MLP Model with Custom Neuron
# ---------------------------
class MLPWithCustomNeuron(nn.Module):
    def __init__(self, num_comp_neurons):
        super(MLPWithCustomNeuron, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = CustomNeuron(in_features=28 * 28, out_features=256, num_comp_neurons=num_comp_neurons)
        self.relu1 = nn.ReLU()
        self.fc2 = CustomNeuron(in_features=256, out_features=128, num_comp_neurons=num_comp_neurons)
        self.relu2 = nn.ReLU()
        self.fc3 = CustomNeuron(in_features=128, out_features=10, num_comp_neurons=num_comp_neurons)

    def forward(self, x):
        x = self.flatten(x)       # [batch_size, 784]
        x = self.fc1(x)           # [batch_size, 256]
        x = self.relu1(x)
        x = self.fc2(x)           # [batch_size, 128]
        x = self.relu2(x)
        x = self.fc3(x)           # [batch_size, 10]
        return x

# ---------------------------
# Training and Testing Functions
# ---------------------------
def train(model, device, train_loader, optimizer, epoch, config):
    model.train()
    criterion = nn.CrossEntropyLoss()
    
    # Initialize counters for verification
    total_entropy = 0.0
    total_samples = 0
    comp_neuron_selection = [0] * config['num_comp_neurons']  # To track selection frequency

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)  # [batch_size, 10]
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Verification: Access the last entropy and selection indices from each CustomNeuron
        entropy_sum = 0.0
        for layer in model.children():
            if isinstance(layer, CustomNeuron):
                entropy_sum += layer.last_entropy
                # Count selection frequencies
                selected_indices = layer.last_selected_idx.cpu().numpy().flatten()
                for idx in selected_indices:
                    comp_neuron_selection[idx] += 1

        batch_size = data.size(0)
        total_entropy += entropy_sum * batch_size
        total_samples += batch_size

        if batch_idx % config['log_interval'] == 0:
            avg_entropy = total_entropy / total_samples
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}\t'
                  f'Avg Entropy: {avg_entropy:.4f}')

    # After each epoch, print selection statistics
    print(f'Epoch {epoch} Selection Statistics:')
    for i, count in enumerate(comp_neuron_selection):
        print(f'  Computation Neuron {i}: Selected {count} times')

def test(model, device, test_loader):
    model.eval()
    criterion = nn.CrossEntropyLoss(reduction='sum')  # Sum the loss over the batch
    test_loss = 0
    correct = 0

    # Variables for verification
    total_entropy = 0.0
    total_samples = 0
    comp_neuron_selection = [0] * config['num_comp_neurons']

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)  # [batch_size, 10]
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)  # Get index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

            # Verification: Access the last entropy and selection indices from each CustomNeuron
            entropy_sum = 0.0
            for layer in model.children():
                if isinstance(layer, CustomNeuron):
                    entropy_sum += layer.last_entropy
                    # Count selection frequencies
                    selected_indices = layer.last_selected_idx.cpu().numpy().flatten()
                    for idx in selected_indices:
                        comp_neuron_selection[idx] += 1

            batch_size = data.size(0)
            total_entropy += entropy_sum * batch_size
            total_samples += batch_size

    test_loss /= len(test_loader.dataset)
    avg_entropy = total_entropy / total_samples

    print(f'\nTest set: Average loss: {test_loss:.4f}, '
          f'Accuracy: {correct}/{len(test_loader.dataset)} '
          f'({100. * correct / len(test_loader.dataset):.2f}%)\t'
          f'Avg Entropy: {avg_entropy:.4f}\n')

    # After testing, print selection statistics
    print(f'Test Set Selection Statistics:')
    for i, count in enumerate(comp_neuron_selection):
        print(f'  Computation Neuron {i}: Selected {count} times')

# ---------------------------
# Gradient Verification Hooks
# ---------------------------
def verify_gradients(model):
    """
    Verifies that gradients are flowing only through the selected computation neurons
    and the selection neurons. This function should be called after backward() and before
    optimizer.step().
    """
    for name, param in model.named_parameters():
        if 'comp_weights' in name or 'comp_biases' in name:
            if param.grad is not None:
                # Check if gradient is sparse or has many zeros, indicating inactive neurons
                num_zero_grads = torch.sum(param.grad == 0).item()
                total_elements = param.grad.numel()
                zero_grad_ratio = num_zero_grads / total_elements
                print(f'Gradient Check - {name}: {zero_grad_ratio*100:.2f}% zeros in gradients')
        elif 'selection_layer' in name:
            if param.grad is not None:
                # Ensure that gradients are flowing through selection layer
                grad_norm = param.grad.norm().item()
                print(f'Gradient Check - {name}: Gradient norm = {grad_norm:.6f}')

# ---------------------------
# Main Function
# ---------------------------
def main():
    # Load data
    train_loader, test_loader = load_data(config)

    # Initialize model
    model = MLPWithCustomNeuron(num_comp_neurons=config['num_comp_neurons']).to(config['device'])
    optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'])

    # Training loop
    for epoch in range(1, config['epochs'] + 1):
        train(model, config['device'], train_loader, optimizer, epoch, config)
        test(model, config['device'], test_loader)

    # After training, perform gradient verification on a sample batch
    model.train()
    sample_data, sample_target = next(iter(train_loader))
    sample_data, sample_target = sample_data.to(config['device']), sample_target.to(config['device'])
    optimizer.zero_grad()
    output = model(sample_data)
    loss = nn.CrossEntropyLoss()(output, sample_target)
    loss.backward()

    print('\nPerforming Gradient Verification:')
    verify_gradients(model)

if __name__ == '__main__':
    main()


Epoch 1 Selection Statistics:
  Computation Neuron 0: Selected 8002692 times
  Computation Neuron 1: Selected 7804890 times
  Computation Neuron 2: Selected 7832418 times

Test set: Average loss: 0.1655, Accuracy: 9512/10000 (95.12%)	Avg Entropy: 0.4662

Test Set Selection Statistics:
  Computation Neuron 0: Selected 1419877 times
  Computation Neuron 1: Selected 1237716 times
  Computation Neuron 2: Selected 1282407 times
Epoch 2 Selection Statistics:
  Computation Neuron 0: Selected 8415331 times
  Computation Neuron 1: Selected 7569419 times
  Computation Neuron 2: Selected 7655250 times

Test set: Average loss: 0.1368, Accuracy: 9568/10000 (95.68%)	Avg Entropy: 0.4476

Test Set Selection Statistics:
  Computation Neuron 0: Selected 1400437 times
  Computation Neuron 1: Selected 1270953 times
  Computation Neuron 2: Selected 1268610 times
Epoch 3 Selection Statistics:
  Computation Neuron 0: Selected 8628825 times
  Computation Neuron 1: Selected 7504886 times
  Computation Neuron 2