Style: change line length to 120 and enforce import sort order #383

Merged: 6 commits, Apr 4, 2024
6 changes: 3 additions & 3 deletions .vscode/settings.json
@@ -1,6 +1,6 @@
{
"editor.rulers": [
100
120
],
"editor.formatOnSave": true,
"[python]": {
@@ -9,11 +9,11 @@
},
"black-formatter.args": [
"--line-length",
"100"
"120"
],
"flake8.args": [
"--max-line-length=120",
"--ignore=E203"
"--ignore=E203,W503"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": false,
13 changes: 4 additions & 9 deletions clients/python/lorax/types.py
@@ -214,15 +214,8 @@ def valid_input(cls, v):
@validator("stream")
def valid_best_of_stream(cls, field_value, values):
parameters = values["parameters"]
if (
parameters is not None
and parameters.best_of is not None
and parameters.best_of > 1
and field_value
):
raise ValidationError(
"`best_of` != 1 is not supported when `stream` == True"
)
if parameters is not None and parameters.best_of is not None and parameters.best_of > 1 and field_value:
raise ValidationError("`best_of` != 1 is not supported when `stream` == True")
return field_value


@@ -236,6 +229,7 @@ class InputToken(BaseModel):
# Optional since the logprob of the first token cannot be computed
logprob: Optional[float]


# Alternative Tokens
class AlternativeToken(BaseModel):
# Token ID from the model tokenizer
@@ -245,6 +239,7 @@ class AlternativeToken(BaseModel):
# Logprob
logprob: float


# Generated tokens
class Token(BaseModel):
# Token ID from the model tokenizer
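The collapsed validator above fits within the new 120-character limit and is behaviorally identical to the multi-line original; the extra blank lines come from the two-blank-lines-before-a-top-level-class rule the ruff formatter applies. A self-contained sketch of the same cross-field check, using a hypothetical stand-in for the client's request model rather than the real one (the sketch raises ValueError, which pydantic v1 surfaces to callers as a ValidationError):

from typing import Optional

from pydantic import BaseModel, validator


class Parameters(BaseModel):
    best_of: Optional[int] = None


class StreamRequest(BaseModel):  # hypothetical stand-in, not the client's actual model
    parameters: Optional[Parameters] = None
    stream: bool = False

    @validator("stream")
    def valid_best_of_stream(cls, field_value, values):
        parameters = values["parameters"]
        if parameters is not None and parameters.best_of is not None and parameters.best_of > 1 and field_value:
            raise ValueError("`best_of` != 1 is not supported when `stream` == True")
        return field_value


# StreamRequest(parameters=Parameters(best_of=2), stream=True) raises a ValidationError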
1 change: 1 addition & 0 deletions server/Makefile
@@ -35,3 +35,4 @@ export-requirements:
format:
pip install ruff
python -m ruff format lorax_server
python -m ruff check lorax_server --fix
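Besides formatting, the format target now runs ruff's lint auto-fixes, which is what enforces the import sort order seen in the files below (assuming the project enables ruff's isort rules, e.g. the "I" rule set, in its ruff configuration; that configuration file is not part of this diff). The resulting ordering follows the usual isort grouping:

# Illustration only: the grouping that isort-style rules produce, with each group alphabetized.

# 1. standard library
import os
from typing import Optional

# 2. third-party packages
import torch
from loguru import logger

# 3. first-party modules
from lorax_server.models.model import Model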
4 changes: 1 addition & 3 deletions server/lorax_server/adapters/__init__.py
@@ -21,9 +21,7 @@ def load_adapter_config(
if "medusa_num_heads" in config:
return MedusaConfig.load(config)

raise ValueError(
f"No valid adapter config file found: " f"tried {adapter_config_path} and {config_path}"
)
raise ValueError(f"No valid adapter config file found: " f"tried {adapter_config_path} and {config_path}")


__all__ = [
32 changes: 7 additions & 25 deletions server/lorax_server/adapters/lora.py
@@ -181,59 +181,41 @@ def has_adapter(self, adapter_index: int) -> bool:
return adapter_index in self.adapter_index_configs

def can_vectorize(self, pg: ProcessGroup) -> bool:
return all(
rank_data.rank // pg.size() <= MAX_RANK_CUSTOM for rank_data in self.rank_data.values()
)
return all(rank_data.rank // pg.size() <= MAX_RANK_CUSTOM for rank_data in self.rank_data.values())

@classmethod
def key(cls) -> str:
return LORA

@classmethod
def load(
self, adapter_weights: Dict[int, AdapterWeights], meta: AdapterBatchMetadata
) -> "BatchLoraWeights":
def load(self, adapter_weights: Dict[int, AdapterWeights], meta: AdapterBatchMetadata) -> "BatchLoraWeights":
adapter_weights = {k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)}

first_weights = list(adapter_weights.values())[0]
device = first_weights.weights_a.device
segment_indices = meta.segment_indices

lora_a = {
idx: adapter_weights[idx].weights_a for idx in segment_indices if idx in adapter_weights
}
lora_a = {idx: adapter_weights[idx].weights_a for idx in segment_indices if idx in adapter_weights}
lora_a_ptr = torch.tensor(
[
(
adapter_weights[idx].weights_a.data_ptr()
if idx in adapter_weights
else EMPTY_TENSOR.data_ptr()
)
(adapter_weights[idx].weights_a.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)
lora_b = {
idx: adapter_weights[idx].weights_b for idx in segment_indices if idx in adapter_weights
}
lora_b = {idx: adapter_weights[idx].weights_b for idx in segment_indices if idx in adapter_weights}
lora_b_ptr = torch.tensor(
[
(
adapter_weights[idx].weights_b.data_ptr()
if idx in adapter_weights
else EMPTY_TENSOR.data_ptr()
)
(adapter_weights[idx].weights_b.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
for idx in segment_indices
],
dtype=torch.int64,
device=device,
)

adapter_index_configs = {
idx: adapter_weights[idx].adapter_config
for idx in segment_indices
if idx in adapter_weights
idx: adapter_weights[idx].adapter_config for idx in segment_indices if idx in adapter_weights
}

rank_indices = defaultdict(list)
18 changes: 4 additions & 14 deletions server/lorax_server/adapters/medusa.py
@@ -71,10 +71,7 @@ class MedusaHead(torch.nn.Module):
def __init__(self, config: MedusaConfig, prefix: str, weights: AbstractWeights):
super().__init__()
self.blocks = torch.nn.ModuleList(
[
ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
for i in range(config.medusa_num_layers)
]
[ResBlock(config, prefix=f"{prefix}.{i}", weights=weights) for i in range(config.medusa_num_layers)]
)
n = len(self.blocks)
self.out = FastLinear.load(config, prefix=f"{prefix}.{n}", weights=weights, bias=False)
@@ -90,10 +87,7 @@ class MedusaModel(torch.nn.Module):
def __init__(self, config: MedusaConfig, weights: AbstractWeights):
super().__init__()
self.heads = torch.nn.ModuleList(
[
MedusaHead(config, prefix=f"{i}", weights=weights)
for i in range(config.medusa_num_heads)
]
[MedusaHead(config, prefix=f"{i}", weights=weights) for i in range(config.medusa_num_heads)]
)

def forward(self, x):
@@ -141,14 +135,10 @@ def key(cls) -> str:
return MEDUSA

@classmethod
def load(
cls, adapter_weights: Dict[int, AdapterWeights], meta: "AdapterBatchMetadata"
) -> "BatchMedusaWeights":
def load(cls, adapter_weights: Dict[int, AdapterWeights], meta: "AdapterBatchMetadata") -> "BatchMedusaWeights":
adapter_weights = {k: v for k, v in adapter_weights.items() if isinstance(v, MedusaWeights)}

adapter_to_medusa = {
idx: adapter_weights[idx] for idx in meta.segment_indices if idx in adapter_weights
}
adapter_to_medusa = {idx: adapter_weights[idx] for idx in meta.segment_indices if idx in adapter_weights}

return BatchMedusaWeights(
adapter_to_medusa=adapter_to_medusa,
4 changes: 1 addition & 3 deletions server/lorax_server/adapters/utils.py
@@ -20,9 +20,7 @@ def download_adapter(
HfApi(token=api_token).model_info(adapter_id, revision=None)

# fail fast if ID is not an adapter (i.e. it is a full model)
source = get_model_source(
adapter_source, adapter_id, extension=".safetensors", api_token=api_token
)
source = get_model_source(adapter_source, adapter_id, extension=".safetensors", api_token=api_token)
source.load_config()

download_weights(adapter_id, source=adapter_source, api_token=api_token)
22 changes: 5 additions & 17 deletions server/lorax_server/adapters/weights.py
@@ -45,9 +45,7 @@ def key(cls) -> str:
pass

@abstractclassmethod
def load(
cls, adapter_weights: Dict[int, AdapterWeights], meta: "AdapterBatchMetadata"
) -> "BatchAdapterWeights":
def load(cls, adapter_weights: Dict[int, AdapterWeights], meta: "AdapterBatchMetadata") -> "BatchAdapterWeights":
pass


@@ -67,18 +65,14 @@ def remove_adapter(self, adapter_idx: int):

@property
def max_speculative_tokens(self) -> int:
return max(
adapter_weights.speculative_tokens for adapter_weights in self.adapter_weights.values()
)
return max(adapter_weights.speculative_tokens for adapter_weights in self.adapter_weights.values())

def is_empty(self) -> bool:
return len(self.adapter_weights) == 0

def get_data(self, meta: AdapterBatchMetadata) -> Dict[str, BatchAdapterWeights]:
# bucket adapters by batch class
adapter_batch_types: Dict[Type[BatchAdapterWeights], Dict[int, AdapterWeights]] = (
defaultdict(dict)
)
adapter_batch_types: Dict[Type[BatchAdapterWeights], Dict[int, AdapterWeights]] = defaultdict(dict)
for adapter_index, adapter_weights in self.adapter_weights.items():
adapter_batch_types[adapter_weights.get_batch_type()][adapter_index] = adapter_weights

@@ -96,9 +90,7 @@ class AdapterBatchData:
data: Dict[str, Dict[str, BatchAdapterWeights]]

@staticmethod
def from_meta(
meta: AdapterBatchMetadata, weights: Dict[str, LayerAdapterWeights]
) -> "AdapterBatchData":
def from_meta(meta: AdapterBatchMetadata, weights: Dict[str, LayerAdapterWeights]) -> "AdapterBatchData":
data = {}
for k, v in weights.items():
if v.is_empty():
@@ -112,11 +104,7 @@ def ranks(self) -> Set[int]:
if lora_data is None:
return set()

return set(
rank_data.rank
for layer_data in self.data.values()
for rank_data in lora_data.rank_data.values()
)
return set(rank_data.rank for layer_data in self.data.values() for rank_data in lora_data.rank_data.values())

@property
def max_rank(self) -> int:
4 changes: 2 additions & 2 deletions server/lorax_server/cache.py
@@ -1,7 +1,7 @@
import torch

from typing import Dict, Optional, TypeVar

import torch

from lorax_server.models.types import Batch

B = TypeVar("B", bound=Batch)
21 changes: 7 additions & 14 deletions server/lorax_server/cli.py
@@ -1,14 +1,13 @@
import os
import sys
import typer

from enum import Enum
from pathlib import Path
from loguru import logger
from typing import Optional
from enum import Enum

from lorax_server.utils.weights import download_weights as _download_weights
import typer
from loguru import logger

from lorax_server.utils.weights import download_weights as _download_weights

app = typer.Typer()

@@ -50,15 +49,9 @@ def serve(
):
if sharded:
assert os.getenv("RANK", None) is not None, "RANK must be set when sharded is True"
assert (
os.getenv("WORLD_SIZE", None) is not None
), "WORLD_SIZE must be set when sharded is True"
assert (
os.getenv("MASTER_ADDR", None) is not None
), "MASTER_ADDR must be set when sharded is True"
assert (
os.getenv("MASTER_PORT", None) is not None
), "MASTER_PORT must be set when sharded is True"
assert os.getenv("WORLD_SIZE", None) is not None, "WORLD_SIZE must be set when sharded is True"
assert os.getenv("MASTER_ADDR", None) is not None, "MASTER_ADDR must be set when sharded is True"
assert os.getenv("MASTER_PORT", None) is not None, "MASTER_PORT must be set when sharded is True"

# Remove default handler
logger.remove()
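For reference, the four assertions collapsed above simply restate the usual torch.distributed rendezvous requirements for a sharded launch. A hypothetical single-shard setup (placeholder values, not part of this PR) would satisfy them like so:

import os

# Placeholder values for a hypothetical single-shard run; a real launcher sets these per rank.
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")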
10 changes: 5 additions & 5 deletions server/lorax_server/interceptor.py
@@ -1,11 +1,11 @@
import torch
import grpc
from typing import Any, Callable

from google.rpc import status_pb2, code_pb2
from grpc_status import rpc_status
import grpc
import torch
from google.rpc import code_pb2, status_pb2
from grpc_interceptor.server import AsyncServerInterceptor
from grpc_status import rpc_status
from loguru import logger
from typing import Callable, Any


class ExceptionInterceptor(AsyncServerInterceptor):
12 changes: 6 additions & 6 deletions server/lorax_server/models/__init__.py
@@ -1,18 +1,18 @@
import torch
from typing import Optional

import torch
from loguru import logger
from transformers.configuration_utils import PretrainedConfig
from typing import Optional

from lorax_server.models.model import Model
from lorax_server.models.bloom import BLOOMSharded
from lorax_server.models.causal_lm import CausalLM
from lorax_server.models.flash_causal_lm import FlashCausalLM
from lorax_server.models.bloom import BLOOMSharded
from lorax_server.models.galactica import GalacticaSharded
from lorax_server.models.model import Model
from lorax_server.models.mpt import MPTSharded
from lorax_server.models.seq2seq_lm import Seq2SeqLM
from lorax_server.models.opt import OPTSharded
from lorax_server.models.galactica import GalacticaSharded
from lorax_server.models.santacoder import SantaCoder
from lorax_server.models.seq2seq_lm import Seq2SeqLM
from lorax_server.models.t5 import T5Sharded
from lorax_server.utils.sources import get_s3_model_local_dir

18 changes: 7 additions & 11 deletions server/lorax_server/models/bloom.py
@@ -1,14 +1,15 @@
import torch
import torch.distributed

from typing import Dict, List, Optional, Tuple, Type

import torch
import torch.distributed
from transformers import (
AutoTokenizer,
AutoConfig,
AutoTokenizer,
PreTrainedTokenizerBase,
)

from lorax_server.adapters import AdapterBatchData
from lorax_server.models.causal_lm import CausalLM, CausalLMBatch
from lorax_server.models.custom_modeling.bloom_modeling import (
ATTN_DENSE,
ATTN_QKV,
@@ -17,16 +18,13 @@
MLP_DENSE_H_TO_4H,
BloomForCausalLM,
)
from lorax_server.models import CausalLM
from lorax_server.models.causal_lm import CausalLMBatch
from lorax_server.pb import generate_pb2
from lorax_server.utils import (
Weights,
initialize_torch_distributed,
weight_files,
Weights,
)
from lorax_server.utils.tokenizer import TokenizerManager
from lorax_server.adapters import AdapterBatchData

ADAPTER_LAYERS = [ATTN_QKV, ATTN_DENSE, MLP_DENSE_H_TO_4H, MLP_DENSE_4H_TO_H]
ROW_PARALLEL = {ATTN_DENSE, MLP_DENSE_4H_TO_H}
@@ -42,9 +40,7 @@ def from_pb(
dtype: torch.dtype,
device: torch.device,
) -> "CausalLMBatch":
batch = super().from_pb(
pb=pb, tokenizer=tokenizer, tokenizers=tokenizers, dtype=dtype, device=device
)
batch = super().from_pb(pb=pb, tokenizer=tokenizer, tokenizers=tokenizers, dtype=dtype, device=device)
batch.keys_head_dim_last = False
return batch
