In [1]:
import copy
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

import numpy as np
import torch.nn.functional as F
from BNNBench.backbones.unet import define_G
from BNNBench.backbones.resnet import ResNet101
from BNNBench.backbones.vgg import VGG
from BNNBench.data.paired_data import get_loader_with_dir

In [2]:
def get_model(f):
    m = define_G(1, 1, 64, "unet_256", norm="instance", use_dropout=False)
    m.load_state_dict(torch.load(f))
    return m

def get_loader():
    data_root = "../datasets/cd105/"
    loader = get_loader_with_dir(data_root + "A/train/", data_root + "B/train/", 1024, 64, True)
    return loader

def get_cls_model(f):
    model_cls = lambda: VGG("VGG16", 10)
    net = model_cls().cuda()
    net = torch.nn.DataParallel(net)
    checkpoint = torch.load(f)
    net.load_state_dict(checkpoint['net'])
    return net

def get_cls_loader():
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    trainset = torchvision.datasets.CIFAR10(
        root='../datasets', train=True, download=False, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=256, shuffle=True, num_workers=2,
    )
    return trainloader

def diff_upper_bound(init_model, first_epoch, last_epoch):
    L2_init = 0
    L2_last = 0
    for w_init, w_first, w_last in zip(init_model.parameters(), first_epoch.parameters(), last_epoch.parameters()):
        dw1 = w_first - w_init
        dw2 = w_last - w_init
        L2_init += torch.sum(dw1 * dw1)
        L2_last += torch.sum(dw2 * dw2)
    print(torch.sqrt(L2_last / L2_init))

loader = get_cls_loader()
init_model = get_cls_model("../checkpoint/vgg16_ln/ckpt_0.pth")
first_epoch = get_cls_model("../checkpoint/vgg16/local0_0_1.pth")
last_epoch = get_cls_model("../checkpoint/vgg16/local0_0_9.pth")
diff_upper_bound(init_model, first_epoch, last_epoch)

tensor(2.4230, device='cuda:0', grad_fn=<SqrtBackward>)


In [3]:
def loss_l1(model):
    model.eval()
    l = 0.0
    for x, y in loader:
        x, y = x.cuda(), y.cuda()
        with torch.no_grad():
            pred = model(x)
        l += F.l1_loss(pred, y).item()
    print(l / len(loader))

def loss_ce(model):
    model.eval()
    l = 0.0
    for x, y in loader:
        x, y = x.cuda(), y.cuda()
        with torch.no_grad():
            pred = model(x)
        l += F.cross_entropy(pred, y)
    print(l / len(loader))

In [4]:
def extrapolate(lam):
    model_ext = copy.deepcopy(init_model)
    for w_ext, w in zip(model_ext.parameters(), first_epoch.parameters()):
        w_ext.data.add_(lam * (w.data - w_ext.data))
    for (n, w_ext), (n, w) in zip(model_ext.named_buffers(), first_epoch.named_buffers()):
        if n.endswith("num_batches_tracked"):
            continue
        w_ext.data.add_(lam * (w.data - w_ext.data))
    loss_ce(model_ext)
    return model_ext

for lam in np.arange(0, 5, 1):
    print(lam)
    model_ext = extrapolate(float(lam))
    torch.save(model_ext.state_dict(), f"../checkpoint/vgg16_ln/extrapolate_local0_0_1_lam{lam}")


0


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


tensor(0.0004, device='cuda:0')
1
tensor(0.0545, device='cuda:0')
2
tensor(0.5925, device='cuda:0')
3
tensor(10.6671, device='cuda:0')
4
tensor(23.2760, device='cuda:0')


In [8]:
[k for k, v in model_ext.named_parameters()]

['module.features.0.weight',
 'module.features.0.bias',
 'module.features.1.weight',
 'module.features.1.bias',
 'module.features.3.weight',
 'module.features.3.bias',
 'module.features.4.weight',
 'module.features.4.bias',
 'module.features.7.weight',
 'module.features.7.bias',
 'module.features.8.weight',
 'module.features.8.bias',
 'module.features.10.weight',
 'module.features.10.bias',
 'module.features.11.weight',
 'module.features.11.bias',
 'module.features.14.weight',
 'module.features.14.bias',
 'module.features.15.weight',
 'module.features.15.bias',
 'module.features.17.weight',
 'module.features.17.bias',
 'module.features.18.weight',
 'module.features.18.bias',
 'module.features.20.weight',
 'module.features.20.bias',
 'module.features.21.weight',
 'module.features.21.bias',
 'module.features.24.weight',
 'module.features.24.bias',
 'module.features.25.weight',
 'module.features.25.bias',
 'module.features.27.weight',
 'module.features.27.bias',
 'module.features.28.weigh