In [None]:
import os
import sys
import shutil

# Move the working directory up one level before the project imports below.
# NOTE(review): presumably ".." is the repository root, so that the `sparsetral`
# package and the relative "output/..." paths resolve - confirm against the repo layout.
parent_dir = ".."
# Remember the directory the kernel started in the first time this cell runs.
# A bare os.chdir("..") is relative to the *current* directory, so re-running the
# cell would climb one more level every time; anchoring to the start directory
# makes the cell idempotent while behaving identically on a fresh kernel.
if "notebook_dir" not in globals():
    notebook_dir = os.getcwd()
os.chdir(os.path.join(notebook_dir, parent_dir))
import json
import torch
import transformers
from peft import PeftModel

from sparsetral.configuration_sparsetral import SparsetralConfig
from sparsetral.modeling_sparsetral import MistralForCausalLM

# Training checkpoint directory (relative to the working directory set above).
# The conversion cell reads "moe_model.bin" and "adapter_model" from it.
trained_weights = "output/checkpoint-7825"
# Destination directory: the conversion cell writes the merged model, the tokenizer,
# the patched config.json and copies of the sparsetral modeling files here.
output_dir = "output/sparsetral-16x7B-v2"

# Convert

In [None]:
# Base checkpoint that supplies the config, the dense weights and the tokenizer.
base_model_id = "mistralai/Mistral-7B-Instruct-v0.2"

model_config = SparsetralConfig.from_pretrained(base_model_id)
model_config.pretraining_tp = 1  # without tensor parallelism rank

# Sparsetral Config
model_config.moe_dtype = "bfloat16"
model_config.adapter_dim = 512
model_config.topk = 4
model_config.moe_scaling = 1
model_config.num_experts = 16
model_config.output_router_logits = False

# Both artifacts are read from the training checkpoint directory.
moe_model = os.path.join(trained_weights, "moe_model.bin")
adapter_model = os.path.join(trained_weights, "adapter_model")

model = MistralForCausalLM.from_pretrained(
    base_model_id,
    config=model_config,
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",
)
# Attach the trained PEFT adapter, then fold it into the base weights so the
# result is a plain model without the PEFT wrapper.
model = PeftModel.from_pretrained(model, adapter_model)
model = model.merge_and_unload()

# NOTE(review): torch.load unpickles the file - only point this at checkpoints
# you produced yourself or otherwise trust.
moe_state_dict = torch.load(moe_model, map_location="cpu")
# Strip the "base_model.model." segment from every key.
# NOTE(review): presumably the keys were saved from the PEFT-wrapped model and
# this makes them line up with the merged model's parameter names - confirm.
new_moe_state_dict = {
    key.replace("base_model.model.", ""): tensor
    for key, tensor in moe_state_dict.items()
}

# strict=False: the file is presumably a partial state dict, so missing keys are
# expected. Unexpected keys are not - they are weights from the file that were
# silently dropped - so surface them instead of saving an incomplete model unnoticed.
load_result = model.load_state_dict(new_moe_state_dict, strict=False)
if load_result.unexpected_keys:
    print(
        f"WARNING: {len(load_result.unexpected_keys)} key(s) from {moe_model} "
        f"matched no model parameter and were NOT loaded, e.g. "
        f"{load_result.unexpected_keys[:5]}"
    )
tokenizer = transformers.AutoTokenizer.from_pretrained(base_model_id)

model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

# Rewrite the saved config.json so the checkpoint refers to the sparsetral
# classes in the two module files copied next to it below.
config_path = os.path.join(output_dir, "config.json")
# Context managers close the handles deterministically; the write in particular
# must be flushed to disk before the folder is used or uploaded.
with open(config_path, "r", encoding="utf-8") as config_file:
    config = json.load(config_file)

config["architectures"] = ["modeling_sparsetral.MistralForCausalLM"]
config["auto_map"] = {
    "AutoConfig": "configuration_sparsetral.SparsetralConfig",
    "AutoModel": "modeling_sparsetral.MistralModel",
    "AutoModelForCausalLM": "modeling_sparsetral.MistralForCausalLM",
}
config["model_type"] = "sparsetral"
# Drop the recorded source path/name if present (no error when it is absent).
config.pop("_name_or_path", None)

with open(config_path, "w", encoding="utf-8") as config_file:
    json.dump(config, config_file, indent=2)

# Copy the modules named in auto_map into the output directory.
for module_file in ("configuration_sparsetral.py", "modeling_sparsetral.py"):
    shutil.copy2(f"sparsetral/{module_file}", os.path.join(output_dir, module_file))

# Push to Hub (Optional)

In [None]:
from huggingface_hub import HfApi

# Target repository on the Hub, e.g. "<user-or-org>/<model-name>".
# Left empty by default; this optional cell then skips the upload instead of failing.
hub_repo_id = ""

if hub_repo_id:
    api = HfApi()
    api.upload_folder(
        folder_path=output_dir,
        repo_id=hub_repo_id,
        repo_type="model",
        # Needs write access. Read from the environment so the secret is never
        # stored in the notebook; when HF_TOKEN is unset this passes None.
        # NOTE(review): with None, huggingface_hub is expected to fall back to
        # the locally cached login - confirm against the installed version.
        token=os.environ.get("HF_TOKEN"),
    )
else:
    print("hub_repo_id is empty - skipping upload to the Hub.")

# Load

In [None]:
# Hub repository the model and tokenizer are loaded from.
hub_model_id = "serpdotai/sparsetral-16x7B-v2"

# Rebinds `model` / `tokenizer`. trust_remote_code=True is passed for the model
# load only; it allows the modeling code shipped with the repository to be executed.
model = transformers.AutoModelForCausalLM.from_pretrained(
    hub_model_id,
    torch_dtype=torch.bfloat16,
    device_map="cuda:0",
    trust_remote_code=True,
)
tokenizer = transformers.AutoTokenizer.from_pretrained(hub_model_id)

# Inference

In [None]:
# One template per role; "{message}" is replaced by the turn's text.
system_str = "<|im_start|>system\n{message}<|im_end|>\n"
user_str = "<|im_start|>user\n{message}<|im_end|>\n"
assistant_str = "<|im_start|>assistant\n{message}<|im_end|>\n"


def construct_prompt(messages):
    """Render a conversation into a single prompt string.

    Args:
        messages: iterable of mappings with a "from" key (the role) and a
            "value" key (the text). Accepted roles are "human"/"user",
            "gpt"/"assistant" and "system"/"instruction".

    Returns:
        The formatted turns concatenated in order, followed by an opening
        assistant header ("<|im_start|>assistant\\n") for the model to complete.

    Raises:
        ValueError: if a message's "from" value is not one of the accepted roles.
    """
    # Alias groups checked in the same order as the role templates are tried;
    # built per call so the module-level templates are read at call time.
    role_table = (
        (("human", "user"), user_str),
        (("gpt", "assistant"), assistant_str),
        (("system", "instruction"), system_str),
    )

    pieces = []
    for turn in messages:
        role = turn["from"]
        template = next(
            (tpl for aliases, tpl in role_table if role in aliases),
            None,
        )
        if template is None:
            raise ValueError(f"Unknown message type: {role}")
        pieces.append(template.format(message=turn["value"]))

    # Leave the assistant turn open so generation continues from here.
    pieces.append("<|im_start|>assistant\n")
    return "".join(pieces)

In [None]:
# Example conversation: one system instruction followed by one user question.
system = (
    "You are a helpful assistant who will help the user to the best of their ability. "
    "If you don't know something, say \"I don't know\""
)
user = "Are you sentient?"

messages = [
    {"from": role, "value": text}
    for role, text in (("system", system), ("user", user))
]

# Render the prompt, tokenize it and move the tensors to the model's device.
prompt = construct_prompt(messages)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Sampling settings passed to generate(); sampling is on, so output varies between runs.
sampling_options = dict(
    max_length=4096,
    do_sample=True,
    top_k=50,
    top_p=0.99,
    temperature=0.9,
    num_return_sequences=1,
)
pred = model.generate(**inputs, **sampling_options)

# Decode the first (only) returned sequence, dropping special tokens.
first_sequence = pred[0].cpu()
print(tokenizer.decode(first_sequence, skip_special_tokens=True))