In [1]:
import enum
from models.resnet import resnet18im, resnet18
import os, random
import copy
import numpy as np
import argparse
import torch
import torchvision
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from pyhocon import ConfigFactory
from options import Option
from dataset import create_loader
from collections import OrderedDict

from torch.utils.tensorboard import SummaryWriter
from log_utils import *

import pandas as pd
import pickle

In [2]:
option = Option("./cifar100.hocon", "test")
# Keep log_override False so that the already existing log directory is loaded
# by set_save_path() rather than replaced.
option.log_override = False
option.set_save_path()

# Seed every random number generator this notebook imports so that a fresh
# "Restart & Run All" reproduces the same numbers. Python's own `random`
# module is imported above and was previously left unseeded.
random.seed(option.seed)
torch.manual_seed(option.seed)
torch.cuda.manual_seed(option.seed)
np.random.seed(option.seed)



./save_log/resnet18_cifar100/log_cifar100_resnet18_bs128_ep200_seed_5/ is exists
load log path ./save_log/resnet18_cifar100/log_cifar100_resnet18_bs128_ep200_seed_5/


In [12]:
# Sanity check of the reshape trick used for per-channel statistics:
# move the channel axis (dim 1) to the front, flatten everything else and
# take the biased variance of each channel row.
s = torch.randn(3, 4, 5, 6)

s.permute(1, 0, 2, 3).reshape(s.shape[1], -1).var(dim=1, unbiased=False)

tensor([1.1470, 1.0576, 0.9212, 1.0043])

In [3]:
class Compare_BatchNorm_Precision(nn.Module):
    """Batch normalisation evaluated twice: in the input dtype and in ``compute_type``.

    The layer applies ``y = (x - mean) / sqrt(var + eps) * weight + bias`` per
    channel (axis 1 of a 4-D input) and returns both the result computed in the
    dtype of ``x`` and the result computed after casting every operand to
    ``compute_type``, so the two can be compared.

    Considered https://github.com/ptrblck/pytorch_misc/blob/master/batch_norm_manual.py
    and rangeBN.

    Args:
        num_features: number of channels (an int or a 1-D ``torch.Size``).
        dim: stored on the instance; not used by ``forward``.
        momentum: stored on the instance; ``forward`` never updates the
            running statistics, so it is currently unused.
        affine: when True, learnable ``weight`` and ``bias`` are created;
            when False both are registered as ``None`` and skipped.
        eps: value added to the variance before the square root.
        compute_type: dtype used for the reduced-precision computation.
    """

    def __init__(self, num_features, dim=1, momentum=0.9, affine=True, eps=1e-5, compute_type=torch.float16):
        super(Compare_BatchNorm_Precision, self).__init__()
        self.register_buffer('running_mean', torch.zeros(num_features))
        # Unit variance is the neutral default: with zeros the inference path
        # would divide by sqrt(eps) and blow the activations up.
        self.register_buffer('running_var', torch.ones(num_features))

        self.momentum = momentum
        self.dim = dim
        self.eps = eps
        self.compute_type = compute_type
        if affine:
            self.bias = nn.Parameter(torch.empty(num_features))
            self.weight = nn.Parameter(torch.empty(num_features))
        else:
            # Register explicit None entries so `self.weight is not None`
            # checks work instead of raising AttributeError.
            self.register_parameter('bias', None)
            self.register_parameter('weight', None)
        self.reset_params()

    def reset_params(self):
        """Initialise ``weight`` uniformly in [0, 1) and ``bias`` to zero (if present)."""
        if self.weight is not None:
            self.weight.data.uniform_()
        if self.bias is not None:
            self.bias.data.zero_()

    @staticmethod
    def _as_parameter(value):
        """Return ``value`` unchanged if it is None or a Parameter, else wrap it in one."""
        if value is None or isinstance(value, nn.Parameter):
            return value
        return nn.Parameter(value)

    def load_params(self, weight, bias, running_mean, running_var):
        """Replace the affine parameters and running statistics with the given tensors.

        ``weight``/``bias`` may be tensors, Parameters or None. The running
        statistics are detached so they stay registered as buffers even when
        the caller passes ``nn.Parameter`` objects.
        """
        self.weight = self._as_parameter(weight)
        self.bias = self._as_parameter(bias)
        self.running_mean = running_mean.detach()
        self.running_var = running_var.detach()

    def _normalize(self, x, mean, var):
        """Return ``(x - mean) / sqrt(var + eps)`` with per-channel broadcasting."""
        # Adding the Python float eps keeps the dtype of `var`, so the
        # reduced-precision path stays in `compute_type`.
        return (x - mean.view(1, -1, 1, 1)) / torch.sqrt(var.view(1, -1, 1, 1) + self.eps)

    def forward(self, x, inference=True):
        """Normalise ``x`` in its own dtype and in ``compute_type``.

        Args:
            x: 4-D input whose axis 1 is the channel axis.
            inference: when True the stored running statistics are used;
                when False the per-channel mean and biased variance of ``x``
                are used (the running statistics are not updated).

        Returns:
            ``(out, out_ch)``: the result in the dtype of ``x`` and the result
            computed entirely in ``compute_type``.
        """
        x_ch = x.type(self.compute_type)
        if inference:
            mean = self.running_mean
            var = self.running_var
            mean_ch = mean.type(self.compute_type)
            var_ch = var.type(self.compute_type)
        else:
            c = x.shape[1]
            x_flat = x.transpose(0, 1).reshape(c, -1)
            x_ch_flat = x_ch.transpose(0, 1).reshape(c, -1)
            mean = x_flat.mean(dim=-1)
            var = x_flat.var(dim=-1, unbiased=False)
            # The reduced-precision statistics are computed from the cast
            # input so reduction error is part of the comparison.
            mean_ch = x_ch_flat.mean(dim=-1)
            var_ch = x_ch_flat.var(dim=-1, unbiased=False)

        out = self._normalize(x, mean, var)
        out_ch = self._normalize(x_ch, mean_ch, var_ch)

        if self.weight is not None:
            weight = self.weight
            weight_ch = weight.type(self.compute_type)
            out = out * weight.view(1, -1, 1, 1)
            out_ch = out_ch * weight_ch.view(1, -1, 1, 1)

        if self.bias is not None:
            bias = self.bias
            bias_ch = bias.type(self.compute_type)
            # The shift is additive: y = x_hat * weight + bias.
            out = out + bias.view(1, -1, 1, 1)
            out_ch = out_ch + bias_ch.view(1, -1, 1, 1)

        return out, out_ch

In [4]:
# Read the BatchNorm parameter table stored as "batchnorm_param.csv" inside
# the log directory (option.save_path).
batchnorm_csv_path = os.path.join(option.save_path, "batchnorm_param.csv")
batchnorm_df = pd.read_csv(batchnorm_csv_path)


# Helper that reads BatchNorm data stored in the CSV back into a tensor.

def csv_txt_to_param(txt):
    """Convert a bracketed, whitespace-separated number string into a float32 tensor.

    Every "[" and "]" character is removed, the remainder is split on
    whitespace and each token is parsed as a float32 value.
    """
    bracket_free = txt.translate(str.maketrans("", "", "[]"))
    values = np.array(bracket_free.split(), dtype=np.float32)
    return torch.Tensor(values)


In [5]:
# Show the configured activation indices; the comparison loop below matches
# each one against the leading "_"-separated token of the CSV column names.
option.activation_index

[2, 7, 14]

In [6]:
# Compare BatchNorm outputs computed in FP32 against FP16 and BF16 for every
# logged epoch and layer index, and write per-channel error metrics to a CSV.

def _print_cast_error(label, dtype, named_params):
    """Print the mean squared error introduced by casting each parameter to ``dtype``."""
    print(f"-------------- FP32 - {label} distance --------------------")
    for param_name, param in named_params:
        print(f"bn.{param_name} mse distance {((param - param.type(dtype))**2).mean()}")
    print(f"-------------- FP32 - {label} end --------------------")


def _per_channel_error(reference, candidate):
    """Return per-channel (MSE, L-infinity) error between two (N, C, ...) tensors."""
    channels = reference.shape[1]
    diff = reference.transpose(0, 1).reshape(channels, -1) - \
        candidate.transpose(0, 1).reshape(channels, -1)
    # L-infinity is the largest absolute difference, not the largest squared one.
    return (diff ** 2).mean(-1), diff.abs().max(-1)[0]


# Rows are collected in a list and turned into a DataFrame once:
# DataFrame.append was removed in pandas 2.0 and was quadratic in a loop.
records = []
for epoch in option.activation_step:
    epoch_df = batchnorm_df[batchnorm_df.epoch == epoch]
    print(epoch)
    for index in option.activation_index:
        found = set()
        # Column names start with the layer index and end with the parameter
        # kind; the tokens in between form the layer name.
        for key in sorted(epoch_df.keys()):  # alpha, avg, beta, var and find index
            key_parts = key.split("_")
            if str(index) != key_parts[0]:
                continue
            file_name = '_'.join(key_parts[1:-1])
            raw_value = epoch_df[key].iloc[0]

            if "alpha" in key:
                bn_weight = nn.Parameter(csv_txt_to_param(raw_value))
                found.add("alpha")
            elif "beta" in key:
                bn_bias = nn.Parameter(csv_txt_to_param(raw_value))
                found.add("beta")
            elif "avg" in key:
                bn_running_mean = nn.Parameter(csv_txt_to_param(raw_value))
                found.add("avg")
            elif "var" in key:
                bn_running_var = nn.Parameter(csv_txt_to_param(raw_value))
                found.add("var")

        # Skip indices for which any of the four parameter columns is missing,
        # so stale values from a previous iteration are never reused.
        if found != {"alpha", "beta", "avg", "var"}:
            continue

        # This value is the layer index, not the epoch (printed above).
        print("index : ", index)
        print("layer name : ", file_name)
        with torch.no_grad():
            named_params = [
                ("weight", bn_weight),
                ("bias", bn_bias),
                ("avg", bn_running_mean),
                ("var", bn_running_var),
            ]
            _print_cast_error("FP16", torch.float16, named_params)
            _print_cast_error("BF16", torch.bfloat16, named_params)

            num_features = bn_weight.numel()
            BatchNorm_fp16_layer = Compare_BatchNorm_Precision(num_features=num_features, compute_type=torch.float16)
            BatchNorm_fp16_layer.load_params(bn_weight, bn_bias, bn_running_mean, bn_running_var)
            BatchNorm_bf16_layer = Compare_BatchNorm_Precision(num_features=num_features, compute_type=torch.bfloat16)
            BatchNorm_bf16_layer.load_params(bn_weight, bn_bias, bn_running_mean, bn_running_var)

            input_pkl_file_name = f"idx_{index}_{file_name}_{epoch}_input.pkl"
            input_pkl_file_path = os.path.join(option.save_path, input_pkl_file_name)

            # NOTE: pickle.load can execute arbitrary code; only load files
            # that were written by your own training run.
            with open(input_pkl_file_path, "rb") as f:
                input_tensor = pickle.load(f)

            output, output_fp16 = BatchNorm_fp16_layer(input_tensor)
            _, output_bf16 = BatchNorm_bf16_layer(input_tensor)

            fp16_mse_dist, fp16_L_inf = _per_channel_error(output, output_fp16)
            bf16_mse_dist, bf16_L_inf = _per_channel_error(output, output_bf16)

            # Store plain Python lists so the CSV holds the numbers themselves
            # instead of (possibly truncated) tensor reprs.
            records.append({
                "epoch": epoch,
                "layer": file_name,
                "FP32-FP16_mse": fp16_mse_dist.tolist(),
                "fp32-bf16_mse": bf16_mse_dist.tolist(),
                "FP32-FP16_L_inf": fp16_L_inf.tolist(),
                "FP32-bf16_L_inf": bf16_L_inf.tolist(),
            })

result_df = pd.DataFrame(records)
result_df.to_csv("./precision_compare.csv", index=False)


30
epoch :  2
layer name :  conv2_x.0.residual_function.4
-------------- FP32 - FP16 distance --------------------
bn.weight mse distance 1.5179169920997992e-08
bn.bias mse distance 1.2109795111125976e-10
bn.avg mse distance 2.539544796675841e-09
bn.var mse distance 1.8388015554648973e-09
-------------- FP32 - FP16 end --------------------
-------------- FP32 - BF16 distance --------------------
bn.weight mse distance 1.1077389672209392e-06
bn.bias mse distance 8.529898565257099e-09
bn.avg mse distance 1.5773014183650957e-07
bn.var mse distance 1.2950454220117535e-07
-------------- FP32 - BF16 end --------------------
epoch :  7
layer name :  conv3_x.0.shortcut.1
-------------- FP32 - FP16 distance --------------------
bn.weight mse distance 1.4793691605063941e-08
bn.bias mse distance 1.378347574965133e-10
bn.avg mse distance 2.9327800188383435e-09
bn.var mse distance 5.129983660090431e-10
-------------- FP32 - FP16 end --------------------
-------------- FP32 - BF16 distance ---------

In [10]:
a = torch.randn(4, 5)
b = torch.randn(5, 4)

c=a @ b

c_fp16=a.type(torch.float16) @ b.type(torch.float16)

c_bf16=a.type(torch.bfloat16) @ b.type(torch.bfloat16)

print(c)
print(((c-c_fp16)*1e+5).type(torch.float32))
print(((c-c_bf16)*1e+5).type(torch.float32))

tensor([[ 3.5620, -1.9661, -3.0356,  0.8310],
        [ 0.6103, -4.3786,  1.1484,  1.2283],
        [-1.9923, -1.5068,  1.3988,  0.7608],
        [-0.6837,  2.0734, -0.2611,  1.1067]])
tensor([[ -53.3104,  -25.1412,  -42.7485,  -10.4547],
        [  -4.5836,  417.6140,   97.5966, -121.7484],
        [  90.8732,    7.7128,   37.9682,    6.2883],
        [  85.6161,  -83.0889,   58.3947,   23.5319]])
tensor([[ 1509.1896, -1294.6725,  1129.1266,  -498.7359],
        [   93.0727,  -363.6360,   781.1904,   952.4703],
        [  774.4670,   105.3691,   819.2181,   -91.3680],
        [  378.5849,  1088.7861,  -332.2303,  -269.4368]])
