In [1]:
import os

# Restrict the visible GPUs before torch or the project modelling code is
# imported, so the setting is already in place if any of those imports
# touches CUDA. This matches the ordering used by the later setup cell.
os.environ["CUDA_VISIBLE_DEVICES"] = "4, 5"

import torch
from modeling_mixtral_predict import MixtralForCausalLM, load_thresholds
from transformers import AutoTokenizer, BitsAndBytesConfig

def get_model(model_name, device_map, dtype=torch.bfloat16, use_cache=True):
    """Load the Mixtral causal LM with 8-bit weights, together with its tokenizer.

    Args:
        model_name: Checkpoint identifier or path given to ``from_pretrained``.
        device_map: Module-to-device placement given to ``from_pretrained``.
        dtype: Value forwarded as ``torch_dtype``.
        use_cache: Value forwarded as ``use_cache``.

    Returns:
        A ``(model, tokenizer)`` tuple. The tokenizer's pad token is set to
        its EOS token.
    """
    # 8-bit loading; "embed_tokens", "lm_head" and "w3" are listed as modules
    # to skip.
    # NOTE(review): the previous comment on this line said "int2-quantize up",
    # which does not match load_in_8bit=True with w3 skipped - confirm intent.
    bnb_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_skip_modules=["embed_tokens", "lm_head", "w3"],
    )

    model = MixtralForCausalLM.from_pretrained(
        model_name,
        device_map=device_map,
        use_cache=use_cache,
        quantization_config=bnb_config,
        torch_dtype=dtype,
    )

    # Alternative HQQ loading path, kept for reference:
    # save_dir = '/home/bcds/On-the-Fly_MoE_Inference/offloading/hqqsaved'
    # dtype = torch.float16
    # llm = MixtralHQQ.from_quantized(save_dir, compute_dtype=dtype, device='cuda:0', use_cache=True)
    # HQQLinear.set_backend(HQQBackend.PYTORCH)

    # Reuse EOS as the padding token (both the string and the id are set).
    tok = AutoTokenizer.from_pretrained(model_name)
    tok.pad_token = tok.eos_token
    tok.pad_token_id = tok.eos_token_id

    return model, tok

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import json

# Read the checkpoint location and the threshold directory from the shared
# path file.
with open('../path.json', 'r') as f:
    path = json.load(f)
model_name = path['mixtral']
threshold_path = path["mixtral_threshold"]

# Module-to-device placement used when loading the model.
with open('../quantize/device_map_1.json', 'r') as f:
    device_map = json.load(f)

# "0.8" -> "0_8": the suffix that appears in the threshold file name.
filepath = f"{0.8}".replace('.', '_')
load_thresholds(f'{threshold_path}/thresholds_{filepath}.pt', use_average=False)

llm, tokenizer = get_model(model_name, device_map=device_map)

  up_th = torch.load(threshold_path, map_location='cuda')[f"{use_type}_proj_states_thresholds"]
  self.up_threshold = torch.tensor(up_th[self.layer_idx][self.expert_idx])


Thresholds loaded from /home/bcds/On-the-Fly_MoE_Inference/saving/threshold/c4_mixtral/thresholds_0_8.pt


Loading checkpoint shards: 100%|██████████| 19/19 [00:20<00:00,  1.10s/it]


In [11]:
# Display the `state` attribute of the w1 layer of expert 0 in decoder layer 0.
llm.model.layers[0].block_sparse_moe.experts[0].w1.state


MatmulLtState(_tile_indices=None, force_no_igemmlt=False, CB=None, CxB=None, SB=None, SCB=None, CxBt=None, SBt=None, CBt=None, subB=None, outlier_pool=None, idx=None)

In [10]:
# Print every parameter registered on w1 of expert 0 in decoder layer 0.
w1_layer = llm.model.layers[0].block_sparse_moe.experts[0].w1
for params in w1_layer.parameters():
    print(params)

Parameter containing:
Parameter(Int8Params([[-11,  33, -50,  ..., -51,   2, -41],
            [ -4,   7, -57,  ...,  60,  36, -12],
            [-45,  56,   5,  ...,  52,  -5, -15],
            ...,
            [-12,  -2,  26,  ...,   8, -42, -38],
            [ 24,  42,  16,  ...,  -8,  20, -22],
            [  9,  33,  95,  ...,  -7, -18,  29]], device='cuda:0',
           dtype=torch.int8))


In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6"
import sys
sys.path.append("/home/bcds/On-the-Fly_MoE_Inference/quantize")

from modeling_teacher import MixtralForCausalLM as MixtralTeacher
from modeling_mixtral_test import MixtralForCausalLM, load_thresholds
from transformers import AutoTokenizer, MixtralConfig
import json
import torch

# Two device-placement maps: `sd` is the one prepare_model passes as
# `device_map`; `td` is loaded alongside it.
with open("../quantize/device_map.json") as f:
    sd = json.load(f)

with open("../quantize/device_map_1.json") as f:
    td = json.load(f)

def prepare_model(model_name, is_eval=False, has_atten=False, sparsity=80):
    """Load the ``MixtralTeacher`` model in float16 with router logits enabled.

    Args:
        model_name: Checkpoint identifier or path; used for both the
            configuration and the weights.
        is_eval: Accepted for call compatibility; not used by this function.
        has_atten: Accepted for call compatibility; not used by this function.
        sparsity: Accepted for call compatibility; not used by this function.

    Returns:
        The loaded model, placed according to the module-level device map ``sd``.
    """
    # Start from the checkpoint's own configuration and override only the
    # flags this notebook needs. Building a bare MixtralConfig(...) would use
    # the library defaults for every other field, whether or not they match
    # the checkpoint being loaded.
    config = MixtralConfig.from_pretrained(
        model_name,
        output_router_logits=True,
        use_cache=False,
        output_hidden_states=False,
    )
    # load_thresholds("/home/bcds/On-the-Fly_MoE_Inference/saving/threshold/c4_mixtral/thresholds_0_5.pt", use_average=False)
    model = MixtralTeacher.from_pretrained(
        pretrained_model_name_or_path=model_name,
        config=config,
        torch_dtype=torch.float16,
        device_map=sd,
    )
    return model

# Keys looked up in ../path.json; `model_name` is then rebound to the
# resolved checkpoint path.
model_name = "mixtral"
dataset_name = "fineweb"
with open('../path.json', 'r') as file:
    paths = json.load(file)
model_name = paths.get(model_name, '')
fineweb_path = paths.get(dataset_name)

sparsity = 0.5
# This model is used to plot the cosine-similarity figure.
student = prepare_model(
    model_name,
    is_eval=False,
    has_atten=True,
    sparsity=sparsity,
)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 19/19 [00:25<00:00,  1.36s/it]


In [3]:
from datasets import load_dataset
from tqdm import tqdm

fineweb = load_dataset("parquet", data_files=fineweb_path)
tokenizer = AutoTokenizer.from_pretrained(model_name)

#### (500,128)
# NOTE(review): presumably the (samples, max_length) of an earlier run - confirm.
test_samples = 1
# Slice the rows first, then select the column. Selecting the column first
# would materialise every text in the split just to keep the first few.
texts = fineweb["train"][:test_samples]["text"]
tokenizer.pad_token = tokenizer.eos_token
# print(texts)

import torch.nn.functional as F
spar_model_logits = []
for text in tqdm(texts):
    # Each sample is padded/truncated to exactly 128 tokens.
    inputs = tokenizer(text, padding="max_length", truncation=True, max_length=128, return_tensors="pt")
    # inputs["labels"] = inputs.input_ids.clone()
    with torch.no_grad():
        output = student(**inputs)
    print(output["router_logits"])

100%|██████████| 1/1 [00:00<00:00,  6.65it/s]

(tensor([[[0, 0, 0,  ..., 1, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[1, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 1, 1,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        ...,

        [[0, 0, 0,  ..., 0, 0, 1],
         [1, 0, 0,  ..., 0, 1, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 1, 0, 1]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]]]), tensor([[[0, 0, 0,  ..., 1, 1, 1],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 1,  ..., 0, 0, 0],
         [0, 1, 0,  ..., 1, 0, 0]],

        ...,

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 1,  ..., 0, 1, 0]],

        [[0, 1, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 1]]]), tensor([[[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 1, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 1]




In [10]:
print(len(output["router_logits"])) ### 32 layers
print(output["router_logits"][0].size()) ### (expert_nums, top2, seq_len)

32
torch.Size([8, 2, 128])


In [5]:
# Collect get_ratio() from each of the 8 experts in decoder layers 1..31;
# recalls[i][j] belongs to layer i + 1, expert j (layer 0 is not queried).
recalls = []
for i in range(31):
    experts = student.model.layers[i + 1].block_sparse_moe.experts
    row = []
    for j in range(8):
        row.append(experts[j].get_ratio())
    recalls.append(row)

counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting start....
counting sta

In [7]:
# Average the 8 per-expert values of each row in `recalls` and print them.
layer_aves = []
for i in range(31):
    # Accumulate explicitly, left to right, starting from 0.
    layer_ave = 0
    for j in range(8):
        layer_ave += recalls[i][j]
    mean_value = layer_ave / 8
    layer_aves.append(mean_value)
    print(f"{i} : {mean_value}")


0 : 0.9929167207616273
1 : 0.993652953224467
2 : 0.9935663523101625
3 : 0.9935381470963337
4 : 0.9906454216999487
5 : 0.9856903147916639
6 : 0.9808128588026573
7 : 0.9728984419980262
8 : 0.974595306383042
9 : 0.9808898134666808
10 : 0.9612637478555939
11 : 0.9615163525617088
12 : 0.9843525421186244
13 : 0.981803754173869
14 : 0.9789732749681834
15 : 0.9802407381094386
16 : 0.9821938400299352
17 : 0.9857402458333276
18 : 0.9818534171565407
19 : 0.9805381052721885
20 : 0.9819862943701141
21 : 0.9828528180270167
22 : 0.9841985369809362
23 : 0.9843406200126629
24 : 0.9843022985078043
25 : 0.9843391257515294
26 : 0.98593894068149
27 : 0.9796869710832403
28 : 0.9740712105702827
29 : 0.9421484266151791
30 : 0.8699232734621583


In [8]:
# Display the list of per-layer averages as the cell output.
layer_aves

[0.9929167207616273,
 0.993652953224467,
 0.9935663523101625,
 0.9935381470963337,
 0.9906454216999487,
 0.9856903147916639,
 0.9808128588026573,
 0.9728984419980262,
 0.974595306383042,
 0.9808898134666808,
 0.9612637478555939,
 0.9615163525617088,
 0.9843525421186244,
 0.981803754173869,
 0.9789732749681834,
 0.9802407381094386,
 0.9821938400299352,
 0.9857402458333276,
 0.9818534171565407,
 0.9805381052721885,
 0.9819862943701141,
 0.9828528180270167,
 0.9841985369809362,
 0.9843406200126629,
 0.9843022985078043,
 0.9843391257515294,
 0.98593894068149,
 0.9796869710832403,
 0.9740712105702827,
 0.9421484266151791,
 0.8699232734621583]

In [None]:
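# NOTE(review): 31 unlabeled values follow, one per line - presumably per-layer
# results from a different run or setting; add the source/metric - confirm.
# 0.7909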
# 0.8449
# 0.8125
# 0.8770
# 0.8607
# 0.8450
# 0.8374
# 0.8356
# 0.8528
# 0.8780
# 0.8976
# 0.8753
# 0.9124
# 0.9138
# 0.9012
# 0.9044
# 0.9051
# 0.9060
# 0.8689
# 0.8621
# 0.8787
# 0.8970
# 0.8855
# 0.9106
# 0.9163
# 0.9252
# 0.9173
# 0.9120
# 0.8901
# 0.9002
# 0.8945