In [2]:
import enum
from models.resnet import resnet18im, resnet18
import os, random
import copy
import numpy as np
import argparse
import torch
import torchvision
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from pyhocon import ConfigFactory
from options import Option
from dataset import create_loader
from collections import OrderedDict

from torch.utils.tensorboard import SummaryWriter
from log_utils import *

import pandas as pd
import pickle

def accuracy(output, target):
    """Return the top-1 accuracy of ``output`` against ``target`` as a fraction.

    The prediction is the argmax over dim 1 of ``output``; ``target`` is
    reshaped to the prediction's shape before comparison. The count of hits is
    divided by ``output.shape[0]``.
    """
    batch_size = output.shape[0]
    predicted = output.argmax(dim=1, keepdim=True)  # top-1 class index per sample
    n_hits = predicted.eq(target.view_as(predicted)).sum().item()
    return n_hits / batch_size


def train(model, train_loader, optimizer, epoch, device, log_interval=1000):
    """Run one training epoch and return ``(accuracy, mean_loss)``.

    Args:
        model: network to optimise; switched to train mode here.
        train_loader: iterable of ``(data, target)`` batches. Its ``.dataset``
            and ``len()`` are used only for the progress message.
        optimizer: optimiser over ``model``'s parameters; stepped once per batch.
        epoch: epoch index, used only for logging.
        device: device every batch is moved to.
        log_interval: print a progress line every ``log_interval`` batches.

    Returns:
        ``(accuracy, loss)`` as Python floats, both averaged per sample over the
        whole epoch so a short final batch is not over-weighted.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

        # Accumulate plain Python numbers (via .item()) rather than loss tensors,
        # so no autograd history or device tensor is kept alive across batches.
        batch_size = target.size(0)
        loss_sum += loss.item() * batch_size
        pred = output.argmax(dim=1, keepdim=True)  # top-1 class index per sample
        n_correct += pred.eq(target.view_as(pred)).sum().item()
        n_seen += batch_size

        if batch_idx % log_interval == 0:
            print('Train Epoch: [{}] [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

    avg_loss = loss_sum / n_seen if n_seen else 0.0
    avg_acc = n_correct / n_seen if n_seen else 0.0
    print('Train Epoch: [{}]\t Average Loss: {:.6f}\t Total Acc : {:.4f}'.format(
                epoch, avg_loss, avg_acc * 100))

    return avg_acc, avg_loss

def test(model, test_loader, epoch, device):
    model.eval()
    total_test_loss = AverageMeter()
    correct = AverageMeter()
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss = F.cross_entropy(output, target)
            total_test_loss.update(test_loss)
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct.update((pred.eq(target.view_as(pred)).sum().item()))

    total_test_acc = correct.sum / len(test_loader.dataset)
    print('\nEpoch [{}] Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.4f}%)\n'.format(
        epoch, total_test_loss.avg, correct.sum, len(test_loader.dataset),
        100. * total_test_acc))
    return total_test_acc, total_test_loss.avg

class Hook():
    """Record what one module's forward (or backward) hook receives.

    After the hooked module runs, ``self.input`` and ``self.output`` hold the
    hook's ``input``/``output`` arguments (for a backward hook these are the
    ``grad_input``/``grad_output`` tuples). Both are ``None`` until the hook
    fires. Call ``close()`` to unregister the hook.
    """

    def __init__(self, module, forward=True):
        self.input = None
        self.output = None
        if forward:
            self.hook = module.register_forward_hook(self.hook_fn)
        elif hasattr(module, "register_full_backward_hook"):
            # register_backward_hook is deprecated in favour of the "full" variant.
            self.hook = module.register_full_backward_hook(self.hook_fn)
        else:
            # Fallback for older PyTorch releases without the full backward hook.
            self.hook = module.register_backward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        """Store the latest hook arguments; returns None so nothing is modified."""
        self.input = input
        self.output = output

    def close(self):
        """Unregister the hook from the module."""
        self.hook.remove()



In [3]:
# Load the experiment configuration and set up the TensorBoard writer under save_path.
option = Option("./cifar100.hocon", "test")
option.set_save_path()

writer = SummaryWriter(os.path.join(option.save_path, "tfboard_result"))

# Seed every RNG the run can touch: Python's `random`, NumPy, torch CPU and all CUDA devices.
random.seed(option.seed)
np.random.seed(option.seed)
torch.manual_seed(option.seed)
torch.cuda.manual_seed_all(option.seed)

dataset_name = option.dataset.lower()
model_name = option.model_name.lower()

# Build the data loaders and the matching ResNet-18 variant for the configured dataset.
if dataset_name == "cifar100":
    cifar100_path = os.path.join(option.data_path, "CIFAR100")
    train_loader, test_loader, n_class, image_size = create_loader(option.batch_size, cifar100_path, option.dataset)
    if "resnet18" in model_name:
        net = resnet18(pretrained=False, num_classes=n_class)
    else:
        raise AssertionError("cifar100 only resnet18 test")

elif dataset_name == "imagenet":
    train_loader, test_loader, n_class, image_size = create_loader(option.batch_size, option.data_path, option.dataset)
    if "resnet18" in model_name:
        net = resnet18im(pretrained=False, num_classes=n_class)
    else:
        raise AssertionError("imagenet only resnet18 test")

else:
    raise AssertionError("using only dataset imagenet/cifar100")


Files already downloaded and verified
Files already downloaded and verified


In [4]:
# NOTE(review): `criterion` is not used by train()/test(), which call F.cross_entropy directly.
criterion = nn.CrossEntropyLoss()

# Build the optimizer named in the config; fail loudly on an unknown name instead of
# leaving `optimizer` undefined for later cells.
optimizer_name = option.optimizer.lower()
if optimizer_name == "sgd":
    optimizer = optim.SGD(net.parameters(), lr=option.lr, momentum=option.momentum, weight_decay=option.weight_decay, nesterov=True)
elif optimizer_name == "adam":
    optimizer = optim.Adam(net.parameters(), lr=option.lr, weight_decay=option.weight_decay)
else:
    raise AssertionError(f"unsupported optimizer: {option.optimizer}")

# Learning-rate schedule: multiply the LR by lr_gamma at each milestone in ml_step.
if option.scheduler.lower() == "multi_step":
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=option.ml_step, gamma=option.lr_gamma)
else:
    raise AssertionError(f"unsupported scheduler: {option.scheduler}")



In [5]:
# Optionally resume from a checkpoint written by the training loop
# (keys: end_epoch, model_state_dict, optimizer_state_dict, scheduler_state_dict).
if option.load_state_dict != False:
    if not os.path.exists(option.load_state_dict):
        raise AssertionError("If using load state dict, please set default path")
    else:
        print(f"load state dict {option.load_state_dict}")
        # `net` has not been moved to `device` yet (that happens in a later cell), so load
        # tensors onto the CPU; this also lets a GPU-saved checkpoint be opened on any host.
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(option.load_state_dict, map_location="cpu")
        start_epoch = checkpoint['end_epoch'] + 1
        net.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
else:
    print(f"train base line")
    start_epoch = 0

train base line


In [6]:
def hook_result(model, data, option):
    """Run ``data`` through ``model`` and capture selected BatchNorm2d inputs/outputs.

    BatchNorm2d layers whose qualified name does not contain ``"downsample"``
    are numbered in ``named_modules()`` order, starting at 0. A forward
    ``Hook`` is attached to each layer whose number is in
    ``option.activation_index``. One forward pass is then run under
    ``torch.no_grad()`` with the model in eval mode.

    Returns:
        A list of ``(name, hook)`` pairs, with ``name`` of the form
        ``"idx_{number}_{module_name}"``. The hooks are still registered, and
        the caller is responsible for calling ``hook.close()`` on each.
        If the forward pass raises, all hooks are closed before the error
        propagates.
    """
    model.eval()
    hook_list = []
    bn_count = 0
    for name, module in model.named_modules():
        if isinstance(module, nn.BatchNorm2d) and "downsample" not in name:
            if bn_count in option.activation_index:
                hook_list.append((f"idx_{bn_count}_{name}", Hook(module)))
            bn_count += 1

    try:
        with torch.no_grad():
            model(data)
    except Exception:
        # Don't leave hooks attached to the model when the forward pass fails.
        for _, hook in hook_list:
            hook.close()
        raise

    return hook_list

In [7]:
def get_batchnorm_param_dict(net, epoch):
    """Snapshot the parameters and running statistics of ``net``'s BatchNorm2d layers.

    Layers whose qualified name contains ``"downsample"`` are skipped. The
    remaining layers are numbered in ``named_modules()`` order from 0.

    Returns:
        A dict with ``'epoch'`` first. Then, for each layer, the keys
        ``"{i}_{name}_alpha"`` (weight), ``"_beta"`` (bias), ``"_avg"``
        (running_mean) and ``"_var"`` (running_var) map to NumPy copies.
        A value is ``None`` when the layer lacks that tensor, as with
        ``affine=False`` or ``track_running_stats=False``.
    """
    def _to_numpy(tensor):
        # BatchNorm fields are None when affine / running stats are disabled.
        return None if tensor is None else tensor.detach().cpu().numpy()

    save_bn_dict = {'epoch': epoch}
    bn_count = 0
    for name, m in net.named_modules():
        if isinstance(m, nn.BatchNorm2d) and 'downsample' not in name:
            prefix = f"{bn_count}_{name}"
            save_bn_dict[f"{prefix}_alpha"] = _to_numpy(m.weight)
            save_bn_dict[f"{prefix}_beta"] = _to_numpy(m.bias)
            save_bn_dict[f"{prefix}_avg"] = _to_numpy(m.running_mean)
            save_bn_dict[f"{prefix}_var"] = _to_numpy(m.running_var)
            bn_count += 1
    return save_bn_dict


In [8]:
device = torch.device('cuda:0')
net = net.to(device)

# net.to(device) does not touch the optimizer's state. State restored from a checkpoint
# (e.g. SGD momentum buffers) would stay where it was loaded and break optimizer.step().
# 'step' counters are left alone because PyTorch chooses their device itself.
for param_state in optimizer.state.values():
    for state_key, state_value in param_state.items():
        if state_key != "step" and torch.is_tensor(state_value):
            param_state[state_key] = state_value.to(device)

# One random batch from the test set, reused for activation logging during training.
random_sampler = torch.utils.data.RandomSampler(test_loader.dataset)
sample_loader = torch.utils.data.DataLoader(test_loader.dataset, batch_size=128, sampler=random_sampler)

check_data, _ = next(iter(sample_loader))
check_data = check_data.to(device)

# NOTE(review): spatial size is hard-coded per dataset; `image_size` from create_loader may be
# the intended source -- confirm its format before switching.
ones_shape = [option.batch_size, 3, 224, 224] if option.dataset.lower() == "imagenet" else [option.batch_size, 3, 32, 32]
dummy_input = torch.ones(ones_shape).to(device) * 0.1

# Trace the model once so TensorBoard can display its graph.
with torch.no_grad():
    net.eval()
    writer.add_graph(net, dummy_input)

csv_path = os.path.join(option.save_path, "batchnorm_param.csv")

print(f"-------start batchnorm param logging -----------\n")

# Row for epoch -1: BatchNorm parameters before any training.
save_bn_dict = get_batchnorm_param_dict(net, epoch=-1)
# DataFrame.append was removed in pandas 2.0; build the first row directly.
batchnorm_df = pd.DataFrame([save_bn_dict])
batchnorm_df.to_csv(csv_path)

print(f"-------end batchnorm param logging -----------\n")

-------start batchnorm param logging -----------

-------end batchnorm param logging -----------



In [9]:
# Show the resolved configuration for this run.
resolved_conf = option.conf
print(resolved_conf)

ConfigTree([('save_path', './save_log/resnet18_cifar100'), ('data_path', '/dataset/'), ('dataset', 'cifar100'), ('nGPU', 4), ('GPU', [0, 1, 2, 3]), ('visible_devices', '1'), ('model_name', 'resnet18'), ('worker', 8), ('seed', 0), ('train', True), ('epochs', 150), ('batch_size', 128), ('momentum', 0.9), ('weight_decay', 0.0005), ('optimizer', 'SGD'), ('warmup', 5), ('lr', 0.01), ('scheduler', 'multi_step'), ('ml_step', [30, 60, 90]), ('lr_gamma', 0.2), ('load_state_dict', False), ('log_override', True), ('activation_index', [2, 7, 14]), ('activation_step', [30, 50, 70, 100])])


In [10]:
best_test_acc = 0
best_epoch = 0
save_best_acc_path = os.path.join(option.save_path, "best_checkpoint.pth")
# Evaluate / update the best checkpoint every `test_interval` epochs (ImageNet runs used 5).
test_interval = 1

# TODO: add a proper logger later.

for epoch in range(start_epoch, option.epochs):
    print(f"-------{epoch} epoch start-----------")
    train_acc, train_loss = train(net, train_loader, optimizer, epoch, device)
    writer.add_scalar("Acc/Train", train_acc, epoch)
    writer.add_scalar("Loss/Train", train_loss, epoch)
    # Advance the LR schedule once per epoch; without this MultiStepLR never decays the LR.
    scheduler.step()
    print(f"-------{epoch} epoch end  -----------\n")

    print(f"-------{epoch} batch norm parameter logging ---------")
    save_bn_dict = get_batchnorm_param_dict(net, epoch)
    # DataFrame.append was removed in pandas 2.0; concatenate a one-row frame instead.
    batchnorm_df = pd.concat([batchnorm_df, pd.DataFrame([save_bn_dict])], ignore_index=True)
    batchnorm_df.to_csv(csv_path)
    print(f"-------{epoch} batch norm parameter log end ---------\n")

    if epoch in option.activation_step:
        print(f"-------- logging {epoch} batch layer input tensor ------------------")
        result_hook_list = hook_result(net, check_data, option)
        try:
            for name, hook in result_hook_list:
                # hook.input is the BatchNorm layer's forward-input tuple; keep the first tensor.
                batch_input = hook.input[0].cpu().detach()
                save_pkl_path = os.path.join(option.save_path, f"{name}_{epoch}_input.pkl")
                with open(save_pkl_path, "wb") as fw:
                    pickle.dump(batch_input, fw)
        finally:
            # Always detach the hooks, even if writing one of the files fails.
            for _, hook in result_hook_list:
                hook.close()
        print(f"-------- logging end {epoch} --------------------")

    if epoch % test_interval == 0:
        print("----- test and print accuracy ------------------")
        test_acc, test_loss = test(net, test_loader, epoch, device)
        writer.add_scalar("Acc/Test", test_acc, epoch)
        writer.add_scalar("Loss/Test", test_loss, epoch)
        print("----- test end -------------------------")
        print("\n")
        print(f"----- save intermediate {epoch} result ------------")

        if test_acc > best_test_acc:
            print(f"logging best performance {epoch} epoch")
            torch.save(net.state_dict(), save_best_acc_path)
            best_epoch = epoch
            best_test_acc = test_acc
            writer.add_scalar("Best Test Acc", best_test_acc, best_epoch)

    # Rolling resume checkpoint: written every epoch, and the previous epoch's file is removed.
    # The paths are built here (not inside the test branch) so this works for any test_interval.
    save_state_dict_path = os.path.join(option.save_path, f"epoch_{epoch}_state_dict.pth")
    save_prev_state_dict_path = os.path.join(option.save_path, f"epoch_{epoch-1}_state_dict.pth")
    torch.save({
        'end_epoch': epoch,
        'model_state_dict': net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict()
    }, save_state_dict_path)

    if os.path.exists(save_prev_state_dict_path):
        os.remove(save_prev_state_dict_path)

writer.flush()



-------0 epoch start-----------
Train Epoch: [0]	 Average Loss: 3.889081	 Total Acc : 10.6909
-------0 epoch end  -----------

-------0 batch norm parameter logging ---------
-------0 batch norm parameter log end ---------

-------- logging 0 batch layer input tensor ------------------
-------- logging end 0 --------------------
----- test and print accuracy ------------------

Epoch [0] Test set: Average loss: 3.4966, Accuracy: 1672/10000 (16.7200%)

----- test end -------------------------


----- save intermediate 0 result ------------
logging best performance 0 epoch
-------1 epoch start-----------
Train Epoch: [1]	 Average Loss: 3.337270	 Total Acc : 18.8971
-------1 epoch end  -----------

-------1 batch norm parameter logging ---------
-------1 batch norm parameter log end ---------

----- test and print accuracy ------------------

Epoch [1] Test set: Average loss: 3.1624, Accuracy: 2275/10000 (22.7500%)

----- test end -------------------------


----- save intermediate 1 resu

KeyboardInterrupt: 

In [1]:
# NOTE(review): scratch cell unrelated to the experiment (executed out of order); consider deleting.
a = dict(a="123", b="456")
print(a)

{'a': '123', 'b': '456'}
