In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy
from ofa.utils import AverageMeter, get_net_device, DistributedTensor
from ofa.imagenet_classification.elastic_nn.modules.dynamic_op import DynamicBatchNorm2d
from ofa.imagenet_classification.elastic_nn.networks import OFAResNets
from MemSE import ROOT
from MemSE.dataset import get_dataset, get_transforms, NORM_COEFS, get_dataloader
from MemSE.nn import MemSE, MEMSE_MAP

dataset, num_classes, input_shape = get_dataset('CIFAR10')
transform_train, transform_test = get_transforms(NORM_COEFS.get('ImageNet'))
dataset = dataset(root=f'{ROOT}/data', train=True, download=True, transform=transform_train)
train_loader, train_clean_loader, test_loader, nclasses, input_shape = get_dataloader('CIFAR10')

def evaluate_ofa_subnet(
    ofa_net, path, net_config, data_loader, batch_size, device="cuda:0"
):
    assert "ks" in net_config and "d" in net_config and "e" in net_config
    assert (
        len(net_config["ks"]) == 20
        and len(net_config["e"]) == 20
        and len(net_config["d"]) == 5
    )
    ofa_net.set_active_subnet(ks=net_config["ks"], d=net_config["d"], e=net_config["e"])
    subnet = ofa_net.get_active_subnet().to(device)
    calib_bn(subnet, path, net_config["r"][0], batch_size)
    top1 = validate(subnet, path, net_config["r"][0], data_loader, batch_size, device)
    return top1


def preprocess_ofa_subnet_for_memse(
    ofa_net, dataset, batch_size, input_shape, sigma: float, N: int, device="cuda:0"
):
    ofa_net.sample_active_subnet()
    subnet = ofa_net.get_active_subnet().to(device)
    calib_bn(subnet, dataset, batch_size)
    memse = MemSE(subnet, MEMSE_MAP).to(device)
    return memse, subnet


def calib_bn(net, dataset, batch_size, num_images=2000):
    # print('Creating dataloader for resetting BN running statistics...')
    chosen_indexes = np.random.choice(list(range(len(dataset))), num_images)
    sub_sampler = torch.utils.data.sampler.SubsetRandomSampler(chosen_indexes)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sub_sampler,
        batch_size=batch_size,
        num_workers=16,
        pin_memory=True,
        drop_last=False,
    )
    # print('Resetting BN running statistics (this may take 10-20 seconds)...')
    set_running_statistics(net, data_loader)
    
    
def set_running_statistics(model, data_loader, distributed=False):
    bn_mean = {}
    bn_var = {}

    forward_model = copy.deepcopy(model)
    for name, m in forward_model.named_modules():
        if isinstance(m, nn.BatchNorm2d):
            if distributed:
                bn_mean[name] = DistributedTensor(name + "#mean")
                bn_var[name] = DistributedTensor(name + "#var")
            else:
                bn_mean[name] = AverageMeter()
                bn_var[name] = AverageMeter()

            def new_forward(bn, mean_est, var_est):
                def lambda_forward(x):
                    batch_mean = (
                        x.mean(0, keepdim=True)
                        .mean(2, keepdim=True)
                        .mean(3, keepdim=True)
                    )  # 1, C, 1, 1
                    batch_var = (x - batch_mean) * (x - batch_mean)
                    batch_var = (
                        batch_var.mean(0, keepdim=True)
                        .mean(2, keepdim=True)
                        .mean(3, keepdim=True)
                    )

                    batch_mean = torch.squeeze(batch_mean)
                    batch_var = torch.squeeze(batch_var)

                    mean_est.update(batch_mean.data, x.size(0))
                    var_est.update(batch_var.data, x.size(0))

                    # bn forward using calculated mean & var
                    _feature_dim = batch_mean.size(0)
                    return F.batch_norm(
                        x,
                        batch_mean,
                        batch_var,
                        bn.weight[:_feature_dim],
                        bn.bias[:_feature_dim],
                        False,
                        0.0,
                        bn.eps,
                    )

                return lambda_forward

            m.forward = new_forward(m, bn_mean[name], bn_var[name])

    if len(bn_mean) == 0:
        # skip if there is no batch normalization layers in the network
        return

    with torch.no_grad():
        DynamicBatchNorm2d.SET_RUNNING_STATISTICS = True
        for images, labels in data_loader:
            images = images.to(get_net_device(forward_model))
            forward_model(images)
        DynamicBatchNorm2d.SET_RUNNING_STATISTICS = False

    for name, m in model.named_modules():
        if name in bn_mean and bn_mean[name].count > 0:
            feature_dim = bn_mean[name].avg.size(0)
            assert isinstance(m, nn.BatchNorm2d)
            m.running_mean.data[:feature_dim].copy_(bn_mean[name].avg)
            m.running_var.data[:feature_dim].copy_(bn_var[name].avg)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [2]:
from MemSE.fx import fuse, remove_dropout


device = "cuda:0"
ofa_net = OFAResNets().to(device)
print(dataset)
memse_net, ofa_subnet = preprocess_ofa_subnet_for_memse(
    ofa_net, dataset, 32, input_shape, 0.01, 1e6, device=device
)

fused = fuse(ofa_subnet)
fused = remove_dropout(fused)
x, y = next(iter(train_loader))
x, y = x.to(device), y.to(device)
res_prefuse = ofa_subnet(x)
res_ofa = fused(x)
res_memse = memse_net(x)
print(res_prefuse)
print(res_ofa)
print(res_memse)
print(torch.mean(torch.square(res_memse - res_ofa)))
assert torch.all(torch.eq(res_memse, res_ofa))

Dataset CIFAR10
    Number of datapoints: 50000
    Root location: /home/sebwood/workspace/MemSE/data
    Split: Train
    StandardTransform
Transform: Compose(
               RandomCrop(size=(32, 32), padding=4)
               RandomHorizontalFlip(p=0.5)
               ToTensor()
               Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.201))
           )




opcode         name                           target                         args                                               kwargs
-------------  -----------------------------  -----------------------------  -------------------------------------------------  ------------------
placeholder    x                              x                              ()                                                 {}
call_module    input_stem_0_conv              input_stem.0.conv              (x,)                                               {}
call_module    input_stem_0_act               input_stem.0.act               (input_stem_0_conv,)                               {}
call_module    input_stem_1_conv_conv         input_stem.1.conv.conv         (input_stem_0_act,)                                {}
call_module    input_stem_1_conv_act          input_stem.1.conv.act          (input_stem_1_conv_conv,)                          {}
call_function  add                            <built-in functio

AssertionError: 