Given a directory, this notebook converts every PEFT adapter checkpoint in it into a standard Hugging Face checkpoint that can be loaded with `.from_pretrained()`. Only files directly inside the given directory are considered; subdirectories are not searched. Why is this useful? If you load a model directly from PEFT checkpoints, you can't load it in 8 bits.

Assumes:
- All checkpoints specified in `ckpt_dir` are from the same model architecture
- All checkpoint files sit directly in `ckpt_dir` and are named `adapter_model*.bin` (e.g. "../models/adapter_model*.bin")
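- `ckpt_dir` does not already contain a `tmp/` subdirectory (e.g. ../models/tmp/); the notebook creates one as scratch space and deletes it when it finishes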

In [None]:
# Directory that holds the PEFT adapter checkpoints to convert. It is scanned
# non-recursively, and the converted checkpoints are written back into it.
ckpt_dir = "../models/"

In [None]:
import torch

# Use the GPU whenever CUDA reports one as available; otherwise stay on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Human-readable names of the two classes and the mappings between those
# names and the integer class ids handed to the model config.
FALSE_LABEL_STR = "False"
TRUE_LABEL_STR = "True"
label2id = {FALSE_LABEL_STR: 0, TRUE_LABEL_STR: 1}
# Inverse mapping, derived so the two dictionaries cannot drift apart.
id2label = {class_id: label for label, class_id in label2id.items()}

In [None]:
from transformers import (
    AutoTokenizer,
    GPTNeoForSequenceClassification,
    LlamaTokenizer,
    LlamaForSequenceClassification,
)


def load_model(model_type, int8_training=False):
    """Build the base sequence-classification model for a supported family.

    Loads the pretrained backbone and its tokenizer, registers a ``<PAD>``
    token, points the model config at that token and resizes the embedding
    matrix to the enlarged vocabulary.

    Args:
        model_type: ``"neo"`` (EleutherAI/gpt-neo-1.3B) or ``"llama"``
            (meta-llama/Llama-2-13b-hf).
        int8_training: Only consulted for ``"llama"``. When false (the
            default) the freshly loaded model is moved to ``device``; when
            true that move is skipped. This used to be read from an undefined
            global, which made the llama path fail with ``NameError``.

    Returns:
        The prepared model. The tokenizer is used only to add the pad token
        and size the embeddings; it is not returned.

    Raises:
        ValueError: If ``model_type`` is not one of the supported values.
    """
    if model_type == "neo":
        model_name = "EleutherAI/gpt-neo-1.3B"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model_cls = GPTNeoForSequenceClassification
    elif model_type == "llama":
        model_name = "meta-llama/Llama-2-13b-hf"
        tokenizer = LlamaTokenizer.from_pretrained(model_name, use_auth_token=True)
        model_cls = LlamaForSequenceClassification
    else:
        # ValueError is still an Exception, so existing handlers keep working.
        raise ValueError("Use one of the model types")

    model = model_cls.from_pretrained(
        model_name,
        num_labels=2,
        id2label=id2label,
        label2id=label2id,
        use_auth_token=True,
    )

    # Only the llama path places the model on `device` explicitly. The
    # original note says placement "is automatically done otherwise"
    # (presumably by the int8 loading path -- confirm).
    if model_type == "llama" and not int8_training:
        model = model.to(device)

    # Register a pad token, then grow the embeddings to cover it.
    tokenizer.add_special_tokens({"pad_token": "<PAD>"})
    model.config.pad_token_id = tokenizer.pad_token_id
    model.resize_token_embeddings(len(tokenizer))

    return model

In [None]:
from peft import PeftModel


def convert_to_peft(model, model_id):
    """Attach the PEFT adapter at ``model_id`` to ``model`` and merge it in.

    Args:
        model: Base model the adapter is loaded on top of.
        model_id: Adapter location, forwarded to ``PeftModel.from_pretrained``.

    Returns:
        The result of ``merge_and_unload()`` on the adapter-wrapped model.
    """
    adapter_wrapped = PeftModel.from_pretrained(model, model_id=model_id)
    return adapter_wrapped.merge_and_unload()

## Convert Found Checkpoints

In [None]:
import glob
import os


# All adapter checkpoints in the directory share one adapter_config.json, so
# fail early with a clear error if it is missing.
peft_config_path = os.path.join(ckpt_dir, "adapter_config.json")
if not os.path.exists(peft_config_path):  # Should only be one config with this name
    raise FileNotFoundError(f"No adapter_config.json found in {ckpt_dir!r}")

# Collect bare file names of the adapter weights directly inside ckpt_dir.
# - Only "adapter_model*.bin" is matched, so already-converted
#   "pytorch_model*.bin" files (or any other .bin) are never picked up.
# - basename() is used instead of splitting on ckpt_dir, which produced a
#   leading-slash name whenever ckpt_dir lacked a trailing separator.
# - Sorting makes the processing order deterministic.
found_peft_ckpts = sorted(
    os.path.basename(path)
    for path in glob.glob(os.path.join(ckpt_dir, "adapter_model*.bin"))
)
found_peft_ckpts

In [None]:
from tqdm import tqdm
import os
import shutil


# Convert every discovered adapter checkpoint into a merged, standalone
# checkpoint. Each adapter is staged in a scratch directory under the fixed
# file name "adapter_model.bin" next to a copy of the adapter config, and
# that directory is what gets loaded as the adapter.

# Make tmp dir and copy config to it
tmp_dir = os.path.join(ckpt_dir, "tmp")
os.makedirs(tmp_dir, exist_ok=True)
shutil.copyfile(
    os.path.join(ckpt_dir, "adapter_config.json"),
    os.path.join(tmp_dir, "adapter_config.json"),
)

model_type = "neo"  # "neo" | "llama"
for i in tqdm(found_peft_ckpts):
    ckpt_name = os.path.basename(i)

    # Anything that is not adapter_model*.bin cannot be parsed below and is
    # not an adapter checkpoint, so leave it untouched.
    if not (ckpt_name.startswith("adapter_model") and ckpt_name.endswith(".bin")):
        print(f"Skipping {i}: not an adapter_model*.bin checkpoint")
        continue

    # Text between "adapter_model" and ".bin", e.g. the epoch number.
    model_name_extension = ckpt_name.split("adapter_model")[1].split(".bin")[0]

    # Move ckpt to tmp
    model_orig_path = os.path.join(ckpt_dir, ckpt_name)
    model_tmp_path = os.path.join(tmp_dir, "adapter_model.bin")
    os.replace(model_orig_path, model_tmp_path)

    converted = False
    try:
        # Load model and save. A fresh base model is loaded per checkpoint.
        model = load_model(model_type)
        model = convert_to_peft(model, tmp_dir)
        # NOTE(review): the moves below assume save_pretrained writes a single
        # "pytorch_model.bin" plus "config.json" -- confirm for the installed
        # transformers version and model size.
        model.save_pretrained(tmp_dir)

        # Move new ckpt and config back. os.replace overwrites an existing
        # destination on every platform; config.json is rewritten on each
        # iteration, which os.rename refuses to do on Windows.
        new_ckpt_name = "pytorch_model" + model_name_extension + ".bin"
        os.replace(
            os.path.join(tmp_dir, "pytorch_model.bin"),
            os.path.join(ckpt_dir, new_ckpt_name),
        )
        os.replace(
            os.path.join(tmp_dir, "config.json"), os.path.join(ckpt_dir, "config.json")
        )
        converted = True
    finally:
        if converted:
            # Delete peft ckpt
            os.remove(model_tmp_path)
        elif os.path.exists(model_tmp_path):
            # Conversion failed or was interrupted: put the adapter back under
            # its original name so it is not lost in the scratch directory.
            os.replace(model_tmp_path, model_orig_path)

    print(f"Converted {i} to {new_ckpt_name}")
    del model  # Drop the reference before the next checkpoint is loaded.

shutil.rmtree(tmp_dir)