In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm import tqdm

In [10]:
class PCNetwork(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=128, output_dim=10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

        self.phi1 = torch.zeros(1, hidden_dim, requires_grad=True)
        self.phi2 = torch.zeros(1, output_dim, requires_grad=True)

        self.mu1 = torch.zeros(1, hidden_dim)
        self.mu2 = torch.zeros(1, output_dim)

    def forward(self, x):
        self.mu1 = torch.relu(self.fc1(x))
        self.mu2 = self.fc2(self.phi1)
        return self.mu2

    def update_activities(self, x, target, lr_phi=0.01, steps=20):
        self.mu1 = torch.relu(self.fc1(x))
        for _ in range(steps):
            self.mu2 = self.fc2(self.phi1)

            eps1 = self.phi1 - self.mu1
            eps2 = self.phi2 - self.mu2

            energy = 0.5 * (eps1.pow(2).sum() + 0.5 * eps2.pow(2).sum())

            energy.backward(retain_graph=True)

            with torch.no_grad():
                if self.phi1.grad is not None:
                    self.phi1 -= lr_phi * self.phi1.grad
                    self.phi1.grad.zero_()
                if self.phi2.grad is not None:
                    self.phi2 -= lr_phi * self.phi2.grad
                    self.phi2.grad.zero_()
    def update_weights(self, x, target, lr_theta=0.001):
        self.forward(x)

        with torch.no_grad():
            self.phi2 = target.clone().detach().requires_grad_(True)

        eps1 = self.phi1 - self.mu1
        eps2 = self.phi2 - self.mu2
        energy = 0.5 * (eps1.pow(2).sum() + 0.5 * eps2.pow(2).sum())

        energy.backward()

        with torch.no_grad():
            for param in self.parameters():
                if param.grad is not None:
                    param -= lr_theta * param.grad
            self.zero_grad()




In [11]:
def train_pc(model, train_loader, epochs=5, lr_phi=0.01, lr_theta=0.001):
    optimizer = optim.Adam(model.parameters(), lr=lr_theta)
    for epoch in range(epochs):
        for x, y in tqdm(train_loader):
            x = x.view(x.size(0), -1)  # Flatten MNIST
            target = torch.zeros(1, 10).scatter_(1, y.unsqueeze(0), 1)

            # E-step: Update neural activities (φ)
            model.update_activities(x, target, lr_phi=lr_phi, steps=20)

            # M-step: Update weights (θ)
            model.update_weights(x, target, lr_theta=lr_theta)

In [12]:
def test_pc(model, test_loader):
    correct = 0
    for x, y in test_loader:
        x = x.view(x.size(0), -1)
        output = model(x)
        pred = output.argmax(dim=1)
        correct += pred.eq(y).sum().item()
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f"Accuracy: {accuracy:.2f}%")

In [13]:
# Load data
transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.MNIST("./data", train=True, download=True, transform=transform)
test_data = datasets.MNIST("./data", train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1, shuffle=False)

# Train
model = PCNetwork()
train_pc(model, train_loader, epochs=5)
test_pc(model, test_loader)

100%|██████████| 9.91M/9.91M [00:01<00:00, 5.08MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 135kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.29MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.10MB/s]
100%|██████████| 60000/60000 [11:37<00:00, 86.05it/s]
100%|██████████| 60000/60000 [11:37<00:00, 86.04it/s]
100%|██████████| 60000/60000 [11:44<00:00, 85.21it/s]
100%|██████████| 60000/60000 [11:46<00:00, 84.93it/s]
100%|██████████| 60000/60000 [11:42<00:00, 85.36it/s]


Accuracy: 11.35%
