## the distribution of the weight and the original activation of SDXL-Turbo (the input of a certain layer)

In [None]:
import torch
import torch.nn as nn
from typing import Union
import json
from pytorch_lightning import seed_everything
import logging
from tqdm import tqdm
from torch.cuda import amp
import matplotlib.pyplot as plt
import numpy as np

from diffusers import StableDiffusionPipeline, UNet2DModel, UNet2DConditionModel, LCMScheduler
from sdxl_pipeline import StableDiffusionXLPipeline
from qdiff.utils import DataSaverHook, StopForwardException
from qdiff.models.quant_layer import QuantLayer
from qdiff.models.quant_model import QuantModel
from qdiff.models.quant_block import BaseQuantBlock

from qdiff.models.quant_block_forward_func import convert_model_split, convert_transformer_storable, set_shortcut_split


In [None]:
def get_model(model_id: str="Lykon/dreamshaper-7", cache_dir: str="/share/public/diffusion_quant/huggingface/hub", type: str="lcm_lora", device: int=4):
    """Load a diffusion pipeline and return its UNet prepared for analysis.

    Args:
        model_id: Model identifier passed to ``from_pretrained``.
        cache_dir: Cache directory passed to ``from_pretrained``.
        type: Variant tag, matched by substring: ``'xl'`` selects
            ``StableDiffusionXLPipeline`` (otherwise ``StableDiffusionPipeline``),
            ``'lcm'`` swaps in an ``LCMScheduler`` and ``'lora'`` loads and
            fuses the LCM-LoRA adapter. The name shadows the builtin ``type``
            inside this function; it is kept for backward compatibility.
        device: CUDA device index the UNet is moved to. Defaults to 4, the
            index that used to be hard-coded.

    Returns:
        ``(model, pipe)``: the pipeline's UNet (on ``device``, in eval mode)
        and the pipeline itself. The pipeline as a whole is not moved here.
    """
    print(f"the weight is from {model_id}")
    pipeline_cls = StableDiffusionXLPipeline if 'xl' in type else StableDiffusionPipeline
    pipe = pipeline_cls.from_pretrained(model_id, cache_dir=cache_dir)

    if 'lcm' in type:
        pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
    if 'lora' in type:
        # Load and fuse the LCM LoRA.
        # NOTE(review): the adapter id targets SD v1.5 — confirm before using
        # a 'lora' type together with an SDXL checkpoint.
        adapter_id = "latent-consistency/lcm-lora-sdv1-5"
        pipe.load_lora_weights(adapter_id)
        pipe.fuse_lora()

    model = pipe.unet

    # convert_model_split(model) is the alternative conversion; only the
    # "storable" transformer conversion is applied here.
    convert_transformer_storable(model)

    model.cuda(device)
    model.eval()
    return model, pipe

# Load SDXL-Turbo. The tag "sdxl_turbo" contains 'xl' (SDXL pipeline class) but
# neither 'lcm' nor 'lora', so the scheduler and weights are left untouched.
model, pipe = get_model(model_id="stabilityai/sdxl-turbo", type="sdxl_turbo")

In [None]:
# Layers ranked as sensitive by the UNet-output error (W4A32 run); iterated as
# a collection of layer names further down.
# NOTE(review): torch.load unpickles the file — only load files you trust.
sensitive_out = torch.load("../error_func/sensitivity_log/unet_out_error/sensitive_layers_list_w4a32_5.pt")
# NOTE(review): this value is discarded; a notebook cell only displays its last expression.
len(sensitive_out)
sensitive_out

In [None]:
# Layers ranked as sensitive by the weight-error SQNR (W4A32 run).
# NOTE(review): torch.load unpickles the file — only load files you trust.
sensitive_weight = torch.load("../error_func/sensitivity_log/weight_error/sensitive_layers_w4a32_5_sqnr.pt")
(sensitive_weight)

In [None]:
# Layers flagged by both criteria, kept in the order they appear in `sensitive_out`.
sensitive_layers = [layer for layer in sensitive_out if layer in sensitive_weight]

print(len(sensitive_layers))
sensitive_layers

### 1. the distribution of the activation

In [None]:
# Layers ranked as sensitive by the activation-error SQNR (W8A8 run).
# NOTE(review): torch.load unpickles the file — only load files you trust.
sensitive_acts = torch.load("../error_func/sensitivity_log/act_error/sensitive_layers_w8a8_5_sqnr.pt")
sensitive_acts

In [None]:
# Module-level logger (used by `sample`) and a fixed global seed.
logger = logging.getLogger(__name__)
seed_everything(42)

In [None]:
def prepare_coco_text_and_image(json_file):
    """Read a COCO captions file and return one caption and one image path per image.

    Args:
        json_file: Path to a JSON file with an ``"annotations"`` list whose
            entries carry ``"image_id"`` and ``"caption"``.

    Returns:
        ``(active_captions, image_paths)``: two parallel lists ordered by the
        first appearance of each image id. ``active_captions`` holds the first
        caption seen for each image; ``image_paths`` holds the matching
        ``COCO_val2014_<12-digit id>.jpg`` path under the val2014 directory.
    """
    # The context manager closes the annotation file; it used to be left open.
    with open(json_file, 'r') as f:
        info = json.load(f)

    # Group captions by image id; dicts preserve first-insertion order.
    image_caption_dict = {}
    for annotation_dict in info["annotations"]:
        image_caption_dict.setdefault(annotation_dict["image_id"], []).append(annotation_dict["caption"])

    active_captions = [texts[0] for texts in image_caption_dict.values()]
    image_paths = [
        "/share/public/diffusion_quant/coco/coco/val2014/"+f"COCO_val2014_{image_id:012}.jpg"
        for image_id in image_caption_dict
    ]
    return active_captions, image_paths


In [None]:
# Look up a sub-module by its qualified name so a hook can be attached to it.
def find_module_by_name(model, name):
    """Return the sub-module of ``model`` whose ``named_modules()`` name equals ``name``, or ``None``."""
    matches = (module for module_name, module in model.named_modules() if module_name == name)
    return next(matches, None)

In [None]:
# Pick the layer to inspect by its qualified name inside the UNet.
name = 'conv_in'  # 'up_blocks.2.resnets.2.conv_shortcut'
module = find_module_by_name(model, name)  # None if no module has this name
module  # display the module

In [None]:
class DataSaverHook:
    """
    Forward hook that can record the inputs and/or output of a module.

    Each call overwrites what was stored before, so only the latest forward
    pass is kept. With ``stop_forward`` set, ``StopForwardException`` is
    raised after storing.
    """
    def __init__(self, store_input=False, store_output=False, stop_forward=False):
        self.store_input, self.store_output = store_input, store_output
        self.stop_forward = stop_forward

        # Nothing captured until the hook fires.
        self.input_store = self.output_store = None

    def __call__(self, module, input_batch, output_batch):
        if self.store_input:
            self.input_store = input_batch
        if self.store_output:
            self.output_store = output_batch
        if not self.stop_forward:
            return
        raise StopForwardException

In [None]:
class GetLayerInpOut_SDXL:
    """Capture the inputs that reach ``layer`` during one forward pass of ``model``.

    A ``DataSaverHook`` is attached to ``layer`` for the duration of each call
    and removed afterwards, even when the forward pass fails.
    """

    def __init__(self, model: QuantModel, layer: Union[QuantLayer, BaseQuantBlock, nn.Module],
                 device: torch.device, asym: bool = False, act_quant: bool = False):
        """Store the model and the layer to observe.

        ``device`` is accepted for interface compatibility but is not used.
        ``asym`` and ``act_quant`` are stored but not read by this class.
        """
        self.model = model
        self.layer = layer
        self.asym = asym
        self.act_quant = act_quant
        self.data_saver = DataSaverHook(store_input=True, store_output=True, stop_forward=False)

    def __call__(self, x, timesteps, context=None, added_conds=None):
        """Run ``model`` once and return the detached inputs seen by ``layer``.

        Returns:
            * a 2-tuple of tensors when the layer received 2 to 6 positional
              inputs and the second one is a tensor;
            * a 7-tuple when the layer received exactly 7 inputs
              (QuantTransformerBlock case); tensors are detached, ``None``
              and other non-tensor entries are passed through;
            * otherwise the first input, detached.

        Raises:
            RuntimeError: if the hook never stored any input.
        """
        self.model.eval()

        handle = self.layer.register_forward_hook(self.data_saver)
        try:
            with torch.no_grad():
                self.model(x, timesteps, context, added_cond_kwargs=added_conds)
        except StopForwardException:
            # Raised by the hook to cut the forward pass short; the data is already stored.
            pass
        finally:
            # Always detach the hook so a failed forward does not leave it registered.
            handle.remove()

        inputs = self.data_saver.input_store
        if inputs is None:
            raise RuntimeError("the forward hook did not fire: the layer was never called")

        n_inputs = len(inputs)
        if 1 < n_inputs < 7 and torch.is_tensor(inputs[1]):
            return (inputs[0].detach(), inputs[1].detach())
        if n_inputs == 7:
            # QuantTransformerBlock takes 7 inputs (to be improved).
            return tuple(item.detach() if torch.is_tensor(item) else item for item in inputs)
        return inputs[0].detach()

In [None]:
# Display the device the model currently lives on.
model.device

In [None]:
def inference_sdxl_turbo(prompt, pipe):
    """Call ``pipe`` on ``prompt`` with one inference step and return ``.images`` of its first result."""
    print("#######################################################################")
    # guidance_scale=0 disables guidance.
    result = pipe(prompt=prompt, num_inference_steps=1, guidance_scale=0)
    first = result[0]
    return first.images


def sample(prompt, unet, pipe, batch_size, quant_inference = False, is_fp16 = False):
    """Run the pipeline over ``prompt`` so that hooks registered on the UNet fire.

    Args:
        prompt: Sequence of prompts; ``len(prompt) // batch_size`` must be 1.
        unet: Model whose ``device`` the pipeline is moved to.
        pipe: Pipeline passed to ``inference_sdxl_turbo``.
        batch_size: Number of prompts per pipeline call.
        quant_inference: Unused; kept for interface compatibility.
        is_fp16: Unused; kept for interface compatibility.

    Returns:
        None. The generated images are discarded.

    Raises:
        AssertionError: if the prompts do not form exactly one batch.
    """
    num = len(prompt) // batch_size
    assert num==1, "num==1 should be true"
    img_id = 0
    logger.info(f"starting from image {img_id}")

    # Seed the CPU and CUDA generators for reproducibility.
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(42)

    pipe.to(unet.device)
    with torch.no_grad():
        for i in tqdm(
            range(num), desc="Generating activations for plotting."
        ):
            with amp.autocast(enabled=False):
                inference_sdxl_turbo(prompt[batch_size*i:batch_size*(i+1)], pipe)


In [None]:
def get_data(model, module, pipe):
    """Run the pipeline on 8 COCO captions and capture what ``module`` saw.

    A ``DataSaverHook`` is attached to ``module`` while ``sample`` runs and is
    always removed afterwards. The hook keeps only the most recent call, so
    if ``module`` runs more than once only the last call is returned.

    Args:
        model: UNet handed to ``sample``.
        module: Sub-module whose inputs and output are recorded.
        pipe: Pipeline handed to ``sample``.

    Returns:
        ``(inputs, output)`` where ``output`` is the detached module output
        and ``inputs`` is:
        * a 2-tuple of detached tensors when the module received 2 to 6
          positional inputs and the second one is a tensor (e.g. a
          ResnetBlock2D);
        * a 7-tuple when it received exactly 7 inputs (QuantTransformerBlock
          case); tensors are detached, ``None`` and other non-tensor entries
          are passed through;
        * otherwise the first input, detached.

    Raises:
        RuntimeError: if ``module`` was never called during sampling.
    """
    json_file = "/share/public/diffusion_quant/coco/coco/annotations/captions_val2014.json"
    prompt_list, _ = prepare_coco_text_and_image(json_file=json_file)
    prompts = prompt_list[0:8]

    data_saver = DataSaverHook(store_input=True, store_output=True, stop_forward=False)
    handle = module.register_forward_hook(data_saver)
    try:
        with torch.no_grad():
            sample(prompts, model, pipe, batch_size=8)
    finally:
        # Always detach the hook so a failed run does not leave it registered.
        handle.remove()

    inputs = data_saver.input_store
    if inputs is None:
        raise RuntimeError("the forward hook did not fire: the module was never called")
    output = data_saver.output_store.detach()

    n_inputs = len(inputs)
    if 1 < n_inputs < 7 and torch.is_tensor(inputs[1]):
        # The input of a ResnetBlock2D contains two tensors.
        return (inputs[0].detach(), inputs[1].detach()), output
    if n_inputs == 7:
        # QuantTransformerBlock takes 7 inputs (to be improved).
        return tuple(item.detach() if torch.is_tensor(item) else item for item in inputs), output
    return inputs[0].detach(), output

In [None]:
# Pick the layer to inspect, then capture its inputs and output on one pipeline run.
name = 'conv_in'  # 'up_blocks.2.resnets.2.conv_shortcut'
module = find_module_by_name(model, name)
module  # display the module

# input_data is a tensor or a tuple, depending on how many inputs the module receives.
input_data, output_data = get_data(model, module, pipe)

In [None]:
# Display the shape of the captured input (only valid when input_data is a single tensor).
input_data.shape

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.colors as colors

import plotly.graph_objects as go
import numpy as np
from plotly.offline import init_notebook_mode, iplot
init_notebook_mode(connected=True)

def plot_activation_3d(input_data, name, color_top, type='plotly'):
    """Plot the absolute activations of ``input_data`` as a 3-D surface.

    The tensor is reduced to a 2-D grid first:
    * 4-D input: mean over dims 2 and 3, giving a (dim 0, dim 1) grid;
    * 3-D input: mean over dim 0, giving a (dim 1, dim 2) grid;
    * anything else: reshaped to ``(shape[0], shape[1])``.

    Args:
        input_data: Tensor to visualise.
        name: Layer name; used for the title and file name of the matplotlib
            figure only.
        color_top: Value mapped to the top of the colour bar (matplotlib
            backend only; the plotly backend spans min..max of the data).
        type: ``'matplotlib'`` (saves a PNG and shows it) or ``'plotly'``
            (interactive figure). Any other value draws nothing. The name
            shadows the builtin; kept for backward compatibility.
    """
    # Convert to a numpy array of absolute values.
    inputs_np = np.abs(input_data.detach().cpu().numpy())

    if inputs_np.ndim == 4:
        # Average over the last two dimensions.
        inputs_np = inputs_np.mean(axis=(2, 3))
        x_label = "Batch Size"
    elif inputs_np.ndim == 3:
        # Average along the first (batch) dimension.
        inputs_np = inputs_np.mean(axis=0)
        x_label = "Tokens"
    else:
        # TODO: handle inputs that have only two dimensions
        inputs_np = inputs_np.reshape(inputs_np.shape[0], inputs_np.shape[1])
        x_label = "Batch Size"

    # Grid coordinates matching the (x, y) layout of inputs_np.
    x_data, y_data = inputs_np.shape
    _X, _Y = np.meshgrid(np.arange(x_data), np.arange(y_data), indexing='ij')

    if type == 'matplotlib':
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')

        surf = ax.plot_surface(_X, _Y, inputs_np, cmap='coolwarm', vmin=np.min(inputs_np), vmax=color_top)

        ax.set_xlabel(x_label)
        ax.set_ylabel('Input Channels')
        ax.set_zlabel('Absolute Activations')

        peak = np.max(np.abs(inputs_np))
        ax.set_zlim([-peak, peak])
        print(peak)
        fig.colorbar(surf)
        # Adjust the viewing angle.
        ax.view_init(elev=20, azim=20)
        plt.title('the input data of '+name)
        # Save the figure; the target directory is assumed to exist.
        plt.savefig(f'./distribution_plot/act/Transformer2d_model/3d_distribution/3d_acts_{name}.png')
        plt.show()
    elif type == 'plotly':
        fig = go.Figure()
        surf = go.Surface(x=_X, y=_Y, z=inputs_np, colorscale='viridis', cmin=np.min(inputs_np),
                          cmax=np.max(inputs_np), opacity=0.5)
        fig.add_trace(surf)

        # Axis titles of a 3-D plot live under `scene`; setting them on the
        # layout root only affects 2-D axes and left this plot unlabeled.
        fig.update_layout(
            scene=dict(
                xaxis_title=x_label,
                yaxis_title='input channels',
                zaxis_title='Absolute Activations',
            ),
            width=800, height=800,  # Set the figure size
        )

        # Show the plot
        iplot(fig)


def plot_activation_pdf(input_data, channel_type = None, view_channel = 1, name=''):
    '''
    Draw a log-scale histogram (3000 bins) of every value in ``input_data``.

    channel_type: dimension along which to inspect the distribution. Not used
        for selecting data yet (the whole tensor is always plotted); it only
        changes the legend label.
    view_channel: channel index to inspect. Not used for selecting data yet;
        it only appears in the legend label when channel_type is given.
    name: layer name used in the legend, the title and the saved file name.
    '''
    # Flatten the whole tensor and hand it to numpy.
    flat_values = input_data.reshape(-1).cpu().numpy()

    label = name if channel_type is None else (name+': '+channel_type+'.'+str(view_channel))

    plt.hist(flat_values, bins=3000, label=label)

    plt.title(f'the distribution of the input data of the {name}')
    plt.xlabel('value')
    plt.ylabel('freq')

    plt.legend()
    plt.yscale('log')
    # The target directory is assumed to exist.
    plt.savefig(f'./distribution_plot/act/Transformer2d_model/pdf_distribution/pdf_acts_{name}.png')
    plt.show()


def plot_activation_box(input_data, channel_type='input_channel', name=''):
    """Draw one box per channel (or per token) of ``input_data``.

    * ``channel_type='input_channel'``: one box per index of the last dim for
      3-D input, per index of dim 1 otherwise.
    * ``channel_type='tokens'``: only for 3-D input, one box per index of dim 1.
    * anything else raises ``RuntimeError``.

    The figure is saved under ``./distribution_plot/act/...`` (directory
    assumed to exist) and then shown.
    """
    tensor = input_data.cpu()
    rank = len(tensor.shape)

    if rank == 3 and channel_type == 'tokens':
        channels = [tensor[:,j].numpy().flatten() for j in range(tensor.shape[1])]
        x_label = "token_index"
    elif channel_type != 'input_channel':
        raise RuntimeError("the channel_type is not the 'input_channel'")
    elif rank == 3:
        channels = [tensor[:,:,j].numpy().flatten() for j in range(tensor.shape[2])]
        x_label = "channel_index"
    else:
        channels = [tensor[:, j].numpy().flatten() for j in range(tensor.shape[1])]
        x_label = "channel_index" if rank == 4 else "time_emb_channel_index"

    plt.figure(figsize=(22, 10))

    # Box plot with filled, notched boxes.
    bplot = plt.boxplot(channels, patch_artist=True, notch=True, vert=1)

    # zip stops at the shorter sequence, so only the first three boxes get a custom colour.
    palette = ['pink', 'lightblue', 'lightgreen']
    for box, face in zip(bplot['boxes'], palette):
        box.set_facecolor(face)

    # Legend entry attached to the first box.
    plt.legend([bplot["boxes"][0]], [name+': '+channel_type], loc='upper right')

    plt.title(f'box-plot of the input_data of the {name}')
    plt.xlabel(x_label)
    plt.ylabel('range')
    plt.savefig(f'./distribution_plot/act/Transformer2d_model/box_distribution/box_acts_{name}.png')
    plt.show()


def plot_activation_channel(input_data, channel_type, name):
    """Bar-plot the mean activation per channel (or per token) of ``input_data``.

    * 4-D input: mean over dims 0, 2 and 3, one bar per index of dim 1.
    * 3-D input: ``channel_type='input_channel'`` averages dims 0 and 1 (one
      bar per index of dim 2); ``channel_type='tokens'`` averages dims 0 and 2
      (one bar per index of dim 1).
    * 2-D input: mean over dim 0, one bar per index of dim 1.

    Args:
        input_data: Tensor to summarise.
        channel_type: ``'input_channel'`` or ``'tokens'``; only consulted for
            3-D input.
        name: Layer name used in the title and the saved file name.

    Raises:
        ValueError: for an unsupported rank, or an unsupported
            ``channel_type`` on 3-D input.
    """
    inputs_np = input_data.detach().cpu().numpy()
    print(input_data.shape)

    rank = inputs_np.ndim
    if rank == 4:
        mean_values = inputs_np.mean(axis=(0, 2, 3))
        x_label = "channel_index"
    elif rank == 3:
        if channel_type == 'input_channel':
            mean_values = inputs_np.mean(axis=(0, 1))
            x_label = "channel_index"
        elif channel_type == 'tokens':
            mean_values = inputs_np.mean(axis=(0, 2))
            x_label = "token_index"
        else:
            raise ValueError(f"unsupported channel_type for a 3-D input: {channel_type!r}")
    elif rank == 2:
        # The rank is tested here; the old check compared the size of dim 0
        # with 2, so 2-D inputs were only handled when dim 0 happened to be 2.
        mean_values = inputs_np.mean(axis=0)
        x_label = "channel_index"
    else:
        raise ValueError(f"unsupported input rank: {rank}")

    # One bar per remaining index.
    xs = np.arange(mean_values.shape[0])

    plt.figure(figsize=(5, 3))
    print(inputs_np.max())
    plt.bar(xs, mean_values)
    plt.xlabel(x_label)
    plt.ylabel('range')
    plt.title('the input data of '+name)
    # The target directory is assumed to exist.
    plt.savefig(f'./distribution_plot/act/Transformer2d_model/channel_distribution/channel_acts_{name}.png')
    plt.show()


#### 3D distribution

In [None]:
name = 'down_blocks.2.attentions.0.transformer_blocks.1.attn2.to_k'  # 'up_blocks.2.resnets.2.conv_shortcut'
module = find_module_by_name(model, name)
module  # display the module

input_data, output_data = get_data(model, module, pipe)

if not isinstance(input_data, tuple):
    # Keep index 0 along dim 1 and swap dims 0 and 1
    # (assumes a (batch, tokens, channels) tensor — TODO confirm).
    input_data = input_data[:,:1,:].permute([1,0,2])
    # Difference of every slice along dim 1 from the first one.
    diff = input_data - input_data[:,0,:].unsqueeze(1)
    print(diff)
    plot_activation_3d(input_data, name, color_top=0.9)
elif len(input_data)==2:
    # the input of the resnet
    plot_activation_3d(input_data[0], name+'0', color_top=1.2)
    plot_activation_3d(input_data[1], name+'1', color_top=1.2)  # time embedding
else:
    # the input of the cross attention or self attention
    # TODO: work out which attention layers actually use cross attention
    plot_activation_3d(input_data[0], name+'0', color_top=1.2)
    # input_data[2] is None for self attention: nothing to plot then
    if input_data[2] is not None:
        plot_activation_3d(input_data[2], name+'1', color_top=1.2)  # text embedding


#### 分布图

In [None]:
if not isinstance(input_data, tuple):
    plot_activation_pdf(input_data, channel_type = None, view_channel = 1, name=name)
elif len(input_data)==2:
    # the input of the resnet
    plot_activation_pdf(input_data[0], channel_type = None, view_channel = 1, name=name+'0')
    plot_activation_pdf(input_data[1], channel_type = None, view_channel = 1, name=name+'1')
else:
    # the input of the cross attention or self attention
    # TODO: work out which attention layers actually use cross attention
    plot_activation_pdf(input_data[0], channel_type = None, view_channel = 1, name=name+'0')
    # input_data[2] is None for self attention: nothing to plot then
    if input_data[2] is not None:
        plot_activation_pdf(input_data[2], channel_type = None, view_channel = 1, name=name+'1')

#### 箱状图

In [None]:
if not isinstance(input_data, tuple):
    plot_activation_box(input_data, channel_type = 'input_channel', name=name)
elif len(input_data)==2:
    # the input of the resnet
    plot_activation_box(input_data[0], channel_type = 'input_channel', name=name+'0')
    plot_activation_box(input_data[1], channel_type = 'input_channel', name=name+'1')
else:
    # the input of the cross attention or self attention
    # TODO: work out which attention layers actually use cross attention
    plot_activation_box(input_data[0], channel_type = 'input_channel', name=name+'0')
    # input_data[2] is None for self attention: nothing to plot then
    if input_data[2] is not None:
        plot_activation_box(input_data[2], channel_type = 'input_channel', name=name+'1')

#### 不同通道之间的数值差异可视化

In [None]:
if not isinstance(input_data, tuple):
    plot_activation_channel(input_data, channel_type = 'input_channel', name=name)
elif len(input_data)==2:
    # the input of the resnet
    plot_activation_channel(input_data[0], channel_type = 'input_channel', name=name+'0')
    plot_activation_channel(input_data[1], channel_type = 'input_channel', name=name+'1')
else:
    # the input of the cross attention or self attention
    # TODO: work out which attention layers actually use cross attention
    plot_activation_channel(input_data[0], channel_type = 'input_channel', name=name+'0')
    # input_data[2] is None for self attention: nothing to plot then
    if input_data[2] is not None:
        plot_activation_channel(input_data[2], channel_type = 'input_channel', name=name+'1')

### the distribution of the weight

In [None]:
# Display the layers ranked as sensitive by the UNet-output error.
sensitive_out

In [None]:
# Output-sensitive layers whose name contains 'attention'.
# A comprehension keeps its loop variable local, so the notebook-global `name`
# (the layer selected for the activation plots) is no longer overwritten.
sensitive_out_attention = [layer_name for layer_name in sensitive_out if 'attention' in layer_name]
sensitive_out_attention

In [None]:
# Display the layers ranked as sensitive by the weight-error SQNR.
sensitive_weight

In [None]:
# Name of the parameter to inspect, as listed by model.named_parameters().
# weight_id = sensitive_layers[12]+'.weight'  # which layer's weight to inspect
# weight_id = sensitive_out_attention[16]+'.weight'
weight_id = "up_blocks.0.attentions.2.transformer_blocks.2.attn2.to_v"+'.weight'

In [None]:
def get_weights(model, layer_name):
    """Return the ``.data`` tensor of the parameter named ``layer_name``.

    Args:
        model: Module whose ``named_parameters()`` are searched.
        layer_name: Full parameter name, e.g. ``"block.0.weight"``.

    Returns:
        The parameter's ``.data`` tensor, or ``None`` when no parameter has
        that name.
    """
    for name, param in model.named_parameters():
        if name == layer_name:
            return param.data
    return None

In [None]:
# Fetch the selected parameter tensor (None if weight_id matches no parameter).
weights = get_weights(model, weight_id)

In [None]:
# Display the shape of the selected weight tensor.
weights.shape

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.colors as colors

def plot_weight_3d(weights, weight_id, color_top):
    """Plot ``|weights|`` as a 3-D surface over (dim 0, dim 1).

    A 4-D weight is averaged over dims 2 and 3 first. ``color_top`` is the
    value mapped to the top of the colour bar. The figure is saved under
    ``./distribution_plot/weight/...`` (directory assumed to exist) and then
    shown.
    """
    # Absolute values as a numpy array.
    magnitudes = np.abs(weights.cpu().numpy())

    # Sizes of the first two dimensions (output and input channels).
    out_channels, in_channels = magnitudes.shape[0:2]

    if len(magnitudes.shape)==4:
        # Conv weight: average over the last two dimensions.
        magnitudes = magnitudes.mean(axis=(2, 3)).reshape(out_channels, in_channels)

    # Grid coordinates laid out as (out_channels, in_channels).
    grid_out, grid_in = np.meshgrid(np.arange(out_channels), np.arange(in_channels), indexing='ij')

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    surf = ax.plot_surface(grid_out, grid_in, magnitudes, cmap='coolwarm', vmin=np.min(magnitudes), vmax=color_top)

    ax.set_xlabel('Output Channels')
    ax.set_ylabel('Input Channels')
    ax.set_zlabel('Absolute Weights')

    # z-axis limits symmetric around zero.
    peak = np.max(np.abs(magnitudes))
    ax.set_zlim([-peak, peak])
    print(peak)

    fig.colorbar(surf)

    # Save the figure to a file.
    plt.savefig(f'./distribution_plot/weight/transformer2d_model/3d/3d_weight_{weight_id}.png')

    plt.show()


def plot_weight_pdf(weights, channel_type=None, view_channel = 1, weight_id=''):
    '''
    Draw a log-scale histogram of weight values.

    channel_type: 'output_channel' plots only weights[view_channel],
        'input_channel' plots only weights[:, view_channel]; any other value
        plots the whole tensor.
    view_channel: index of the channel to look at.
    weight_id: parameter name used in the legend and the saved file name.
    '''
    cpu_weights = weights.cpu()

    # Works for both linear and conv weights.
    if channel_type=='output_channel':
        selected = cpu_weights[view_channel]
    elif channel_type=='input_channel':
        selected = cpu_weights[:,view_channel]
    else:
        selected = cpu_weights
    # reshape rather than view: view cannot handle non-contiguous tensors.
    values = selected.reshape(-1).numpy()

    if channel_type is not None:
        label = weight_id+': '+channel_type+'.'+str(view_channel)
    else:
        label = weight_id

    # bins='auto' lets numpy pick the number of bins.
    plt.hist(values, bins='auto', label=label)

    plt.title('distribution')
    plt.xlabel('value')
    plt.ylabel('freq')

    plt.legend()
    plt.yscale('log')

    # The target directory is assumed to exist.
    plt.savefig(f'./distribution_plot/weight//transformer2d_model/pdf/pdf_weight_{weight_id}.png')
    plt.show()


def plot_weight_box(weights, channel_type = 'output_channel', weight_id=''):
    """Draw one box per output or input channel of ``weights``.

    Args:
        weights: Weight tensor; dim 0 is treated as output channels and dim 1
            as input channels.
        channel_type: ``'output_channel'`` (one box per index of dim 0) or
            ``'input_channel'`` (one box per index of dim 1).
        weight_id: Parameter name used in the legend and the saved file name.

    Raises:
        ValueError: if ``channel_type`` is neither supported value.
    """
    tensor = weights.cpu()

    if channel_type == 'output_channel':
        channels = [tensor[i].numpy().flatten() for i in range(tensor.shape[0])]
    elif channel_type == 'input_channel':
        channels = [tensor[:, j].numpy().flatten() for j in range(tensor.shape[1])]
    else:
        # `channels` used to be left undefined here and failed later with a NameError.
        raise ValueError(f"channel_type must be 'output_channel' or 'input_channel', got {channel_type!r}")

    plt.figure(figsize=(22, 10))

    # Box plot with small outlier markers.
    flierprops = dict(marker='o', markersize=2)
    bplot = plt.boxplot(channels, patch_artist=True, notch=True, vert=1, flierprops=flierprops)

    # zip stops at the shorter sequence, so only the first three boxes get a custom colour.
    colors = ['pink', 'lightblue', 'lightgreen']
    for patch, color in zip(bplot['boxes'], colors):
        patch.set_facecolor(color)

    # Legend entry attached to the first box.
    plt.legend([bplot["boxes"][0]], [weight_id+': '+channel_type], loc='upper right')

    plt.title('box-plot')
    plt.xlabel('channel_index')
    plt.ylabel('range')

    # Same root as the other weight plots (./distribution_plot/weight/...);
    # the old path pointed at ./distribution/distribution_plot instead.
    plt.savefig(f'./distribution_plot/weight/transformer2d_model/box/box_weight_{weight_id}.png')
    plt.show()



#### 3D distribution

In [None]:
# 3-D surface of |weights|; the colour bar tops out at 0.03.
plot_weight_3d(weights, weight_id, color_top=0.03)

#### 分布图

In [None]:
# Histogram over the whole weight tensor (no per-channel selection).
plot_weight_pdf(weights, weight_id=weight_id)

#### 箱状图

In [None]:
# One box per output channel of the selected weight.
plot_weight_box(weights, channel_type='output_channel', weight_id=weight_id)

#### the comparison between the sdxl-turbo and the original sdxl

In [None]:
model_id_non_turbo = "stabilityai/stable-diffusion-xl-base-1.0"
# get_model returns (unet, pipe). Unpack it and use separate names so the
# SDXL-Turbo `model` / `pipe` loaded earlier stay available.
model_non_turbo, pipe_non_turbo = get_model(model_id=model_id_non_turbo, type="sdxl")
weight_id_non_turbo = weight_id  # which layer's weight to inspect
weights_non_turbo = get_weights(model_non_turbo, weight_id)

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np

# Whether to look at the distribution of a single channel:
# 'output_channel', 'input_channel', or None for the whole tensor.
channel_type = None
# Which channel to look at.
view_channel = 1

if channel_type=='output_channel':
    flat_tensor1 = weights[view_channel].reshape(-1)
    flat_tensor2 = weights_non_turbo[view_channel].reshape(-1)
elif channel_type=='input_channel':
    # reshape rather than view: view cannot handle non-contiguous tensors.
    flat_tensor1 = weights[:,view_channel].reshape(-1)
    flat_tensor2 = weights_non_turbo[:,view_channel].reshape(-1)
else:
    flat_tensor1 = weights.reshape(-1)
    flat_tensor2 = weights_non_turbo.reshape(-1)

# The models live on a CUDA device, and .numpy() only works on CPU tensors.
numpy_array1 = flat_tensor1.cpu().numpy()
numpy_array2 = flat_tensor2.cpu().numpy()

# Legend suffix naming the selected channel, if any.
if channel_type is not None:
    label_suffix = ': '+channel_type+'.'+str(view_channel)
else:
    label_suffix = ''

# Overlay both histograms.
plt.hist(numpy_array2, bins=3000, alpha=0.5, label='sdxl_'+weight_id_non_turbo+label_suffix)
plt.hist(numpy_array1, bins=3000, alpha=0.5, label='sdxl_turbo_'+weight_id+label_suffix)


# Logarithmic y axis.
plt.yscale('log')
plt.title('distribution')
plt.xlabel('value')
plt.ylabel('freq')

plt.legend()
plt.show()




## 计算张量各种属性：信息熵、方差、离群点占比等

In [None]:
import torch
import numpy as np

def compute_weight_info(weights):
    """Compute entropy, outlier fraction and variance of a tensor's values.

    Args:
        weights: Tensor of any shape and device; it is flattened, detached
            and copied to the CPU first.

    Returns:
        ``(entropy, outlier_count, variance)``:
        * ``entropy``: Shannon entropy (natural log) of a 3000-bin histogram
          of the values;
        * ``outlier_count``: fraction of values lying more than 1.5 times the
          5th-95th percentile range below the 5th or above the 95th
          percentile;
        * ``variance``: variance of the values.
    """
    # detach()/cpu() make this work for CUDA tensors and tensors that require grad.
    numpy_array = weights.detach().reshape(-1).cpu().numpy()

    # Entropy of the empirical distribution; the epsilon keeps log() finite for empty bins.
    counts, _ = np.histogram(numpy_array, bins=3000)
    p = counts / counts.sum() + 1e-10
    entropy = -np.sum(p * np.log(p))

    # IQR-style outlier rule, using the 5th-95th percentile range as the spread.
    upper, lower = np.percentile(numpy_array, [95, 5])
    threshold = 1.5 * (upper - lower)
    # The margin is applied outside the percentile band; the bounds used to be
    # swapped (upper - threshold / lower + threshold).
    is_outlier = (numpy_array < (lower - threshold)) | (numpy_array > (upper + threshold))
    outlier_count = np.count_nonzero(is_outlier) / len(numpy_array)

    variance = np.var(numpy_array)

    print(f'entropy: {entropy}')
    print(f'outlier_percent: {outlier_count*1e3}x10^(-3)')
    print(f'variance: {variance*1e5}x10^(-5)')
    return entropy, outlier_count, variance

In [None]:
# Print and display (entropy, outlier fraction, variance) for the selected weight.
compute_weight_info(weights)

In [None]:
def get_weights_info(model):
    """Collect ``compute_weight_info`` statistics for every parameter of ``model``.

    Returns:
        ``(weight_names, entropys, variances, outlier_freqs)``: four parallel
        lists, one entry per parameter in ``named_parameters()`` order.
    """
    weight_names = []
    entropys = []
    variances = []
    outlier_freqs = []
    for name, param in model.named_parameters():
        print(param.data.shape)
        weight_names.append(name)
        # compute_weight_info returns (entropy, outlier fraction, variance) in
        # that order; the last two used to be unpacked the wrong way round.
        entropy, outlier_freq, variance = compute_weight_info(param.data)
        entropys.append(entropy)
        variances.append(variance)
        outlier_freqs.append(outlier_freq)
    return weight_names, entropys, variances, outlier_freqs

In [None]:
# Collect per-parameter statistics for every parameter of `model`.
weight_names, entropys, variances, outlier_freqs = get_weights_info(model)
# Draw a line chart with matplotlib: entropy per parameter.
plt.figure(figsize=(80, 6))
plt.plot(weight_names, entropys, marker='o')
plt.xlabel('weight of layers')
plt.ylabel('entropys')
plt.title('entropys for weights of different layers')
plt.grid(True)
plt.xticks(rotation=90, fontsize=2)
plt.show()

In [None]:
# Same entropy curve with a taller figure and larger tick labels.
# The labels describe the data that is plotted (`entropys`); they used to say
# "SQNR (dB)" / "Blocks".
plt.figure(figsize=(80, 8))
plt.plot(weight_names, entropys, marker='o')
plt.xlabel('weight of layers')
plt.ylabel('entropys')
plt.title('entropys for weights of different layers')
plt.grid(True)
plt.xticks(rotation=90, fontsize=6)
plt.show()

In [None]:
# Same entropy curve with tick labels rotated by 60 degrees.
# The labels describe the data that is plotted (`entropys`); they used to say
# "SQNR (dB)" / "Blocks".
plt.figure(figsize=(80, 8))
plt.plot(weight_names, entropys, marker='o')
plt.xlabel('weight of layers')
plt.ylabel('entropys')
plt.title('entropys for weights of different layers')
plt.grid(True)
plt.xticks(rotation=60, fontsize=6)
plt.show()