In [1]:
# Reload edited modules (e.g. text_generation_server) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
# Work from the TGI server sources (relative to the kernel's starting directory;
# resolved to /usr/src/server-dev in this run). The next cell's `make` runs from here.
%cd server-dev

/usr/src/server-dev


In [3]:
# Regenerate the protobuf/gRPC Python stubs in text_generation_server/pb
# from ../proto/generate.proto.
!make gen-server

# Compile protos
pip install grpcio-tools==1.51.1 mypy-protobuf==3.4.0 'types-protobuf>=3.20.4' --no-cache-dir
[0mmkdir text_generation_server/pb || true
mkdir: cannot create directory ‘text_generation_server/pb’: File exists
python -m grpc_tools.protoc -I../proto --python_out=text_generation_server/pb \
	--grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb ../proto/generate.proto
  import pkg_resources
Writing mypy to generate_pb2.pyi
find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
text_generation_server/pb/generate_pb2.py text_generation_server/pb/generate_pb2_grpc.py text_generation_server/pb/__init__.py touch text_generation_server/pb/__init__.py


In [4]:
from text_generation_server.models.flash_llama import FlashLlama

# Base model that the LoRA adapters in later cells are attached to.
model = FlashLlama(
    model_id="meta-llama/Llama-2-7b-hf",
)


Welcome to bitsandbytes. For bug reports, please run

python -m bitsandbytes

 and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
bin /opt/conda/lib/python3.9/site-packages/bitsandbytes/libbitsandbytes_cuda118.so
CUDA SETUP: CUDA runtime path found: /opt/conda/lib/libcudart.so.11.0
CUDA SETUP: Highest compute capability among GPUs detected: 8.6
CUDA SETUP: Detected CUDA version 118
CUDA SETUP: Loading binary /opt/conda/lib/python3.9/site-packages/bitsandbytes/libbitsandbytes_cuda118.so...


You are using a model of type llama to instantiate a model of type . This is not supported for all configurations of models and can yield errors.


In [5]:
import torch
from peft.tuners.lora import LoraConfig
from text_generation_server.utils.layers import SuperLayer, TensorParallelColumnLinear, TensorParallelRowLinear
from text_generation_server.utils import (
    weight_files,
    Weights,
)
from typing import Dict, List, Tuple

class BLoraConfig:
    """Record describing one LoRA adapter to attach to the model.

    Attributes:
        lora_id: key under which the adapter's weights are registered in each
            wrapped layer (and later selected per batch row).
        lora_r: LoRA rank of the adapter.
        lora_alpha: LoRA alpha; layers use ``lora_alpha / r`` as the scale.
        weights: reader used to fetch the adapter tensors from its files.
    """

    def __init__(
        self,
        lora_id: str,
        lora_r: int,
        lora_alpha: int,
        weights: Weights,
    ):
        # identity of the adapter
        self.lora_id = lora_id
        # hyper-parameters
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        # where the adapter tensors are read from
        self.weights = weights

class BLoraLinear(torch.nn.Module):
    """Linear layer plus per-row LoRA adapters chosen at batch time ("batched LoRA").

    Any number of adapters can be registered with ``load_adapter``. Before a
    forward pass, ``set_batch_lora_ids`` picks one adapter per batch row, and
    ``forward`` returns, for every row ``i``,
    ``linear(x)[i] + scale_i * (x[i] @ A_i @ B_i)`` with ``scale_i = lora_alpha_i / r``.

    Args:
        linear: base layer, called as ``linear(x)``.
        r: LoRA rank shared by every adapter on this layer.
        target_modules: names of the sub-matrices covered by this layer
            (e.g. ``["o_proj"]``); each adapter must supply weights for all of them.

    NOTE(review): adapter tensors are kept in plain dicts, not registered
    parameters/buffers, so ``module.to(...)`` and ``state_dict()`` do not see them.
    """

    def __init__(self, linear, r, target_modules) -> None:
        super().__init__()
        self.linear = linear
        self.r = r
        self.target_modules = target_modules

        # adapter weights: target_module -> lora_id -> value
        self.lora_ids = {target_module: set() for target_module in self.target_modules}
        self.scales = {target_module: {} for target_module in self.target_modules}
        self.lora_A = {target_module: {} for target_module in self.target_modules}
        self.lora_B = {target_module: {} for target_module in self.target_modules}

        # adapter weights stacked in batch order (filled by set_batch_lora_ids)
        self.batch_lora_ids = {target_module: [] for target_module in self.target_modules}
        self.scales_batch = {target_module: None for target_module in self.target_modules}
        self.lora_A_batch = {target_module: None for target_module in self.target_modules}
        self.lora_B_batch = {target_module: None for target_module in self.target_modules}

    def load_adapter(
        self,
        lora_id: str,
        lora_alpha: int,
        weights: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
    ):
        """Register adapter ``lora_id`` for every target module.

        Args:
            lora_id: adapter identifier, later passed to ``set_batch_lora_ids``.
            lora_alpha: LoRA alpha; the stored scale is ``lora_alpha / self.r``.
            weights: ``target_module -> (lora_A, lora_B)`` as read from the
                checkpoint (observed: A is ``(r, in_features)``, B is ``(out_features, r)``).

        Raises:
            NotImplementedError: if ``weights`` does not cover every target module.
            ValueError: on an unknown module or an id already loaded. Nothing is
                registered when either error is raised.
        """
        # Validate everything before mutating so a rejected adapter is not half-loaded.
        if len(weights) != len(self.target_modules):
            raise NotImplementedError("Currently require adapter for all of sub-matrices")
        for target_module in weights:
            if target_module not in self.target_modules:
                raise ValueError(f"Module passed to load_adapter must be in {self.target_modules}")
            if lora_id in self.lora_ids[target_module]:
                raise ValueError(f"{lora_id} already loaded into this module")

        for target_module, (weight_A, weight_B) in weights.items():
            self.lora_ids[target_module].add(lora_id)
            self.scales[target_module][lora_id] = lora_alpha / self.r
            # Stored transposed so that x @ A @ B maps in_features -> r -> out_features.
            self.lora_A[target_module][lora_id] = weight_A.T
            self.lora_B[target_module][lora_id] = weight_B.T

    def set_batch_lora_ids(self, lora_ids: List[str]):
        """Select one adapter per batch row and stack its weights in batch order.

        Args:
            lora_ids: adapter id for each batch row (ids may repeat).

        Raises:
            NotImplementedError: if an id is not loaded for every target module;
                the previous batch selection is left untouched in that case.
        """
        # Validate first so a bad id does not leave a half-updated selection.
        for target_module in self.target_modules:
            for lora_id in lora_ids:
                if lora_id not in self.lora_ids[target_module]:
                    raise NotImplementedError("Not yet handling some items in batch not having an adapter")

        batch_ids = list(lora_ids)
        # TODO: figure out how to get this on the right device in sharded mode
        for target_module in self.target_modules:
            self.batch_lora_ids[target_module] = batch_ids
            if not batch_ids:
                # torch.stack rejects an empty list; an empty batch has no adapter tensors.
                self.lora_A_batch[target_module] = None
                self.lora_B_batch[target_module] = None
                self.scales_batch[target_module] = None
                continue

            lora_A_batch = torch.stack([self.lora_A[target_module][lora_id] for lora_id in batch_ids])
            lora_B_batch = torch.stack([self.lora_B[target_module][lora_id] for lora_id in batch_ids])
            self.lora_A_batch[target_module] = lora_A_batch
            self.lora_B_batch[target_module] = lora_B_batch
            # Build the scales on the adapters' device/dtype: a multi-element CPU
            # tensor cannot be combined with CUDA activations.
            self.scales_batch[target_module] = torch.tensor(
                [self.scales[target_module][lora_id] for lora_id in batch_ids],
                dtype=lora_A_batch.dtype,
                device=lora_A_batch.device,
            ).reshape(-1, 1, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``linear(x)`` plus each row's scaled LoRA delta.

        ``x`` must be 3-D ``(batch, seq, in_features)`` because ``torch.bmm``
        requires it, and ``x.shape[0]`` must equal the number of ids given to
        ``set_batch_lora_ids``.

        Raises:
            NotImplementedError: if the batch size does not match the selection.
        """
        previous_dtype = x.dtype

        # xW
        out = self.linear(x)

        # + scale * xAB for every target module
        for target_module in self.target_modules:
            if x.shape[0] != len(self.batch_lora_ids[target_module]):
                raise NotImplementedError("Not yet handling some items in batch not having an adapter")
            if self.batch_lora_ids[target_module]:
                self.lora_forward(out, x, target_module)

        return out.to(previous_dtype)

    def lora_forward(self, out: torch.Tensor, x: torch.Tensor, target_module: str):
        """Add ``scale * (x @ A @ B)`` for ``target_module`` into ``out`` in place."""
        delta = torch.bmm(
            torch.bmm(x, self.lora_A_batch[target_module]),
            self.lora_B_batch[target_module],
        )
        out += self.scales_batch[target_module] * delta

class BLoraLinearQKV(BLoraLinear):
    """BLoraLinear over a fused projection whose output concatenates the target
    modules' outputs (Wq | Wk | Wv by default) along the last dimension.

    Each target module's LoRA delta is added only to its own slice of the output
    features. All sub-matrices must have the same output width.
    """

    def __init__(self, linear, r, target_modules=None) -> None:
        # None instead of a list default so instances never share one mutable list.
        if target_modules is None:
            target_modules = ["q_proj", "k_proj", "v_proj"]
        super().__init__(linear, r, target_modules)

        # Split the fused output width evenly, in target_modules order.
        combined_width = self.linear.weight.shape[0]
        num_modules = len(self.target_modules)
        if num_modules == 0 or combined_width % num_modules != 0:
            raise NotImplementedError("Currently requires all Wq, Wk, Wv to be the same size")
        width = combined_width // num_modules
        self.start_out_indexes = {target_module: idx * width for idx, target_module in enumerate(self.target_modules)}
        self.end_out_indexes = {target_module: (idx + 1) * width for idx, target_module in enumerate(self.target_modules)}

    def load_adapter(
        self,
        lora_id: str,
        lora_alpha: int,
        weights: Dict[str, Tuple[torch.Tensor, torch.Tensor]],
    ):
        """Check each adapter's output width against its slice, then register it.

        Raises:
            NotImplementedError: if a lora_B output width differs from the slice
                width (i.e. the fused sub-matrices are not equally sized).
        """
        for target_module, (_, weight_B) in weights.items():
            if target_module not in self.start_out_indexes:
                continue  # the base class reports unknown modules
            slice_width = self.end_out_indexes[target_module] - self.start_out_indexes[target_module]
            # weight_B arrives as (out_features, r), before the base class transposes it.
            if weight_B.shape[0] != slice_width:
                raise NotImplementedError("Currently requires all Wq, Wk, Wv to be the same size")
        super().load_adapter(lora_id, lora_alpha, weights)

    def lora_forward(self, out: torch.Tensor, x: torch.Tensor, target_module: str):
        """Add ``scale * (x @ A @ B)`` into ``target_module``'s output-feature slice."""
        start = self.start_out_indexes[target_module]
        end = self.end_out_indexes[target_module]

        delta = torch.bmm(
            torch.bmm(x, self.lora_A_batch[target_module]),
            self.lora_B_batch[target_module],
        )
        # `...` slices the feature (last) axis whatever the rank of `out`;
        # `out[:, start:end]` would slice the sequence axis of a 3-D output.
        out[..., start:end] += self.scales_batch[target_module] * delta

class BLoraTensorParallelColumnLinear(SuperLayer):
    """Column-parallel fused QKV layer whose inner linear is a BLoraLinearQKV."""

    def __init__(self, linear):
        super().__init__(linear)

    @classmethod
    def from_linear(
        cls,
        linear: TensorParallelColumnLinear,
        prefix: str,
        lora_r: int,
        lora_configs: List[BLoraConfig],
        target_modules: List[str],
    ):
        """Wrap ``linear`` and load every adapter in ``lora_configs`` for ``target_modules``.

        Args:
            linear: existing fused layer; its ``.linear`` becomes the wrapper's base.
            prefix: module path inside the base model, e.g. ``model.layers.0.self_attn``.
            lora_r: rank every adapter must have.
            lora_configs: adapters to load.
            target_modules: sub-matrices of the fused projection, in output order.

        Raises:
            ValueError: if any adapter's rank differs from ``lora_r``; checked
                before any adapter tensors are read.
        """
        # Reject rank mismatches up front instead of after loading tensors.
        for lora_config in lora_configs:
            if lora_r != lora_config.lora_r:
                raise ValueError("All LORA adapters must have the same rank")

        # SETUP WRAPPER
        blora_linear = BLoraLinearQKV(
            linear=linear.linear,
            r=lora_r,
            target_modules=target_modules,
        )

        # LOAD WEIGHTS INTO MEMORY
        for lora_config in lora_configs:
            adapter_weights = {}
            for target_module in target_modules:
                module_prefix = f"base_model.model.{prefix}.{target_module}"
                # NOTE(review): with dim=0 both tensors are presumably sharded along
                # their first axis; for lora_A (r, in_features) that would split the
                # rank in a sharded setup. Confirm before enabling tensor parallelism.
                weight_A = lora_config.weights.get_multi_weights_col(
                    prefixes=[f"{module_prefix}.lora_A"],
                    quantize=None,
                    dim=0,
                )
                weight_B = lora_config.weights.get_multi_weights_col(
                    prefixes=[f"{module_prefix}.lora_B"],
                    quantize=None,
                    dim=0,
                )
                adapter_weights[target_module] = (weight_A, weight_B)

            # SETUP ADAPTER
            blora_linear.load_adapter(
                lora_id=lora_config.lora_id,
                lora_alpha=lora_config.lora_alpha,
                weights=adapter_weights,
            )

        return cls(blora_linear)
    
class BLoraTensorParallelRowLinear(SuperLayer):
    """Row-parallel projection (o_proj) whose inner linear is a BLoraLinear.

    Only a single-process (unsharded) process group is supported for now.
    """

    def __init__(self, linear, process_group):
        # Sharded row-parallel LoRA is not implemented yet.
        if process_group.size() > 1:
            raise NotImplementedError("Currently not supporting sharded")

        super().__init__(linear)
        self.process_group = process_group

    @classmethod
    def from_linear(
        cls,
        linear: TensorParallelRowLinear,
        prefix: str,
        lora_r: int,
        lora_configs: List[BLoraConfig],
        target_modules: List[str],
    ):
        """Wrap ``linear`` and load every adapter in ``lora_configs`` for ``target_modules``.

        Args:
            linear: existing row-parallel layer; its ``.linear`` becomes the base
                and its ``.process_group`` is reused.
            prefix: module path inside the base model, e.g. ``model.layers.0.self_attn``.
            lora_r: rank every adapter must have.
            lora_configs: adapters to load.
            target_modules: sub-matrices covered by this layer (e.g. ``["o_proj"]``).

        Raises:
            ValueError: if any adapter's rank differs from ``lora_r``; checked
                before any adapter tensors are read.
            NotImplementedError: if ``linear.process_group`` spans more than one rank.
        """
        # Reject rank mismatches up front instead of after loading tensors.
        for lora_config in lora_configs:
            if lora_r != lora_config.lora_r:
                raise ValueError("All LORA adapters must have the same rank")

        # SETUP WRAPPER
        blora_linear = BLoraLinear(
            linear=linear.linear,
            r=lora_r,
            target_modules=target_modules,
        )

        # LOAD WEIGHTS INTO MEMORY
        for lora_config in lora_configs:
            adapter_weights = {}
            for target_module in target_modules:
                module_prefix = f"base_model.model.{prefix}.{target_module}"
                weight_A = lora_config.weights.get_multi_weights_row(
                    prefix=f"{module_prefix}.lora_A",
                    quantize=None,
                )
                weight_B = lora_config.weights.get_multi_weights_row(
                    prefix=f"{module_prefix}.lora_B",
                    quantize=None,
                )
                adapter_weights[target_module] = (weight_A, weight_B)

            # SETUP ADAPTER
            blora_linear.load_adapter(
                lora_id=lora_config.lora_id,
                lora_alpha=lora_config.lora_alpha,
                weights=adapter_weights,
            )

        return cls(blora_linear, process_group=linear.process_group)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the wrapped layer; all-reduce across the group when it is sharded."""
        out = super().forward(input)
        # Only reachable if self.process_group is replaced after construction,
        # since __init__ rejects groups larger than 1; kept for future sharding.
        if self.process_group.size() > 1:
            torch.distributed.all_reduce(out, group=self.process_group)
        return out

In [6]:
class BLoraFlashLlama:
    """Patch a loaded FlashLlama so its attention layers serve several LoRA adapters.

    For every decoder layer, the fused ``query_key_value`` projection is replaced by
    a ``BLoraTensorParallelColumnLinear`` (adapters on q/k/v) and ``o_proj`` by a
    ``BLoraTensorParallelRowLinear`` (adapters on o). The model passed in is modified
    in place and kept as ``self.model``.

    Args:
        model: a loaded FlashLlama.
        lora_configs: PEFT ``LoraConfig`` per adapter id; the id is also used to
            locate the adapter's ``.safetensors`` files.
        lora_r: rank every adapter must have.

    Raises:
        NotImplementedError: if an adapter does not target exactly q/k/v/o_proj.
        ValueError: on a rank mismatch, missing weight files, or unexpected layer types.
    """

    # Sub-matrices of the fused QKV projection (in its output order) and of o_proj.
    _QKV_MODULES = ["q_proj", "k_proj", "v_proj"]
    _O_MODULES = ["o_proj"]

    def __init__(
        self,
        model,
        lora_configs: Dict[str, LoraConfig],
        lora_r=16,
    ):
        self.model = model
        blora_configs = self._build_blora_configs(lora_configs, lora_r)
        self._patch_attention_layers(blora_configs, lora_r)

    def _build_blora_configs(self, lora_configs: Dict[str, LoraConfig], lora_r) -> List[BLoraConfig]:
        """Validate each PEFT config and pair it with a Weights reader for its files."""
        target_modules = self._QKV_MODULES + self._O_MODULES

        blora_configs = []
        for lora_id, lora_config in lora_configs.items():
            # target_modules may be None in a PEFT config; treat it as "no modules".
            if set(lora_config.target_modules or ()) != set(target_modules):
                raise NotImplementedError(
                    f"Currently require lora adapters on exactly {target_modules}; "
                    f"{lora_id} targets {lora_config.target_modules}"
                )

            if lora_config.r != lora_r:
                raise ValueError(
                    "Currently require all lora adapters to have the same r. "
                    f"lora_config.r={lora_config.r} / lora_r={lora_r}"
                )

            filenames = weight_files(lora_id, extension=".safetensors")
            if len(filenames) < 1:
                raise ValueError(
                    "Weight files not found for LORA adapter. Make sure you download with "
                    f"`text-generation-server download-weights {lora_id}`"
                )

            # unpack configurations
            blora_configs.append(BLoraConfig(
                lora_id=lora_id,
                lora_r=lora_config.r,
                lora_alpha=lora_config.lora_alpha,
                weights=Weights(
                    filenames,
                    self.model.device,
                    dtype=self.model.dtype,
                    process_group=self.model.process_group,
                ),
            ))
        return blora_configs

    def _patch_attention_layers(self, blora_configs: List[BLoraConfig], lora_r) -> None:
        """Replace query_key_value and o_proj of every decoder layer with BLora wrappers."""
        for layer_id, layer in enumerate(self.model.model.model.layers):
            prefix = f"model.layers.{layer_id}.self_attn"
            self_attn = layer.self_attn

            # update q_proj, k_proj, v_proj
            qkv = self_attn.query_key_value
            if not isinstance(qkv, TensorParallelColumnLinear):
                raise ValueError(
                    "Expected query_key_value to be TensorParallelColumnLinear, "
                    f"got {type(qkv).__name__} in layer {layer_id}"
                )
            self_attn.query_key_value = BLoraTensorParallelColumnLinear.from_linear(
                linear=qkv,
                prefix=prefix,
                lora_r=lora_r,
                lora_configs=blora_configs,
                target_modules=self._QKV_MODULES,
            )

            # update o_proj
            o_proj = self_attn.o_proj
            if not isinstance(o_proj, TensorParallelRowLinear):
                raise ValueError(
                    "Expected o_proj to be TensorParallelRowLinear, "
                    f"got {type(o_proj).__name__} in layer {layer_id}"
                )
            self_attn.o_proj = BLoraTensorParallelRowLinear.from_linear(
                linear=o_proj,
                prefix=prefix,
                lora_r=lora_r,
                lora_configs=blora_configs,
                target_modules=self._O_MODULES,
            )

In [7]:
# PEFT LoRA adapter to serve on top of `model`.
lora_id = "nealchandra/llama-2-7b-hf-lora-alpaca-json"
lora_config = LoraConfig.from_pretrained(lora_id)

# Wraps the attention projections of `model` in place.
# NOTE: re-running this cell on the same `model` presumably fails, because its
# layers are no longer the TensorParallel* types BLoraFlashLlama checks for.
blora_llama = BLoraFlashLlama(
    model,
    lora_configs={lora_id: lora_config},
)

In [13]:
print(
    blora_llama.model.model.model.layers[0]
    .self_attn.o_proj.linear
    .lora_B["o_proj"]["nealchandra/llama-2-7b-hf-lora-alpaca-json"]
    .shape
)

torch.Size([16, 4096])


In [14]:
# Inspect adapter shapes on the first decoder layer. For each target module,
# A is stored as (in_features, r) and B as (r, out_features).
self_attn_0 = blora_llama.model.model.model.layers[0].self_attn
for wrapper, target_module in [
    (self_attn_0.o_proj, "o_proj"),
    (self_attn_0.query_key_value, "q_proj"),
    (self_attn_0.query_key_value, "k_proj"),
    (self_attn_0.query_key_value, "v_proj"),
]:
    adapter_A = wrapper.linear.lora_A[target_module][lora_id]
    adapter_B = wrapper.linear.lora_B[target_module][lora_id]
    print(adapter_A.shape)
    print(adapter_B.shape)
    # The inner dimensions must agree (both are the adapter rank).
    assert adapter_A.shape[1] == adapter_B.shape[0], (target_module, adapter_A.shape, adapter_B.shape)

torch.Size([4096, 16])
torch.Size([16, 4096])
torch.Size([4096, 16])
torch.Size([16, 4096])
torch.Size([4096, 16])
torch.Size([16, 4096])
torch.Size([4096, 16])
torch.Size([16, 4096])
