From 06825388bd84c645530a596cce3aaf60f90ecc89 Mon Sep 17 00:00:00 2001
From: "Bird.Z"
Date: Thu, 3 Nov 2022 15:47:55 +0800
Subject: [PATCH 1/2] Update trainer.py

Add more layer-distillation versions, inspired by RAIL-KD.
---
 trainer/trainer.py | 28 +++++++++++++++++++++-------
 1 file changed, 21 insertions(+), 7 deletions(-)

diff --git a/trainer/trainer.py b/trainer/trainer.py
index 5da2dd1..0749918 100644
--- a/trainer/trainer.py
+++ b/trainer/trainer.py
@@ -548,11 +548,12 @@ def save_model(self, output_dir: Optional[str] = None):
         self.model.save_pretrained(output_dir)
 
-    def calculate_layer_distillation_loss(self, teacher_outputs, student_outputs, zs):
+    def calculate_layer_distillation_loss(self, teacher_outputs, student_outputs, zs):
         mse_loss = torch.nn.MSELoss(reduction="mean")
         if self.additional_args.do_layer_distill: #! only do layer distill
             mlp_z = None
             head_layer_z = None
+            # logger.info(f"zs={zs}")
             if "mlp_z" in zs:
                 mlp_z = zs["mlp_z"].detach().cpu()
             if "head_layer_z" in zs:
                 head_layer_z = zs["head_layer_z"].detach().cpu()
@@ -566,13 +567,26 @@ def calculate_layer_distillation_loss(self, teacher_outputs, student_outputs, zs
                 for layer_num, (t_layer_o, s_layer_o) in enumerate(zip(teacher_layer_output, student_layer_output)):
                     s_layer_o = self.model.layer_transformation(s_layer_o)
                     l = mse_loss(t_layer_o, s_layer_o)
-                    if mlp_z[layer_num] > 0:
+                    if mlp_z is None or mlp_z[layer_num] > 0:
                         layer_loss += l
 
             # distilling layers with a minimal distance
             elif self.additional_args.layer_distill_version > 2:
                 l = []
-                specified_teacher_layers = [2, 5, 8, 11]
+                if self.additional_args.layer_distill_version > 4:
+                    specified_teacher_layers = [i for i in range(12)]
+                    if self.additional_args.layer_distill_version == 5:
+                        specified_teacher_layers = sorted(random.sample(specified_teacher_layers, 4))
+                    elif self.additional_args.layer_distill_version == 6:
+                        result_layers_T = []
+                        skip_window = len(specified_teacher_layers) // 4
+                        for i in range(0, len(specified_teacher_layers), skip_window):
+                            result_layers_T.append(random.sample(specified_teacher_layers[i:i + skip_window], 1)[0])
+                        specified_teacher_layers = result_layers_T
+                    specified_teacher_layers[0] = max(2, specified_teacher_layers[0])
+                else:
+                    specified_teacher_layers = [2, 5, 8, 11]
+                # logger.info(f"sampled teacher layers: {specified_teacher_layers}")
                 transformed_s_layer_o = [self.model.layer_transformation(
                     s_layer_o) for s_layer_o in student_layer_output]
                 specified_teacher_layer_reps = [
@@ -595,10 +609,10 @@ def calculate_layer_distillation_loss(self, teacher_outputs, student_outputs, zs
                 if self.additional_args.layer_distill_version == 3:
                     alignment = torch.argmin(layerwiseloss, dim=1)
                 #! added the ordering restriction -> to choose the min loss in 4 student layers
-                elif self.additional_args.layer_distill_version == 4:
+                elif self.additional_args.layer_distill_version in (3, 4, 5, 6):
                     last_aligned_layer = 12
                     alignment = []
-                    for search_index in range(3, -1, -1):
+                    for search_index in range(len(specified_teacher_layers) - 1, -1, -1):
                         indexes = layerwiseloss[search_index].sort()[1]
                         if existing_layers is not None:
                             align = indexes[(
@@ -618,14 +632,14 @@ def calculate_layer_distillation_loss(self, teacher_outputs, student_outputs, zs
                         f"{self.additional_args.layer_distill_version} version is not specified.")
                     sys.exit()
 
-                layerwise = torch.arange(4).to(device)
+                layerwise = torch.arange(len(specified_teacher_layers)).to(device)
                 layer_loss += layerwiseloss[layerwise, alignment].sum() #! layerwise: teacher (specified layers) / alignment: student (min loss layers) / layerwiseloss: [4,12]
                 if self.global_step % 100 == 0:
                     logger.info(f"v{self.additional_args.layer_distill_version} Global step: {self.global_step}, Alignment: " + str(alignment))
             return layer_loss
         else:
             return None
-    
+
     def calculate_distillation_loss(self, teacher_outputs, student_outputs, zs):
         layer_loss = self.calculate_layer_distillation_loss(teacher_outputs, student_outputs, zs)
         distill_loss = layer_loss
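A minimal sketch of the teacher-layer selection that versions 5 and 6 introduce in PATCH 1/2, assuming a 12-layer teacher distilled against 4 student layers. The helper name select_teacher_layers and its parameters are illustrative rather than part of the patch, and note that the patch's own code assumes random is importable in trainer.py:

    import random

    def select_teacher_layers(version, num_teacher_layers=12, num_student_layers=4):
        # Illustrative re-implementation of the sampling added in PATCH 1/2.
        if version > 4:
            layers = list(range(num_teacher_layers))
            if version == 5:
                # v5: draw 4 teacher layers uniformly at random, kept in depth order
                layers = sorted(random.sample(layers, num_student_layers))
            elif version == 6:
                # v6 (RAIL-KD style): draw one layer from each contiguous window
                window = num_teacher_layers // num_student_layers
                layers = [random.choice(layers[i:i + window])
                          for i in range(0, num_teacher_layers, window)]
            # as in the patch, never distill from a teacher layer below 2
            layers[0] = max(2, layers[0])
            return layers
        # v3/v4 keep the fixed teacher layers
        return [2, 5, 8, 11]

Version 5 re-draws an arbitrary ordered subset on every call, while version 6 follows the RAIL-KD recipe of one layer per window, so the sampled set always spans the full depth of the teacher.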
From b399e64a865697a02d25181fe4ca137dd2ce9be6 Mon Sep 17 00:00:00 2001
From: "Bird.Z"
Date: Thu, 3 Nov 2022 15:48:45 +0800
Subject: [PATCH 2/2] Update trainer.py

---
 trainer/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/trainer/trainer.py b/trainer/trainer.py
index 0749918..8523bc4 100644
--- a/trainer/trainer.py
+++ b/trainer/trainer.py
@@ -548,7 +548,7 @@ def save_model(self, output_dir: Optional[str] = None):
         self.model.save_pretrained(output_dir)
 
-    def calculate_layer_distillation_loss(self, teacher_outputs, student_outputs, zs):
+    def calculate_layer_distillation_loss(self, teacher_outputs, student_outputs, zs):
         mse_loss = torch.nn.MSELoss(reduction="mean")
         if self.additional_args.do_layer_distill: #! only do layer distill
             mlp_z = None
             head_layer_z = None
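Read in isolation, the generalized alignment search from PATCH 1/2 amounts to the greedy, order-preserving matching sketched below. This assumes layerwise_loss holds the MSE between each sampled teacher layer (rows) and each student layer (columns), and it omits the existing_layers mask the real code applies when student layers have been pruned away; the function name is illustrative:

    import torch

    def ordered_alignment(layerwise_loss):
        # Greedy depth-ordered teacher-to-student matching (sketch).
        num_teacher, num_student = layerwise_loss.shape
        last_aligned = num_student  # sentinel: one past the deepest student layer
        alignment = []
        for t in range(num_teacher - 1, -1, -1):
            # student layers sorted by increasing loss against teacher layer t
            candidates = layerwise_loss[t].sort().indices
            valid = candidates[candidates < last_aligned]
            # fall back to the previous match if the ordering rules everything out
            align = int(valid[0]) if len(valid) > 0 else last_aligned
            alignment.append(align)
            last_aligned = align
        alignment.reverse()
        return alignment

    # e.g. ordered_alignment(torch.rand(4, 12)) yields depth-ordered student
    # indices such as [1, 4, 9, 11]

Walking from the deepest sampled teacher layer to the shallowest, and only accepting student layers strictly below the previous match, keeps the alignment monotone in depth; that is the "ordering restriction" the in-code comment refers to, which plain torch.argmin (version 3) does not guarantee. Note also that version 3 is already caught by the argmin branch above, so its inclusion in the new elif tuple never triggers.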