Merge pull request #35 from zhangzhenyu13/main
Introduce random teacher layer sets
xiamengzhou committed Nov 7, 2022
2 parents 202d832 + b399e64 commit 756f67b
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions trainer/trainer.py
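
Before the diff, a minimal standalone sketch of the new selection logic for context. The commit adds two modes for choosing which teacher layers to distil from (layer_distill_version 5 and 6); the helper name sample_teacher_layers below is hypothetical and only illustrates the behaviour, it is not code from trainer.py.

import random

def sample_teacher_layers(version, num_teacher_layers=12, num_targets=4):
    """Pick the teacher layers used for layer distillation."""
    if version == 5:
        # version 5: a sorted random subset of 4 out of the 12 teacher layers
        layers = sorted(random.sample(range(num_teacher_layers), num_targets))
    elif version == 6:
        # version 6: split the 12 layers into 4 contiguous windows of 3
        # and draw one layer from each window, keeping the set spread out
        window = num_teacher_layers // num_targets
        layers = [random.choice(range(i, i + window))
                  for i in range(0, num_teacher_layers, window)]
    else:
        # versions 3/4 keep the fixed set used before this commit
        return [2, 5, 8, 11]
    layers[0] = max(2, layers[0])  # never distil from a teacher layer below layer 2
    return layers

For example, sample_teacher_layers(6) might return [2, 4, 7, 10], one layer per window, while sample_teacher_layers(5) can cluster, e.g. [2, 3, 4, 9].
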
@@ -553,6 +553,7 @@ def calculate_layer_distillation_loss(self, teacher_outputs, student_outputs, zs
         if self.additional_args.do_layer_distill: #! only do layer distill
             mlp_z = None
             head_layer_z = None
+            # logger.info(f"zs={zs}")
             if "mlp_z" in zs:
                 mlp_z = zs["mlp_z"].detach().cpu()
             if "head_layer_z" in zs:
@@ -566,13 +567,26 @@ def calculate_layer_distillation_loss(self, teacher_outputs, student_outputs, zs
                 for layer_num, (t_layer_o, s_layer_o) in enumerate(zip(teacher_layer_output, student_layer_output)):
                     s_layer_o = self.model.layer_transformation(s_layer_o)
                     l = mse_loss(t_layer_o, s_layer_o)
-                    if mlp_z[layer_num] > 0:
+                    if mlp_z is None or mlp_z[layer_num] > 0:
                         layer_loss += l

             # distilling layers with a minimal distance
             elif self.additional_args.layer_distill_version > 2:
                 l = []
-                specified_teacher_layers = [2, 5, 8, 11]
+                if self.additional_args.layer_distill_version > 4:
+                    specified_teacher_layers = [i for i in range(12)]
+                    if self.additional_args.layer_distill_version ==5:
+                        specified_teacher_layers = sorted(random.sample(specified_teacher_layers, 4))
+                    elif self.additional_args.layer_distill_version ==6:
+                        result_layers_T= []
+                        skip_window = len(specified_teacher_layers)//4
+                        for i in range(0, len(specified_teacher_layers), skip_window):
+                            result_layers_T.append(random.sample(specified_teacher_layers[i:i+skip_window], 1)[0])
+                        specified_teacher_layers = result_layers_T
+                    specified_teacher_layers[0] = max(2, specified_teacher_layers[0])
+                else:
+                    specified_teacher_layers = [2, 5, 8, 11]
+                # logger.info(f"sampled teacher layers: {specified_teacher_layers}")
                 transformed_s_layer_o = [self.model.layer_transformation(
                     s_layer_o) for s_layer_o in student_layer_output]
                 specified_teacher_layer_reps = [
@@ -595,10 +609,10 @@ def calculate_layer_distillation_loss(self, teacher_outputs, student_outputs, zs
                 if self.additional_args.layer_distill_version == 3:
                     alignment = torch.argmin(layerwiseloss, dim=1)
                 #! added the ordering restriction -> to choose the min loss in 4 student layers
-                elif self.additional_args.layer_distill_version == 4:
+                elif self.additional_args.layer_distill_version in (3, 4, 5, 6):
                     last_aligned_layer = 12
                     alignment = []
-                    for search_index in range(3, -1, -1):
+                    for search_index in range(len(specified_teacher_layers)-1, -1, -1):
                         indexes = layerwiseloss[search_index].sort()[1]
                         if existing_layers is not None:
                             align = indexes[(
@@ -618,14 +632,14 @@ def calculate_layer_distillation_loss(self, teacher_outputs, student_outputs, zs
                         f"{self.additional_args.layer_distill_version} version is not specified.")
                     sys.exit()

-                layerwise = torch.arange(4).to(device)
+                layerwise = torch.arange(len(specified_teacher_layers)).to(device)
                 layer_loss += layerwiseloss[layerwise, alignment].sum() #! layerwise: teacher (specified layers) / alignment: student (min loss layers) / layerwiseloss: [4,12]
                 if self.global_step % 100 == 0:
                     logger.info(f"v{self.additional_args.layer_distill_version} Global step: {self.global_step}, Alignment: " + str(alignment))
             return layer_loss
         else:
             return None

     def calculate_distillation_loss(self, teacher_outputs, student_outputs, zs):
         layer_loss = self.calculate_layer_distillation_loss(teacher_outputs, student_outputs, zs)
         distill_loss = layer_loss
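
The alignment search in the last two hunks now runs over however many teacher layers were sampled rather than a hard-coded 4. Below is a minimal sketch of that ordering-restricted alignment using plain Python lists, ignoring the existing_layers mask used in the full implementation; ordered_alignment is a hypothetical name, not a function in the repo, and layerwise_loss[t][s] is assumed to hold the distillation loss between sampled teacher layer t and student layer s.

def ordered_alignment(layerwise_loss, num_student_layers):
    """Greedy ordered alignment: walk the sampled teacher layers from last to
    first and, for each one, pick the student layer with the smallest loss
    among those strictly below the previously aligned student layer."""
    last_aligned = num_student_layers  # sentinel: one past the deepest student layer
    alignment = []
    for t in range(len(layerwise_loss) - 1, -1, -1):
        candidates = [s for s in range(num_student_layers) if s < last_aligned]
        if candidates:
            best = min(candidates, key=lambda s: layerwise_loss[t][s])
        else:
            best = last_aligned  # nothing admissible left; fall back to the sentinel
        alignment.append(best)
        last_aligned = best
    alignment.reverse()  # restore teacher order, first sampled layer first
    return alignment

With four sampled teacher layers and twelve student layers this mirrors the original version-4 loop; generalising the range over len(specified_teacher_layers) is what lets versions 5 and 6 reuse the same search unchanged.
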
