In [None]:
import torch
import sys
import argparse
from torchvision.models import ResNet34_Weights, ResNet18_Weights, ResNet50_Weights, VGG16_BN_Weights
sys.path.append("../")
from dataset import imagenet_dataloader
from utils import *
from ResNet import resnet34, resnet18, resnet50
from torchvision.models import mobilenet_v2, regnet_x_8gf, regnet_x_3_2gf, vgg16_bn

import os
from copy import deepcopy
from torchattacks import FGSM

In [None]:
# --- Run configuration: everything a reader might want to tune ---------------

# Restrict this process to a single GPU (device index "0").
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Reproducibility and model choice.
seed = 42
model_name = "resnet18"   # handled below: resnet18 / resnet34 / resnet50 / vgg16 / mobilenet

# Data loading and calibration / simulation lengths.
batch_size = 100
iter = 10                 # NOTE: shadows the builtin `iter`; name kept because later cells read it
timestep = 512

# Checkpoint handling.
load_path = None          # set to a checkpoint stem under ./results/ to skip calibration
save_path = "./tmp"       # not referenced by the cells visible in this notebook

In [None]:
# Seed the RNGs first so everything that follows is reproducible.
seed_all(seed)
identifier = "imagenet"

# Architectures that are constructed empty and then filled from a torchvision
# weights enum: model_name -> (constructor, pretrained weights).
pretrained_specs = {
    "resnet34": (resnet34, ResNet34_Weights.IMAGENET1K_V1),
    "resnet18": (resnet18, ResNet18_Weights.IMAGENET1K_V1),
    "resnet50": (resnet50, ResNet50_Weights.IMAGENET1K_V2),
    "vgg16": (vgg16_bn, VGG16_BN_Weights.IMAGENET1K_V1),
}

if model_name in pretrained_specs:
    build_fn, weight = pretrained_specs[model_name]
    model = build_fn()
    model.load_state_dict(weight.get_state_dict(progress=True))
elif model_name == "mobilenet":
    # torchvision builds and loads MobileNetV2 in a single call.
    model = mobilenet_v2(weights='IMAGENET1K_V1')
else:
    raise AssertionError("No such model!")

# The run identifier is used to name result files: imagenet_<model>_iter<iter>.
identifier = "{}_{}_iter{}".format(identifier, model_name, iter)

train_loader, val_loader = imagenet_dataloader(
    "/home/data10T/data/public/ImageNet/",
    batchsize=batch_size,
)

In [None]:
# Convert the pretrained ANN into its SNN counterpart. `cnt` is the second
# value returned by the converter; the visible cells do not use it.
model, cnt = convert_ann_to_snn(model)

if load_path is not None:
    # Resume from a previously saved checkpoint instead of re-calibrating.
    sd = torch.load("./results/"+ load_path +".pth")
    # A checkpoint without any "running_mean" entries carries no BatchNorm
    # statistics; presumably it was saved after BN folding, so fold BN in the
    # model first so the parameter layout lines up (TODO confirm).
    has_bn_stats = any("running_mean" in key for key in sd)
    if not has_bn_stats:
        search_fold_and_remove_bn(model)
    model.load_state_dict(sd, strict=False)
else:
    # Calibrate the converted network on training data for `iter` iterations.
    weight_scaling_iter(train_loader, model, "cuda", iter)
    # Optional: persist the calibrated weights so `load_path` can reuse them.
    # torch.save(model.state_dict(), './results/%s.pth'%(identifier))

# Fold/remove BatchNorm before inference. This may be a second call when the
# checkpoint branch already folded BN; the helper is assumed to be a no-op
# then (NOTE(review): confirm idempotence).
search_fold_and_remove_bn(model)

# Make sure the output directory exists *before* the expensive inference run,
# otherwise np.save would raise FileNotFoundError only after all the work is done.
os.makedirs("./results", exist_ok=True)

acc = snn_inference(val_loader, model, "cuda", timestep)
np.save('./results/%s_t%d_mat.npy'%(identifier, timestep), acc)

In [None]:
# Display the accuracy result returned by snn_inference as this cell's output.
acc

In [None]:
# calculate delay time
delay_t = cal_delay_time(train_loader, model, "cuda")
print(delay_t)

# Reload the accuracy matrix written by the inference cell. The file name must
# match what that cell saves ('<identifier>_t<timestep>_mat.npy'); the previous
# hard-coded '<identifier>_t511.npy' is never produced by this notebook, so the
# cell failed (or read a stale file) on Restart & Run All.
acc = np.load('./results/%s_t%d_mat.npy'%(identifier, timestep))

# Column indices to report: 8, 16, 32, 64, 128, 256 and 511.
# NOTE(review): acc looks like it is indexed as [start step, end step] - confirm.
xx = np.array([7, 15, 31, 63, 127, 255, 510]) + 1
for i in xx:
    if i < delay_t + 4:
        # Too close to the delay: read the row 4 steps before the column
        # (clamped at 0) instead of the delay_t row.
        print(acc[max(i - 4, 0), i])
    else:
        print(acc[delay_t, i])