In [9]:
import enum
from models.resnet import resnet18im, resnet18
import os, random
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from options import Option

from log_utils import *

import pandas as pd
import pickle

In [10]:
option = Option("./resnet18_cifar100.hocon", "test")
# 기존에 있는 log 를 불러오기 위해 log_override를 false로 하고 진행
option.log_override=False
option.set_save_path()


torch.manual_seed(option.seed)
torch.cuda.manual_seed(option.seed)
np.random.seed(option.seed)

option.print_parameters()

/data/nsh/save_log/cifar100_resnet18/log_cifar100_resnet18_bs128_ep200_seed_3/ is exists
load log path /data/nsh/save_log/cifar100_resnet18/log_cifar100_resnet18_bs128_ep200_seed_3/
GPU : [0, 1]
activation_index : [3, 9, 15]
activation_step : [0, 30, 50, 70, 100]
batch_size : 128
data_path : /dataset/
dataset : cifar100
epochs : 200
get_weight_grad_param : True
get_weight_param : True
load_state_dict : False
log_override : True
lr : 0.1
lr_gamma : 0.2
ml_step : [60, 120, 160]
model_name : resnet18
momentum : 0.9
nGPU : 1
nesterov : True
optimizer : SGD
save_path : /data/nsh/save_log/cifar100_resnet18
scheduler : multi_step
seed : 3
train : True
visible_devices : 1
warmup : 5
weight_decay : 0.0005
worker : 8


In [11]:
str(torch.float16)

'torch.float16'

In [12]:
# This TensorChannelAdder (N, C, H, W) -> (1, C, 1, 1)
# It is similar result tensor.sum(axis=0).sum(axis=2).sum(axis=3)
# We change tensor.transpose(0, 1).reshape(C, chunk, N*H*W//chunk).sum(axis=2).sum(axis=1)
def BatchNormMeanSim(tensor, chunk=1024,AdderType=torch.float16):
    epsilon = 1e-10
    if not len(tensor.shape) ==4 :
        AssertionError(f"It only supported 4d Matrix, but this tensor shape {tensor.shape}")

    if AdderType==torch.float16:
        mantissa = 10
    elif AdderType==torch.bfloat16:
        mantissa = 7
    elif AdderType==torch.float32:
        mantissa = 22
    elif AdderType=="test":
        mantissa= 100
    else:
        AssertionError("This Adder only supported FP16|BF16|FP32")
    
    temp_tensor =tensor.clone()
    zero_mask_counter = []
    
    n, c, h, w = tensor.shape

    
    if not n*h*w % chunk == 0: 
        AssertionError(f"The n*h*w should always be divisible chunk but result {n*h*w % chunk}")
    #change (c, n*h*w//chunk, chunk) 
    chunk_tensor = temp_tensor.transpose(1, 0).reshape(c, chunk, n*h*w//chunk)
    
    # first chunk based Adder (last dim size is equal to n*h*w divided by chunk, so last dim adder is always chunk adder)
    # (C, chunk, 0) + (C, chunk, 1) = C*chunk adder
    # accumulated that result of adder is final values (chunk_tensor[:, :, -1])
    for i in range(chunk_tensor.shape[-1] -1):
        prev = chunk_tensor[:, :, i]
        prec = chunk_tensor[:, :, i+1]
        log_prev = torch.log2(torch.abs(prev)+epsilon)
        log_prec = torch.log2(torch.abs(prec)+epsilon)
        zero_mask = torch.abs(log_prec-log_prev) > mantissa
        max_log_tensor = prec.clone()
        max_log_tensor[log_prec<log_prev] = prev[log_prec<log_prev] # 두 벡터 중 log2 의 value가 큰 값을 가지고 있는 vector 생성
        output = prec+prev # 두 벡터를 더함
        output[zero_mask] = max_log_tensor[zero_mask] # zero_mask에 해당하는 부분은 log2 value가 큰 값만 저장
        chunk_tensor[:,:, i+1] = output
        zero_mask_counter.append(zero_mask.sum())
    
    sum_tensor = chunk_tensor[:, :, -1] # C, chunk_size
    #print(f"chunk based sum result : {sum(zero_mask_counter)}/{c * chunk * (chunk_tensor.shape[-1]-1)} = {sum(zero_mask_counter) / (c * chunk * (chunk_tensor.shape[-1]-1)) * 100}%")
    
    for j in range(chunk_tensor.shape[1]-1):
        prev = sum_tensor[:, j]
        prec = sum_tensor[:, j+1]        
        log_prev = torch.log2(torch.abs(prev)+epsilon)
        log_prec = torch.log2(torch.abs(prec)+epsilon)
        zero_mask = torch.abs(log_prec-log_prev) > mantissa
        max_log_tensor = prec.clone()
        max_log_tensor[log_prec<log_prev] = prev[log_prec<log_prev] # 두 벡터 중 log2 의 value가 큰 값을 가지고 있는 vector 생성
        output = prec+prev # 두 벡터를 더함
        output[zero_mask] = max_log_tensor[zero_mask] # zero_mask에 해당하는 부분은 log2 value가 큰 값만 저장
        sum_tensor[:, j+1] = output
        zero_mask_counter.append(zero_mask.sum())
        
    # chunk_tensor[:, :, -1] is same to chunk_tensor.sum(dim=-2), and then finally  

    #print(f"final sum result : {sum(zero_mask_counter)}/{c*(chunk-1)+c * chunk * (chunk_tensor.shape[-1]-1)} =\
    #    {sum(zero_mask_counter)/(c*(chunk-1)+ c * chunk * (chunk_tensor.shape[-1]-1)) * 100}%")
    
    return sum_tensor[:, -1]/(n*h*w), sum(zero_mask_counter), (n*h*w-1)*c

In [13]:
# This TensorChannelAdder (N, C, H, W) -> (1, C, 1, 1)
# It is similar result tensor.sum(axis=0).sum(axis=2).sum(axis=3)
# We change tensor.transpose(0, 1).reshape(C, chunk, N*H*W//chunk).sum(axis=2).sum(axis=1)
def BatchNormStdSim(tensor, mean_tensor, chunk=1024, AdderType=torch.float16):
    epsilon = 1e-10
    if not len(tensor.shape) ==4 :
        AssertionError(f"It only supported 4d Matrix, but this tensor shape {tensor.shape}")

    if AdderType==torch.float16:
        mantissa = 10
    elif AdderType==torch.bfloat16:
        mantissa = 7
    elif AdderType==torch.float32:
        mantissa = 22
    elif AdderType=="test":
        mantissa= 100
    else:
        AssertionError("This Adder only supported FP16|BF16|FP32")
    
    temp_tensor =tensor.clone()
    zero_mask_counter = []
    
    n, c, h, w = tensor.shape

    
    if not n*h*w % chunk == 0: 
        AssertionError(f"The n*h*w should always be divisible chunk but result {n*h*w % chunk}")
    #change (c, n*h*w//chunk, chunk) 

    if mean_tensor.dim() == 1:
        #change 4d tensor
        mean_tensor = mean_tensor.reshape(1, -1, 1, 1)
    elif mean_tensor.dim() !=4:
        AssertionError("mean_tensor input only 1d or 4d tensor")
    
    if not mean_tensor.shape[1] == c:
        AssertionError("mean_tensor and tensor is required same shape")
    # first computing (X-mean)**2

    mean_tensor = torch.zeros_like(temp_tensor) + mean_tensor # broadcasting and same shape result tensor
    log_temp_tensor = torch.log2(torch.abs(temp_tensor) + epsilon)
    log_temp_mean = torch.log2(torch.abs(mean_tensor) + epsilon)
    zero_mask = torch.abs(log_temp_tensor - log_temp_mean) > mantissa
    output = temp_tensor - mean_tensor # X - mean(X)
    max_log_tensor = temp_tensor.clone()
    max_log_tensor[log_temp_tensor<log_temp_mean] = mean_tensor[log_temp_tensor<log_temp_mean] # get log2 max_value
    output[zero_mask]=max_log_tensor[zero_mask]
    var = output**2 # (X - mean(X))^2

    chunk_tensor = var.transpose(1, 0).reshape(c, chunk, n*h*w//chunk)

    
    # second chunk based Adder (last dim size is equal to n*h*w divided by chunk, so last dim adder is always chunk adder)
    # (C, chunk, 0) + (C, chunk, 1) = C*chunk adder
    # accumulated that result of adder is final values (chunk_tensor[:, :, -1])
    for i in range(chunk_tensor.shape[-1] -1):
        prev = chunk_tensor[:, :, i]
        prec = chunk_tensor[:, :, i+1]
        log_prev = torch.log2(torch.abs(prev)+epsilon)
        log_prec = torch.log2(torch.abs(prec)+epsilon)
        zero_mask = torch.abs(log_prec-log_prev) > mantissa
        max_log_tensor = prec.clone()
        max_log_tensor[log_prec<log_prev] = prev[log_prec<log_prev] # 두 벡터 중 log2 의 value가 큰 값을 가지고 있는 vector 생성
        output = prec+prev # 두 벡터를 더함
        output[zero_mask] = max_log_tensor[zero_mask] # zero_mask에 해당하는 부분은 log2 value가 큰 값만 저장
        chunk_tensor[:,:, i+1] = output
        zero_mask_counter.append(zero_mask.sum())
    
    sum_tensor = chunk_tensor[:, :, -1] # C, chunk_size
    #print(f"chunk based sum result : {sum(zero_mask_counter)}/{c * chunk * (chunk_tensor.shape[-1]-1)} = {sum(zero_mask_counter) / (c * chunk * (chunk_tensor.shape[-1]-1)) * 100}%")
    
    for j in range(chunk_tensor.shape[1]-1):
        prev = sum_tensor[:, j]
        prec = sum_tensor[:, j+1]        
        log_prev = torch.log2(torch.abs(prev)+epsilon)
        log_prec = torch.log2(torch.abs(prec)+epsilon)
        zero_mask = torch.abs(log_prec-log_prev) > mantissa
        max_log_tensor = prec.clone()
        max_log_tensor[log_prec<log_prev] = prev[log_prec<log_prev] # 두 벡터 중 log2 의 value가 큰 값을 가지고 있는 vector 생성
        output = prec+prev # 두 벡터를 더함
        output[zero_mask] = max_log_tensor[zero_mask] # zero_mask에 해당하는 부분은 log2 value가 큰 값만 저장
        sum_tensor[:, j+1] = output
        zero_mask_counter.append(zero_mask.sum())
        
    # chunk_tensor[:, :, -1] is same to chunk_tensor.sum(dim=-2), and then finally  

    #print(f"final sum result : {sum(zero_mask_counter)}/{c*(chunk-1)+c * chunk * (chunk_tensor.shape[-1]-1)} =\
    #    {sum(zero_mask_counter)/(c*(chunk-1)+ c * chunk * (chunk_tensor.shape[-1]-1)) * 100}%")
    
    return torch.sqrt(sum_tensor[:, -1]/(n*h*w)), sum(zero_mask_counter), (n*h*w-1)*c

In [22]:
test_tensor = torch.randn(128, 32, 16, 16)
original_mean = test_tensor.transpose(0,1).reshape(32, -1).mean(1)
refined_mean = BatchNormMeanSim(test_tensor, chunk=1024, AdderType="test") # zero_setting error 발생안함.
print(original_mean)
print(refined_mean)

original_std = test_tensor.transpose(0,1).reshape(32,-1).std(1)
refined_std =BatchNormStdSim(test_tensor, original_mean, chunk=1024, AdderType="test")
print(original_std)
print(refined_std)

chunk based sum result : 0/1015808 = 0.0%
tensor([-0.0078, -0.0006, -0.0035, -0.0039, -0.0056,  0.0142,  0.0017,  0.0023,
         0.0043, -0.0045, -0.0013, -0.0046,  0.0023, -0.0019,  0.0101, -0.0036,
         0.0102,  0.0052,  0.0104, -0.0016,  0.0012,  0.0014,  0.0037,  0.0004,
        -0.0075,  0.0056,  0.0108, -0.0060,  0.0024,  0.0025, -0.0106, -0.0031])
(tensor([-0.0078, -0.0006, -0.0035, -0.0039, -0.0056,  0.0142,  0.0017,  0.0023,
         0.0043, -0.0045, -0.0013, -0.0046,  0.0023, -0.0019,  0.0101, -0.0036,
         0.0102,  0.0052,  0.0104, -0.0016,  0.0012,  0.0014,  0.0037,  0.0004,
        -0.0075,  0.0056,  0.0108, -0.0060,  0.0024,  0.0025, -0.0106, -0.0031]), tensor(0), 1048544)
chunk based sum result : 0/1015808 = 0.0%
tensor([1.0005, 1.0002, 0.9954, 1.0013, 1.0012, 0.9972, 0.9992, 0.9966, 1.0012,
        1.0026, 1.0054, 1.0032, 0.9957, 1.0012, 1.0064, 1.0021, 1.0002, 0.9988,
        0.9989, 0.9955, 0.9992, 0.9890, 1.0029, 1.0000, 1.0011, 1.0081, 0.9985,
        0.99

In [14]:
class Compare_BatchNorm_Precision(nn.Module):
    # this is custom batchnorm different precision
    # Considered https://github.com/ptrblck/pytorch_misc/blob/master/batch_norm_manual.py and rangeBN 


    def __init__(self, num_features, intermediate_result = False, chunk=1024, dim=1, momentum=0.9, affine=True, eps=1e-5, compute_type=torch.float16):
        super(Compare_BatchNorm_Precision, self).__init__()
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.zeros(num_features))

        self.momentum = momentum
        self.dim = dim
        self.eps = 1e-10
        self.intermediate_result = intermediate_result
        if affine:
            self.bias = nn.Parameter(torch.Tensor(num_features))
            self.weight = nn.Parameter(torch.Tensor(num_features))
        self.compute_type = compute_type
        self.eps = eps
        self.chunk = chunk
        self.reset_params()

    def reset_params(self):
        if self.weight is not None:
            self.weight.data.uniform_()
        if self.bias is not None:
            self.bias.data.zero_()

    def load_params(self, weight, bias, running_mean, running_var):
        self.weight=weight
        self.bias=bias
        self.running_mean = running_mean
        self.running_var = running_var
        

    def forward(self, x, inference=True):
        x_ch = x.type(self.compute_type)        
        if not inference:
        
            # not using momentum
            n,c,h,w = x.shape
            mean = x.transpose(1,0).reshape(c,-1).mean(dim=-1) # c axis mean
            std = x.transpose(1,0).reshape(c,-1).std(dim=-1) # c axis std

            mean_ch, zero_mean_count, total_mean_compute =  BatchNormMeanSim(x, chunk=self.chunk, AdderType=self.compute_type)
            std_ch, zero_std_count, total_std_compute = BatchNormStdSim(x, mean_ch, chunk=self.chunk, AdderType=self.compute_type)

            out = (x - mean.view(1, mean.size(0), 1, 1)) / \
                (std.view(1, std.size(0), 1, 1) + self.eps)

            out_ch = (x_ch - mean_ch.view(1, mean_ch.size(0), 1, 1)) / \
                (std_ch.view(1, std_ch.size(0), 1, 1) + self.eps)


        else:
            # using running_mean, and running_var
            c = x_ch.shape[1]
            mean = self.running_mean
            scale = self.running_var
            mean_ch = self.running_mean.type(self.compute_type)
            scale_ch = self.running_var.type(self.compute_type)

            out = (x - mean.view(1, mean.size(0), 1, 1)) / \
                (scale.view(1, scale.size(0), 1, 1) + self.eps)

            out_ch = (x_ch - mean_ch.view(1, mean_ch.size(0), 1, 1)) / \
                (scale_ch.view(1, scale.size(0), 1, 1) + self.eps)
        
        if self.intermediate_result:
            # return only normalized value

            return out, out_ch, zero_mean_count, total_mean_compute, zero_std_count, total_std_compute

        if self.weight is not None:
            
            weight = self.weight
            weight_ch = self.weight.type(self.compute_type)
            out = out * weight.view(1, weight.size(0), 1, 1)
            out_ch = out_ch * weight_ch.view(1, weight_ch.size(0), 1, 1)

        if self.bias is not None:
            bias =self.bias
            out = out + bias.view(1, bias.size(0), 1, 1)
            bias_ch = self.bias.type(self.compute_type)
            out_ch = out_ch + bias_ch.view(1, bias_ch.size(0), 1, 1)

        return out, out_ch, zero_mean_count, total_mean_compute, zero_std_count, total_std_compute


In [15]:
batchnorm_csv_path = os.path.join(option.save_path, "batchnorm_param.csv")
batchnorm_df = pd.read_csv(batchnorm_csv_path)


# csv에 저장된 batch_norm 데이터를 읽어오는 함수

def csv_txt_to_param(txt):
    temp_txt = txt
    temp_txt = temp_txt.replace("[", "")
    temp_txt = temp_txt.replace("]", "")
    temp_list = temp_txt.split()

    np_txt = np.array(temp_list, dtype=np.float32)
    torch_result = torch.Tensor(np_txt)
    return torch_result


In [16]:
result_df =pd.DataFrame()
for epoch in option.activation_step:
    epoch_df = batchnorm_df[batchnorm_df.epoch == epoch]
    print(epoch)
    for index in option.activation_index:
        alpha_trigger = False
        beta_trigger =False
        avg_trigger = False
        var_trigger =False
        for key in sorted(epoch_df.keys()): # alpha, avg, beta, var and find index
            if str(index) == key.split("_")[0]:
                file_name = '_'.join(key.split("_")[1:-1])
                target_key = key
                target_df = epoch_df[key].iloc[0]
                
                if "alpha" in key:
                    bn_weight = nn.Parameter(csv_txt_to_param(target_df))
                    alpha_trigger=True
                elif "beta" in key :
                    bn_bias = nn.Parameter(csv_txt_to_param(target_df))
                    beta_trigger= True
                elif "avg" in key:
                    bn_running_mean = nn.Parameter(csv_txt_to_param(target_df))
                    avg_trigger = True
                elif "var" in key:
                    bn_running_var = nn.Parameter(csv_txt_to_param(target_df))
                    var_trigger = True
        
        if alpha_trigger and beta_trigger and var_trigger and avg_trigger:

            print("idx : ", index)
            print("layer name : ", file_name)
            with torch.no_grad():
                
                BatchNorm_fp16_layer = Compare_BatchNorm_Precision(num_features=bn_weight.shape, intermediate_result=True, chunk=1024, compute_type=torch.float16)
                BatchNorm_fp16_layer.load_params(bn_weight, bn_bias, bn_running_mean, bn_running_var)
                """ print("-------------- FP32 - FP16 distance --------------------")
                print(f"bn.weight mse distance {((bn_weight - bn_weight.type(torch.float16))**2).mean()}")
                print(f"bn.bias mse distance {((bn_bias - bn_bias.type(torch.float16))**2).mean()}")
                print(f"bn.avg mse distance {((bn_running_mean - bn_running_mean.type(torch.float16))**2).mean()}")
                print(f"bn.var mse distance {((bn_running_var - bn_running_var.type(torch.float16))**2).mean()}")
                print("-------------- FP32 - FP16 end --------------------")
                

                print("-------------- FP32 - BF16 distance --------------------")
                print(f"bn.weight mse distance {((bn_weight - bn_weight.type(torch.bfloat16))**2).mean()}")
                print(f"bn.bias mse distance {((bn_bias - bn_bias.type(torch.bfloat16))**2).mean()}")
                print(f"bn.avg mse distance {((bn_running_mean - bn_running_mean.type(torch.bfloat16))**2).mean()}")
                print(f"bn.var mse distance {((bn_running_var - bn_running_var.type(torch.bfloat16))**2).mean()}")
                
                print("-------------- FP32 - BF16 end --------------------")
                 """

                BatchNorm_bf16_layer = Compare_BatchNorm_Precision(num_features=bn_weight.shape, intermediate_result=True, chunk=1024, compute_type=torch.bfloat16)
                BatchNorm_bf16_layer.load_params(bn_weight, bn_bias, bn_running_mean, bn_running_var)

                folder_name = f"idx_{index}_{file_name}"
                input_pkl_file_name = f"{epoch}_fwd_input.pkl"
                input_pkl_file_path = os.path.join(option.save_path, folder_name, input_pkl_file_name)

                with open(input_pkl_file_path, "rb") as f:
                    input_tensor = pickle.load(f)

                output, output_fp16, fp_zero_mean_count, fp_total_mean_count, fp_zero_std_count, fp_total_std_count = BatchNorm_fp16_layer.forward(input_tensor, inference=False)
                _, output_bf16, bf_zero_mean_count, bf_total_mean_count, bf_zero_std_count, bf_total_std_count = BatchNorm_bf16_layer.forward(input_tensor, inference=False)


                C = output.shape[1]

                output_t = output.transpose(0, 1).reshape(C, -1)
                output_fp16_t = output_fp16.transpose(0, 1).reshape(C, -1)
                fp16_mse_dist = ((output_t - output_fp16_t)**2).mean(-1)
                fp16_L_inf = ((output_t - output_fp16_t)**2).abs().max(-1)[0]
                
                output_bf16_t = output_bf16.transpose(0, 1).reshape(C, -1)
                bf16_mse_dist = ((output_t - output_bf16_t)**2).mean(-1)
                bf16_L_inf = ((output_t - output_bf16_t)**2).abs().max(-1)[0]
                
                csv_dict = {"epoch": epoch, "layer": file_name, "FP32-FP16_mse": fp16_mse_dist, "FP32-FP16_L_inf":fp16_L_inf, 
                "FP32-FP16 mean zero setting Error" : f"{fp_zero_mean_count} / {fp_total_mean_count}", 
                "FP32-FP16 mean std setting Error" : f"{fp_zero_std_count} / {fp_total_std_count}" ,
                "FP32-BF16_mse" : bf16_mse_dist, "FP32-bf16_L_inf":bf16_L_inf, 
                "FP32-BF16 mean zero setting Error" : f"{bf_zero_mean_count} / {bf_total_mean_count}", 
                "FP32-BF16 mean std setting Error" : f"{bf_zero_std_count} / {bf_total_std_count}" ,
                
                }
                print(csv_dict)
                result_df = result_df.append(csv_dict, ignore_index=True)

result_df.to_csv("./precision_compare.csv", index=False)


0
idx :  3
layer name :  conv2_x.1.residual_function.1
chunk based sum result : 160785/8323072 = 1.9317988157272339%
chunk based sum result : 1353443/8323072 = 16.26133918762207%
chunk based sum result : 1426712/8323072 = 17.141651153564453%
chunk based sum result : 3479627/8323072 = 41.807003021240234%
{'epoch': 0, 'layer': 'conv2_x.1.residual_function.1', 'FP32-FP16_mse': tensor([0.0039, 0.0122, 0.0177, 0.0150, 0.0040, 0.0040, 0.0022, 0.0045, 0.0034,
        0.0036, 0.0052, 0.0150, 0.0042, 0.0047, 0.0036, 0.0026, 0.0055, 0.0035,
        0.0032, 0.0122, 0.0023, 0.0097, 0.0072, 0.0032, 0.0027, 0.0106, 0.0025,
        0.0036, 0.0051, 0.0051, 0.0041, 0.0028, 0.0038, 0.0065, 0.0019, 0.0034,
        0.0029, 0.0042, 0.0026, 0.0032, 0.0042, 0.0034, 0.0095, 0.0108, 0.0052,
        0.0147, 0.0020, 0.0033, 0.0109, 0.0030, 0.0029, 0.0026, 0.0143, 0.0130,
        0.0033, 0.0155, 0.0031, 0.0039, 0.0041, 0.0115, 0.0034, 0.0030, 0.0035,
        0.0035]), 'FP32-FP16_L_inf': tensor([0.0950, 0.1259, 0.

In [None]:
# backup code momentum check

class Compare_BatchNorm_Precision(nn.Module):
    # this is custom batchnorm different precision
    # Considered https://github.com/ptrblck/pytorch_misc/blob/master/batch_norm_manual.py and rangeBN 


    def __init__(self, num_features, intermediate_result = False, chunk=1024, dim=1, momentum=0.9, affine=True, eps=1e-5, compute_type=torch.float16):
        super(Compare_BatchNorm_Precision, self).__init__()
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.zeros(num_features))

        self.momentum = momentum
        self.dim = dim
        self.eps = 1e-10
        self.intermediate_result = intermediate_result
        if affine:
            self.bias = nn.Parameter(torch.Tensor(num_features))
            self.weight = nn.Parameter(torch.Tensor(num_features))
        self.compute_type = compute_type
        self.eps = eps
        self.chunk = chunk
        self.reset_params()

    def reset_params(self):
        if self.weight is not None:
            self.weight.data.uniform_()
        if self.bias is not None:
            self.bias.data.zero_()

    def load_params(self, weight, bias, running_mean, running_var):
        self.weight=weight
        self.bias=bias
        self.running_mean = running_mean
        self.running_var = running_var
        

    def forward(self, x, inference=True):
        x_ch = x.type(self.compute_type)        
        if not inference:
            if self.momentum !=0: # using momentum
                n,c,h,w = x.shape
                mean = x.transpose(1,0).view(c,-1).mean(dim=-1) # c axis mean
                std = x.transpose(1,0).view(c,-1).std(dim=-1) # c axis std

                mean_ch, zero_mean_count, total_mean_compute =  BatchNormMeanSim(x, chunk=self.chunk, AdderType=self.compute_type)
                std_ch, zero_std_count, total_std_compute = BatchNormStdSim(x, mean_ch, chunk=self.chunk, AdderType=self.compute_type)


                momentum_mean = self.momentum * self.running_mean + (1-self.momentum) *  mean
                momentum_mean_ch = self.momentum * self.running_mean.type(self.compute_type) + (1-self.momentum) * mean_ch

                momentum_std = self.momentum * self.running_var + (1 - self.momentum) * std
                momentum_std_ch = self.momentum * self.running_var.type(self.compute_type) + (1-self.momentum) * std_ch

                out = (x - momentum_mean.view(1, mean.size(0), 1, 1)) / \
                    (momentum_std.view(1, momentum_std.size(0), 1, 1) + self.eps)

                out_ch = (x_ch - momentum_mean_ch.view(1, momentum_mean_ch.size(0), 1, 1)) / \
                    (momentum_std_ch.view(1, momentum_std_ch.size(0), 1, 1) + self.eps)
            else:
                # not using momentum
                n,c,h,w = x.shape
                mean = x.transpose(1,0).view(c,-1).mean(dim=-1) # c axis mean
                std = x.transpose(1,0).view(c,-1).std(dim=-1) # c axis std

                mean_ch, zero_mean_count, total_mean_compute =  BatchNormMeanSim(x, chunk=self.chunk, AdderType=self.compute_type)
                std_ch, zero_std_count, total_std_compute = BatchNormStdSim(x, mean_ch, chunk=self.chunk, AdderType=self.compute_type)

                out = (x - mean.view(1, mean.size(0), 1, 1)) / \
                    (std.view(1, std.size(0), 1, 1) + self.eps)

                out_ch = (x_ch - mean_ch.view(1, mean_ch.size(0), 1, 1)) / \
                    (std_ch.view(1, std_ch.size(0), 1, 1) + self.eps)


        else:
            # using running_mean, and running_var
            c = x_ch.shape[1]
            mean = self.running_mean
            scale = self.running_var
            mean_ch = self.running_mean.type(self.compute_type)
            scale_ch = self.running_var.type(self.compute_type)

            out = (x - mean.view(1, mean.size(0), 1, 1)) / \
                (scale.view(1, scale.size(0), 1, 1) + self.eps)

            out_ch = (x_ch - mean_ch.view(1, mean_ch.size(0), 1, 1)) / \
                (scale_ch.view(1, scale.size(0), 1, 1) + self.eps.type(self.compute_type))
        
        if self.intermediate_result:

        if self.weight is not None:
            
            weight = self.weight
            weight_ch = self.weight.type(self.compute_type)
            out = out * weight.view(1, weight.size(0), 1, 1)
            out_ch = out_ch * weight_ch.view(1, weight_ch.size(0), 1, 1)

        if self.bias is not None:
            bias =self.bias
            out = out + bias.view(1, bias.size(0), 1, 1)
            bias_ch = self.bias.type(self.compute_type)
            out_ch = out_ch + bias_ch.view(1, bias_ch.size(0), 1, 1)

        return out, out_ch