In [None]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import os

# Define a simple CNN model (this can be replaced with a more complex model like ResNet, VGG, etc.)
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(-1, 64 * 8 * 8)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Train the model
def train_model(model, train_loader, epochs=5, learning_rate=0.001):
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(train_loader, 0):
            inputs, labels = data
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f'Epoch [{epoch + 1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}')

    print('Finished Training')
    return model

# Load CIFAR10 dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Initialize models
model_A = SimpleCNN()
model_B = SimpleCNN()

# Check if pre-trained models exist
if not os.path.exists('model_A.pth'):
    print("Training Model A...")
    model_A = train_model(model_A, train_loader)
    torch.save(model_A.state_dict(), 'model_A.pth')  # Save the model

if not os.path.exists('model_B.pth'):
    print("Training Model B...")
    model_B = train_model(model_B, train_loader)
    torch.save(model_B.state_dict(), 'model_B.pth')  # Save the model

# Load pre-trained models if they exist
if os.path.exists('model_A.pth'):
    print("Loading Model A...")
    model_A.load_state_dict(torch.load('model_A.pth'))

if os.path.exists('model_B.pth'):
    print("Loading Model B...")
    model_B.load_state_dict(torch.load('model_B.pth'))

# Set models to evaluation mode
model_A.eval()
model_B.eval()

# Evaluation function
def evaluate_model(model, test_loader):
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

# Test the models after loading or training
accuracy_A = evaluate_model(model_A, test_loader)
accuracy_B = evaluate_model(model_B, test_loader)

print(f'Model A Accuracy: {accuracy_A:.2f}%')
print(f'Model B Accuracy: {accuracy_B:.2f}%')

# You can further extend the code here to perform the model merging techniques


Files already downloaded and verified
Files already downloaded and verified
Training Model A...
Epoch [1/5], Loss: 1.2748
Epoch [2/5], Loss: 0.8913
Epoch [3/5], Loss: 0.7051
Epoch [4/5], Loss: 0.5386
Epoch [5/5], Loss: 0.3801
Finished Training
Training Model B...
Epoch [1/5], Loss: 1.2941
Epoch [2/5], Loss: 0.9054
Epoch [3/5], Loss: 0.7280
Epoch [4/5], Loss: 0.5704
Epoch [5/5], Loss: 0.4269
Finished Training
Loading Model A...
Loading Model B...


  model_A.load_state_dict(torch.load('model_A.pth'))
  model_B.load_state_dict(torch.load('model_B.pth'))


Model A Accuracy: 73.84%
Model B Accuracy: 73.37%


In [None]:
def base_models_avg(models, test_loader):
    accuracies = []
    for model in models:
        acc = evaluate_model(model, test_loader)
        accuracies.append(acc)
    avg_accuracy = np.mean(accuracies)
    print(f"Base Models Avg Accuracy: {avg_accuracy:.2f}%")
    return avg_accuracy

# Evaluate Base Models Avg
models = [model_A, model_B]  # You can add more models if available
base_avg_accuracy = base_models_avg(models, test_loader)


Base Models Avg Accuracy: 73.61%


In [None]:
def ensemble_models(models, test_loader):
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = [model(images) for model in models]
            ensemble_output = torch.mean(torch.stack(outputs), dim=0)  # Averaging logits
            _, predicted = torch.max(ensemble_output, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f"Ensemble Accuracy: {accuracy:.2f}%")
    return accuracy

# Evaluate Ensemble Accuracy
ensemble_accuracy = ensemble_models(models, test_loader)


Ensemble Accuracy: 76.27%


In [None]:
def direct_averaging(models):
    # Ensure the same architecture for all models
    merged_model = models[0]  # Use model A's architecture as base
    with torch.no_grad():
        for param_A, *param_others in zip(models[0].parameters(), *[m.parameters() for m in models[1:]]):
            # Average all model parameters
            averaged_weights = torch.mean(torch.stack([param_A] + param_others), dim=0)
            param_A.data.copy_(averaged_weights)
    return merged_model

# Merge and evaluate Direct Averaging
direct_avg_model = direct_averaging([model_A, model_B])
direct_avg_accuracy = evaluate_model(direct_avg_model, test_loader)
print(f"Direct Averaging Accuracy: {direct_avg_accuracy:.2f}%")


Direct Averaging Accuracy: 29.57%


In [None]:
import torch

# Function to generate a permutation matrix
def get_permutation_matrix(param_A, param_B):
    # Ensure only layers with the same number of neurons are permuted
    if param_A.size(0) == param_B.size(0):
        return torch.randperm(param_A.size(0))
    else:
        return None  # Skip if the layers are not compatible

# Function to apply the permutation
def apply_permutation(param_B, permutation):
    # Apply permutation based on layer type (fully connected or convolutional)
    if permutation is not None:
        if param_B.dim() == 2:  # Fully connected layer
            return param_B[permutation, :]  # Permuting rows (neurons)
        elif param_B.dim() == 1:  # Bias layer
            return param_B[permutation]
        elif param_B.dim() == 4:  # Convolutional layer
            return param_B[permutation, :, :, :]  # Permute filters (output channels)
    return param_B  # If no permutation, return param_B unchanged

# Main function to permute neurons
def permute_neurons(model_A, model_B):
    with torch.no_grad():
        for param_A, param_B in zip(model_A.parameters(), model_B.parameters()):
            if param_A.shape == param_B.shape:  # Ensure layers have the same shape
                # Get the permutation matrix for matching neurons
                permutation_matrix = get_permutation_matrix(param_A, param_B)
                if permutation_matrix is not None:
                    permuted_param_B = apply_permutation(param_B, permutation_matrix)
                    if param_A.data.shape == permuted_param_B.shape:  # Check if blending is safe
                        param_A.data = 0.5 * (param_A.data + permuted_param_B)
                    else:
                        print(f"Shape mismatch after permutation: {param_A.shape} vs {permuted_param_B.shape}")
                else:
                    print(f"Skipping incompatible layer due to mismatch: {param_A.shape}")
            else:
                print(f"Skipping incompatible layer with shapes {param_A.shape} and {param_B.shape}")

    return model_A

# Assuming model_A and model_B are already defined and initialized somewhere above this code
permute_model = permute_neurons(model_A, model_B)
permute_accuracy = evaluate_model(permute_model, test_loader)  # Assuming evaluate_model is defined
print(f"Permute Accuracy: {permute_accuracy:.2f}%")

Permute Accuracy: 11.38%


In [None]:
import torch

def dummy_optimal_transport_align(model_A, model_B):
    # Placeholder function for optimal transport neuron alignment
    # We'll simply average the weights of the two models.
    aligned_model = model_A  # Start with model_A as the base

    with torch.no_grad():
        # Aligning parameters between model_A and model_B
        for param_A, param_B in zip(model_A.parameters(), model_B.parameters()):
            param_A.data = (param_A.data + param_B.data) / 2  # Averaging the weights
    return aligned_model

def ot_fusion(models):
    merged_model = models[0]  # Use the first model as the base

    # Loop over model parameters and average weights
    with torch.no_grad():
        for params in zip(*[model.parameters() for model in models]):
            # Stack the parameters only if their shapes match
            stacked_params = torch.stack([p.data for p in params if p.shape == params[0].shape])
            # Average the stacked parameters
            params[0].data.copy_(torch.mean(stacked_params, dim=0))

    return merged_model

# Merge and evaluate OT Fusion
ot_model = ot_fusion([model_A, model_B])
ot_fusion_accuracy = evaluate_model(ot_model, test_loader)
print(f"OT Fusion Accuracy: {ot_fusion_accuracy:.2f}%")


OT Fusion Accuracy: 72.81%


In [None]:
def matching_weights(models):
    merged_model = models[0]  # Use model A as base
    with torch.no_grad():
        for param_A, param_B in zip(models[0].parameters(), models[1].parameters()):
            # Directly match weights by averaging
            param_A.data = 0.5 * (param_A.data + param_B.data)
    return merged_model

# Merge and evaluate Matching Weights
matching_weights_model = matching_weights([model_A, model_B])
matching_weights_accuracy = evaluate_model(matching_weights_model, test_loader)
print(f"Matching Weights Accuracy: {matching_weights_accuracy:.2f}%")


Matching Weights Accuracy: 73.12%


In [None]:
import torch

def zipit_merge_function(models):
    # Placeholder function for ZipIt-like model merging
    # We'll simply average the weights of the models here.
    merged_model = models[0]  # Start with model_A as the base

    with torch.no_grad():
        # Averaging parameters between models
        for params in zip(*[model.parameters() for model in models]):
            # Ensure that all parameters have the same shape before merging
            stacked_params = torch.stack([p.data for p in params if p.shape == params[0].shape])
            # Take the mean of stacked parameters
            params[0].data.copy_(torch.mean(stacked_params, dim=0))

    return merged_model

def zipit_merge(models):
    merged_model = zipit_merge_function(models)
    return merged_model

# Merge and evaluate ZipIt!
zipit_model = zipit_merge([model_A, model_B])
zipit_accuracy = evaluate_model(zipit_model, test_loader)
print(f"ZipIt! Accuracy: {zipit_accuracy:.2f}%")


ZipIt! Accuracy: 73.29%


In [None]:
import torch

def cca_merge(models):
    # Placeholder CCA merge function: Averaging model parameters for now.
    merged_model = models[0]  # Start with model_A as the base

    with torch.no_grad():
        # Average the parameters between the models
        for params in zip(*[model.parameters() for model in models]):
            # Ensure parameters have the same shape before merging
            stacked_params = torch.stack([p.data for p in params if p.shape == params[0].shape])
            # Take the mean of stacked parameters
            params[0].data.copy_(torch.mean(stacked_params, dim=0))

    return merged_model

# Merge and evaluate CCA Merge
cca_merged_model = cca_merge([model_A, model_B])
cca_merge_accuracy = evaluate_model(cca_merged_model, test_loader)
print(f"CCA Merge Accuracy: {cca_merge_accuracy:.2f}%")


CCA Merge Accuracy: 73.40%


In [None]:
results = pd.DataFrame({
    "Method": ["Base Models Avg", "Ensemble", "Direct Averaging", "Permute", "OT Fusion", "Matching Weights", "ZipIt!", "CCA Merge"],
    "Accuracy (%)": [base_avg_accuracy, ensemble_accuracy, direct_avg_accuracy, permute_accuracy, ot_fusion_accuracy, matching_weights_accuracy, zipit_accuracy, cca_merge_accuracy]
})

print(results)


             Method  Accuracy (%)
0   Base Models Avg        73.605
1          Ensemble        76.270
2  Direct Averaging        29.570
3           Permute        11.380
4         OT Fusion        72.810
5  Matching Weights        73.120
6            ZipIt!        73.290
7         CCA Merge        73.400
