diff --git a/server/lorax_server/utils/weights.py b/server/lorax_server/utils/weights.py
index 9e0533ce..7ea1d400 100644
--- a/server/lorax_server/utils/weights.py
+++ b/server/lorax_server/utils/weights.py
@@ -242,14 +242,40 @@ def _get_gptq_params(self) -> Tuple[int, int]:
         return bits, groupsize
 
     def _set_gptq_params(self, model_id):
+        filename = "config.json"
         try:
-            filename = hf_hub_download(model_id, filename="quantize_config.json")
+            if os.path.exists(os.path.join(model_id, filename)):
+                filename = os.path.join(model_id, filename)
+            else:
+                filename = hf_hub_download(model_id, filename=filename)
             with open(filename, "r") as f:
                 data = json.load(f)
-            self.gptq_bits = data["bits"]
-            self.gptq_groupsize = data["group_size"]
+            self.gptq_bits = data["quantization_config"]["bits"]
+            self.gptq_groupsize = data["quantization_config"]["group_size"]
         except Exception:
-            pass
+            filename = "quantize_config.json"
+            try:
+                if os.path.exists(os.path.join(model_id, filename)):
+                    filename = os.path.join(model_id, filename)
+                else:
+                    filename = hf_hub_download(model_id, filename=filename)
+                with open(filename, "r") as f:
+                    data = json.load(f)
+                self.gptq_bits = data["bits"]
+                self.gptq_groupsize = data["group_size"]
+            except Exception:
+                filename = "quant_config.json"
+                try:
+                    if os.path.exists(os.path.join(model_id, filename)):
+                        filename = os.path.join(model_id, filename)
+                    else:
+                        filename = hf_hub_download(model_id, filename=filename)
+                    with open(filename, "r") as f:
+                        data = json.load(f)
+                    self.gptq_bits = data["w_bit"]
+                    self.gptq_groupsize = data["q_group_size"]
+                except Exception:
+                    pass
 
 def get_start_stop_idxs_for_rank(size, rank, world_size):
     block_size = size // world_size
@@ -272,4 +298,4 @@ def shard_on_dim(t: torch.Tensor, dim: int, process_group: torch.distributed.Pro
     else:
         raise NotImplementedError("Let's make that generic when needed")
 
-    return tensor
\ No newline at end of file
+    return tensor
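
For reference, the patch replaces the single quantize_config.json lookup with a three-step fallback chain, checking the local model directory before hitting the Hub at each step: config.json with a quantization_config block, then quantize_config.json (GPTQ), then quant_config.json (AWQ, whose keys are w_bit and q_group_size). The same resolution order can be written as a flat, data-driven loop instead of nested try/except blocks. The sketch below is illustrative only, not part of the patch; load_quantize_params, _lookup, and CANDIDATES are hypothetical names, while hf_hub_download, os, and json are used exactly as in the diff.

    import json
    import os

    from huggingface_hub import hf_hub_download

    # (config filename, key path to bit width, key path to group size)
    CANDIDATES = [
        ("config.json", ("quantization_config", "bits"), ("quantization_config", "group_size")),
        ("quantize_config.json", ("bits",), ("group_size",)),    # GPTQ
        ("quant_config.json", ("w_bit",), ("q_group_size",)),    # AWQ
    ]

    def _lookup(data, keys):
        # Walk a nested dict along `keys`, e.g. ("quantization_config", "bits").
        for key in keys:
            data = data[key]
        return data

    def load_quantize_params(model_id):
        # Try each candidate in order; return (bits, groupsize) from the first
        # config that resolves, or None if none of them do.
        for name, bits_keys, group_keys in CANDIDATES:
            try:
                local = os.path.join(model_id, name)
                if not os.path.exists(local):
                    local = hf_hub_download(model_id, filename=name)
                with open(local, "r") as f:
                    data = json.load(f)
                return _lookup(data, bits_keys), _lookup(data, group_keys)
            except Exception:
                continue
        return None

Note that the ordering matters: config.json exists for nearly every model but only carries a quantization_config block for quantized checkpoints, so a missing key raises and falls through to the next candidate, which is exactly the behavior the nested except clauses in the patch rely on.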