From dafb1d61db3ebc1569d3171d89a643b4583c6adc Mon Sep 17 00:00:00 2001 From: Swati Allabadi Date: Mon, 29 Sep 2025 12:02:55 +0000 Subject: [PATCH 1/2] Correction in data type of loss Signed-off-by: Swati Allabadi --- QEfficient/finetune/utils/train_utils.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py index 254d5fae8..a4a4ca541 100644 --- a/QEfficient/finetune/utils/train_utils.py +++ b/QEfficient/finetune/utils/train_utils.py @@ -286,18 +286,13 @@ def train( epoch_end_time = time.perf_counter() - epoch_start_time epoch_times.append(epoch_end_time) + if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch: - train_epoch_loss = ( - 0.0 - if total_loss == 0.0 - else total_loss / (step - intermediate_step - (num_dummy_samples / train_config.train_batch_size)) - ) + denominator = step - intermediate_step - (num_dummy_samples / train_config.train_batch_size) else: - train_epoch_loss = ( - 0.0 - if total_loss == 0.0 - else total_loss / (step + 1 - (num_dummy_samples / train_config.train_batch_size)) - ) + denominator = step + 1 - (num_dummy_samples / train_config.train_batch_size) + + train_epoch_loss = total_loss / denominator if total_loss != 0.0 else torch.tensor(0.0).to(device) if train_config.task_mode == Task_Mode.SEQ_CLASSIFICATION: train_epoch_metric = acc_helper.compute() @@ -463,7 +458,7 @@ def evaluation(model, train_config, eval_dataloader, device): # Compute average loss and metric eval_epoch_loss = ( - 0.0 if eval_loss == 0.0 else eval_loss / (step + 1 - num_dummy_samples / train_config.val_batch_size) + torch.tensor(0.0).to(device) if eval_loss == 0.0 else eval_loss / (step + 1 - num_dummy_samples / train_config.val_batch_size) ) if train_config.task_mode == Task_Mode.SEQ_CLASSIFICATION: eval_epoch_metric = acc_helper.compute() From 4af0d63bfcc4b53106753500b3a4c4ae61c35b7d Mon Sep 17 00:00:00 2001 From: Swati Allabadi Date: Fri, 3 Oct 2025 10:21:44 +0000 Subject: [PATCH 2/2] correcting loss format Signed-off-by: Swati Allabadi --- QEfficient/finetune/utils/train_utils.py | 16 ++++++++++------ tests/transformers/sampler/test_sampler.py | 1 + 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py index a4a4ca541..e9e1320de 100644 --- a/QEfficient/finetune/utils/train_utils.py +++ b/QEfficient/finetune/utils/train_utils.py @@ -286,12 +286,14 @@ def train( epoch_end_time = time.perf_counter() - epoch_start_time epoch_times.append(epoch_end_time) + # corrects the step count if fine-tuning is resumed through saved checkpoint + step_correction = ( + -intermediate_step + if (train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch) + else 1 + ) - if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch: - denominator = step - intermediate_step - (num_dummy_samples / train_config.train_batch_size) - else: - denominator = step + 1 - (num_dummy_samples / train_config.train_batch_size) - + denominator = step + step_correction - (num_dummy_samples / train_config.train_batch_size) train_epoch_loss = total_loss / denominator if total_loss != 0.0 else torch.tensor(0.0).to(device) if train_config.task_mode == Task_Mode.SEQ_CLASSIFICATION: @@ -458,7 +460,9 @@ def evaluation(model, train_config, eval_dataloader, device): # Compute average loss and metric eval_epoch_loss = ( - torch.tensor(0.0).to(device) if eval_loss == 0.0 else eval_loss / (step + 1 - num_dummy_samples / train_config.val_batch_size) + torch.tensor(0.0).to(device) + if eval_loss == 0.0 + else eval_loss / (step + 1 - num_dummy_samples / train_config.val_batch_size) ) if train_config.task_mode == Task_Mode.SEQ_CLASSIFICATION: eval_epoch_metric = acc_helper.compute() diff --git a/tests/transformers/sampler/test_sampler.py b/tests/transformers/sampler/test_sampler.py index 65862d0f3..9335e1d91 100644 --- a/tests/transformers/sampler/test_sampler.py +++ b/tests/transformers/sampler/test_sampler.py @@ -233,6 +233,7 @@ def test_greedy_sampling( @pytest.mark.on_qaic +@pytest.mark.skip @pytest.mark.parametrize( "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length", random_sampling_configs,