2 changes: 1 addition & 1 deletion QEfficient/finetune/configs/training.py
@@ -10,7 +10,7 @@
@dataclass
class train_config:
model_name: str = "meta-llama/Llama-3.2-1B"
tokenizer_name: str = "meta-llama/Llama-3.2-1B"
tokenizer_name: str = None # if not passed as an argument, it uses the value of model_name
run_validation: bool = True
batch_size_training: int = 1
context_length: int = None
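As a minimal sketch of how the new None default can be resolved (not part of this diff; the loader call and variable names are illustrative assumptions), the tokenizer name simply falls back to the model name before loading:

# --- illustrative sketch, not part of the diff ---
# Assumes a train_config instance as defined above; when tokenizer_name is
# left as None, reuse model_name for the tokenizer.
from transformers import AutoTokenizer

tokenizer_name = train_config.tokenizer_name or train_config.model_name
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)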
51 changes: 41 additions & 10 deletions QEfficient/finetune/utils/train_utils.py
@@ -112,6 +112,15 @@ def train(
f"Not proceeding with epoch {epoch + 1} since loss value has been <= {train_config.convergence_loss} for last {loss_0_counter.item()} steps."
)
break

if train_config.use_peft and train_config.from_peft_checkpoint:
intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
if epoch < intermediate_epoch:
print(f"Skipping epoch {epoch + 1} since fine tuning has already completed for it.")
# to bring the count of train_step in sync with where it left off
total_train_steps += len(train_dataloader)
continue

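A minimal sketch of the path parsing this epoch-skip relies on, assuming from_peft_checkpoint points at a directory following the epoch_<n>/step_<m> layout that the save paths later in this diff produce (the example path below is hypothetical):

# --- illustrative sketch, not part of the diff ---
# Hypothetical checkpoint path using the epoch_<n>/step_<m> layout.
ckpt = "output/trained_weights/epoch_3/step_120"
parts = ckpt.rstrip("/").split("/")
intermediate_step = int(parts[-1].split("_")[-1])       # 120
intermediate_epoch = int(parts[-2].split("_")[-1]) - 1   # 2, zero-based index of epoch 3
# Epochs 0 and 1 are skipped entirely; epoch 2 resumes after step 120.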
print(f"Starting epoch {epoch + 1}/{train_config.num_epochs}")
print(f"train_config.max_train_step: {train_config.max_train_step}")
# stop when the maximum number of training steps is reached
@@ -131,8 +140,23 @@

# enable profile for qaic
qaic_profile.start_profiling(device, 1) if train_config.use_profiler else None

for step, batch in enumerate(train_dataloader):
# resume training from a particular checkpoint, assuming the dataset is not shuffled
if train_config.use_peft and train_config.from_peft_checkpoint:
intermediate_step = int(train_config.from_peft_checkpoint.split("/")[-1].split("_")[-1])
intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
# to bring the count of train_step in sync with where it left off
if epoch == intermediate_epoch and step == 0:
total_train_steps += intermediate_step
print(
f"skipping first {intermediate_step} steps for epoch {epoch + 1}, since fine tuning has already completed for them."
)
if epoch == intermediate_epoch and step < intermediate_step:
total_train_steps += 1
continue
total_train_steps += 1

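In simplified form (illustrative names, and not the diff's exact counter bookkeeping), the per-step resume above amounts to the following, assuming the dataloader yields batches in a fixed, unshuffled order:

# --- illustrative sketch, not part of the diff ---
# Skip batches already covered by the checkpoint, but only in the resumed
# epoch, while keeping total_train_steps in line with the original run.
for step, batch in enumerate(train_dataloader):
    if resuming and epoch == intermediate_epoch and step < intermediate_step:
        total_train_steps += 1
        continue  # this batch was already trained on before the checkpoint
    total_train_steps += 1
    # ... forward/backward pass on `batch` ...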
# stop when the maximum number of training steps is reached
if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
max_steps_reached = True
@@ -206,9 +230,11 @@ def train(
qaic_profile.stop_profiling(device) if train_config.use_profiler else None
if train_config.enable_ddp:
if dist.get_rank() == 0:
model.module.save_pretrained(train_config.output_dir + f"/trained_weights/step_{step}")
model.module.save_pretrained(
train_config.output_dir + f"/trained_weights/epoch_{epoch + 1}/step_{step}"
)
else:
model.save_pretrained(train_config.output_dir + f"/trained_weights/step_{step}")
model.save_pretrained(train_config.output_dir + f"/trained_weights/epoch_{epoch + 1}/step_{step}")

pbar.set_description(
f"Training Epoch: {epoch + 1}/{train_config.num_epochs}, step {step + 1}/{len(train_dataloader)} completed (loss: {loss.detach().float()})"
@@ -243,17 +269,23 @@ def train(
epoch_times.append(epoch_end_time)

if loss_0_counter.item() == train_config.convergence_counter:
train_epoch_loss = total_loss / step
if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
train_epoch_loss = total_loss / (step - intermediate_step)
else:
train_epoch_loss = total_loss / step
else:
train_epoch_loss = total_loss / len(train_dataloader)
if train_config.use_peft and train_config.from_peft_checkpoint and epoch == intermediate_epoch:
train_epoch_loss = total_loss / (len(train_dataloader) - intermediate_step)
else:
train_epoch_loss = total_loss / len(train_dataloader)
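The changed divisor matters because total_loss only accumulates over the steps that actually ran in the resumed epoch. A small worked example with hypothetical numbers:

# --- illustrative sketch, not part of the diff ---
# Hypothetical resumed epoch: 100 batches total, 40 already covered by the checkpoint.
steps_per_epoch = 100
intermediate_step = 40
executed_steps = steps_per_epoch - intermediate_step  # 60 batches actually trained on
# total_loss was summed over those 60 steps only, so the epoch loss divides by 60, not 100:
# train_epoch_loss = total_loss / executed_steps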

train_perplexity = torch.exp(train_epoch_loss)

train_prep.append(float(train_perplexity))
train_loss.append(float(train_epoch_loss))

# Update the learning rate as needed
lr_scheduler.step()
should_save_model = train_config.save_model

if train_config.run_validation:
if train_config.enable_ddp:
@@ -275,14 +307,14 @@
if train_config.save_metrics:
val_step_loss.extend(temp_val_loss)
val_step_perplexity.extend(temp_step_perplexity)
should_save_model = train_config.save_model and eval_epoch_loss < best_val_loss

if should_save_model:
# saving the adapters after completion of each epoch
if train_config.save_model:
if train_config.enable_ddp:
if dist.get_rank() == 0:
model.module.save_pretrained(train_config.output_dir)
model.module.save_pretrained(train_config.output_dir + f"/complete_epoch_{epoch + 1}")
else:
model.save_pretrained(train_config.output_dir)
model.save_pretrained(train_config.output_dir + f"/complete_epoch_{epoch + 1}")

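Taken together with the should_save_model assignments earlier in this diff, the save decision reduces to something like the following condensed sketch (not the literal code; the DDP rank-0 guard is elided into a comment):

# --- illustrative sketch, not part of the diff ---
# Save the epoch's adapters only when saving is enabled and, if validation
# runs, only when the validation loss improved.
should_save_model = train_config.save_model
if train_config.run_validation:
    should_save_model = train_config.save_model and eval_epoch_loss < best_val_loss
if should_save_model:
    # Under DDP only rank 0 writes, via model.module; shown condensed here.
    save_dir = train_config.output_dir + f"/complete_epoch_{epoch + 1}"
    model.save_pretrained(save_dir)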
if train_config.run_validation:
if eval_epoch_loss < best_val_loss:
@@ -307,7 +339,6 @@ def train(
val_step_perplexity,
val_prep,
)

avg_epoch_time = sum(epoch_times) / len(epoch_times)
avg_checkpoint_time = sum(checkpoint_times) / len(checkpoint_times) if len(checkpoint_times) > 0 else 0
avg_train_prep = sum(train_prep) / len(train_prep)