Style: change line length to 120 and enforce import sort order (#383)
tgaddair committed Apr 4, 2024
1 parent e729652 commit 6317e29
Showing 84 changed files with 735 additions and 1,564 deletions.
6 changes: 3 additions & 3 deletions .vscode/settings.json
@@ -1,6 +1,6 @@
 {
     "editor.rulers": [
-        100
+        120
     ],
     "editor.formatOnSave": true,
     "[python]": {
@@ -9,11 +9,11 @@
     },
     "black-formatter.args": [
         "--line-length",
-        "100"
+        "120"
     ],
     "flake8.args": [
         "--max-line-length=120",
-        "--ignore=E203"
+        "--ignore=E203,W503"
    ],
    "python.testing.unittestEnabled": false,
    "python.testing.pytestEnabled": false,
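
Note: E203 (whitespace before ':') and W503 (line break before a binary operator) are the two pycodestyle
checks that conflict with Black-style formatting, so both are ignored here. Black wraps long expressions
with the operator leading the continuation line, which W503 would otherwise flag. A minimal illustration,
not taken from this commit:

    # W503 would fire on the operator-leading line below if it were not ignored.
    is_valid = (
        parameters is not None
        and parameters.best_of is not None
    )
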
13 changes: 4 additions & 9 deletions clients/python/lorax/types.py
@@ -214,15 +214,8 @@ def valid_input(cls, v):
     @validator("stream")
     def valid_best_of_stream(cls, field_value, values):
         parameters = values["parameters"]
-        if (
-            parameters is not None
-            and parameters.best_of is not None
-            and parameters.best_of > 1
-            and field_value
-        ):
-            raise ValidationError(
-                "`best_of` != 1 is not supported when `stream` == True"
-            )
+        if parameters is not None and parameters.best_of is not None and parameters.best_of > 1 and field_value:
+            raise ValidationError("`best_of` != 1 is not supported when `stream` == True")
         return field_value
 
 
@@ -236,6 +229,7 @@ class InputToken(BaseModel):
     # Optional since the logprob of the first token cannot be computed
     logprob: Optional[float]
 
+
 # Alternative Tokens
 class AlternativeToken(BaseModel):
     # Token ID from the model tokenizer
@@ -245,6 +239,7 @@ class AlternativeToken(BaseModel):
     # Logprob
     logprob: float
 
+
 # Generated tokens
 class Token(BaseModel):
     # Token ID from the model tokenizer
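
Aside from the reflowed validator, the only additions in this file are second blank lines before the
top-level AlternativeToken and Token classes: PEP 8 (and the Black/ruff formatting style) expects two
blank lines before top-level definitions.
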
1 change: 1 addition & 0 deletions server/Makefile
@@ -35,3 +35,4 @@ export-requirements:
 format:
 	pip install ruff
 	python -m ruff format lorax_server
+	python -m ruff check lorax_server --fix
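
With this target, ruff format handles layout and ruff check --fix applies lint fixes, including the
isort-style import sorting seen throughout the diffs below. The repository's ruff configuration is not
part of this commit; a minimal pyproject.toml sketch that would produce this behavior, assuming a ruff
version that reads lint settings from [tool.ruff.lint], looks like:

    [tool.ruff]
    line-length = 120

    [tool.ruff.lint]
    # "I" selects the isort rules, so `ruff check --fix` sorts imports
    extend-select = ["I"]
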
4 changes: 1 addition & 3 deletions server/lorax_server/adapters/__init__.py
@@ -21,9 +21,7 @@ def load_adapter_config(
     if "medusa_num_heads" in config:
         return MedusaConfig.load(config)
 
-    raise ValueError(
-        f"No valid adapter config file found: " f"tried {adapter_config_path} and {config_path}"
-    )
+    raise ValueError(f"No valid adapter config file found: " f"tried {adapter_config_path} and {config_path}")
 
 
 __all__ = [
32 changes: 7 additions & 25 deletions server/lorax_server/adapters/lora.py
@@ -181,59 +181,41 @@ def has_adapter(self, adapter_index: int) -> bool:
         return adapter_index in self.adapter_index_configs
 
     def can_vectorize(self, pg: ProcessGroup) -> bool:
-        return all(
-            rank_data.rank // pg.size() <= MAX_RANK_CUSTOM for rank_data in self.rank_data.values()
-        )
+        return all(rank_data.rank // pg.size() <= MAX_RANK_CUSTOM for rank_data in self.rank_data.values())
 
     @classmethod
     def key(cls) -> str:
         return LORA
 
     @classmethod
-    def load(
-        self, adapter_weights: Dict[int, AdapterWeights], meta: AdapterBatchMetadata
-    ) -> "BatchLoraWeights":
+    def load(self, adapter_weights: Dict[int, AdapterWeights], meta: AdapterBatchMetadata) -> "BatchLoraWeights":
         adapter_weights = {k: v for k, v in adapter_weights.items() if isinstance(v, LoraWeights)}
 
         first_weights = list(adapter_weights.values())[0]
         device = first_weights.weights_a.device
         segment_indices = meta.segment_indices
 
-        lora_a = {
-            idx: adapter_weights[idx].weights_a for idx in segment_indices if idx in adapter_weights
-        }
+        lora_a = {idx: adapter_weights[idx].weights_a for idx in segment_indices if idx in adapter_weights}
         lora_a_ptr = torch.tensor(
             [
-                (
-                    adapter_weights[idx].weights_a.data_ptr()
-                    if idx in adapter_weights
-                    else EMPTY_TENSOR.data_ptr()
-                )
+                (adapter_weights[idx].weights_a.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
                 for idx in segment_indices
             ],
             dtype=torch.int64,
             device=device,
         )
-        lora_b = {
-            idx: adapter_weights[idx].weights_b for idx in segment_indices if idx in adapter_weights
-        }
+        lora_b = {idx: adapter_weights[idx].weights_b for idx in segment_indices if idx in adapter_weights}
         lora_b_ptr = torch.tensor(
             [
-                (
-                    adapter_weights[idx].weights_b.data_ptr()
-                    if idx in adapter_weights
-                    else EMPTY_TENSOR.data_ptr()
-                )
+                (adapter_weights[idx].weights_b.data_ptr() if idx in adapter_weights else EMPTY_TENSOR.data_ptr())
                 for idx in segment_indices
             ],
             dtype=torch.int64,
             device=device,
         )
 
         adapter_index_configs = {
-            idx: adapter_weights[idx].adapter_config
-            for idx in segment_indices
-            if idx in adapter_weights
+            idx: adapter_weights[idx].adapter_config for idx in segment_indices if idx in adapter_weights
         }
 
         rank_indices = defaultdict(list)
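
The collapsed comprehensions above preserve the pointer-gathering pattern: for each batch segment, the
raw device pointer of that adapter's LoRA weight tensor is collected into an int64 tensor, with an empty
tensor's pointer as filler, so downstream code (such as a fused kernel) can select per-segment weights
without Python-level branching. A self-contained sketch with hypothetical toy values, not taken from
this commit:

    import torch

    EMPTY_TENSOR = torch.tensor([])

    # adapters 0 and 2 are loaded; segment 1 has no adapter
    weights = {0: torch.randn(8, 16), 2: torch.randn(8, 16)}
    segment_indices = [0, 1, 2]

    # one raw data pointer per segment, stored as int64
    ptrs = torch.tensor(
        [weights[i].data_ptr() if i in weights else EMPTY_TENSOR.data_ptr() for i in segment_indices],
        dtype=torch.int64,
    )
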
18 changes: 4 additions & 14 deletions server/lorax_server/adapters/medusa.py
@@ -71,10 +71,7 @@ class MedusaHead(torch.nn.Module):
     def __init__(self, config: MedusaConfig, prefix: str, weights: AbstractWeights):
         super().__init__()
         self.blocks = torch.nn.ModuleList(
-            [
-                ResBlock(config, prefix=f"{prefix}.{i}", weights=weights)
-                for i in range(config.medusa_num_layers)
-            ]
+            [ResBlock(config, prefix=f"{prefix}.{i}", weights=weights) for i in range(config.medusa_num_layers)]
         )
         n = len(self.blocks)
         self.out = FastLinear.load(config, prefix=f"{prefix}.{n}", weights=weights, bias=False)
@@ -90,10 +87,7 @@ class MedusaModel(torch.nn.Module):
     def __init__(self, config: MedusaConfig, weights: AbstractWeights):
         super().__init__()
         self.heads = torch.nn.ModuleList(
-            [
-                MedusaHead(config, prefix=f"{i}", weights=weights)
-                for i in range(config.medusa_num_heads)
-            ]
+            [MedusaHead(config, prefix=f"{i}", weights=weights) for i in range(config.medusa_num_heads)]
         )
 
     def forward(self, x):
@@ -141,14 +135,10 @@ def key(cls) -> str:
         return MEDUSA
 
     @classmethod
-    def load(
-        cls, adapter_weights: Dict[int, AdapterWeights], meta: "AdapterBatchMetadata"
-    ) -> "BatchMedusaWeights":
+    def load(cls, adapter_weights: Dict[int, AdapterWeights], meta: "AdapterBatchMetadata") -> "BatchMedusaWeights":
         adapter_weights = {k: v for k, v in adapter_weights.items() if isinstance(v, MedusaWeights)}
 
-        adapter_to_medusa = {
-            idx: adapter_weights[idx] for idx in meta.segment_indices if idx in adapter_weights
-        }
+        adapter_to_medusa = {idx: adapter_weights[idx] for idx in meta.segment_indices if idx in adapter_weights}
 
         return BatchMedusaWeights(
             adapter_to_medusa=adapter_to_medusa,
4 changes: 1 addition & 3 deletions server/lorax_server/adapters/utils.py
@@ -20,9 +20,7 @@ def download_adapter(
     HfApi(token=api_token).model_info(adapter_id, revision=None)
 
     # fail fast if ID is not an adapter (i.e. it is a full model)
-    source = get_model_source(
-        adapter_source, adapter_id, extension=".safetensors", api_token=api_token
-    )
+    source = get_model_source(adapter_source, adapter_id, extension=".safetensors", api_token=api_token)
     source.load_config()
 
     download_weights(adapter_id, source=adapter_source, api_token=api_token)
22 changes: 5 additions & 17 deletions server/lorax_server/adapters/weights.py
@@ -45,9 +45,7 @@ def key(cls) -> str:
         pass
 
     @abstractclassmethod
-    def load(
-        cls, adapter_weights: Dict[int, AdapterWeights], meta: "AdapterBatchMetadata"
-    ) -> "BatchAdapterWeights":
+    def load(cls, adapter_weights: Dict[int, AdapterWeights], meta: "AdapterBatchMetadata") -> "BatchAdapterWeights":
         pass
 
 
@@ -67,18 +65,14 @@ def remove_adapter(self, adapter_idx: int):
 
     @property
     def max_speculative_tokens(self) -> int:
-        return max(
-            adapter_weights.speculative_tokens for adapter_weights in self.adapter_weights.values()
-        )
+        return max(adapter_weights.speculative_tokens for adapter_weights in self.adapter_weights.values())
 
     def is_empty(self) -> bool:
         return len(self.adapter_weights) == 0
 
     def get_data(self, meta: AdapterBatchMetadata) -> Dict[str, BatchAdapterWeights]:
         # bucket adapters by batch class
-        adapter_batch_types: Dict[Type[BatchAdapterWeights], Dict[int, AdapterWeights]] = (
-            defaultdict(dict)
-        )
+        adapter_batch_types: Dict[Type[BatchAdapterWeights], Dict[int, AdapterWeights]] = defaultdict(dict)
         for adapter_index, adapter_weights in self.adapter_weights.items():
             adapter_batch_types[adapter_weights.get_batch_type()][adapter_index] = adapter_weights
 
@@ -96,9 +90,7 @@ class AdapterBatchData:
     data: Dict[str, Dict[str, BatchAdapterWeights]]
 
     @staticmethod
-    def from_meta(
-        meta: AdapterBatchMetadata, weights: Dict[str, LayerAdapterWeights]
-    ) -> "AdapterBatchData":
+    def from_meta(meta: AdapterBatchMetadata, weights: Dict[str, LayerAdapterWeights]) -> "AdapterBatchData":
         data = {}
         for k, v in weights.items():
             if v.is_empty():
@@ -112,11 +104,7 @@ def ranks(self) -> Set[int]:
         if lora_data is None:
             return set()
 
-        return set(
-            rank_data.rank
-            for layer_data in self.data.values()
-            for rank_data in lora_data.rank_data.values()
-        )
+        return set(rank_data.rank for layer_data in self.data.values() for rank_data in lora_data.rank_data.values())
 
     @property
     def max_rank(self) -> int:
4 changes: 2 additions & 2 deletions server/lorax_server/cache.py
@@ -1,7 +1,7 @@
-import torch
-
 from typing import Dict, Optional, TypeVar
 
+import torch
+
 from lorax_server.models.types import Batch
 
 B = TypeVar("B", bound=Batch)
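
This reordering, repeated in the files below, follows ruff's isort convention: standard-library imports
first, then third-party packages (torch, typer, loguru, grpc, transformers), then first-party
lorax_server modules, with the groups separated by blank lines, plain "import x" statements sorted ahead
of "from x import y" within a group, and names alphabetized.
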
21 changes: 7 additions & 14 deletions server/lorax_server/cli.py
@@ -1,14 +1,13 @@
 import os
 import sys
-import typer
-
+from enum import Enum
 from pathlib import Path
-from loguru import logger
 from typing import Optional
-from enum import Enum
 
-from lorax_server.utils.weights import download_weights as _download_weights
+import typer
+from loguru import logger
 
+from lorax_server.utils.weights import download_weights as _download_weights
 
 app = typer.Typer()
 
@@ -50,15 +49,9 @@ def serve(
 ):
     if sharded:
         assert os.getenv("RANK", None) is not None, "RANK must be set when sharded is True"
-        assert (
-            os.getenv("WORLD_SIZE", None) is not None
-        ), "WORLD_SIZE must be set when sharded is True"
-        assert (
-            os.getenv("MASTER_ADDR", None) is not None
-        ), "MASTER_ADDR must be set when sharded is True"
-        assert (
-            os.getenv("MASTER_PORT", None) is not None
-        ), "MASTER_PORT must be set when sharded is True"
+        assert os.getenv("WORLD_SIZE", None) is not None, "WORLD_SIZE must be set when sharded is True"
+        assert os.getenv("MASTER_ADDR", None) is not None, "MASTER_ADDR must be set when sharded is True"
+        assert os.getenv("MASTER_PORT", None) is not None, "MASTER_PORT must be set when sharded is True"
 
     # Remove default handler
     logger.remove()
10 changes: 5 additions & 5 deletions server/lorax_server/interceptor.py
@@ -1,11 +1,11 @@
-import torch
-import grpc
+from typing import Any, Callable
 
-from google.rpc import status_pb2, code_pb2
-from grpc_status import rpc_status
+import grpc
+import torch
+from google.rpc import code_pb2, status_pb2
 from grpc_interceptor.server import AsyncServerInterceptor
+from grpc_status import rpc_status
 from loguru import logger
-from typing import Callable, Any
 
 
 class ExceptionInterceptor(AsyncServerInterceptor):
12 changes: 6 additions & 6 deletions server/lorax_server/models/__init__.py
@@ -1,18 +1,18 @@
-import torch
+from typing import Optional
 
+import torch
 from loguru import logger
 from transformers.configuration_utils import PretrainedConfig
-from typing import Optional
 
-from lorax_server.models.model import Model
+from lorax_server.models.bloom import BLOOMSharded
 from lorax_server.models.causal_lm import CausalLM
 from lorax_server.models.flash_causal_lm import FlashCausalLM
-from lorax_server.models.bloom import BLOOMSharded
+from lorax_server.models.galactica import GalacticaSharded
+from lorax_server.models.model import Model
 from lorax_server.models.mpt import MPTSharded
-from lorax_server.models.seq2seq_lm import Seq2SeqLM
 from lorax_server.models.opt import OPTSharded
-from lorax_server.models.galactica import GalacticaSharded
 from lorax_server.models.santacoder import SantaCoder
+from lorax_server.models.seq2seq_lm import Seq2SeqLM
 from lorax_server.models.t5 import T5Sharded
 from lorax_server.utils.sources import get_s3_model_local_dir
18 changes: 7 additions & 11 deletions server/lorax_server/models/bloom.py
@@ -1,14 +1,15 @@
-import torch
-import torch.distributed
-
 from typing import Dict, List, Optional, Tuple, Type
 
+import torch
+import torch.distributed
 from transformers import (
-    AutoTokenizer,
     AutoConfig,
+    AutoTokenizer,
     PreTrainedTokenizerBase,
 )
 
+from lorax_server.adapters import AdapterBatchData
+from lorax_server.models.causal_lm import CausalLM, CausalLMBatch
 from lorax_server.models.custom_modeling.bloom_modeling import (
     ATTN_DENSE,
     ATTN_QKV,
@@ -17,16 +18,13 @@
     MLP_DENSE_H_TO_4H,
     BloomForCausalLM,
 )
-from lorax_server.models import CausalLM
-from lorax_server.models.causal_lm import CausalLMBatch
 from lorax_server.pb import generate_pb2
 from lorax_server.utils import (
+    Weights,
     initialize_torch_distributed,
     weight_files,
-    Weights,
 )
 from lorax_server.utils.tokenizer import TokenizerManager
-from lorax_server.adapters import AdapterBatchData
 
 ADAPTER_LAYERS = [ATTN_QKV, ATTN_DENSE, MLP_DENSE_H_TO_4H, MLP_DENSE_4H_TO_H]
 ROW_PARALLEL = {ATTN_DENSE, MLP_DENSE_4H_TO_H}
@@ -42,9 +40,7 @@ def from_pb(
         dtype: torch.dtype,
         device: torch.device,
     ) -> "CausalLMBatch":
-        batch = super().from_pb(
-            pb=pb, tokenizer=tokenizer, tokenizers=tokenizers, dtype=dtype, device=device
-        )
+        batch = super().from_pb(pb=pb, tokenizer=tokenizer, tokenizers=tokenizers, dtype=dtype, device=device)
         batch.keys_head_dim_last = False
         return batch