Skip to content

Commit

Permalink
fixed typos (issue 27919) (huggingface#27920)
Browse files Browse the repository at this point in the history
* fixed typos (issue 27919)

* Update docs/source/en/tasks/knowledge_distillation_for_image_classification.md

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
  • Loading branch information
asusevski and amyeroberts committed Dec 11, 2023
1 parent e5079b0 commit e660424
Showing 1 changed file with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ import torch.nn.functional as F


class ImageDistilTrainer(Trainer):
def __init__(self, *args, teacher_model=None, **kwargs):
super().__init__(*args, **kwargs)
def __init__(self, teacher_model=None, student_model=None, temperature=None, lambda_param=None, *args, **kwargs):
super().__init__(model=student_model, *args, **kwargs)
self.teacher = teacher_model
self.student = student_model
self.loss_function = nn.KLDivLoss(reduction="batchmean")
Expand Down Expand Up @@ -164,7 +164,7 @@ trainer = ImageDistilTrainer(
train_dataset=processed_datasets["train"],
eval_dataset=processed_datasets["validation"],
data_collator=data_collator,
tokenizer=teacher_extractor,
tokenizer=teacher_processor,
compute_metrics=compute_metrics,
temperature=5,
lambda_param=0.5
Expand Down

0 comments on commit e660424

Please sign in to comment.