Skip to content

Commit

Permalink
fix bug #263
Browse files Browse the repository at this point in the history
  • Loading branch information
zanshuxun committed Feb 11, 2023
1 parent 5d03ead commit 2149673
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion deepctr_torch/layers/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ def forward(self, inputs):
moe_out = torch.matmul(output_of_experts, gating_score_of_experts.softmax(1))
x_l = moe_out + x_l # (bs, in_features, 1)

x_l = x_l.squeeze() # (bs, in_features)
x_l = x_l.squeeze(-1) # (bs, in_features)
return x_l


Expand Down
4 changes: 2 additions & 2 deletions deepctr_torch/models/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, initial_epoc
x = x_train.to(self.device).float()
y = y_train.to(self.device).float()

y_pred = model(x).squeeze()
y_pred = model(x)

optim.zero_grad()
if isinstance(loss_func, list):
Expand All @@ -251,7 +251,7 @@ def fit(self, x=None, y=None, batch_size=None, epochs=1, verbose=1, initial_epoc
loss = sum(
[loss_func[i](y_pred[:, i], y[:, i], reduction='sum') for i in range(self.num_tasks)])
else:
loss = loss_func(y_pred, y.squeeze(), reduction='sum')
loss = loss_func(y_pred, y, reduction='sum')
reg_loss = self.get_regularization_loss()

total_loss = loss + reg_loss + self.aux_loss
Expand Down
2 changes: 1 addition & 1 deletion deepctr_torch/models/multitask/mmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def forward(self, X):
else:
gate_dnn_out = self.gate_dnn_final_layer[i](dnn_input)
gate_mul_expert = torch.matmul(gate_dnn_out.softmax(1).unsqueeze(1), expert_outs) # (bs, 1, dim)
mmoe_outs.append(gate_mul_expert.squeeze())
mmoe_outs.append(gate_mul_expert.squeeze(1))

# tower dnn (task-specific)
task_outs = []
Expand Down
4 changes: 2 additions & 2 deletions deepctr_torch/models/multitask/ple.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def cgc_net(self, inputs, level_num):
else:
gate_dnn_out = self.specific_gate_dnn_final_layer[level_num][i](inputs[i])
gate_mul_expert = torch.matmul(gate_dnn_out.softmax(1).unsqueeze(1), cur_experts_outputs) # (bs, 1, dim)
cgc_outs.append(gate_mul_expert.squeeze())
cgc_outs.append(gate_mul_expert.squeeze(1))

# gates for shared experts
cur_experts_outputs = specific_expert_outputs + shared_expert_outputs
Expand All @@ -189,7 +189,7 @@ def cgc_net(self, inputs, level_num):
else:
gate_dnn_out = self.shared_gate_dnn_final_layer[level_num](inputs[-1])
gate_mul_expert = torch.matmul(gate_dnn_out.softmax(1).unsqueeze(1), cur_experts_outputs) # (bs, 1, dim)
cgc_outs.append(gate_mul_expert.squeeze())
cgc_outs.append(gate_mul_expert.squeeze(1))

return cgc_outs

Expand Down

0 comments on commit 2149673

Please sign in to comment.