In [1]:
import math
import copy
import random
# PyTorch
import torch
import torchvision
torch.manual_seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [2]:
import sys
sys.path.append("../src/")

%load_ext autoreload
%autoreload 2
# Importing our custom module(s)
import permutations

In [3]:
x = torch.randn(1, 64, 224, 224, dtype=torch.float64)

block1 = torch.nn.Sequential(
    torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1, bias=True, dtype=torch.float64),
    torch.nn.BatchNorm2d(num_features=64, dtype=torch.float64),
    torch.nn.ReLU(),
    torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=1, bias=True, dtype=torch.float64),
)

block2 = copy.deepcopy(block1)

perm = torch.randperm(block1[0].out_channels)

permutations.permute_conv2d_out_channels(block1[0], perm)
permutations.permute_batchnorm2d(block1[1], perm)
permutations.permute_conv2d_in_channels(block1[3], perm)

out1 = block1(x)
out2 = block2(x)

diff = (out1 - out2).abs()

print(diff.max().item())

1.3322676295501878e-15


In [4]:
x = torch.randn(1, 3, 224, 224)

model1 = torchvision.models.resnet50(weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V1)
model2 = copy.deepcopy(model1)

log_num_perms = 0.0

for layer_name in ["layer1", "layer2", "layer3", "layer4"]:
    
    layer = getattr(model1, layer_name)
    
    for block_idx in range(len(layer)):
        
        block = layer[block_idx]
        
        log_num_perms += math.lgamma(block.conv1.out_channels + 1)
        perm = torch.randperm(block.conv1.out_channels)
        permutations.permute_conv2d_out_channels(block.conv1, perm)
        permutations.permute_batchnorm2d(block.bn1, perm)
        permutations.permute_conv2d_in_channels(block.conv2, perm)
        
        log_num_perms += math.lgamma(block.conv1.out_channels + 1)
        perm = torch.randperm(block.conv2.out_channels)
        permutations.permute_conv2d_out_channels(block.conv2, perm)
        permutations.permute_batchnorm2d(block.bn2, perm)
        permutations.permute_conv2d_in_channels(block.conv3, perm)

out1 = torch.nn.functional.softmax(model1(x), dim=1)
out2 = torch.nn.functional.softmax(model2(x), dim=1)

diff = (out1 - out2).abs()

print(diff.max().item())

4.889443516731262e-09


In [5]:
with torch.no_grad():
    
    print(f"log_num_perms: {log_num_perms}")
    
    params1 = torch.nn.utils.parameters_to_vector(model1.parameters())
    params2 = torch.nn.utils.parameters_to_vector(model2.parameters())

    l2_norm = ((params1 - params2)**2).sum()

    print(f"l2_norm: {l2_norm}")

    percent = 100 * torch.sum((params1 != params2).float()).item() / len(params1)
    
    print(f"percent_different: {percent:.2f}%")

log_num_perms: 35325.70319730624
l2_norm: 10667.8818359375
percent_different: 80.85%
