dspy/clients/lm_local.py: 50 changes (34 additions & 16 deletions)
@@ -253,25 +253,43 @@ def tokenize_function(example):
         task_type="CAUSAL_LM",
     )

-    sft_config = SFTConfig(
-        output_dir=train_kwargs["output_dir"],
-        num_train_epochs=train_kwargs["num_train_epochs"],
-        per_device_train_batch_size=train_kwargs["per_device_train_batch_size"],
-        gradient_accumulation_steps=train_kwargs["gradient_accumulation_steps"],
-        learning_rate=train_kwargs["learning_rate"],
-        max_grad_norm=2.0, # note that the current SFTConfig default is 1.0
-        logging_steps=20,
-        warmup_ratio=0.03,
-        lr_scheduler_type="constant",
-        save_steps=10_000,
-        bf16=train_kwargs["bf16"],
-        max_seq_length=train_kwargs["max_seq_length"],
-        packing=train_kwargs["packing"],
-        dataset_kwargs={ # We need to pass dataset_kwargs because we are processing the dataset ourselves
+    # Handle compatibility between different TRL versions
+    # TRL >= 0.16.0 uses 'max_length' instead of 'max_seq_length' in SFTConfig
+    import inspect
+    sft_config_params = inspect.signature(SFTConfig.__init__).parameters
+
+    # Build config parameters, handling the max_seq_length -> max_length change
+    config_kwargs = {
+        "output_dir": train_kwargs["output_dir"],
+        "num_train_epochs": train_kwargs["num_train_epochs"],
+        "per_device_train_batch_size": train_kwargs["per_device_train_batch_size"],
+        "gradient_accumulation_steps": train_kwargs["gradient_accumulation_steps"],
+        "learning_rate": train_kwargs["learning_rate"],
+        "max_grad_norm": 2.0, # note that the current SFTConfig default is 1.0
+        "logging_steps": 20,
+        "warmup_ratio": 0.03,
+        "lr_scheduler_type": "constant",
+        "save_steps": 10_000,
+        "bf16": train_kwargs["bf16"],
+        "packing": train_kwargs["packing"],
+        "dataset_kwargs": { # We need to pass dataset_kwargs because we are processing the dataset ourselves
             "add_special_tokens": False, # Special tokens handled by template
             "append_concat_token": False, # No additional separator needed
         },
-    )
+    }
+
+    # Add the sequence length parameter using the appropriate name for the TRL version
+    if "max_seq_length" in sft_config_params:
+        # Older TRL versions (< 0.16.0)
+        config_kwargs["max_seq_length"] = train_kwargs["max_seq_length"]

Inline review comment from a Collaborator:
It's possible that the user is setting max_length instead of max_seq_length; in that case max_length will go nowhere.

Instead of this silent transformation, can we raise an error when we detect a mismatch between train_kwargs and SFTConfig on the name of the max-length parameter? The error can be something like "max_seq_length is replaced by max_length in trl>=0.16.0, but you set ... on trl={the_detected_version}".
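
A minimal sketch of the fail-loudly check suggested here, purely for illustration and not code from this PR: it assumes the caller may pass either key in train_kwargs, reads the installed trl version via importlib.metadata, and uses a hypothetical helper name.

import inspect
from importlib.metadata import version

from trl import SFTConfig


def resolve_max_length_key(train_kwargs: dict) -> tuple[str, int]:
    """Map the user's setting onto the name SFTConfig expects, raising on a mismatch."""
    accepted = inspect.signature(SFTConfig.__init__).parameters
    expected = "max_length" if "max_length" in accepted else "max_seq_length"
    provided = "max_length" if "max_length" in train_kwargs else "max_seq_length"
    if provided not in train_kwargs:
        raise ValueError("train_kwargs must set either 'max_seq_length' or 'max_length'")
    if provided != expected:
        # Mirrors the error message proposed in the review comment above.
        raise ValueError(
            f"'max_seq_length' is replaced by 'max_length' in trl>=0.16.0, "
            f"but you set '{provided}' on trl=={version('trl')}"
        )
    return expected, train_kwargs[provided]

The existing block could then call key, value = resolve_max_length_key(train_kwargs) and set config_kwargs[key] = value instead of renaming the key silently.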

@fsndzomga (Contributor, Author) replied on Oct 3, 2025:
Good point, thank you for the feedback. After thinking about this, I've come to the conclusion that it's probably better to either let the code fail as it currently does and have the user install the correct version of trl, or update the current implementation in DSPy to work only with the latest trl API and fail otherwise. In both cases, we should specify which version of trl is recommended and perhaps include a check for it. Otherwise it becomes a maintenance mess if the trl API changes again in the future.
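
The author's alternative could look something like the sketch below; this is only an illustration, and the minimum version, function name, and message are assumptions rather than anything decided in this thread.

from importlib.metadata import version

from packaging.version import Version

RECOMMENDED_TRL_VERSION = "0.16.0"  # hypothetical floor; use whichever release DSPy actually supports


def assert_supported_trl() -> None:
    """Fail fast with an actionable message instead of a confusing TypeError from SFTConfig later."""
    installed = version("trl")
    if Version(installed) < Version(RECOMMENDED_TRL_VERSION):
        raise ImportError(
            f"Local SFT in DSPy targets trl>={RECOMMENDED_TRL_VERSION} "
            f"(which renamed 'max_seq_length' to 'max_length'), but trl=={installed} is installed. "
            f"Please upgrade trl."
        )

A single check like this keeps the SFTConfig construction on one code path, which addresses the maintenance concern if the trl API changes again.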

+    elif "max_length" in sft_config_params:
+        # Newer TRL versions (>= 0.16.0)
+        config_kwargs["max_length"] = train_kwargs["max_seq_length"]
+    else:
+        logger.warning("Neither 'max_seq_length' nor 'max_length' parameter found in SFTConfig. "
+                       "This may indicate an incompatible TRL version.")
+
+    sft_config = SFTConfig(**config_kwargs)

     logger.info("Starting training")
     trainer = SFTTrainer(