From e660424717daf47d4e511a78b9dda230a2f2a602 Mon Sep 17 00:00:00 2001
From: Anthony Susevski <77211520+asusevski@users.noreply.github.com>
Date: Mon, 11 Dec 2023 18:44:23 -0500
Subject: [PATCH] fixed typos (issue 27919) (#27920)

* fixed typos (issue 27919)

* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
---
 .../knowledge_distillation_for_image_classification.md | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/docs/source/en/tasks/knowledge_distillation_for_image_classification.md b/docs/source/en/tasks/knowledge_distillation_for_image_classification.md
index d06b64fbc5a87..8448e53011494 100644
--- a/docs/source/en/tasks/knowledge_distillation_for_image_classification.md
+++ b/docs/source/en/tasks/knowledge_distillation_for_image_classification.md
@@ -61,8 +61,8 @@ import torch.nn.functional as F
 
 
 class ImageDistilTrainer(Trainer):
-    def __init__(self, *args, teacher_model=None, **kwargs):
-        super().__init__(*args, **kwargs)
+    def __init__(self, teacher_model=None, student_model=None, temperature=None, lambda_param=None, *args, **kwargs):
+        super().__init__(model=student_model, *args, **kwargs)
         self.teacher = teacher_model
         self.student = student_model
         self.loss_function = nn.KLDivLoss(reduction="batchmean")
@@ -164,7 +164,7 @@ trainer = ImageDistilTrainer(
     train_dataset=processed_datasets["train"],
     eval_dataset=processed_datasets["validation"],
     data_collator=data_collator,
-    tokenizer=teacher_extractor,
+    tokenizer=teacher_processor,
     compute_metrics=compute_metrics,
     temperature=5,
     lambda_param=0.5
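
For context on why the first hunk matters: `Trainer.__init__` does not accept `teacher_model`, `student_model`, `temperature`, or `lambda_param`, so the corrected signature consumes them before delegating, forwarding only `model=student_model` upward. Below is a minimal sketch of how the patched constructor fits into the rest of the distillation trainer; the `compute_loss` body is a standard soft-label distillation loss in the style of the surrounding tutorial, not part of this diff, so treat it as illustrative rather than the file's exact contents:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Trainer


class ImageDistilTrainer(Trainer):
    def __init__(self, teacher_model=None, student_model=None, temperature=None, lambda_param=None, *args, **kwargs):
        # Trainer.__init__ knows nothing about the distillation arguments,
        # so they are consumed here and only model=student_model is forwarded.
        super().__init__(model=student_model, *args, **kwargs)
        self.teacher = teacher_model
        self.student = student_model
        # KLDivLoss expects log-probabilities as input and probabilities as target
        self.loss_function = nn.KLDivLoss(reduction="batchmean")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.teacher.to(device)
        self.teacher.eval()
        self.temperature = temperature
        self.lambda_param = lambda_param

    def compute_loss(self, student, inputs, return_outputs=False):
        student_output = self.student(**inputs)
        with torch.no_grad():  # the teacher is frozen; no gradients needed
            teacher_output = self.teacher(**inputs)

        # Soften both distributions with the same temperature
        soft_teacher = F.softmax(teacher_output.logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(student_output.logits / self.temperature, dim=-1)

        # KL divergence between the softened distributions, scaled by T^2
        distillation_loss = self.loss_function(soft_student, soft_teacher) * (self.temperature ** 2)

        # Ordinary supervised loss of the student against the true labels
        student_target_loss = student_output.loss

        # lambda_param blends the distillation term with the hard-label term
        loss = (1.0 - self.lambda_param) * student_target_loss + self.lambda_param * distillation_loss
        return (loss, student_output) if return_outputs else loss
```

The second hunk's rename to `teacher_processor` points the `tokenizer=` argument at the image processor variable the doc actually defines (there was no `teacher_extractor`), so that `Trainer` can save and push the processor alongside the model.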