In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np

In [11]:
# --- Environment Setup ---
# Seed PyTorch's random number generators up front so that weight
# initialisation, the synthetic data and DataLoader shuffling are repeatable
# across "Restart Kernel -> Run All" executions (some GPU kernels may still be
# nondeterministic).
SEED = 42
torch.manual_seed(SEED)

# Use the GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def squash(s, dim=-1, epsilon=1e-7):
    """Apply the capsule "squash" non-linearity along ``dim``.

    Every vector taken along ``dim`` keeps its direction while its length is
    rescaled to ``|s|^2 / (1 + |s|^2)``, i.e. into the range [0, 1).

    Args:
        s: Tensor whose vectors lie along ``dim``.
        dim: Axis over which the vector norm is computed.
        epsilon: Small constant added under the square root for stability.

    Returns:
        Tensor with the same shape as ``s``.
    """
    squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
    scale = squared_norm / (1 + squared_norm)
    # epsilon sits *inside* the square root: the derivative of sqrt at exactly
    # 0 is infinite, so an all-zero vector would otherwise backpropagate NaNs.
    return scale * s / torch.sqrt(squared_norm + epsilon)

class CapsuleLayer(nn.Module):
    """Capsule layer that merges input capsules by iterative routing.

    Each of the ``num_routes`` input capsules (vectors of length ``in_dim``)
    is projected by its own learned matrix into one prediction vector of
    length ``out_dim`` per output capsule. The predictions are then combined
    over ``routing_iters`` rounds: coupling coefficients come from a softmax
    over the routing logits, and the logits grow with the dot-product
    agreement between a prediction and the current output capsule.

    Args:
        num_capsules: Number of output capsules.
        num_routes: Number of input capsules per sample.
        in_dim: Length of each input capsule vector.
        out_dim: Length of each output capsule vector.
        routing_iters: Number of routing rounds (at least 1).

    Raises:
        ValueError: If ``routing_iters`` is smaller than 1.
    """

    def __init__(self, num_capsules, num_routes, in_dim, out_dim, routing_iters=3):
        super().__init__()
        if routing_iters < 1:
            # With zero rounds forward() would have no output capsule to return.
            raise ValueError(f"routing_iters must be >= 1, got {routing_iters}")
        self.num_capsules = num_capsules
        self.num_routes = num_routes
        self.routing_iters = routing_iters
        # One (out_dim x in_dim) projection per (input capsule, output capsule)
        # pair; the leading 1 broadcasts over the batch dimension.
        self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, out_dim, in_dim))

    def forward(self, x):
        """Route the input capsules into the output capsules.

        Args:
            x: Tensor of shape [batch_size, num_routes, in_dim].

        Returns:
            Tensor of shape [batch_size, num_capsules, out_dim].

        Raises:
            ValueError: If ``x`` is not 3-D or its second axis is not
                ``num_routes`` (a size of 1 would otherwise broadcast silently).
        """
        if x.dim() != 3 or x.size(1) != self.num_routes:
            raise ValueError(
                f"expected input of shape [batch, {self.num_routes}, in_dim], "
                f"got {tuple(x.shape)}"
            )

        batch_size = x.size(0)
        x = x.unsqueeze(2).unsqueeze(-1)  # [batch_size, num_routes, 1, in_dim, 1]
        # W broadcasts over the batch, x over the output capsules.
        u_hat = torch.matmul(self.W, x).squeeze(-1)  # [batch_size, num_routes, num_capsules, out_dim]

        # Routing logits live on the same device and dtype as the predictions,
        # so the layer does not depend on any notebook-level global.
        b_ij = torch.zeros(
            batch_size,
            self.num_routes,
            self.num_capsules,
            device=u_hat.device,
            dtype=u_hat.dtype,
        )

        for iteration in range(self.routing_iters):
            c_ij = torch.softmax(b_ij, dim=2).unsqueeze(-1)  # [batch_size, num_routes, num_capsules, 1]
            s_j = (c_ij * u_hat).sum(dim=1)  # [batch_size, num_capsules, out_dim]
            v_j = squash(s_j)
            # The logits are not needed after the final round, so skip the update.
            if iteration < self.routing_iters - 1:
                agreement = (u_hat * v_j.unsqueeze(1)).sum(dim=-1)  # [batch_size, num_routes, num_capsules]
                # Out-of-place update keeps the autograd graph free of in-place edits.
                b_ij = b_ij + agreement
        return v_j  # [batch_size, num_capsules, out_dim]

class DRBCModel(nn.Module):
    """Two-class classifier: a capsule layer followed by a linear read-out.

    Args:
        input_dim: Length of each input capsule vector.
        num_routes: Number of input capsules per sample.
        num_capsules: Number of output capsules in the capsule layer.
        out_dim: Length of each output capsule vector.
    """

    def __init__(self, input_dim, num_routes, num_capsules, out_dim):
        super().__init__()
        # Sub-modules are created in the same order as before (capsules, then
        # the linear head) so random initialisation consumes the RNG identically.
        self.capsule_layer = CapsuleLayer(
            num_capsules,
            num_routes,
            in_dim=input_dim,
            out_dim=out_dim,
        )
        flat_features = num_capsules * out_dim
        self.fc = nn.Linear(flat_features, 2)  # two logits: binary classification

    def forward(self, x):
        """Return the [batch_size, 2] class logits for a batch of inputs."""
        capsules = self.capsule_layer(x)
        batch = capsules.size(0)
        # Concatenate all capsule vectors of a sample into one feature row.
        flattened = capsules.view(batch, -1)
        return self.fc(flattened)

# --- Preprocessing ---
def generate_random_data(samples=256, num_routes=8, feature_dim=16, seed=None):
    """Create a synthetic binary-classification dataset of pure noise.

    Args:
        samples: Number of samples to draw.
        num_routes: Number of input capsules per sample.
        feature_dim: Length of each input capsule vector.
        seed: Optional integer seed. When given, a private ``torch.Generator``
            is used, so the same seed always yields the same data and the
            global RNG state is left untouched. When ``None`` the global RNG
            is used.

    Returns:
        Tuple ``(X, y)``: ``X`` holds standard-normal features with shape
        [samples, num_routes, feature_dim]; ``y`` holds integer labels in
        {0, 1} with shape [samples]. The labels are drawn independently of
        ``X``.
    """
    generator = None
    if seed is not None:
        generator = torch.Generator().manual_seed(seed)
    X = torch.randn(samples, num_routes, feature_dim, generator=generator)
    y = torch.randint(0, 2, (samples,), generator=generator)
    return X, y

# Draw the synthetic dataset and split it in order: the leading 80% of the
# samples are used for training, the remainder is held out for testing.
X, y = generate_random_data()
train_size = int(len(X) * 0.8)
train_dataset = TensorDataset(*(t[:train_size] for t in (X, y)))
test_dataset = TensorDataset(*(t[train_size:] for t in (X, y)))
# Only the training batches are reshuffled; the test order stays fixed.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=8)
test_loader = DataLoader(test_dataset, batch_size=8)

In [12]:
# --- Train Base Model (DRBC) ---
# The loss does not depend on the model, so it is created first; the optimizer
# is built last because it needs the model's (already moved) parameters.
criterion = nn.CrossEntropyLoss()
model = DRBCModel(
    input_dim=16,    # matches the default feature_dim of generate_random_data
    num_routes=8,    # matches the default num_routes of generate_random_data
    num_capsules=4,
    out_dim=8,
).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

def train_model(model, loader, optimizer, criterion, epochs=3):
    """Train ``model`` for ``epochs`` passes over ``loader``.

    Args:
        model: Module mapping a batch of inputs to class logits.
        loader: Iterable yielding ``(inputs, targets)`` batches.
        optimizer: Optimizer already bound to ``model``'s parameters.
        criterion: Loss called as ``criterion(outputs, targets)``.
        epochs: Number of full passes over ``loader``.

    Returns:
        List with the mean batch loss of each epoch (``nan`` for an epoch in
        which ``loader`` yielded no batches).
    """
    import warnings

    # Follow the model's own device rather than a notebook-level global, so the
    # function also works for models that were moved elsewhere.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    history = []
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for inputs, targets in loader:
            inputs, targets = inputs.to(target_device), targets.to(target_device)
            optimizer.zero_grad()
            outputs = model(inputs)

            if outputs.size(0) != targets.size(0):
                # Truncation is kept as a best-effort fallback, but a mismatch
                # points to a shape bug in the model, so make it visible.
                warnings.warn(
                    f"Model returned {outputs.size(0)} rows for {targets.size(0)} "
                    "targets; truncating targets to match.",
                    RuntimeWarning,
                    stacklevel=2,
                )
                targets = targets[:outputs.size(0)]  # Align sizes

            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1

        # Count batches ourselves: avoids dividing by zero on an empty loader
        # and does not require the loader to implement len().
        mean_loss = total_loss / num_batches if num_batches else float("nan")
        history.append(mean_loss)
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {mean_loss:.4f}")

    return history

# Initial training run with the default number of epochs.
train_model(
    model=model,
    loader=train_loader,
    optimizer=optimizer,
    criterion=criterion,
)

Epoch 1/3, Loss: 0.7066
Epoch 2/3, Loss: 0.6953
Epoch 3/3, Loss: 0.6853


In [13]:
# --- Planning (Capsule Structure Exploration & Routing Analysis) ---
def analyze_capsule_outputs(model, loader):
    """Run the capsule layer on the first batch of ``loader`` and report its shape.

    The model is switched to evaluation mode and left there.

    Args:
        model: Module exposing a ``capsule_layer`` sub-module.
        loader: Iterable of batches whose first element is the input tensor.

    Returns:
        The capsule-layer output for the first batch, or ``None`` when
        ``loader`` yields no batches.
    """
    # Follow the model's own device rather than a notebook-level global.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    # Only one batch is needed for a shape check.
    first_batch = next(iter(loader), None)
    if first_batch is None:
        return None

    inputs = first_batch[0].to(target_device)
    with torch.no_grad():
        outputs = model.capsule_layer(inputs)
    print(f"Capsule outputs shape: {outputs.shape}")  # Expected: [batch_size, num_capsules, out_dim]
    return outputs

# Shape sanity check of the capsule layer on one held-out batch.
analyze_capsule_outputs(model=model, loader=test_loader)

# --- Fine-Tune Model ---
# A new Adam instance is created (its running statistics start from scratch)
# with half the learning rate of the initial training run.
optimizer = optim.Adam(model.parameters(), lr=5e-4)
train_model(
    model=model,
    loader=train_loader,
    optimizer=optimizer,
    criterion=criterion,
)

# --- Evaluate ---
def evaluate_model(model, loader):
    """Measure the classification accuracy of ``model`` on ``loader``.

    The model is switched to evaluation mode and left there; predictions are
    the arg-max over the logits of each sample.

    Args:
        model: Module mapping a batch of inputs to class logits.
        loader: Iterable yielding ``(inputs, targets)`` batches.

    Returns:
        Accuracy in percent, or ``nan`` when ``loader`` yields no samples.
    """
    # Follow the model's own device rather than a notebook-level global.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct = total = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(target_device), targets.to(target_device)
            outputs = model(inputs)
            predictions = outputs.argmax(dim=1)
            correct += (predictions == targets).sum().item()
            total += targets.size(0)

    # Guard the division so an empty loader reports nan instead of crashing.
    accuracy = 100 * correct / total if total else float("nan")
    print(f"Accuracy: {accuracy:.2f}%")
    return accuracy

# Report accuracy on the held-out split.
evaluate_model(model=model, loader=test_loader)

# --- Deploy Policy ---
# "Deploying" here means writing the learned weights (the state dict only,
# not the module object) to a file in the current working directory.
torch.save(obj=model.state_dict(), f="drbc_model.pth")
print("Model deployed (saved) successfully.")

Capsule outputs shape: torch.Size([8, 4, 8])
Epoch 1/3, Loss: 0.6763
Epoch 2/3, Loss: 0.6695
Epoch 3/3, Loss: 0.6662
Accuracy: 46.15%
Model deployed (saved) successfully.
