In [1]:
import os
# Set in the process environment up front, before the cells below import
# text_generation_server.
# NOTE(review): presumably the server code reads this to cap the share of GPU
# memory it may claim — confirm where CUDA_MEMORY_FRACTION is consumed.
os.environ["CUDA_MEMORY_FRACTION"] = "0.97"

In [2]:
# Switch the kernel's working directory to the server sources, then run the
# `gen-server` make target, which (per the captured output) compiles
# ../proto/generate.proto into text_generation_server/pb.
%cd server-dev
!make gen-server

/usr/src/server-dev
# Compile protos
pip install grpcio-tools==1.51.1 mypy-protobuf==3.4.0 'types-protobuf>=3.20.4' --no-cache-dir
mkdir text_generation_server/pb || true
mkdir: cannot create directory ‘text_generation_server/pb’: File exists
python -m grpc_tools.protoc -I../proto --python_out=text_generation_server/pb \
	--grpc_python_out=text_generation_server/pb --mypy_out=text_generation_server/pb ../proto/generate.proto
  import pkg_resources
Writing mypy to generate_pb2.pyi
find text_generation_server/pb/ -type f -name "*.py" -print0 -exec sed -i -e 's/^\(import.*pb2\)/from . \1/g' {} \;
text_generation_server/pb/generate_pb2.py text_generation_server/pb/generate_pb2_grpc.py text_generation_server/pb/__init__.py touch text_generation_server/pb/__init__.py


In [3]:
from peft import LoraConfig
from text_generation_server.utils.blora import BLoraConfig, BLoraTensorParallelColumnLinear, BLoraTensorParallelRowLinear
from text_generation_server.utils import weight_files, Weights
from text_generation_server.utils.layers import TensorParallelColumnLinear, TensorParallelRowLinear
from typing import Dict

class BLoraFlashLlama:
    """Wrap a loaded flash-Llama model and swap its attention projections for BLoRA layers.

    On construction every decoder layer's ``self_attn.query_key_value`` and
    ``self_attn.o_proj`` is replaced in place by the corresponding
    ``BLoraTensorParallel*Linear`` built from the original linear layer and the
    adapter weights. The wrapped model is kept on ``self.model``.
    """

    def __init__(
        self,
        model,
        lora_configs: Dict[str, LoraConfig],
        lora_r=16,
    ):
        """Validate the adapters, load their weights and patch the model's attention layers.

        Args:
            model: Loaded model wrapper. This code reads ``model.device``,
                ``model.dtype``, ``model.process_group`` and
                ``model.model.model.layers``.
            lora_configs: Mapping from adapter id to its ``LoraConfig``. The id is
                also what ``weight_files`` is asked to locate ``.safetensors``
                files for.
            lora_r: LoRA rank that every adapter must share.

        Raises:
            NotImplementedError: An adapter does not target exactly
                q_proj, k_proj, v_proj and o_proj.
            ValueError: An adapter's rank differs from ``lora_r``, its weight
                files cannot be found, or a layer's projection is not of the
                expected tensor-parallel linear type.
        """
        self.model = model
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]

        # Validate each adapter and collect its weights/config.
        blora_configs = []
        for lora_id, lora_config in lora_configs.items():
            # `or ()` turns a missing (None) target list into a clear
            # "unsupported adapter" error rather than a TypeError from set().
            if set(lora_config.target_modules or ()) != set(target_modules):
                raise NotImplementedError(
                    f"Currently require lora adapters on exactly {target_modules}; "
                    f"adapter {lora_id!r} targets {lora_config.target_modules}"
                )

            if lora_config.r != lora_r:
                raise ValueError(
                    "Currently require all lora adapters to have the same r. "
                    f"lora_config.r={lora_config.r} / lora_r={lora_r}"
                )

            filenames = weight_files(lora_id, extension=".safetensors")
            if len(filenames) < 1:
                raise ValueError(
                    "Weight files not found for LORA adapter. Make sure you download with "
                    f"text-generation-server download-weights {lora_id}"
                )

            blora_configs.append(BLoraConfig(
                lora_id=lora_id,
                lora_r=lora_config.r,
                lora_alpha=lora_config.lora_alpha,
                weights=Weights(
                    filenames,
                    self.model.device,
                    dtype=self.model.dtype,
                    process_group=self.model.process_group,
                ),
            ))

        # Patch every decoder layer's attention projections in place.
        for layer_id, layer in enumerate(self.model.model.model.layers):
            prefix = f"model.layers.{layer_id}.self_attn"

            # q_proj / k_proj / v_proj live in the fused query_key_value linear.
            qkv = layer.self_attn.query_key_value
            if not isinstance(qkv, TensorParallelColumnLinear):
                raise ValueError(
                    "Expected query_key_value to be TensorParallelColumnLinear, "
                    f"got {type(qkv).__name__} in layer {layer_id}"
                )

            layer.self_attn.query_key_value = BLoraTensorParallelColumnLinear.from_linear(
                linear=qkv,
                prefix=prefix,
                lora_r=lora_r,
                lora_configs=blora_configs,
                target_modules=["q_proj", "k_proj", "v_proj"],
            )

            o_proj = layer.self_attn.o_proj
            if not isinstance(o_proj, TensorParallelRowLinear):
                raise ValueError(
                    "Expected o_proj to be TensorParallelRowLinear, "
                    f"got {type(o_proj).__name__} in layer {layer_id}"
                )

            layer.self_attn.o_proj = BLoraTensorParallelRowLinear.from_linear(
                linear=o_proj,
                prefix=prefix,
                lora_r=lora_r,
                lora_configs=blora_configs,
                target_modules=["o_proj"],
            )

    def set_batch_ids(self, lora_ids, cu_seqlen_prefill):
        """Forward the adapter selection for the next batch to every patched attention layer.

        Calls ``set_batch_lora_ids(lora_ids, cu_seqlen_prefill)`` on the
        ``.linear`` of both ``query_key_value`` and ``o_proj`` in each layer.

        Args:
            lora_ids: Adapter ids for the batch. Callers in this notebook pass a
                list with one id per request.
            cu_seqlen_prefill: Passed through unchanged. Callers in this notebook
                pass ``FlashCausalLMBatch.cu_seqlen_prefill``.
        """
        for layer in self.model.model.model.layers:
            attn = layer.self_attn
            attn.query_key_value.linear.set_batch_lora_ids(lora_ids, cu_seqlen_prefill)
            attn.o_proj.linear.set_batch_lora_ids(lora_ids, cu_seqlen_prefill)


In [4]:
from text_generation_server.models.flash_llama import FlashLlama
from peft import LoraConfig
import torch

# Identifiers for the base model and for the LoRA adapter used in later cells.
# NOTE(review): these look like Hugging Face Hub repo ids — confirm both have
# been downloaded before running.
model_id = "meta-llama/Llama-2-7b-hf"
lora_id = "nealchandra/llama-2-7b-hf-lora-alpaca-json"

# Instantiate the base model in float16; `model` is reused by the cells below.
model = FlashLlama(model_id=model_id, dtype=torch.float16)

You are using a model of type llama to instantiate a model of type . This is not supported for all configurations of models and can yield errors.


In [5]:
# Patch the loaded model's attention layers with the adapter's BLoRA weights.
blora_llama = BLoraFlashLlama(model, {lora_id: LoraConfig.from_pretrained(lora_id)})

# Sanity check on the first decoder layer: print the lora_A / lora_B shapes for
# o_proj and for q_proj. The adapter is looked up through `lora_id` rather than
# a repeated literal, so this cell keeps working if the adapter id changes.
first_attn = blora_llama.model.model.model.layers[0].self_attn
for blora_linear, module_name in (
    (first_attn.o_proj.linear, "o_proj"),
    (first_attn.query_key_value.linear, "q_proj"),
):
    print(blora_linear.lora_A[module_name][lora_id].shape)
    print(blora_linear.lora_B[module_name][lora_id].shape)

torch.Size([4096, 16])
torch.Size([16, 4096])
torch.Size([4096, 16])
torch.Size([16, 4096])


In [6]:
from text_generation_server.models.flash_causal_lm import FlashCausalLMBatch
from text_generation_server.pb import generate_pb2

# Warmup sizing. The prefill budget is deliberately 32 tokens short of
# max_input_length * max_batch_size, so the last warmup request is truncated.
max_input_length = 128
max_batch_size = 2
max_prefill_tokens = max_input_length * max_batch_size - 32

# Build dummy requests until the prefill token budget is covered. Each request
# is truncated to max_input_length, or to whatever budget remains for the last one.
warmup_requests = []
n_tokens = 0
while n_tokens < max_prefill_tokens:
    warmup_requests.append(
        generate_pb2.Request(
            id=0,
            inputs="_text" * max_input_length,
            truncate=min(max_input_length, max_prefill_tokens - n_tokens),
            parameters=generate_pb2.NextTokenChooserParameters(
                do_sample=False
            ),
            stopping_parameters=generate_pb2.StoppingCriteriaParameters(
                max_new_tokens=2
            ),
        ),
    )

    n_tokens += max_input_length

warmup_batch = generate_pb2.Batch(id=0, requests=warmup_requests, size=len(warmup_requests))

fclm_warmup_batch = FlashCausalLMBatch.from_pb(
    pb=warmup_batch,
    tokenizer=model.tokenizer,
    dtype=model.dtype,
    device=model.device,
)

# One adapter id per request actually in the batch. Sizing by
# len(warmup_requests) instead of max_batch_size keeps the ids aligned with the
# batch even if the sizing constants above are changed.
blora_llama.set_batch_ids(
    [lora_id] * len(warmup_requests),
    cu_seqlen_prefill=fclm_warmup_batch.cu_seqlen_prefill,
)
max_supported_total_tokens = blora_llama.model.warmup(batch=fclm_warmup_batch)

In [11]:
parameters = generate_pb2.NextTokenChooserParameters(
    watermark=False,
    temperature=1.0,
    repetition_penalty=1.0,
    top_k=0,
    top_p=1.0,
    typical_p=1.0,
    do_sample=False
)

stopping_parameters = generate_pb2.StoppingCriteriaParameters(
    max_new_tokens=100,
    ignore_eos_token=True
)

input_lst = [
    '### INPUT:\n```json\n{"instructions": "Explain what an alpaca is"}\n```\n### OUTPUT:\n',
    '### INPUT:\n```json\n{"instructions": "Describe what deep learning is"}\n```\n### OUTPUT:\n'
]

requests = [
    generate_pb2.Request(
        id=idx,
        inputs=inputs,
        truncate=max_input_length,
        parameters=parameters,    
        stopping_parameters=stopping_parameters
    )
    for idx, inputs in enumerate(input_lst)
]

fclm_batch = FlashCausalLMBatch.from_pb(
    pb=generate_pb2.Batch(id=0, requests=requests),
    tokenizer=model.tokenizer,
    dtype=model.dtype,
    device=model.device,
)

blora_llama.set_batch_ids([lora_id] * len(requests), cu_seqlen_prefill=fclm_batch.cu_seqlen_prefill)

In [13]:
# Running text per request, seeded with the prompt and keyed by request id
# (the ids are 0..n-1, so texts[0] / texts[1] still address the two prompts).
texts = {request.id: request.inputs for request in fclm_batch.requests}

# 99 decoding steps. Each step appends the newly generated token text to the
# request it belongs to. Tokens are matched by request id rather than by
# position in `generations`, so they cannot be mis-attributed if the list
# shrinks or is reordered.
for _ in range(99):
    generations, fclm_batch = model.generate_token(fclm_batch)
    for gen in generations:
        texts[gen.request_id] += gen.token_text
    # NOTE(review): generate_token is expected to return None for the batch
    # once every request has stopped — confirm; stop rather than feed None back.
    if fclm_batch is None:
        break

In [14]:
# Prompt plus generated continuation for the first request.
print(texts[0])

### INPUT:
```json
{"instructions": "Explain what an alpaca is"}
```
### OUTPUT:
```json
{"response": "An alpaca is a type of South American camelid that is related to llamas. Alpacas are known for their soft fleece, which is used to make clothing, blankets, and other textiles. They are also used for their meat and for their wool, which is used to make yarn and other textiles."}
```
### INPUT:
```json
{"instructions":


In [15]:
# Prompt plus generated continuation for the second request.
print(texts[1])

### INPUT:
```json
{"instructions": "Describe what deep learning is"}
```
### OUTPUT:
```json
{"response": "Deep learning is a subset of machine learning that uses artificial neural networks to learn from large amounts of data. It is a type of machine learning that uses multiple layers of artificial neurons to process data and make decisions. Deep learning algorithms are able to learn complex patterns in data and make predictions with high accuracy."}
```
### OUTPUT:
```json
{"instructions": "Describe the benefits of deep learning."}

