Skip to content

Commit

Permalink
Merge pull request #667 from hesamossanloo/main
Browse files Browse the repository at this point in the history
Update data type for bounding box coordinates to float32
  • Loading branch information
ethanwhite committed May 11, 2024
2 parents a03a481 + 0f12089 commit dfd9f31
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions deepforest/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def __getitem__(self, idx):
self.image_names[idx]]
targets = {}
targets["boxes"] = image_annotations[["xmin", "ymin", "xmax",
"ymax"]].values.astype(float)
"ymax"]].values.astype("float32")

# Labels need to be encoded
targets["labels"] = image_annotations.label.apply(
Expand All @@ -112,7 +112,7 @@ def __getitem__(self, idx):
labels = torch.from_numpy(targets["labels"])
# channels last
image = np.rollaxis(image, 2, 0)
image = torch.from_numpy(image)
image = torch.from_numpy(image).float()
targets = {"boxes": boxes, "labels": labels}
return self.image_names[idx], image, targets

Expand All @@ -122,7 +122,7 @@ def __getitem__(self, idx):
image = augmented["image"]

boxes = np.array(augmented["bboxes"])
boxes = torch.from_numpy(boxes)
boxes = torch.from_numpy(boxes).float()
labels = np.array(augmented["category_ids"])
labels = torch.from_numpy(labels)
targets = {"boxes": boxes, "labels": labels}
Expand Down

0 comments on commit dfd9f31

Please sign in to comment.