From 5d03ead5e1182b03a67cc01ae5efa1d59f989278 Mon Sep 17 00:00:00 2001 From: zanshuxun <631763140@qq.com> Date: Sat, 11 Feb 2023 23:22:50 +0800 Subject: [PATCH] fix bug #261 --- deepctr_torch/models/basemodel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepctr_torch/models/basemodel.py b/deepctr_torch/models/basemodel.py index cd36340a..abc5b846 100644 --- a/deepctr_torch/models/basemodel.py +++ b/deepctr_torch/models/basemodel.py @@ -76,7 +76,7 @@ def forward(self, X, sparse_feat_refine_weight=None): sparse_embedding_list += varlen_embedding_list - linear_logit = torch.zeros([X.shape[0], 1]).to(self.device) + linear_logit = torch.zeros([X.shape[0], 1]).to(self.weight.device) if len(sparse_embedding_list) > 0: sparse_embedding_cat = torch.cat(sparse_embedding_list, dim=-1) if sparse_feat_refine_weight is not None: