In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

In [5]:
# Seed PyTorch's global random number generator so repeated runs are comparable
SEED = 42
torch.manual_seed(SEED)

<torch._C.Generator at 0x10e9c1ad0>

In [11]:
# Select the compute backend: MPS is tried first, then CUDA, and CPU is the fallback
if torch.backends.mps.is_available():
    backend_name = "mps"
    backend_message = "MPS device is available. Using MPS for acceleration."
elif torch.cuda.is_available():
    backend_name = "cuda"
    backend_message = "CUDA device is available. Using CUDA for acceleration."
else:
    backend_name = "cpu"
    backend_message = "No GPU acceleration available. Using CPU."

device = torch.device(backend_name)
print(backend_message)

MPS device is available. Using MPS for acceleration.


In [12]:
# Define Model A's forward function
def model_A_forward(x, params):
    """Run Model A functionally, using the weights supplied in ``params``.

    Pipeline: 2-D convolution (stride 1, no padding) -> ReLU -> 2x2 max-pool
    -> flatten everything except the batch dimension -> linear layer.

    Args:
        x: Input batch handed directly to ``F.conv2d``.
        params: Mapping providing 'conv1.weight', 'conv1.bias', 'fc.weight'
            and 'fc.bias'.

    Returns:
        The output of the final linear layer, one row per input sample.
    """
    conv_out = F.conv2d(
        x, params['conv1.weight'], params['conv1.bias'], stride=1, padding=0
    )
    pooled = F.max_pool2d(F.relu(conv_out), 2)
    # Collapse channel and spatial dimensions into a single feature axis
    features = pooled.view(pooled.size(0), -1)
    return F.linear(features, params['fc.weight'], params['fc.bias'])

# Initialize Model A's parameters
def initialize_model_A_params(device):
    """Create a fresh parameter dictionary for Model A on ``device``.

    Weights are drawn from a standard normal distribution scaled by 0.1;
    biases start at zero. Every entry is wrapped in ``nn.Parameter``.

    Args:
        device: Device on which the tensors are allocated.

    Returns:
        dict with keys 'conv1.weight', 'conv1.bias', 'fc.weight', 'fc.bias'
        (inserted in that order).
    """
    def _scaled_normal(*shape):
        return nn.Parameter(torch.randn(*shape, device=device) * 0.1)

    def _zeros(length):
        return nn.Parameter(torch.zeros(length, device=device))

    params = {}
    params['conv1.weight'] = _scaled_normal(10, 1, 5, 5)
    params['conv1.bias'] = _zeros(10)
    # 1440 = 10 channels * 12 * 12, the flattened size that a 5x5 conv plus
    # 2x2 pooling yields for 28x28 inputs
    params['fc.weight'] = _scaled_normal(10, 1440)
    params['fc.bias'] = _zeros(10)
    return params

# Flatten parameters with consistent ordering
def flatten_params(params):
    """Concatenate every tensor in ``params`` into a single 1-D tensor.

    Tensors are visited in sorted-key order, which is the same order
    ``unflatten_params`` uses to slice the vector back apart.

    Args:
        params: Mapping from parameter name to tensor.

    Returns:
        A 1-D tensor holding all elements of all tensors, back to back.
    """
    # reshape(-1) instead of view(-1): view raises on non-contiguous tensors
    # (e.g. a transposed weight), while reshape copies only when it has to.
    return torch.cat([params[name].reshape(-1) for name in sorted(params)])

# Unflatten parameters with consistent ordering
def unflatten_params(flat_params, param_shapes, device):
    """Split a flat vector back into named tensors of the given shapes.

    Names are processed in sorted order, mirroring ``flatten_params``, so a
    flatten/unflatten round trip restores each tensor to its own name.

    Args:
        flat_params: 1-D tensor containing all parameter values back to back.
        param_shapes: Mapping from parameter name to its shape (``torch.Size``,
            tuple/list of ints, or a single int). ``()`` denotes a scalar.
        device: Device each resulting tensor is moved to.

    Returns:
        dict mapping each name to its reshaped slice of ``flat_params``.
        Elements of ``flat_params`` beyond the total requested size are ignored.

    Raises:
        RuntimeError: If ``flat_params`` holds too few elements for a shape.
    """
    params = {}
    offset = 0
    for name in sorted(param_shapes):
        dims = param_shapes[name]
        dims = (dims,) if isinstance(dims, int) else tuple(dims)
        shape = torch.Size(dims)
        # Integer element count without allocating a tensor. The former
        # torch.prod(torch.tensor(shape)).item() returned the float 1.0 for a
        # scalar shape (), which is not a valid slice bound.
        count = shape.numel()
        params[name] = flat_params[offset:offset + count].reshape(shape).to(device)
        offset += count
    return params

# Model B
class ModelB(nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear.

    Args:
        input_size: Number of input features.
        output_size: Number of output features.
        hidden_size: Width of the hidden layer. Defaults to 256, the value
            that was previously hard-coded.
    """

    def __init__(self, input_size, output_size, hidden_size=256):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Apply fc1, a ReLU, then fc2 to ``x`` and return the result."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)

In [14]:
# Tunable settings for this cell
DATA_ROOT = '../data'
TRAIN_FRACTION = 0.9   # share of the official training set kept for training
BATCH_SIZE = 64
LEARNING_RATE = 1e-4

# Images are only converted to tensors; no further preprocessing is applied
transform = transforms.Compose([transforms.ToTensor()])

# MNIST train and test sets (downloaded into DATA_ROOT when missing)
full_train_dataset = datasets.MNIST(root=DATA_ROOT, train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root=DATA_ROOT, train=False, transform=transform, download=True)

# Hold out the remainder of the training set for validation
train_size = int(TRAIN_FRACTION * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size
train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size])

# Only the training loader shuffles; validation and test keep dataset order
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


# Model A: a parameter dictionary plus the shape of every entry, which is
# needed later to rebuild the dictionary from a flat vector
params_A = initialize_model_A_params(device)
param_shapes = {name: params_A[name].shape for name in params_A}

# Model B maps Model A's flattened parameter vector to a vector of equal length
theta_A_flat = flatten_params(params_A)
input_size = theta_A_flat.numel()
output_size = input_size
model_B = ModelB(input_size, output_size).to(device)

# Only Model B's weights are handed to the optimizer
optimizer_B = optim.Adam(model_B.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()



In [15]:
# Training loop.
# Each step, Model B turns Model A's current flat parameter vector into a new
# one; Model A is evaluated with those generated parameters and the loss is
# back-propagated into Model B only. After every epoch the generated
# parameters become Model A's parameters for the next epoch.
# NOTE: this cell rebinds params_A, so re-running it continues from the last
# state rather than from the initial parameters.
num_epochs = 5
for epoch in range(num_epochs):
    # Validation below switches Model B to eval mode; switch back every epoch
    model_B.train()

    # params_A does not change within an epoch, so flatten it once.
    # detach(): params_A is never optimised, so there is no reason to let
    # backward() accumulate (never-zeroed) gradients on its tensors.
    theta_A_input = flatten_params(params_A).detach()

    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer_B.zero_grad()

        # Generate Model A's parameters with Model B
        theta_A_prime_flat = model_B(theta_A_input)
        params_A_prime = unflatten_params(theta_A_prime_flat, param_shapes, device)

        # Forward pass of Model A with the generated parameters
        outputs = model_A_forward(inputs, params_A_prime)
        loss = criterion(outputs, labels)

        # Batch accuracy (no gradient needed; avoids assigning to `_`, which
        # would shadow IPython's last-output variable)
        predicted = outputs.detach().argmax(dim=1)
        correct = (predicted == labels).sum().item()
        accuracy = correct / labels.size(0)

        # Back-propagate through Model A's forward pass into Model B and update it
        loss.backward()
        optimizer_B.step()

        # Print loss and accuracy every 100 batches
        if (batch_idx + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], '
                  f'Step [{batch_idx+1}/{len(train_loader)}], '
                  f'Loss: {loss.item():.4f}, '
                  f'Training Accuracy: {accuracy * 100:.2f}%')

    # Validate
    model_B.eval()
    with torch.no_grad():
        # Model B and its input are both fixed during validation, so the
        # generated parameters are identical for every batch: compute them once
        theta_A_prime_flat = model_B(theta_A_input)
        params_A_prime = unflatten_params(theta_A_prime_flat, param_shapes, device)

        running_val_correct = 0
        total_val = 0
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model_A_forward(inputs, params_A_prime)

            predicted = outputs.argmax(dim=1)
            running_val_correct += (predicted == labels).sum().item()
            total_val += labels.size(0)

        val_accuracy = (running_val_correct / total_val) * 100
        print(f'*** Epoch [{epoch+1}/{num_epochs}] Validation Accuracy: {val_accuracy:.2f}% ***\n')

    # The parameters generated after this epoch's training become Model A's
    # parameters for the next epoch; detach so they carry no graph history
    params_A = {name: param.detach() for name, param in params_A_prime.items()}
    



Epoch [1/5], Step [100/844], Loss: 0.4113, Training Accuracy: 84.38%
Epoch [1/5], Step [200/844], Loss: 0.1275, Training Accuracy: 96.88%
Epoch [1/5], Step [300/844], Loss: 0.1088, Training Accuracy: 96.88%
Epoch [1/5], Step [400/844], Loss: 0.4814, Training Accuracy: 87.50%
Epoch [1/5], Step [500/844], Loss: 0.1943, Training Accuracy: 96.88%
Epoch [1/5], Step [600/844], Loss: 0.1136, Training Accuracy: 96.88%
Epoch [1/5], Step [700/844], Loss: 0.1551, Training Accuracy: 96.88%
Epoch [1/5], Step [800/844], Loss: 0.0600, Training Accuracy: 98.44%
*** Epoch [1/5] Validation Accuracy: 97.02% ***

Epoch [2/5], Step [100/844], Loss: 0.0937, Training Accuracy: 98.44%
Epoch [2/5], Step [200/844], Loss: 0.1912, Training Accuracy: 95.31%
Epoch [2/5], Step [300/844], Loss: 0.0376, Training Accuracy: 98.44%
Epoch [2/5], Step [400/844], Loss: 0.1044, Training Accuracy: 98.44%
Epoch [2/5], Step [500/844], Loss: 0.0508, Training Accuracy: 98.44%
Epoch [2/5], Step [600/844], Loss: 0.0644, Training Ac