Skip to content

Commit

Permalink
fix warning for pytorch 1.2
Browse files Browse the repository at this point in the history
  • Loading branch information
ranahanocka committed Sep 2, 2019
1 parent cf93099 commit feaa6c1
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
4 changes: 2 additions & 2 deletions models/layers/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def remove_edge(self, edge_id):
self.ve[v].remove(edge_id)

def clean(self, edges_mask, groups):
torch_mask = torch.from_numpy(edges_mask.copy())
edges_mask = edges_mask.astype(bool)
torch_mask = torch.from_numpy(edges_mask.copy())
self.gemm_edges = self.gemm_edges[edges_mask]
self.edges = self.edges[edges_mask]
self.sides = self.sides[edges_mask]
Expand Down Expand Up @@ -155,7 +155,7 @@ def init_history(self):
'occurrences': [],
'old2current': np.arange(self.edges_count, dtype=np.int32),
'current2old': np.arange(self.edges_count, dtype=np.int32),
'edges_mask': [torch.ones(self.edges_count,dtype=torch.uint8)],
'edges_mask': [torch.ones(self.edges_count,dtype=torch.bool)],
'edges_count': [self.edges_count],
}
if self.export_folder:
Expand Down
2 changes: 1 addition & 1 deletion models/layers/mesh_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __pool_main(self, mesh_index):
# recycle = []
# last_queue_len = len(queue)
last_count = mesh.edges_count + 1
mask = np.ones(mesh.edges_count, dtype=np.uint8)
mask = np.ones(mesh.edges_count, dtype=np.bool)
edge_groups = MeshUnion(mesh.edges_count, self.__fe.device)
while mesh.edges_count > self.__out_target:
value, edge_id = heappop(queue)
Expand Down

0 comments on commit feaa6c1

Please sign in to comment.