In [3]:
import torch
import json
# from modeling_llama_up import set_profile_mode
import os
import csv
from utils import get_model, set_seed

# Expose only GPUs 1 and 2 to CUDA for this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"

# Keys used to look up the model / dataset locations in ../path.json.
model_name = "mixtral"
dataset_name = "c4"

with open('../path.json', 'r') as file:
    paths = json.load(file)

# A key missing from path.json falls back to an empty string.
model_path, dataset_path, save_path = (
    paths.get(key, '') for key in (model_name, dataset_name, 'chess_up_threshold')
)
print('model path:', model_path, '\ndataset path:', dataset_path, '\nsave path:', save_path)

# Seed via the project helper, then load the model from the resolved path.
set_seed(42)
model = get_model(model_path)

  from .autonotebook import tqdm as notebook_tqdm
MixtralForCausalLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


model path: /home/lz/Mixtral-8x7B-v0.1 
dataset path: /home/lz/c4 
save path: /home/lz/On-the-Fly_MoE_Inference/saving/threshold/c4_mixtral_up


Loading checkpoint shards: 100%|██████████| 19/19 [00:37<00:00,  1.97s/it]


with sparsity of 0


In [4]:
# Display the loaded model's module tree (rich repr of the last expression).
model

MixtralForCausalLM(
  (model): MixtralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MixtralDecoderLayer(
        (self_attn): MixtralAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MixtralRotaryEmbedding()
        )
        (block_sparse_moe): MixtralSparseMoeBlock(
          (gate): Linear(in_features=4096, out_features=8, bias=False)
          (experts): ModuleList(
            (0-7): 8 x MixtralBLockSparseTop2MLP(
              (w1): Linear(in_features=4096, out_features=14336, bias=False)
              (w2): Linear(in_features=14336, out_features=4096, bias=False)
              (w3): Linear(in_features=4096, out_features=14336, bias=False)
        

In [4]:
from datasets import load_dataset
from transformers import AutoTokenizer

# Load the raw dataset from the location configured in path.json
# (`dataset_path`, set in the first cell) instead of repeating the absolute
# path here, so the notebook has a single source of truth for locations.
raw_datasets = load_dataset(dataset_path)

# Tokenizer taken from the same checkpoint directory as the model.
tokenizer = AutoTokenizer.from_pretrained(model_path)
def process(example):
    """Tokenize one dataset record.

    Args:
        example: mapping with a 'text' entry holding the raw string.

    Returns:
        dict with 'ids' (the encoded token ids) and 'len' (their count).
    """
    token_ids = tokenizer.encode(example['text'])
    return {'ids': token_ids, 'len': len(token_ids)}

import numpy as np

# Tokenize every record in parallel worker processes using `process`.
tokenized = raw_datasets.map(process, desc='tokenizing raw datasets', num_proc=64)

# Concatenate each split's token ids into one flat uint32 array.
datasets = {}
for split, dset in tokenized.items():
    length = np.sum(dset['len'])
    # Uninitialised buffer; every slot is overwritten by the loop below.
    flat_tokens = np.empty((length,), np.uint32)
    idx = 0
    for row in dset:
        n_tokens = row['len']
        flat_tokens[idx:idx + n_tokens] = row['ids']
        idx += n_tokens
    datasets[split] = flat_tokens

# Persist the {split: token array} mapping in the current working directory.
torch.save(datasets, 'datasets.pt')

Generating validation split: 0 examples [00:00, ? examples/s]

Generating validation split: 45576 examples [00:00, 121496.76 examples/s]
tokenizing raw datasets (num_proc=64): 100%|██████████| 45576/45576 [00:02<00:00, 17946.73 examples/s]


In [2]:
import torch
import numpy as np

# Load the pre-tokenized token streams.
# NOTE(review): presumably this is the file written by the tokenization cell,
# which saves to 'datasets.pt' in the working directory — confirm it was
# moved to this location.
# weights_only=False is stated explicitly: the tokenization cell stores numpy
# arrays, which the restricted (weights_only=True) unpickler rejects by
# default. Only use this on a trusted, locally produced file.
datasets = torch.load('./threshold/chess/datasets.pt', weights_only=False)
def get_batch(data, batch_size, block_size):
    """Sample random contiguous windows from a 1-D token array.

    Args:
        data: 1-D numpy array of token ids.
        batch_size: number of windows to draw.
        block_size: length of each window.

    Returns:
        (x, y): two int64 tensors of shape (batch_size, block_size), where
        y holds the same windows as x shifted one position to the right.
    """
    def _window(offset):
        # Copy the slice to int64 so it can be wrapped as a LongTensor.
        return torch.from_numpy(data[offset:offset + block_size].astype(np.int64))

    # Exclusive upper bound leaves room for the one-step-shifted target window.
    starts = torch.randint(len(data) - block_size, (batch_size,)).tolist()
    x = torch.stack([_window(start) for start in starts])
    y = torch.stack([_window(start + 1) for start in starts])
    return x, y

  datasets = torch.load('./threshold/chess/datasets.pt')


In [4]:
# --- Calibration settings ---
sparsity_level = 0.7   # quantile used when picking the k-th value thresholds
device_2 = 'cpu'       # device on which the k-th value / sums are computed
avg_loss = 0.0
n_batch = 64           # total number of forward passes
accum_steps = 2        # forward passes pooled together per threshold estimate
batch_size = 1
block_size = 2048
torch.manual_seed(42)

decoder_layers = model.model.layers
n_layers = len(decoder_layers)
n_experts = len(decoder_layers[0].block_sparse_moe.experts)

hidden_size = model.config.intermediate_size
# Each per-expert buffer holds half of the tokens seen in one accumulation step.
buffer_rows = accum_steps * batch_size * block_size // 2

# Running sums, later averaged over the accumulation steps:
# one scalar threshold per (layer, expert) ...
up_proj_states_thresholds = [torch.zeros([n_experts,]) for _ in range(n_layers)]
# ... and one mean-square vector per (layer, expert).
gate_proj_states_mean_squares = [
    [torch.zeros(hidden_size) for _ in range(n_experts)]
    for _ in range(n_layers)
]

# Scratch buffers that collect the recorded activations of one accumulation step.
up_states = [
    [torch.zeros([buffer_rows, hidden_size]) for _ in range(n_experts)]
    for _ in range(n_layers)
]
gate_states = [
    [torch.zeros([buffer_rows, hidden_size]) for _ in range(n_experts)]
    for _ in range(n_layers)
]

# Pass 1: estimate, per (layer, expert),
#   * the `sparsity_level` quantile of |up_proj activations|, and
#   * the per-channel mean square of the gate_proj activations,
# each averaged over `n_batch // accum_steps` accumulation steps.
# `up_proj_states` / `gate_proj_states` are attributes read from each expert
# module — presumably filled in by the patched model's forward pass; verify
# against the modeling code.
hidden_dim = model.config.intermediate_size
n_steps = n_batch // accum_steps

with torch.no_grad():
    for step in range(n_steps):
        print(step * accum_steps)
        # Rows already written into each expert's buffer during this step.
        all_counts = [0 for _ in range(n_layers * n_experts)]
        for batch_idx in range(accum_steps):
            # `labels` (the shifted targets) is unused: the inputs themselves
            # are passed as labels to the model below.
            inputs, labels = get_batch(datasets['validation'], batch_size, block_size)
            inputs = inputs.cuda()
            outputs = model(inputs, labels=inputs)
            avg_loss = avg_loss + outputs.loss / n_batch

            for layer_idx in range(n_layers):
                for expert_idx in range(n_experts):
                    flat_idx = layer_idx * n_experts + expert_idx
                    expert = model.model.layers[layer_idx].block_sparse_moe.experts[expert_idx]
                    counts = all_counts[flat_idx]

                    states = expert.up_proj_states.reshape(-1, hidden_dim)
                    cur_counts = states.size(0)

                    # The buffers only have room for half of a step's tokens.
                    # Fail with a clear message instead of the opaque
                    # shape-mismatch error a truncated slice would raise.
                    capacity = up_states[layer_idx][expert_idx].size(0)
                    if counts + cur_counts > capacity:
                        raise RuntimeError(
                            f'activation buffer overflow at layer {layer_idx}, expert {expert_idx}: '
                            f'need {counts + cur_counts} rows but only {capacity} are allocated'
                        )
                    up_states[layer_idx][expert_idx][counts:counts + cur_counts, :] = states

                    # Keep the name `states` bound: a later cell reads it.
                    states = expert.gate_proj_states.reshape(-1, hidden_dim)
                    gate_states[layer_idx][expert_idx][counts:counts + cur_counts, :] = states
                    all_counts[flat_idx] += cur_counts

        for layer_idx in range(n_layers):
            for expert_idx in range(n_experts):
                useful_num = all_counts[layer_idx * n_experts + expert_idx]
                # Rank of the quantile among all collected activation values.
                topk_num = int(useful_num * hidden_dim * sparsity_level)

                up_abs = up_states[layer_idx][expert_idx][0:useful_num, :].to(device_2).abs().flatten()
                up_proj_states_thresholds[layer_idx][expert_idx] += up_abs.kthvalue(topk_num).values.to('cpu')

                gate_sq_sum = torch.sum(gate_states[layer_idx][expert_idx][0:useful_num, :].to(device_2) ** 2, dim=0)
                gate_proj_states_mean_squares[layer_idx][expert_idx] += gate_sq_sum.to('cpu') / useful_num

# Turn the running sums into averages over the accumulation steps.
for layer_idx in range(n_layers):
    for expert_idx in range(n_experts):
        gate_proj_states_mean_squares[layer_idx][expert_idx] /= n_steps
        up_proj_states_thresholds[layer_idx][expert_idx] /= n_steps

0
2
4
6
8
10
12
14
16
18
20
22
24
26
28
30
32
34
36
38
40
42
44
46
48
50
52
54
56
58
60
62


In [5]:
# Spot-check the statistics for layer 0, expert 1.
# NOTE(review): in the recorded output every displayed entry of the gate
# mean-square vector is 0. A later cell divides by this vector — confirm the
# gate activations are actually being recorded.
gate_proj_states_mean_squares[0][1],up_proj_states_thresholds[0][1]

(tensor([0., 0., 0.,  ..., 0., 0., 0.]), tensor(0.0481))

In [7]:
# Pass 2: per (layer, expert), find the `sparsity_level` quantile of the
# importance score  up_state**2 * mean(gate_state**2), then convert it back
# into a per-channel threshold on |up_state|.
importance_thresholds = [torch.zeros([n_experts,]) for _ in range(n_layers)]
up_proj_states_thresholds_2 = [[torch.zeros(model.config.intermediate_size) for _ in range(n_experts)] for _ in range(n_layers)]

with torch.no_grad():
    for step in range(n_batch // accum_steps):
        print(step * accum_steps)
        # Rows already written into each expert's buffer during this step.
        all_counts = [0 for _ in range(n_layers * n_experts)]
        for batch_idx in range(accum_steps):
            inputs, labels = get_batch(datasets['validation'], batch_size, block_size)
            inputs = inputs.cuda()
            outputs = model(inputs, labels=inputs)
            # NOTE(review): `avg_loss` is carried over from the previous cell,
            # so after this loop it holds the sum of both passes' means.
            avg_loss = avg_loss + outputs.loss / n_batch

            for layer_idx in range(n_layers):
                for expert_idx in range(n_experts):
                    counts = all_counts[layer_idx * n_experts + expert_idx]
                    # Reshape with the configured width rather than
                    # `states.size(-1)`, which read a stale `states` left over
                    # from the previous cell / previous loop iteration.
                    states = model.model.layers[layer_idx].block_sparse_moe.experts[expert_idx].up_proj_states.reshape(-1, model.config.intermediate_size)
                    cur_counts = states.size(0)
                    up_states[layer_idx][expert_idx][counts:cur_counts+counts, :] = states
                    all_counts[layer_idx * n_experts + expert_idx] += cur_counts

        for layer_idx in range(n_layers):
            for expert_idx in range(n_experts):
                useful_num = all_counts[layer_idx * n_experts + expert_idx]
                # The mean-square vector broadcasts across the token rows.
                importance_scores = up_states[layer_idx][expert_idx][:useful_num,:] ** 2 * gate_proj_states_mean_squares[layer_idx][expert_idx]
                importance_thresholds[layer_idx][expert_idx] += importance_scores.to(device_2).flatten().kthvalue(int(importance_scores.numel() * sparsity_level)).values.to('cpu')

for layer_idx in range(n_layers):
    for expert_idx in range(n_experts):
        # Average over the accumulation steps.
        importance_thresholds[layer_idx][expert_idx] /= n_batch // accum_steps
        # Invert score = up**2 * gate_ms  ->  |up| = sqrt(score / gate_ms).
        up_proj_states_thresholds_2[layer_idx][expert_idx] = (importance_thresholds[layer_idx][expert_idx].expand_as(up_proj_states_thresholds_2[layer_idx][expert_idx]) / gate_proj_states_mean_squares[layer_idx][expert_idx]) ** 0.5

thresholds = {'up_proj_states_thresholds': up_proj_states_thresholds, 'up_proj_states_thresholds_2': up_proj_states_thresholds_2}

# Make sure the output directory exists so the save cannot fail after the
# expensive calibration above.
if save_path:
    os.makedirs(save_path, exist_ok=True)
torch.save(thresholds, f'{save_path}/thresholds.pt')

0
2
4


In [2]:
import torch
# Directory containing the saved calibration thresholds.
save_path = './threshold/c4_mixtral_up'
# The file holds only a dict of (nested) lists of tensors, so the restricted
# unpickler is sufficient; passing weights_only explicitly avoids relying on
# the version-dependent default.
thresholds = torch.load(f'{save_path}/thresholds.pt', weights_only=True)
print(thresholds["up_proj_states_thresholds_2"])

[[tensor([0.0828, 0.0886, 0.0805,  ..., 0.0463, 0.0734, 0.0805]), tensor([0.0914, 0.0872, 0.1076,  ..., 0.0616, 0.0878, 0.0748]), tensor([0.1119, 0.0712, 0.1166,  ..., 0.0903, 0.0871, 0.0887]), tensor([0.1036, 0.0800, 0.1252,  ..., 0.0806, 0.0852, 0.0863]), tensor([0.1034, 0.1043, 0.1129,  ..., 0.0677, 0.0950, 0.0903]), tensor([0.0984, 0.0832, 0.1072,  ..., 0.0709, 0.0874, 0.0822]), tensor([0.0671, 0.1130, 0.1085,  ..., 0.0673, 0.0583, 0.0829]), tensor([0.0929, 0.0765, 0.0864,  ..., 0.0673, 0.0767, 0.0761])], [tensor([0.1306, 0.0817, 0.1159,  ..., 0.0995, 0.1281, 0.1266]), tensor([0.1352, 0.1360, 0.1233,  ..., 0.1193, 0.1400, 0.1366]), tensor([0.1269, 0.1036, 0.1278,  ..., 0.1105, 0.1419, 0.1258]), tensor([0.1210, 0.1199, 0.1266,  ..., 0.1159, 0.1094, 0.1222]), tensor([0.1323, 0.1283, 0.1320,  ..., 0.1297, 0.1526, 0.1441]), tensor([0.1278, 0.1137, 0.1114,  ..., 0.1229, 0.1466, 0.1378]), tensor([0.1113, 0.0857, 0.1023,  ..., 0.0926, 0.1235, 0.1224]), tensor([0.1244, 0.1014, 0.1169,  ...

  thresholds = torch.load(f'{save_path}/thresholds.pt')
