### 1. SQNR of the quantized output of the UNet

In [None]:
import os
import re

import matplotlib.pyplot as plt
import torch

# Plot the SQNR of the UNet output recorded per module in a sensitivity-analysis
# log, and save the names of the most sensitive modules (lowest SQNR).

# Log file to analyse (an earlier run is kept commented out for reference).
# filename = '/home/fangtongcheng/diffuser-dev/analysis_tools/error_func/quant_error/unet_output_error/sdxl_turbo_w4a32_minmax/run.log'
filename = "/share/public/diffusion_quant/text2img_diffusers_ptq/w8a8_test/sensitivity_quant_error_unet_output.log"
level_0 = 'layer'  # log section to read: 'top' / 'quant_block' / 'layer'
level_1 = 'layer'  # 'block' / 'layer'; layer names differ from block names, so the two cases are handled separately
# Metric parsed by this cell (only SQNR values are extracted below). The cell
# previously read `error_type` without defining it, which raised NameError
# unless a later cell had already been executed.
error_type = 'SQNR'

analysis_target_pattern = re.compile(r'\bquant_error_unet_output\b', re.IGNORECASE)
top_block_pattern = re.compile(r'\btop level blocks\b', re.IGNORECASE)
lower_block_pattern = re.compile(r'\blower level blocks\b', re.IGNORECASE)
layer_pattern = re.compile(r'\bthe layers\b', re.IGNORECASE)

# Locate the section headers; if a header occurs several times, the last one wins.
analysis_target_pattern_matched = False
top_block_line_number = None
lower_block_line_number = None
layer_line_number = None
line_number = 0
with open(filename, 'r') as file:
    for line_number, line in enumerate(file, start=1):
        if top_block_pattern.search(line):
            print(f'Top Level Block found at line {line_number}')
            top_block_line_number = line_number
        elif lower_block_pattern.search(line):
            print(f'Lower Level Block found at line {line_number}')
            lower_block_line_number = line_number
        elif layer_pattern.search(line):
            print(f'Layers found at line {line_number}')
            layer_line_number = line_number
        elif analysis_target_pattern.search(line):
            print(f'Analysis target pattern matched')
            analysis_target_pattern_matched = True

# Report missing headers explicitly instead of failing later with a NameError.
missing_sections = [
    header
    for header, found_at in (
        ('top level blocks', top_block_line_number),
        ('lower level blocks', lower_block_line_number),
        ('the layers', layer_line_number),
    )
    if found_at is None
]
if missing_sections:
    raise ValueError(f'Section header(s) {missing_sections} not found in {filename}')

# Section boundaries: [top header, lower header, layer header, last line of file].
splits_line_number = [top_block_line_number, lower_block_line_number, layer_line_number, line_number]
print(f'Final line number {line_number}')
assert analysis_target_pattern_matched is True

# Pick the line range of the requested level. Each range runs from its header
# to the next header (inclusive); the layer section runs to the end of the file.
if level_0 == 'top':
    start_line, end_line = splits_line_number[0], splits_line_number[1]
elif level_0 == 'quant_block':
    start_line, end_line = splits_line_number[1], splits_line_number[2]
else:
    start_line, end_line = splits_line_number[2], splits_line_number[3]

# Read only the selected lines of the log.
with open(filename, 'r') as file:
    lines = file.readlines()[start_line - 1:end_line]

modules = []
errors = []
for line in lines:
    if "model." in line:
        # Module name: text after the first '.' that follows the second 'm'
        # in the line, up to the next ':'. NOTE(review): this relies on the
        # log prefix layout placing the 'm' of "model." second — confirm.
        first_m = line.find('m')
        start = line.find('m', first_m + 1)
        start = line.find('.', start)
        end = line.find(':', start)
        modules.append(line[start + 1:end])
    if "SQNR:" in line:
        # SQNR value: text after the ':' that follows the second 'S', up to
        # the first 'B'. The character before 'B' is dropped; presumably it is
        # the 'd' of "dB".
        start = line.find('S')
        start = line.find('S', start + 1)
        start = line.find(':', start)
        end = line.find('B')
        errors.append(float(line[start + 1:end - 1]))

print(len(modules))
if len(modules) != len(errors):
    # zip() below silently drops unmatched entries; make the mismatch visible.
    print(f'Warning: parsed {len(modules)} module names but {len(errors)} SQNR values')
data = list(zip(modules, errors))

SORT_TYPE = 'error'  # 'layer_name' / 'error'
extract_min = True  # keep only the most sensitive fraction of modules
extract_min_rate = 0.1
# Ascending for SQNR (lowest SQNR first); descending for MSE (largest first).
descending = error_type != 'SQNR'

data_ = sorted(data, key=lambda x: x[1], reverse=descending)
if extract_min:
    num_min = int(extract_min_rate * len(modules))
    data_ = data_[:num_min]
if SORT_TYPE == 'layer_name':
    # Sort by the first dotted name component, then by full name, so that
    # (alphabetically) down_blocks < mid_block < up_blocks. The sort is needed
    # because some layer names carry no block prefix. Only entries 5..-2 are
    # reordered; the first five and the last entry keep their positions.
    data_[5:-1] = sorted(data_[5:-1], key=lambda x: (x[0].split('.')[0], x[0]))

if not data_:
    raise ValueError('No modules selected; check the log contents and extract_min_rate')
modules, errors = zip(*data_)

# Name the output file after the "wXaX" configuration it corresponds to.
# NOTE(review): these names say w4a32 while the log above is from a w8a8 run — confirm.
if extract_min:
    if level_1 == 'layer':
        save_path = './sensitivity_log/sensitive_layers_list_w4a32_20.pt'
    elif level_1 == 'block':
        save_path = './sensitivity_log/sensitive_blocks_list_w4a32_5.pt'
    else:
        save_path = None
    if save_path is not None:
        # torch.save does not create missing directories.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(modules, save_path)

# Line plot of the selected modules' SQNR.
plt.figure(figsize=(40, 8))
plt.plot(modules, errors, marker='o')
plt.xlabel('Blocks')
plt.ylabel('SQNR (dB)')
plt.title('SQNR for different blocks')
plt.grid(True)
plt.xticks(rotation=90, fontsize=14)
plt.show()


### 2. Error from the quantized weights

In [None]:
import os
import re

import matplotlib.pyplot as plt
import torch

# Plot the per-layer weight-quantization error (SQNR or MSE) recorded in a
# sensitivity-analysis log, and save the names of the most sensitive layers.

# Log file to analyse (an earlier run is kept commented out for reference).
# filename = '/home/fangtongcheng/diffuser-dev/analysis_tools/error_func/quant_error/quant_weight_error/sdxl_turbo_w4a32_minmax/run.log'
filename = '/share/public/diffusion_quant/text2img_diffusers_ptq/w8a8_test/sensitivity_quant_error_weight.log'

# The pattern previously ended in a stray 'g' ('quant_error_weightg') instead of
# the closing word boundary '\b', so it could not match the target name.
analysis_target_pattern = re.compile(r'\bquant_error_weight\b', re.IGNORECASE)
analysis_target_pattern_matched = False
analysis_target_pattern_line_number = None
line_number = 0
with open(filename, 'r') as file:
    for line_number, line in enumerate(file, start=1):
        if analysis_target_pattern.search(line):
            # If the target occurs several times, the last occurrence wins.
            analysis_target_pattern_matched = True
            analysis_target_pattern_line_number = line_number
            print(f'Matched Weight Quant Error line number: {line_number}')

print(f'Final line number {line_number}')
final_line_number = line_number
assert analysis_target_pattern_matched is True

# Analyse everything from the (last) target line to the end of the file.
start_line = analysis_target_pattern_line_number
end_line = final_line_number

with open(filename, 'r') as file:
    lines = file.readlines()[start_line - 1:end_line]

error_type = "SQNR"  # MSE or SQNR

level = 'layer'  # only layer supported for now
if level != 'layer':
    raise NotImplementedError(f'level={level!r} is not supported; only "layer" is')

modules = []
sqnrs = []
mses = []
for line in lines:
    if "model." in line:
        # Module name: text after the first '.' that follows the second 'm'
        # in the line, up to the next ':'. NOTE(review): this relies on the
        # log prefix layout — confirm.
        first_m = line.find('m')
        start = line.find('m', first_m + 1)
        start = line.find('.', start)
        end = line.find(':', start)
        modules.append(line[start + 1:end])
    if "SQNR:" in line:
        # SQNR value: text after the ':' that follows the second 'S', up to
        # the first 'B' (the character before 'B', presumably the 'd' of "dB",
        # is dropped).
        start = line.find('S')
        start = line.find('S', start + 1)
        start = line.find(':', start)
        end = line.find('B')
        sqnrs.append(float(line[start + 1:end - 1]))
        # MSE value on the same line: text after the ':' following the first
        # 'M', up to the first 'x' (presumably the mantissa of "<v>x1e-5").
        start = line.find('M')
        start = line.find(':', start)
        end = line.find('x')
        mses.append(float(line[start + 1:end]))

print(len(modules))

SORT_TYPE = 'error'   # 'error' / 'layer_name'
extract_min = True  # keep only the most sensitive fraction of layers
extract_min_rate = 0.15
# Ascending for SQNR (lowest SQNR first); descending for MSE (largest first).
descending = error_type != 'SQNR'

# Use the metric that matches error_type; previously SQNR values were used
# even when error_type == "MSE".
if error_type == 'SQNR':
    metric_values = sqnrs
elif error_type == 'MSE':
    metric_values = mses
else:
    raise ValueError(f'Unsupported error_type: {error_type!r} (expected "SQNR" or "MSE")')
data = list(zip(modules, metric_values))

data_ = sorted(data, key=lambda x: x[1], reverse=descending)
if extract_min:
    num_min = int(extract_min_rate * len(modules))
    data_ = data_[:num_min]
if SORT_TYPE == 'layer_name':
    # Sort by the first dotted name component, then by full name, so that
    # (alphabetically) down_blocks < mid_block < up_blocks. The sort is needed
    # because some layer names carry no block prefix. Only entries 5..-2 are
    # reordered.
    data_[5:-1] = sorted(data_[5:-1], key=lambda x: (x[0].split('.')[0], x[0]))

if not data_:
    raise ValueError('No layers selected; check the log contents and extract_min_rate')
modules, errors = zip(*data_)

if error_type == "SQNR":
    plt.figure(figsize=(35, 7))
    plt.plot(modules, errors, marker='o')
    plt.xlabel('Module')
    plt.ylabel('SQNR (dB)')
    plt.title('SQNR for quantized weight')
    plt.grid(True)
    plt.xticks(rotation=90, fontsize=13)
    plt.show()
    if extract_min:
        save_path = './sensitivity_log/weight_error/sensitive_layers_w4a32_5_sqnr.pt'
        # torch.save does not create missing directories.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(modules, save_path)
elif error_type == "MSE":
    plt.figure(figsize=(80, 8))
    plt.plot(modules, errors, marker='o')
    plt.xlabel('Module')
    plt.ylabel('MSE (x1e-5)')
    plt.title('MSE for quantized weight')
    plt.grid(True)
    plt.xticks(rotation=90, fontsize=12)
    plt.show()
    if extract_min:
        # NOTE(review): the SQNR file says w4a32 while this says w8a32 — confirm the config tag.
        save_path = './sensitivity_log/weight_error/sensitive_layers_w8a32_5_mse.pt'
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(modules, save_path)


In [None]:
import os
import re

import matplotlib.pyplot as plt
import torch

# Plot the per-layer activation-quantization error (SQNR or MSE) recorded in a
# sensitivity-analysis log, and save the names of the most sensitive layers.

# Log file to analyse.
filename = '/home/fangtongcheng/diffuser-dev/analysis_tools/error_func/quant_error/quant_activation_error/sdxl_turbo_w32a8_minmax/run.log'

# NOTE(review): 'quant_act_weight' looks unusual for an activation analysis —
# confirm it matches the key actually written to the log.
analysis_target_pattern = re.compile(r'\bquant_act_weight\b', re.IGNORECASE)
analysis_target_pattern_matched = False
analysis_target_pattern_line_number = None
line_number = 0
with open(filename, 'r') as file:
    for line_number, line in enumerate(file, start=1):
        if analysis_target_pattern.search(line):
            # If the target occurs several times, the last occurrence wins.
            analysis_target_pattern_matched = True
            analysis_target_pattern_line_number = line_number
            print(f'Matched Activation Quant Error line number: {line_number}')

print(f'Final line number {line_number}')
final_line_number = line_number
assert analysis_target_pattern_matched is True

# Analyse everything from the (last) target line to the end of the file.
start_line = analysis_target_pattern_line_number
end_line = final_line_number

with open(filename, 'r') as file:
    lines = file.readlines()[start_line - 1:end_line]

error_type = "SQNR"  # MSE or SQNR

level = 'layer'  # only layer supported for now
if level != 'layer':
    raise NotImplementedError(f'level={level!r} is not supported; only "layer" is')

modules = []
sqnrs = []
mses = []
kurts = []  # parsed kurtosis values (not used for plotting below)
for line in lines:
    if "model." in line:
        # Module name: text after the first '.' that follows the second 'm'
        # in the line, up to the next ':'. NOTE(review): this relies on the
        # log prefix layout — confirm.
        first_m = line.find('m')
        start = line.find('m', first_m + 1)
        start = line.find('.', start)
        end = line.find(':', start)
        modules.append(line[start + 1:end])
    elif "SQNR:" in line:
        # SQNR value: text after the ':' that follows the second 'S', up to
        # the first 'B' (the character before 'B', presumably the 'd' of "dB",
        # is dropped).
        start = line.find('S')
        start = line.find('S', start + 1)
        start = line.find(':', start)
        end = line.find('B')
        sqnrs.append(float(line[start + 1:end - 1]))
        # MSE value on the same line: text after the ':' following the first
        # 'M', up to the first 'x' (presumably the mantissa of "<v>x1e-5").
        start = line.find('M')
        start = line.find(':', start)
        end = line.find('x')
        mses.append(float(line[start + 1:end]))
    elif "Kurt" in line:
        # Kurtosis value: text after the ':' following the first 'K', up to
        # the next space (assumes a space follows the number — TODO confirm).
        start = line.find('K')
        start = line.find(':', start)
        end = line.find(' ', start)
        kurts.append(float(line[start + 1:end]))

print(len(modules))
SORT_TYPE = 'error'   # 'error' / 'layer_name'
extract_min = True  # keep only the most sensitive fraction of layers
extract_min_rate = 0.15
# Ascending for SQNR (lowest SQNR first); descending for MSE (largest first).
descending = error_type != 'SQNR'

# Use the metric that matches error_type; previously SQNR values were used
# even when error_type == "MSE".
if error_type == 'SQNR':
    metric_values = sqnrs
elif error_type == 'MSE':
    metric_values = mses
else:
    raise ValueError(f'Unsupported error_type: {error_type!r} (expected "SQNR" or "MSE")')
data = list(zip(modules, metric_values))

data_ = sorted(data, key=lambda x: x[1], reverse=descending)
if extract_min:
    num_min = int(extract_min_rate * len(modules))
    data_ = data_[:num_min]
if SORT_TYPE == 'layer_name':
    # Sort by the first dotted name component, then by full name, so that
    # (alphabetically) down_blocks < mid_block < up_blocks. The sort is needed
    # because some layer names carry no block prefix. Only entries 5..-2 are
    # reordered.
    data_[5:-1] = sorted(data_[5:-1], key=lambda x: (x[0].split('.')[0], x[0]))

if not data_:
    raise ValueError('No layers selected; check the log contents and extract_min_rate')
# Sorted (and possibly truncated) layer names and their error values. The plot
# and save code previously referenced undefined `new_modules` / `new_sqnrs` /
# `new_mses`, which raised NameError.
modules, errors = zip(*data_)

if error_type == "SQNR":
    plt.figure(figsize=(30, 7))
    plt.plot(modules, errors, marker='o')
    plt.xlabel('Module')
    plt.ylabel('SQNR (dB)')
    plt.title('SQNR for quantized activation')
    plt.grid(True)
    plt.xticks(rotation=80, fontsize=13)
    plt.show()
    if extract_min:
        save_path = './sensitivity_log/act_error/sensitive_layers_w8a8_5_sqnr.pt'
        # torch.save does not create missing directories.
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(modules, save_path)
elif error_type == "MSE":
    plt.figure(figsize=(40, 8))
    plt.plot(modules, errors, marker='o')
    plt.xlabel('Module')
    plt.ylabel('MSE (x1e-5)')
    plt.title('MSE for quantized activation')
    plt.grid(True)
    plt.xticks(rotation=90, fontsize=12)
    plt.show()
    if extract_min:
        save_path = './sensitivity_log/act_error/sensitive_layers_w8a8_5_mse.pt'
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save(modules, save_path)
