diff --git a/QEfficient/finetune/configs/training.py b/QEfficient/finetune/configs/training.py
index 0257a63ed..bc51b2a31 100644
--- a/QEfficient/finetune/configs/training.py
+++ b/QEfficient/finetune/configs/training.py
@@ -10,7 +10,7 @@
 @dataclass
 class train_config:
     model_name: str = "meta-llama/Llama-3.2-1B"
-    tokenizer_name: str = "meta-llama/Llama-3.2-1B"
+    tokenizer_name: str = None  # if not passed as an argument, it uses the value of model_name
     run_validation: bool = True
     batch_size_training: int = 1
     context_length: int = None
diff --git a/QEfficient/finetune/utils/train_utils.py b/QEfficient/finetune/utils/train_utils.py
index d8e2799f4..3867bd7b6 100644
--- a/QEfficient/finetune/utils/train_utils.py
+++ b/QEfficient/finetune/utils/train_utils.py
@@ -112,6 +112,15 @@ def train(
                     f"Not proceeding with epoch {epoch + 1} since loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps."
                 )
                 break
+
+        if train_config.use_peft and train_config.from_peft_checkpoint:
+            intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
+            if epoch < intermediate_epoch:
+                print(f"Skipping epoch {epoch + 1} since fine tuning has already completed for it.")
+                # to bring the count of train_step in sync with where it left off
+                total_train_steps += len(train_dataloader)
+                continue
+
         print(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
         print(f"train_config.max_train_step: {train_config.max_train_step}")
         # stop when the maximum number of training steps is reached
@@ -131,8 +140,23 @@ def train(
         # enable profile for qaic
         qaic_profile.start_profiling(device, 1) if train_config.use_profiler else None
+
         for step, batch in enumerate(train_dataloader):
+            # resume training from a particular checkpoint, assuming the dataset is not shuffled
+            if train_config.use_peft and train_config.from_peft_checkpoint:
+                intermediate_step = int(train_config.from_peft_checkpoint.split("/")[-1].split("_")[-1])
+                intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
+                # to bring the count of train_step in sync with where it left off
+                if epoch == intermediate_epoch and step == 0:
+                    total_train_steps += intermediate_step
+                    print(
+                        f"skipping first {intermediate_step} steps for epoch {epoch + 1}, since fine tuning has already completed for them."
+                    )
+                if epoch == intermediate_epoch and step < intermediate_step:
+                    total_train_steps += 1
+                    continue
             total_train_steps += 1
+
             # stop when the maximum number of training steps is reached
             if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
                 max_steps_reached = True
@@ -206,9 +230,11 @@ def train(
                 qaic_profile.stop_profiling(device) if train_config.use_profiler else None
                 if train_config.enable_ddp:
                     if dist.get_rank() == 0:
-                        model.module.save_pretrained(train_config.output_dir + f"/trained_weights/step_{step}")
+                        model.module.save_pretrained(
+                            train_config.output_dir + f"/trained_weights/epoch_{epoch + 1}/step_{step}"
+                        )
                 else:
-                    model.save_pretrained(train_config.output_dir + f"/trained_weights/step_{step}")
+                    model.save_pretrained(train_config.output_dir + f"/trained_weights/epoch_{epoch + 1}/step_{step}")
 
             pbar.set_description(
                 f"Training Epoch: {epoch + 1}/{train_config.num_epochs}, step {step + 1}/{len(train_dataloader)} completed (loss: {loss.detach().float()})"
@@ -243,9 +269,16 @@ def train(
         epoch_times.append(epoch_end_time)
 
         if loss_0_counter.item() == train_config.convergence_counter:
-            train_epoch_loss = total_loss / step
+            if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
+                train_epoch_loss = total_loss / (step - intermediate_step)
+            else:
+                train_epoch_loss = total_loss / step
         else:
-            train_epoch_loss = total_loss / len(train_dataloader)
+            if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
+                train_epoch_loss = total_loss / (len(train_dataloader) - intermediate_step)
+            else:
+                train_epoch_loss = total_loss / len(train_dataloader)
+
         train_perplexity = torch.exp(train_epoch_loss)
         train_prep.append(float(train_perplexity))
@@ -253,7 +286,6 @@ def train(
         # Update the learning rate as needed
         lr_scheduler.step()
 
-        should_save_model = train_config.save_model
         if train_config.run_validation:
             if train_config.enable_ddp:
@@ -275,14 +307,14 @@ def train(
             if train_config.save_metrics:
                 val_step_loss.extend(temp_val_loss)
                 val_step_perplexity.extend(temp_step_perplexity)
-            should_save_model = train_config.save_model and eval_epoch_loss < best_val_loss
 
-        if should_save_model:
+        # saving the adapters after completion of each epoch
+        if train_config.save_model:
             if train_config.enable_ddp:
                 if dist.get_rank() == 0:
-                    model.module.save_pretrained(train_config.output_dir)
+                    model.module.save_pretrained(train_config.output_dir + f"/complete_epoch_{epoch + 1}")
             else:
-                model.save_pretrained(train_config.output_dir)
+                model.save_pretrained(train_config.output_dir + f"/complete_epoch_{epoch + 1}")
 
         if train_config.run_validation:
             if eval_epoch_loss < best_val_loss:
@@ -307,7 +339,6 @@ def train(
                 val_step_perplexity,
                 val_prep,
             )
-
     avg_epoch_time = sum(epoch_times) / len(epoch_times)
     avg_checkpoint_time = sum(checkpoint_times) / len(checkpoint_times) if len(checkpoint_times) > 0 else 0
     avg_train_prep = sum(train_prep) / len(train_prep)
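For reference, a minimal sketch (not part of the diff itself) of how the resume logic above recovers the epoch and step from a --from_peft_checkpoint path of the form <output_dir>/trained_weights/epoch_<N>/step_<M>, mirroring the split("/") expressions added in train_utils.py; the helper name parse_checkpoint_path and the example path are illustrative only:

    # Illustrative sketch only: recover the zero-based epoch index and the step
    # count from a checkpoint directory saved by the new save_pretrained calls,
    # e.g. "out/trained_weights/epoch_3/step_200".
    def parse_checkpoint_path(from_peft_checkpoint: str) -> tuple[int, int]:
        # last path component: "step_200" -> 200
        intermediate_step = int(from_peft_checkpoint.split("/")[-1].split("_")[-1])
        # second-to-last component: "epoch_3" -> 3, converted to zero-based index 2
        intermediate_epoch = int(from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
        return intermediate_epoch, intermediate_step

    print(parse_checkpoint_path("out/trained_weights/epoch_3/step_200"))  # (2, 200)

Because the resume path is parsed positionally and steps are skipped by index, this sketch assumes, as the diff's comment states, that the dataset is not shuffled between the original run and the resumed run.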