In [26]:
import argparse
import json
import os
import shutil
import warnings

import torch
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer

In [27]:
# Optional dependency: the fast tokenizer is only available in some
# transformers builds. When it cannot be imported, LlamaTokenizerFast is set
# to None.
try:
    from transformers import LlamaTokenizerFast
except ImportError as e:
    # warnings.warn expects a str or a Warning instance. Passing the raw
    # ImportError makes any installed message-regex warning filter fail with
    # TypeError when it tries to match against a non-string, so convert first.
    warnings.warn(str(e))
    warnings.warn("Failed to import LlamaTokenizerFast")
    LlamaTokenizerFast = None

In [28]:
### some helper functions ###
# Number of checkpoint shards for each model-size label.
# Looked up as NUM_SHARDS[model_size] by the conversion cell; a label that is
# not listed here raises KeyError at that lookup.
# NOTE(review): the trailing "f" labels presumably denote fine-tuned/chat
# variants — confirm against the checkpoint naming you use.
NUM_SHARDS = {
    "7B": 1,
    "7Bf": 1,
    "13B": 2,
    "13Bf": 2,
    "34B": 4,
    "30B": 4,
    "65B": 8,
    "70B": 8,
    "70Bf": 8,
}

def compute_intermediate_size(n, ffn_dim_multiplier=1, multiple_of=256):
    """Return the feed-forward width derived from the model dimension ``n``.

    The width starts from ``int(8 * n / 3)``, is scaled by
    ``ffn_dim_multiplier`` (truncated to an int), and is then rounded up to
    the next multiple of ``multiple_of``.

    Args:
        n: Model (embedding) dimension.
        ffn_dim_multiplier: Scale factor applied before rounding.
        multiple_of: Granularity the result is rounded up to.

    Returns:
        An integer that is a multiple of ``multiple_of``.
    """
    two_thirds_of_4n = int(8 * n / 3)
    scaled_width = int(ffn_dim_multiplier * two_thirds_of_4n)
    # Ceiling division expressed with integer arithmetic.
    n_multiples = (scaled_width + multiple_of - 1) // multiple_of
    return multiple_of * n_multiples

def read_json(path):
    """Load the JSON document stored at ``path`` and return the parsed object.

    Args:
        path: Filesystem path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file contents.
    """
    with open(path, mode='r') as handle:
        parsed = json.load(handle)
    return parsed
    
def write_json(text, path):
    """Serialize ``text`` as JSON into the file at ``path``.

    The file is opened in write mode, so existing contents are replaced.

    Args:
        text: Any JSON-serializable object (despite the name, not only str).
        path: Destination file path.
    """
    with open(path, mode='w') as out_file:
        json.dump(text, fp=out_file)

In [29]:
# User-tunable settings read by the conversion cell.
model_path = "test_path" # Output directory (created if missing)
input_base_path = "llama-2-7b"  # Input directory; expected to hold params.json and consolidated.*.pth
model_size = "7B" # Size label of the model; must be a key of NUM_SHARDS
safe_serialization= False # Whether or not to save using `safetensors`
tokenizer_path= None # Path to the tokenizer; when None no tokenizer is loaded and vocab_size defaults to 32000
# NOTE(review): input_base_path names a Llama 2 checkpoint while llama_version
# is 1; with the default rope base this selects a 2048 rather than 4096
# position limit in the conversion cell — confirm which is intended.
llama_version = 1 # Either 1 (Llama 1) or 2 (Llama 2)


In [38]:
###### Write model into hugging face format. ######
# Resolve the checkpoint directory. If params.json is not directly under
# input_base_path, fall back to the per-size subdirectory
# (<input_base_path>/<model_size>). The original code computed this joined
# path but discarded the result, so the fallback never took effect.
# The reassignment is guarded so that re-running the cell is idempotent.
if not os.path.isfile(os.path.join(input_base_path, "params.json")):
    _sized_input_path = os.path.join(input_base_path, model_size)
    if os.path.isfile(os.path.join(_sized_input_path, "params.json")):
        input_base_path = _sized_input_path

# Output directory plus a scratch subdirectory for intermediate files.
os.makedirs(model_path, exist_ok=True)
tmp_model_path = os.path.join(model_path, "tmp")
os.makedirs(tmp_model_path, exist_ok=True)

# Model hyper-parameters shipped with the checkpoint.
params = read_json(os.path.join(input_base_path, "params.json"))
num_shards = NUM_SHARDS[model_size]
n_layers = params["n_layers"]
n_heads = params["n_heads"]
n_heads_per_shard = n_heads // num_shards
dim = params["dim"]
dims_per_head = dim // n_heads

# Rotary position embedding (RoPE) inverse frequencies:
#   inv_freq[i] = base ** (-2i / dims_per_head), for i in [0, dims_per_head / 2)
# The rotation acts within a single attention head, so both the range and the
# normaliser are the per-head dimension (not n_heads_per_shard / dim).
base = params.get("rope_theta", 10000.0)
inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))

# Maximum sequence length recorded for the converted model.
if base > 10000:
    # Larger rope base => long-context variant. 16384 == 2 ** 14
    # (the previous value 16348 was a digit transposition).
    max_positional_embedding = 16384
else:
    if llama_version == 1:
        max_positional_embedding = 2048
    elif llama_version == 2:
        max_positional_embedding = 4096
    else:
        raise NotImplementedError(
            f"Version {llama_version} of llama is not supported yet. "
            "Current supported versions of llama are [1, 2]."
        )

# Prefer the fast tokenizer when it could be imported.
tokenizer_class = LlamaTokenizerFast if LlamaTokenizerFast else LlamaTokenizer
if tokenizer_path is not None:
    tokenizer = tokenizer_class(tokenizer_path)
    tokenizer.save_pretrained(tmp_model_path)
# Without a tokenizer, fall back to a vocabulary size of 32000.
vocab_size = tokenizer.vocab_size if tokenizer_path is not None else 32000

# Key/value head layout. When params.json provides n_kv_heads (grouped-query /
# multi-query attention) there are fewer K/V heads than query heads.
if params.get("n_kv_heads", None) is not None:
    num_key_value_heads = params["n_kv_heads"]
    # K/V heads are split across shards just like query heads, and the K/V
    # projection width is (per-head dim) * (number of K/V heads). The earlier
    # expressions (n_heads_per_shard // n_kv_heads and dim // n_kv_heads) only
    # agree with these when n_heads == n_kv_heads ** 2.
    num_local_key_value_heads = num_key_value_heads // num_shards
    key_value_dim = dims_per_head * num_key_value_heads
else:
    # Plain multi-head attention: one K/V head per query head.
    num_key_value_heads = n_heads
    num_local_key_value_heads = n_heads_per_shard
    key_value_dim = dim


# permute for sliced rotary
def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
    """Reorder the rows of a ``(dim1, dim2)`` weight for sliced rotary embeddings.

    The rows are viewed as ``n_heads`` groups of ``dim1 // n_heads`` rows. Within
    each group the rows at even offsets are moved ahead of the rows at odd
    offsets, and the result is flattened back to ``(dim1, dim2)``.

    Args:
        w: Tensor whose element count equals ``dim1 * dim2``.
        n_heads: Number of head groups along the first axis.
        dim1: Size of the first axis of the result.
        dim2: Size of the second axis of the result.

    Returns:
        The permuted tensor with shape ``(dim1, dim2)``.
    """
    half_rows_per_head = dim1 // n_heads // 2
    paired = w.view(n_heads, half_rows_per_head, 2, dim2)
    regrouped = paired.transpose(1, 2)
    return regrouped.reshape(dim1, dim2)


# Load the original checkpoint onto the CPU.
# NOTE: torch.load unpickles the file; only use it on checkpoints you trust.
if num_shards == 1:
    # Not sharded: a single state dict.
    loaded = torch.load(os.path.join(input_base_path,  "consolidated.00.pth"), map_location="cpu")
else:
    # Sharded: one state dict per shard, in shard order. Previously `loaded`
    # was simply left undefined for multi-shard model sizes.
    # NOTE(review): assumes shard files are named consolidated.NN.pth with a
    # zero-padded two-digit index — confirm against the checkpoint directory.
    loaded = [
        torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
        for i in range(num_shards)
    ]



In [46]:
import networkx as nx
import matplotlib.pyplot as plt

# Read the first consolidated checkpoint shard into CPU memory for inspection.
# This is the same file the conversion cell loads into `loaded` when
# num_shards == 1, so a second copy of the weights is held in memory here.
model_state_dict = torch.load(
    os.path.join(input_base_path, "consolidated.00.pth"),
    map_location="cpu",
)

# To list every parameter name together with its tensor shape, uncomment:
# for name, param in model_state_dict.items():
#     print(name, param.shape)

In [58]:
from collections import defaultdict

# Group parameter names by layer index.
# A name is grouped when its second dot-separated component parses as an
# integer (e.g. 'layers.3.attention.wq.weight' -> 3). All other names, such as
# 'tok_embeddings.weight', are skipped.
layers_dict = defaultdict(list)
for name in model_state_dict:
    print(name)
    try:
        layer_index = int(name.split('.')[1])  # Extract the layer index from the parameter name
    except (IndexError, ValueError):
        # IndexError: the name has no second component.
        # ValueError: the second component is not an integer.
        # Catching only these keeps genuine errors (e.g. KeyboardInterrupt)
        # from being silently swallowed, as the bare `except:` did.
        continue
    layers_dict[layer_index].append(name)



tok_embeddings.weight
norm.weight
output.weight
layers.0.attention.wq.weight
layers.0.attention.wk.weight
layers.0.attention.wv.weight
layers.0.attention.wo.weight
layers.0.feed_forward.w1.weight
layers.0.feed_forward.w2.weight
layers.0.feed_forward.w3.weight
layers.0.attention_norm.weight
layers.0.ffn_norm.weight
layers.1.attention.wq.weight
layers.1.attention.wk.weight
layers.1.attention.wv.weight
layers.1.attention.wo.weight
layers.1.feed_forward.w1.weight
layers.1.feed_forward.w2.weight
layers.1.feed_forward.w3.weight
layers.1.attention_norm.weight
layers.1.ffn_norm.weight
layers.2.attention.wq.weight
layers.2.attention.wk.weight
layers.2.attention.wv.weight
layers.2.attention.wo.weight
layers.2.feed_forward.w1.weight
layers.2.feed_forward.w2.weight
layers.2.feed_forward.w3.weight
layers.2.attention_norm.weight
layers.2.ffn_norm.weight
layers.3.attention.wq.weight
layers.3.attention.wk.weight
layers.3.attention.wv.weight
layers.3.attention.wo.weight
layers.3.feed_forward.w1.weight


In [59]:
# Display the layer-index -> parameter-names mapping built in the previous cell.
layers_dict

defaultdict(list,
            {0: ['layers.0.attention.wq.weight',
              'layers.0.attention.wk.weight',
              'layers.0.attention.wv.weight',
              'layers.0.attention.wo.weight',
              'layers.0.feed_forward.w1.weight',
              'layers.0.feed_forward.w2.weight',
              'layers.0.feed_forward.w3.weight',
              'layers.0.attention_norm.weight',
              'layers.0.ffn_norm.weight'],
             1: ['layers.1.attention.wq.weight',
              'layers.1.attention.wk.weight',
              'layers.1.attention.wv.weight',
              'layers.1.attention.wo.weight',
              'layers.1.feed_forward.w1.weight',
              'layers.1.feed_forward.w2.weight',
              'layers.1.feed_forward.w3.weight',
              'layers.1.attention_norm.weight',
              'layers.1.ffn_norm.weight'],
             2: ['layers.2.attention.wq.weight',
              'layers.2.attention.wk.weight',
              'layers.2.attention.wv.

In [61]:
# Display the raw state dict (parameter name -> tensor).
# NOTE(review): the original comment here was cut off ("what is the meanin"),
# so the question it was asking is unknown — restore or remove.
# NOTE(review): this echoes every tensor's repr; showing only names and shapes
# would keep the output readable.
model_state_dict

{'tok_embeddings.weight': tensor([[ 1.2293e-06, -1.8179e-06, -4.3511e-06,  ...,  8.7172e-07,
          -6.5267e-06,  8.9034e-07],
         [ 1.8616e-03, -3.3722e-03,  3.9864e-04,  ..., -8.3008e-03,
           2.5787e-03, -3.9368e-03],
         [ 1.0986e-02,  9.8877e-03, -5.0964e-03,  ...,  2.5177e-03,
           7.7057e-04, -5.0049e-03],
         ...,
         [-1.3977e-02, -2.7313e-03, -1.9897e-02,  ..., -1.0437e-02,
           9.5825e-03, -1.8005e-03],
         [-1.0742e-02,  9.3384e-03,  1.2939e-02,  ..., -3.3203e-02,
          -1.6357e-02,  3.3875e-03],
         [-8.3008e-03, -4.0588e-03, -1.1063e-03,  ...,  3.4790e-03,
          -1.2939e-02,  3.1948e-05]], dtype=torch.bfloat16),
 'norm.weight': tensor([1.8672, 1.8672, 1.8047,  ..., 1.7188, 1.8281, 1.6016],
        dtype=torch.bfloat16),
 'output.weight': tensor([[-0.0039,  0.0032, -0.0071,  ...,  0.0053, -0.0082,  0.0070],
         [-0.0315,  0.0466, -0.0023,  ..., -0.0211,  0.0173,  0.0334],
         [-0.0125,  0.0036,  0.0195,  

ModuleNotFoundError: No module named 'torchviz'