diff --git a/qlib/contrib/model/pytorch_gats.py b/qlib/contrib/model/pytorch_gats.py index 7f379c3b96..7c2c99432c 100644 --- a/qlib/contrib/model/pytorch_gats.py +++ b/qlib/contrib/model/pytorch_gats.py @@ -260,7 +260,7 @@ def fit( if self.model_path is not None: self.logger.info("Loading pretrained model...") - pretrained_model.load_state_dict(torch.load(self.model_path)) + pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device)) model_dict = self.GAT_model.state_dict() pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict} diff --git a/qlib/contrib/model/pytorch_gats_ts.py b/qlib/contrib/model/pytorch_gats_ts.py index 4017192750..53a7817e27 100644 --- a/qlib/contrib/model/pytorch_gats_ts.py +++ b/qlib/contrib/model/pytorch_gats_ts.py @@ -276,7 +276,7 @@ def fit( if self.model_path is not None: self.logger.info("Loading pretrained model...") - pretrained_model.load_state_dict(torch.load(self.model_path)) + pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device)) model_dict = self.GAT_model.state_dict() pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict} diff --git a/qlib/contrib/model/pytorch_nn.py b/qlib/contrib/model/pytorch_nn.py index 780dc4b91d..7086cdb5a6 100644 --- a/qlib/contrib/model/pytorch_nn.py +++ b/qlib/contrib/model/pytorch_nn.py @@ -257,7 +257,7 @@ def fit( self.scheduler.step(cur_loss_val) # restore the optimal parameters after training - self.dnn_model.load_state_dict(torch.load(save_path)) + self.dnn_model.load_state_dict(torch.load(save_path, map_location=self.device)) if self.use_gpu: torch.cuda.empty_cache() @@ -296,7 +296,7 @@ def load(self, buffer, **kwargs): ] _model_path = os.path.join(model_dir, _model_name) # Load model - self.dnn_model.load_state_dict(torch.load(_model_path)) + self.dnn_model.load_state_dict(torch.load(_model_path, map_location=self.device)) self.fitted = True diff --git a/qlib/contrib/model/pytorch_tabnet.py b/qlib/contrib/model/pytorch_tabnet.py index 504048210b..e0e2093e8f 100644 --- a/qlib/contrib/model/pytorch_tabnet.py +++ b/qlib/contrib/model/pytorch_tabnet.py @@ -160,7 +160,7 @@ def fit( self.logger.info("Pretrain...") self.pretrain_fn(dataset, self.pretrain_file) self.logger.info("Load Pretrain model") - self.tabnet_model.load_state_dict(torch.load(self.pretrain_file)) + self.tabnet_model.load_state_dict(torch.load(self.pretrain_file, map_location=self.device)) # adding one more linear layer to fit the final output dimension self.tabnet_model = FinetuneModel(self.out_dim, self.final_out_dim, self.tabnet_model).to(self.device) diff --git a/qlib/contrib/model/pytorch_tcts.py b/qlib/contrib/model/pytorch_tcts.py index 7cd59be9b4..d813ae01f1 100644 --- a/qlib/contrib/model/pytorch_tcts.py +++ b/qlib/contrib/model/pytorch_tcts.py @@ -350,9 +350,9 @@ def training( break print("best loss:", best_loss, "@", best_epoch) - best_param = torch.load(save_path + "_fore_model.bin") + best_param = torch.load(save_path + "_fore_model.bin", map_location=self.device) self.fore_model.load_state_dict(best_param) - best_param = torch.load(save_path + "_weight_model.bin") + best_param = torch.load(save_path + "_weight_model.bin", map_location=self.device) self.weight_model.load_state_dict(best_param) self.fitted = True