Switching between different LoRA adapters? #75

Open
ziwang-com opened this issue Jun 5, 2023 · 0 comments

@ziwang-com (Owner)
Lightning-AI/lit-llama#193

Is there a way to switch between LoRA adapters? That is, load several different adapters for different tasks and quickly switch between them to perform different tasks, ideally using the peft library.

@totaltube I don't think we support this. Could you give an example? Note that LoRA and adapters are two different things.

For example, to switch from LoRA to adapter fine-tuning, you have to 1. remove the with lora() context manager, 2. replace the model class lit_llama.model.LLaMA with lit_llama.adapter.LLaMA, and 3. change mark_only_lora_as_trainable(model) to mark_only_adapter_as_trainable(model).
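
For reference, a minimal sketch of that difference, assuming lit-llama's lora() context manager and the mark_only_*_as_trainable helpers; the r/alpha/dropout values are only illustrative:

# Sketch only: assumes lit-llama's lora() context manager and the
# mark_only_*_as_trainable helpers; r/alpha/dropout are illustrative values.
import lit_llama.adapter
import lit_llama.model
from lit_llama.adapter import mark_only_adapter_as_trainable
from lit_llama.lora import lora, mark_only_lora_as_trainable

# LoRA fine-tuning: wrap model construction in the lora() context manager
with lora(r=8, alpha=16, dropout=0.05, enabled=True):
    lora_model = lit_llama.model.LLaMA(lit_llama.model.LLaMAConfig())
mark_only_lora_as_trainable(lora_model)

# Adapter fine-tuning: no context manager, different model class and helper
adapter_model = lit_llama.adapter.LLaMA(lit_llama.adapter.LLaMAConfig())
mark_only_adapter_as_trainable(adapter_model)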

I think the goal is not to switch from LoRA to adapters, but to switch between different LoRA weights on top of the same base model.

Correct.
Here is an example from another library: https://github.com/huggingface/peft/blob/main/examples/multi_adapter_examples/PEFT_Multi_LoRA_Inference.ipynb

The goal is that I can fine-tune the model independently for different tasks and quickly switch the weights, without retraining on an extended dataset that, for example, also includes another task.
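
For the peft route asked about above, a minimal sketch along the lines of the linked notebook; the base-model path and the two adapter directories (./lora-task-a, ./lora-task-b) are placeholders:

# Sketch only: follows the multi-adapter pattern from the linked PEFT notebook.
# The base-model path and adapter directories are placeholders.
from peft import PeftModel
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("path/to/llama-7b-hf")

# Attach the first LoRA adapter and register it under a name
model = PeftModel.from_pretrained(base, "./lora-task-a", adapter_name="task_a")
# Load a second adapter into the same wrapper
model.load_adapter("./lora-task-b", adapter_name="task_b")

# Switch between tasks without touching the base weights
model.set_adapter("task_a")
# ... run inference for task A ...
model.set_adapter("task_b")
# ... run inference for task B ...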

An example class showing how to implement what @awaelchli described:

# Imports assume the lit-llama repository layout used by the snippet.
from pathlib import Path
from typing import Optional

import lightning as L
import torch

from generate import generate
from lit_llama import Tokenizer
from lit_llama.utils import EmptyInitOnDevice, lazy_load
from scripts.prepare_alpaca import generate_prompt


class FinetunedAdapter:
    from lit_llama.adapter import LLaMA, LLaMAConfig

    def __init__(
        self,
        adapter_path: Optional[Path] = None,
        pretrained_path: Optional[Path] = None,
        tokenizer_path: Optional[Path] = None,
        quantize: Optional[str] = None,
    ) -> None:
        if not adapter_path:
            adapter_path = Path("out/adapter/alpaca/lit-llama-adapter-finetuned.pth")
        if not pretrained_path:
            pretrained_path = Path("./checkpoints/lit-llama/7B/lit-llama.pth")
        if not tokenizer_path:
            tokenizer_path = Path("./checkpoints/lit-llama/tokenizer.model")

        assert adapter_path.is_file()
        assert pretrained_path.is_file()
        assert tokenizer_path.is_file()

        self.fabric = L.Fabric(devices=1)
        dtype = (
            torch.bfloat16
            if self.fabric.device.type == "cuda" and torch.cuda.is_bf16_supported()
            else torch.float32
        )

        with EmptyInitOnDevice(
            device=self.fabric.device, dtype=dtype, quantization_mode=quantize
        ):
            self.model = self.LLaMA(self.LLaMAConfig())

        # 1. Load the pretrained weights
        pretrained_checkpoint = lazy_load(pretrained_path)
        self.model.load_state_dict(pretrained_checkpoint, strict=False)

        # 2. Load the fine-tuned adapter weights
        adapter_checkpoint = lazy_load(adapter_path)
        self.model.load_state_dict(adapter_checkpoint, strict=False)

        self.model.eval()
        self.model = self.fabric.setup_module(self.model)

        self.tokenizer = Tokenizer(tokenizer_path)

    def load_adapter(self, adapter_path: Path):
        """Swap in a different adapter checkpoint on top of the same base weights."""
        assert adapter_path.is_file()

        adapter_checkpoint = lazy_load(adapter_path)
        self.model.load_state_dict(adapter_checkpoint, strict=False)

    def generate(
        self,
        instruction: str = "",
        input_text: str = "",
        max_new_tokens: int = 100,
        top_k: int = 200,
        temperature: float = 0.8,
        use_instruction: bool = True,
    ):
        if use_instruction:
            sample = {"instruction": instruction, "input": input_text}
            prompt = generate_prompt(sample)
        else:
            assert input_text, "input_text must be provided if use_instruction is False."
            assert (
                len(instruction) == 0
            ), "instruction must be empty if use_instruction is False."
            # generate_no_prompt is assumed to be a user-defined helper
            # (it is not part of lit-llama).
            prompt = generate_no_prompt(input_text)

        encoded = self.tokenizer.encode(
            prompt, bos=True, eos=False, device=self.model.device
        )

        # Calls the module-level generate() imported from lit-llama's
        # generate.py, not this method.
        output = generate(
            self.model,
            idx=encoded,
            max_seq_length=max_new_tokens,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            eos_id=self.tokenizer.eos_id,
        )

        output = self.tokenizer.decode(output)
        output = output.split("### Response:")[1].strip()
        return output

Then:

adapter = FinetunedAdapter(adapter_path=adapter_path)
adapter.load_adapter(checkpoint)
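
For example, switching between two hypothetical adapter checkpoints fine-tuned for different tasks:

# Hypothetical checkpoint paths for two independently fine-tuned adapters.
summarize_ckpt = Path("out/adapter/summarize/lit-llama-adapter-finetuned.pth")
translate_ckpt = Path("out/adapter/translate/lit-llama-adapter-finetuned.pth")

adapter = FinetunedAdapter(adapter_path=summarize_ckpt)
print(adapter.generate(instruction="Summarize the text.", input_text="..."))

# Swap the adapter weights on top of the same base model and reuse it.
adapter.load_adapter(translate_ckpt)
print(adapter.generate(instruction="Translate to French.", input_text="..."))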
