Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion gamesense/pipelines/train_accelerated.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@

from steps import (
evaluate_model,
finetune_accelerated,
finetune,
log_metadata_from_step_artifact,
prepare_data,
promote,
)
from zenml import pipeline
from zenml.integrations.huggingface.steps import run_with_accelerate


@pipeline
Expand Down Expand Up @@ -75,6 +76,9 @@ def llm_peft_full_finetune(
id="log_metadata_evaluation_base",
)

finetune_accelerated = run_with_accelerate(
finetune, num_processes=2, multi_gpu=True, mixed_precision="bf16"
)
ft_model_dir = finetune_accelerated(
base_model_id=base_model_id,
dataset_dir=datasets_dir,
Expand Down
4 changes: 2 additions & 2 deletions gamesense/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@

\b
# Run the pipeline with custom config
python run.py --config custom_finetune.yaml
python run.py --config phi3.5_finetune_local.yaml
"""
)
@click.option(
"--config",
type=str,
default="default_finetune.yaml",
default="phi3.5_finetune_local.yaml",
help="Path to the YAML config file.",
)
@click.option(
Expand Down
2 changes: 1 addition & 1 deletion gamesense/steps/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#

from .evaluate_model import evaluate_model
from .finetune import finetune, finetune_accelerated
from .finetune import finetune
from .log_metadata import log_metadata_from_step_artifact
from .prepare_datasets import prepare_data
from .promote import promote
7 changes: 2 additions & 5 deletions gamesense/steps/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
from zenml.logger import get_logger
from zenml.materializers import BuiltInMaterializer
from zenml.utils.cuda_utils import cleanup_gpu_memory
from zenml.client import Client


logger = get_logger(__name__)

Expand Down Expand Up @@ -196,8 +198,3 @@ def finetune(
)

return ft_model_dir


finetune_accelerated = run_with_accelerate(
finetune, num_processes=2, multi_gpu=True, mixed_precision="bf16"
)
Loading