In [1]:
import sys
import os
sys.path.insert(0,'..')

In [2]:
import numpy as np
import pandas as pd
import numpy.random as npr
import torch
import torch.nn as nn
from data_utils import *
from relu import *
from maxpool import *
from models import *
from data_utils import *
import torch.optim as optim
import torch.nn.init as init
import torchvision
from torchvision import datasets, transforms
import torch.backends.cudnn as cudnn
from tqdm import tqdm
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
%matplotlib inline

In [3]:
# Run configuration read by the cells below.
# deterministic: when True (and CUDA is available) the device cell sets
# torch.backends.cudnn.deterministic to True; otherwise it is set to False.
deterministic = False
# precision: floating-point width in bits. The values 16 and 64 make later
# cells call .half() / .double() on the networks and on the training inputs;
# any other value leaves tensors in the dtype they were created with.
precision = 32
# get_cifar10_loaders is provided by one of the star imports above
# (presumably data_utils -- confirm); it yields the train and test loaders.
train_loader, test_loader = get_cifar10_loaders()

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Index of the GPU this experiment prefers to run on.
preferred_gpu = 2

if torch.cuda.is_available():
    # Hosts with fewer than `preferred_gpu + 1` GPUs would otherwise fail
    # with an invalid-device error the first time a network is moved to
    # `device`; fall back to the first GPU instead.
    gpu_index = preferred_gpu if preferred_gpu < torch.cuda.device_count() else 0
    device = torch.device(f"cuda:{gpu_index}")

    # No cuDNN autotuning of convolution algorithms.
    cudnn.benchmark = False
    # Deterministic convolutions only when requested in the configuration cell.
    cudnn.deterministic = bool(deterministic)
else:
    device = torch.device("cpu")

In [5]:
def init_weights(model):
    """Re-initialise the Conv2d, BatchNorm2d and Linear layers of ``model`` in place.

    * ``nn.Conv2d``: Kaiming-uniform weights (``fan_in`` mode, ReLU gain), zero bias.
    * ``nn.BatchNorm2d``: weight set to 1, bias set to 0.
    * ``nn.Linear``: weights drawn from a normal law with mean 0 and std 0.01,
      zero bias.

    Parameters a layer does not own (``bias=False`` on Conv2d/Linear,
    ``affine=False`` on BatchNorm2d) are skipped. Other module types are
    left untouched.

    Args:
        model: module whose sub-modules are visited through ``model.modules()``.

    Returns:
        None; the parameters are modified in place.
    """
    for m in model.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='relu')
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # With affine=False both attributes are None and constant_ would fail.
            if m.weight is not None:
                init.constant_(m.weight, 1)
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, 0, 0.01)
            if m.bias is not None:
                init.constant_(m.bias, 0)
            
def build_models():
    """Create seven VGG11 networks that start from identical weights.

    The networks differ only in the argument handed to ``MaxPool2DBeta``
    (0, 1, 10, 100, 1000, 10000 and 0 again); every one uses ``ReLUAlpha(0)``.
    The first network is initialised with ``init_weights`` and its state dict
    is copied into the other six. All networks are then moved to the global
    ``device`` and cast to half / double when the global ``precision`` is
    16 / 64.

    Returns:
        Tuple ``(net1, ..., net7)`` in the order of the beta values above.
    """
    # The last entry repeats the first value, giving a second network with
    # the same setting as net1.
    betas = (0, 1, 10, 100, 1000, 10000, 0)

    # `beta=beta` freezes the current value inside each factory so it does
    # not depend on when VGG invokes it.
    networks = [
        VGG(
            'VGG11',
            maxpool_fn=lambda beta=beta: MaxPool2DBeta(beta),
            relu_fn=lambda: ReLUAlpha(0),
        )
        for beta in betas
    ]

    reference = networks[0]
    init_weights(reference)
    for clone in networks[1:]:
        clone.load_state_dict(reference.state_dict())

    for network in networks:
        network.to(device)
        if precision == 16:
            network.half()
        if precision == 64:
            network.double()

    return tuple(networks)

def compute_norms(model1, model2):
    """Return the L1 distance between the parameters of two models.

    Parameters are paired in the order produced by ``parameters()``; for each
    pair the sum of absolute element-wise differences is added to the total.

    Args:
        model1: first ``nn.Module``.
        model2: second ``nn.Module``. It is expected to have the same
            architecture; ``zip`` stops at the shorter parameter list, so
            surplus parameters are ignored.

    Returns:
        The accumulated L1 distance as a Python float (``0.0`` when the
        models have no parameters).
    """
    total = 0.0
    # This is a pure measurement: keep it out of the autograd graph.
    with torch.no_grad():
        for weights1, weights2 in zip(model1.parameters(), model2.parameters()):
            total = total + torch.norm(weights1 - weights2, 1)

    # float() accepts both the 0-dim tensor and the untouched 0.0 start value.
    return float(total)

In [6]:
# Seven networks sharing the same initial weights (see build_models).
nets = list(build_models())
net1, net2, net3, net4, net5, net6, net7 = nets

# One learning rate shared by every network, drawn uniformly from
# [0.01, 0.012) with NumPy's global random state.
lr = 0.01 * (npr.random(1)[0]/5.0+1)

# One plain SGD optimizer per network, in the same order as `nets`.
optimizers = [torch.optim.SGD(net.parameters(), lr=lr) for net in nets]

criterion = nn.CrossEntropyLoss()

In [7]:
def compute_accuracy(net, loader=None):
    """Return the fraction of correctly classified samples, rounded to 9 decimals.

    The network is switched to evaluation mode for the duration of the call
    and its previous training flag is restored afterwards, so layers such as
    batch normalisation neither use nor update statistics from the evaluated
    batches. Inputs are moved to the device of the network's parameters, and
    floating-point inputs are cast to the parameters' dtype so that networks
    converted with ``.half()`` / ``.double()`` can be evaluated.

    Args:
        net: the ``nn.Module`` to evaluate.
        loader: iterable of ``(inputs, labels)`` batches. Defaults to the
            notebook-level ``test_loader``.

    Returns:
        ``correct / total`` rounded to 9 decimals, or ``nan`` when the loader
        yields no sample.
    """
    if loader is None:
        loader = test_loader

    # A reference parameter tells where and in which dtype the network lives.
    ref_param = next(net.parameters(), None)

    was_training = net.training
    net.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                if ref_param is not None:
                    x = x.to(ref_param.device)
                    if x.is_floating_point():
                        x = x.to(ref_param.dtype)
                output = net(x)
                # Arg-max over everything but the batch axis, one prediction
                # per output row.
                predicted = output.reshape(output.size(0), -1).argmax(dim=1)
                labels = y.to(predicted.device).reshape(-1)
                correct += (predicted == labels).sum().item()
                total += predicted.size(0)
    finally:
        # Hand the network back in the mode the caller left it in.
        net.train(was_training)

    if total == 0:
        return float("nan")
    return round(correct / total, 9)

In [None]:
n_epochs = 200
for net in nets:
    net.train()

## Difference between weights
# differenceNM[k]: L1 distance between the weights of net N and net M,
# measured just before epoch k starts.
difference12 = []
difference13 = []
difference14 = []
difference15 = []
difference16 = []
difference17 = []

# accuracyN[k]: test accuracy of net N measured just before epoch k starts.
accuracy1 = []
accuracy2 = []
accuracy3 = []
accuracy4 = []
accuracy5 = []
accuracy6 = []

# Pair every history list with the network it tracks so that the per-epoch
# logging is a loop rather than one hand-written line per network.
difference_logs = [
    (difference12, net2),
    (difference13, net3),
    (difference14, net4),
    (difference15, net5),
    (difference16, net6),
    (difference17, net7),
]
accuracy_logs = [
    (accuracy1, net1),
    (accuracy2, net2),
    (accuracy3, net3),
    (accuracy4, net4),
    (accuracy5, net5),
    (accuracy6, net6),
]

bar = tqdm(range(n_epochs), desc='epoch', leave=False)
for epoch in bar:
    t = tqdm(iter(train_loader), desc="batch_loop", leave=False)

    # Metrics are recorded at the start of the epoch, so index 0 is the
    # state before any optimisation step.
    for history, other in difference_logs:
        history.append(compute_norms(net1, other))
    for history, tracked in accuracy_logs:
        history.append(compute_accuracy(tracked))

    for inputs, targets in t:
        # The batch is identical for every network: transfer and cast it once
        # instead of once per network.
        inputs = inputs.to(device)
        if precision == 16:
            inputs = inputs.half()
        elif precision == 64:
            inputs = inputs.double()
        targets = targets.to(device)

        # Gradients of every network on this batch ...
        for net, optimizer in zip(nets, optimizers):
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            loss.backward()

        ## Perform an optimization step for each network
        # ... and only then the parameter updates.
        for optimizer in optimizers:
            optimizer.step()
    

epoch:   0%|          | 0/200 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

batch_loop:   0%|          | 0/391 [00:00<?, ?it/s]

In [None]:
# Test accuracy per epoch for the six tracked networks.
# savefig does not create missing directories, so make sure the target exists.
os.makedirs('figures', exist_ok=True)

plt.figure(num=None, figsize=(6, 4), dpi=80, facecolor='w', edgecolor='k')

accuracy_curves = [
    (accuracy1, r'$\beta=0$'),
    (accuracy2, r'$\beta=1$'),
    (accuracy3, r'$\beta=10$'),
    (accuracy4, r'$\beta=10^2$'),
    (accuracy5, r'$\beta=10^3$'),
    (accuracy6, r'$\beta=10^4$'),
]
for history, label in accuracy_curves:
    plt.plot(history, label=label)

plt.xlabel("Epoch")
plt.ylabel("Test accuracy")
plt.ylim(0,0.9)
plt.legend()

# The file name follows the configured precision (32 -> accuracy_32bits.pdf)
# so a 16- or 64-bit run does not overwrite the 32-bit figure.
plt.savefig(f'figures/accuracy_{precision}bits.pdf', bbox_inches='tight')

In [None]:
# L1 weight distance to net1 per epoch for the six other networks.
# savefig does not create missing directories, so make sure the target exists.
os.makedirs('figures', exist_ok=True)

plt.figure(num=None, figsize=(6, 4), dpi=80, facecolor='w', edgecolor='k')

difference_curves = [
    (difference17, r'0 vs 0'),
    (difference12, r'0 vs 1'),
    (difference13, r'0 vs 10'),
    (difference14, r'0 vs $10^2$'),
    (difference15, r'0 vs $10^3$'),
    (difference16, r'0 vs $10^4$'),
]
for history, label in difference_curves:
    plt.plot(history, label=label)

plt.xlabel("Epoch")
plt.ylabel("Weight difference ")
# symlog keeps exact zeros drawable while compressing large values.
plt.yscale('symlog')
plt.legend()

# The file name follows the configured precision (32 -> diff_32bits.pdf)
# so a 16- or 64-bit run does not overwrite the 32-bit figure.
plt.savefig(f'figures/diff_{precision}bits.pdf', bbox_inches='tight')