diff --git a/pyproject.toml b/pyproject.toml
index ad90d2b..4a6267e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "xturing"
-version = "0.1.7"
+version = "0.1.8"
 description = "Fine-tuning, evaluation and data generation for LLMs"
 authors = [
@@ -43,12 +43,12 @@ keywords = [
 dependencies = [
     "torch >= 1.9.0",
     "pytorch-lightning",
-    "transformers==4.28.1",
-    "datasets",
-    "evaluate",
-    "bitsandbytes==0.37.2",
+    "transformers==4.31.0",
+    "datasets==2.14.5",
+    "evaluate==0.4.0",
+    "bitsandbytes==0.41.1",
     "sentencepiece",
-    "deepspeed",
+    "deepspeed==0.9.5",
     "gradio",
     "click",
     "wget",
@@ -58,7 +58,7 @@ dependencies = [
     "openai >= 0.27.0",
     "pydantic >= 1.10.0",
     "rouge-score >= 0.1.2",
-    "accelerate",
+    "accelerate==0.22.0",
     "wandb",
 ]
diff --git a/src/xturing/__about__.py b/src/xturing/__about__.py
index f1380ee..9cb17e7 100644
--- a/src/xturing/__about__.py
+++ b/src/xturing/__about__.py
@@ -1 +1 @@
-__version__ = "0.1.7"
+__version__ = "0.1.8"
diff --git a/src/xturing/config/finetuning_config.yaml b/src/xturing/config/finetuning_config.yaml
index 3f12670..37b82ed 100644
--- a/src/xturing/config/finetuning_config.yaml
+++ b/src/xturing/config/finetuning_config.yaml
@@ -32,6 +32,13 @@ bloom_lora_int8:
   batch_size: 8
   max_length: 256

+bloom_int8:
+  learning_rate: 1e-4
+  weight_decay: 0.01
+  num_train_epochs: 3
+  batch_size: 8
+  max_length: 256
+
 cerebras:
   learning_rate: 5e-5
   weight_decay: 0.01
@@ -50,6 +57,13 @@ cerebras_lora_int8:
   batch_size: 8
   max_length: 256

+cerebras_int8:
+  learning_rate: 1e-4
+  weight_decay: 0.01
+  num_train_epochs: 3
+  batch_size: 8
+  max_length: 256
+
 distilgpt2:
   learning_rate: 1e-3
   weight_decay: 0.01
@@ -115,6 +129,13 @@ galactica_lora_int8:
   batch_size: 8
   max_length: 256

+galactica_int8:
+  learning_rate: 1e-4
+  weight_decay: 0.01
+  num_train_epochs: 3
+  batch_size: 8
+  max_length: 256
+
 generic:
   learning_rate: 1e-4
   weight_decay: 0.01
@@ -169,6 +190,13 @@ gptj_lora_int8:
   batch_size: 8
   max_length: 256

+gptj_int8:
+  learning_rate: 1e-4
+  weight_decay: 0.01
+  num_train_epochs: 3
+  batch_size: 8
+  max_length: 256
+
 gpt2:
   learning_rate: 1e-3
   weight_decay: 0.01
@@ -187,13 +215,18 @@ gpt2_lora_int8:
   num_train_epochs: 3
   batch_size: 16

+gpt2_int8:
+  learning_rate: 3e-3
+  weight_decay: 0.01
+  num_train_epochs: 3
+  batch_size: 16
+
 llama:
   learning_rate: 5e-5
   weight_decay: 0.01
   num_train_epochs: 3
   optimizer_name: cpu_adam
-
 llama_lora:
   learning_rate: 1e-4
   weight_decay: 0.01
@@ -207,6 +240,13 @@ llama_lora_int8:
   batch_size: 8
   max_length: 256

+llama_int8:
+  learning_rate: 1e-4
+  weight_decay: 0.01
+  num_train_epochs: 3
+  batch_size: 8
+  max_length: 256
+
 llama_lora_kbit:
   learning_rate: 3e-4
   num_train_epochs: 3
@@ -275,3 +315,10 @@ opt_lora_int8:
   num_train_epochs: 3
   batch_size: 8
   max_length: 256
+
+opt_int8:
+  learning_rate: 1e-4
+  weight_decay: 0.01
+  num_train_epochs: 3
+  batch_size: 8
+  max_length: 256
diff --git a/src/xturing/config/generation_config.yaml b/src/xturing/config/generation_config.yaml
index f844cf0..2eba241 100644
--- a/src/xturing/config/generation_config.yaml
+++ b/src/xturing/config/generation_config.yaml
@@ -25,6 +25,11 @@ bloom_lora_int8:
   max_new_tokens: 256
   do_sample: false

+# Greedy search
+bloom_int8:
+  max_new_tokens: 256
+  do_sample: false
+
 # Contrastive search
 cerebras:
   penalty_alpha: 0.6
@@ -44,6 +49,11 @@ cerebras_lora_int8:
   max_new_tokens: 256
   do_sample: false

+# Greedy search
+cerebras_int8:
+  max_new_tokens: 256
+  do_sample: false
+
 # Top-p sampling
 distilgpt2:
   do_sample: true
@@ -102,6 +112,11 @@ galactica_lora_int8:
   max_new_tokens: 256
   do_sample: false

+# Greedy search
+galactica_int8:
+  max_new_tokens: 256
+  do_sample: false
+
 # Greedy search
 generic:
   max_new_tokens: 256
@@ -146,6 +161,11 @@ gptj_lora_int8:
   max_new_tokens: 256
   do_sample: false

+# Greedy search
+gptj_int8:
+  max_new_tokens: 256
+  do_sample: false
+
 # Top-p sampling
 gpt2:
   do_sample: true
@@ -167,6 +187,13 @@ gpt2_lora_int8:
   top_p: 0.92
   max_new_tokens: 256

+# Top-p sampling
+gpt2_int8:
+  do_sample: true
+  top_k: 0
+  top_p: 0.92
+  max_new_tokens: 256
+
 # Contrastive search
 llama:
   penalty_alpha: 0.6
@@ -186,6 +213,11 @@ llama_lora_int8:
   max_new_tokens: 256
   do_sample: false

+# Greedy search
+llama_int8:
+  max_new_tokens: 256
+  do_sample: false
+
 # Greedy search
 llama_lora_kbit:
   max_new_tokens: 256
@@ -238,3 +270,8 @@ opt_lora:
 opt_lora_int8:
   max_new_tokens: 256
   do_sample: false
+
+# Greedy search
+opt_int8:
+  max_new_tokens: 256
+  do_sample: false
diff --git a/src/xturing/engines/__init__.py b/src/xturing/engines/__init__.py
index 701db90..7422985 100644
--- a/src/xturing/engines/__init__.py
+++ b/src/xturing/engines/__init__.py
@@ -44,7 +44,13 @@
     GPTJLoraEngine,
     GPTJLoraInt8Engine,
 )
-from xturing.engines.llama2_engine import LLama2Engine
+from xturing.engines.llama2_engine import (
+    LLama2Engine,
+    LLama2Int8Engine,
+    LLama2LoraEngine,
+    LLama2LoraInt8Engine,
+    LLama2LoraKbitEngine,
+)
 from xturing.engines.llama_engine import (
     LLamaEngine,
     LLamaInt8Engine,
@@ -97,6 +103,10 @@
 BaseEngine.add_to_registry(LlamaLoraInt8Engine.config_name, LlamaLoraInt8Engine)
 BaseEngine.add_to_registry(LlamaLoraKbitEngine.config_name, LlamaLoraKbitEngine)
 BaseEngine.add_to_registry(LLama2Engine.config_name, LLama2Engine)
+BaseEngine.add_to_registry(LLama2Int8Engine.config_name, LLama2Int8Engine)
+BaseEngine.add_to_registry(LLama2LoraEngine.config_name, LLama2LoraEngine)
+BaseEngine.add_to_registry(LLama2LoraInt8Engine.config_name, LLama2LoraInt8Engine)
+BaseEngine.add_to_registry(LLama2LoraKbitEngine.config_name, LLama2LoraKbitEngine)
 BaseEngine.add_to_registry(OPTEngine.config_name, OPTEngine)
 BaseEngine.add_to_registry(OPTInt8Engine.config_name, OPTInt8Engine)
 BaseEngine.add_to_registry(OPTLoraEngine.config_name, OPTLoraEngine)
diff --git a/src/xturing/engines/generic_engine.py b/src/xturing/engines/generic_engine.py
index e8cc813..b36dde4 100644
--- a/src/xturing/engines/generic_engine.py
+++ b/src/xturing/engines/generic_engine.py
@@ -64,7 +64,7 @@ def __init__(


 class GenericLoraKbitEngine(CausalLoraKbitEngine):
-    config_name: str = "generic+lora_kbit_engine"
+    config_name: str = "generic_lora_kbit_engine"

     def __init__(
         self,
@@ -75,7 +75,6 @@ def __init__(
         super().__init__(
             model_name=model_name,
             weights_path=weights_path,
-            load_4bit=True,
             target_modules=target_modules,
         )

diff --git a/src/xturing/models/__init__.py b/src/xturing/models/__init__.py
index 5be4f12..95be19c 100644
--- a/src/xturing/models/__init__.py
+++ b/src/xturing/models/__init__.py
@@ -36,7 +36,13 @@
     LlamaLoraInt8,
     LlamaLoraKbit,
 )
-from xturing.models.llama2 import Llama2
+from xturing.models.llama2 import (
+    Llama2,
+    Llama2Int8,
+    Llama2Lora,
+    Llama2LoraInt8,
+    Llama2LoraKbit,
+)
 from xturing.models.opt import OPT, OPTInt8, OPTLora, OPTLoraInt8
 from xturing.models.stable_diffusion import StableDiffusion

@@ -78,6 +84,10 @@
 BaseModel.add_to_registry(LlamaLoraInt8.config_name, LlamaLoraInt8)
 BaseModel.add_to_registry(LlamaLoraKbit.config_name, LlamaLoraKbit)
 BaseModel.add_to_registry(Llama2.config_name, Llama2)
+BaseModel.add_to_registry(Llama2Int8.config_name, Llama2Int8)
+BaseModel.add_to_registry(Llama2Lora.config_name, Llama2Lora)
+BaseModel.add_to_registry(Llama2LoraInt8.config_name, Llama2LoraInt8)
+BaseModel.add_to_registry(Llama2LoraKbit.config_name, Llama2LoraKbit)
 BaseModel.add_to_registry(OPT.config_name, OPT)
 BaseModel.add_to_registry(OPTInt8.config_name, OPTInt8)
 BaseModel.add_to_registry(OPTLora.config_name, OPTLora)
diff --git a/src/xturing/models/causal.py b/src/xturing/models/causal.py
index bfa87f3..62bb274 100644
--- a/src/xturing/models/causal.py
+++ b/src/xturing/models/causal.py
@@ -1,10 +1,8 @@
 import json
 from pathlib import Path
-
-from typing import Iterable, List, Optional, Tuple, Type, Union
+from typing import Iterable, List, Optional, Tuple, Union

 import torch
-import torch.nn.functional as F
 from pytorch_lightning.loggers import Logger
 from torch.utils.data import DataLoader
 from tqdm import tqdm
@@ -21,15 +19,7 @@
 from xturing.trainers.base import BaseTrainer
 from xturing.trainers.lightning_trainer import LightningTrainer
 from xturing.utils.logging import configure_logger
-from xturing.utils.metrics import get_accuracy
-from xturing.utils.prompt import (
-    OpenAIChatMessage,
-    OpenAICreateChatPrompt,
-    OpenAICreatePrompt,
-    Prompt,
-    chat_prompt_to_text,
-    is_chat_prompt,
-)
+from xturing.utils.prompt import OpenAICreateChatPrompt, OpenAICreatePrompt, Prompt
 from xturing.utils.utils import _filter_args, _index_samples

 TokenSequence = Union[List[int], torch.LongTensor, torch.Tensor, BatchEncoding]
@@ -44,6 +34,7 @@ def __init__(
         weights_path: Optional[str] = None,
         model_name: Optional[str] = None,
         target_modules: Optional[List[str]] = None,
+        transfer_to_device: Optional[bool] = True,
         **kwargs,
     ):
         arguments = dict(
@@ -82,6 +73,8 @@ def __init__(
         logger.debug(f"Finetuning parameters: {self.finetuning_args}")
         logger.debug(f"Generation parameters: {self.generation_args}")

+        self.transfer_to_device = transfer_to_device
+
     def finetuning_config(self):
         return self.finetuning_args

@@ -163,7 +156,9 @@ def generate(
         batch_size: Optional[int] = 1,
     ):
         self.engine.model.eval()
-        self.engine.model = self.engine.model.to(DEFAULT_DEVICE)
+
+        if self.transfer_to_device:
+            self.engine.model = self.engine.model.to(DEFAULT_DEVICE)

         outputs = []

@@ -239,18 +234,9 @@ def _model_call(
     def completion_query(
         self, prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt]
     ):
-        # actual_prompt = chat_prompt_to_text(prompt)
         actual_prompt = prompt
         logger.info(prompt)
         text_out = self.generate(texts=[actual_prompt])
-
-        # parse results
-        # result = {
-        #     "text": text_out,
-        #     "tokens": None,
-        #     "logprobs": None,
-        # }
-
         return text_out, actual_prompt

     def check_sampled_text(
@@ -314,8 +300,6 @@ def evaluate(
         dataset: Union[TextDataset, InstructionDataset],
         batch_size: Optional[int] = 1,
     ):
-        # outputs = self.eval_all_samples(dataset)
-        # return get_accuracy(outputs)
         collate_fn = self._make_collate_fn(dataset)
         dataloader = DataLoader(
             dataset,
@@ -338,7 +322,11 @@ def __init__(
     ):
         assert_not_cpu_int8()
         super().__init__(
-            engine, weights_path=weights_path, model_name=model_name, **kwargs
+            engine,
+            weights_path=weights_path,
+            model_name=model_name,
+            transfer_to_device=False,
+            **kwargs,
         )


@@ -400,18 +388,19 @@
 class CausalLoraKbitModel(CausalLoraModel):
     def __init__(
-            self,
-            engine: str,
-            weights_path: Optional[str] = None,
-            model_name: Optional[str] = None,
-            target_modules: Optional[List[str]] = None,
-            **kwargs,
-        ):
+        self,
+        engine: str,
+        weights_path: Optional[str] = None,
+        model_name: Optional[str] = None,
+        target_modules: Optional[List[str]] = None,
+        **kwargs,
+    ):
         assert_not_cpu_int8()
         super().__init__(
             engine,
             weights_path=weights_path,
             model_name=model_name,
             target_modules=target_modules,
+            transfer_to_device=False,
             **kwargs,
         )
diff --git a/tests/xturing/models/test_generic_models.py b/tests/xturing/models/test_generic_models.py
new file mode 100644
index 0000000..9c2c7b7
--- /dev/null
+++ b/tests/xturing/models/test_generic_models.py
@@ -0,0 +1,55 @@
+import tempfile
+from pathlib import Path
+
+from xturing.models import (
+    GenericInt8Model,
+    GenericLoraInt8Model,
+    GenericLoraKbitModel,
+    GenericLoraModel,
+    GenericModel,
+)
+
+
+def test_generic_model():
+    saving_path = Path(tempfile.gettempdir()) / "test_xturing_generic"
+    model = GenericModel("distilgpt2")
+    model.save(str(saving_path))
+
+    model2 = GenericModel(str(saving_path))
+    model2.generate(texts=["Why are LLMs so important?"])
+
+
+def test_generic_model_int8():
+    saving_path = Path(tempfile.gettempdir()) / "test_xturing_generic_int8"
+    model = GenericInt8Model("distilgpt2")
+    model.save(str(saving_path))
+
+    model2 = GenericInt8Model(str(saving_path))
+    model2.generate(texts=["Why are LLMs so important?"])
+
+
+def test_generic_model_lora():
+    saving_path = Path(tempfile.gettempdir()) / "test_xturing_generic_lora"
+    model = GenericLoraModel("distilgpt2")
+    model.save(str(saving_path))
+
+    model2 = GenericLoraModel(str(saving_path))
+    model2.generate(texts=["Why are LLMs so important?"])
+
+
+def test_generic_model_int8_lora():
+    saving_path = Path(tempfile.gettempdir()) / "test_xturing_lora_int8"
+    model = GenericLoraInt8Model("distilgpt2")
+    model.save(str(saving_path))
+
+    model2 = GenericLoraInt8Model(str(saving_path))
+    model2.generate(texts=["Why are LLMs so important?"])
+
+
+def test_generic_model_lora_kbit():
+    saving_path = Path(tempfile.gettempdir()) / "test_xturing_lora_kbit"
+    model = GenericLoraKbitModel("distilgpt2")
+    model.save(str(saving_path))
+
+    model2 = GenericLoraKbitModel(str(saving_path))
+    model2.generate(texts=["Why are LLMs so important?"])
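
Note (not part of the patch): the registrations above expose the new LLaMA 2 variants through xturing's public BaseModel registry, and the *_int8 / *_kbit model classes now skip the .to(DEFAULT_DEVICE) transfer via transfer_to_device=False. A minimal usage sketch of the intended flow follows; the config name "llama2_lora_int8" is assumed by analogy with the existing "llama_lora_int8", and the dataset path is hypothetical.

    from xturing.datasets import InstructionDataset
    from xturing.models import BaseModel

    # Resolve one of the newly registered variants by config name
    # ("llama2_lora_int8" is an assumed name, mirroring "llama_lora_int8").
    model = BaseModel.create("llama2_lora_int8")

    # Fine-tuning picks up the matching preset from finetuning_config.yaml;
    # the *_int8 presets added above use lr=1e-4, 3 epochs, batch_size=8,
    # max_length=256.
    dataset = InstructionDataset("./alpaca_data")  # hypothetical local dataset
    model.finetune(dataset=dataset)

    # Generation reads generation_config.yaml; the int8 presets are greedy
    # (do_sample: false, max_new_tokens: 256). Because this is an int8 model,
    # it is not moved to DEFAULT_DEVICE before generating.
    print(model.generate(texts=["Why are LLMs so important?"]))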