Multi anchor crop (#41)
aaprasad authored May 14, 2024
1 parent 1b7e817 commit 3d6bf9a
Showing 24 changed files with 607 additions and 244 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -132,6 +132,8 @@ dmypy.json
# lightning
lightning_logs
outputs
logs
*.ckpt

# vscode
.vscode
84 changes: 51 additions & 33 deletions biogtr/data_structures.py
@@ -14,10 +14,10 @@ def __init__(
self,
gt_track_id: int = -1,
pred_track_id: int = -1,
bbox: ArrayLike = torch.empty((0, 4)),
crop: ArrayLike = torch.tensor([]),
bbox: ArrayLike = None,
crop: ArrayLike = None,
centroid: dict[str, ArrayLike] = None,
features: ArrayLike = torch.tensor([]),
features: ArrayLike = None,
track_score: float = -1.0,
point_scores: ArrayLike = None,
instance_score: float = -1.0,
@@ -56,46 +56,58 @@ def __init__(
else:
self._skeleton = skeleton

if not isinstance(bbox, torch.Tensor):
if bbox is None:
self._bbox = torch.empty(1, 0, 4)

elif not isinstance(bbox, torch.Tensor):
self._bbox = torch.tensor(bbox)

else:
self._bbox = bbox

if self._bbox.shape[0] and len(self._bbox.shape) == 1:
self._bbox = self._bbox.unsqueeze(0) # (n_anchors, 4)

if self._bbox.shape[1] and len(self._bbox.shape) == 2:
self._bbox = self._bbox.unsqueeze(0) # (1, n_anchors, 4)

if centroid is not None:
self._centroid = centroid
elif self.bbox.shape[0]:
y1, x1, y2, x2 = self.bbox.squeeze()

elif self.bbox.shape[1]:
y1, x1, y2, x2 = self.bbox.squeeze(dim=0).nanmean(dim=0)
self._centroid = {"centroid": np.array([(x1 + x2) / 2, (y1 + y2) / 2])}

else:
self._centroid = {}

if self._bbox.shape[0] and len(self._bbox.shape) == 1:
self._bbox = self._bbox.unsqueeze(0)

if not isinstance(crop, torch.Tensor):
if crop is None:
self._crop = torch.tensor([])
elif not isinstance(crop, torch.Tensor):
self._crop = torch.tensor(crop)
else:
self._crop = crop

if len(self._crop.shape) == 2:
self._crop = self._crop.unsqueeze(0).unsqueeze(0)
elif len(self._crop.shape) == 3:
self._crop = self._crop.unsqueeze(0)
if len(self._crop.shape) == 2: # (h, w)
self._crop = self._crop.unsqueeze(0) # (c, h, w)
if len(self._crop.shape) == 3:
self._crop = self._crop.unsqueeze(0) # (1, c, h, w)

if not isinstance(features, torch.Tensor):
if features is None:
self._features = torch.tensor([])
elif not isinstance(features, torch.Tensor):
self._features = torch.tensor(features)
else:
self._features = features

if self._features.shape[0] and len(self._features.shape) == 1:
self._features = self._features.unsqueeze(0)
if self._features.shape[0] and len(self._features.shape) == 1: # (d,)
self._features = self._features.unsqueeze(0) # (1, d)

if pose is not None:
self._pose = pose

elif self.bbox.shape[0]:

y1, x1, y2, x2 = self.bbox.squeeze()
elif self.bbox.shape[1]:
y1, x1, y2, x2 = self.bbox.squeeze(dim=0).mean(dim=0)
self._pose = {"centroid": np.array([(x1 + x2) / 2, (y1 + y2) / 2])}

else:
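# A minimal, self-contained sketch (not the repository class; `normalize_bbox` is a
# hypothetical helper) of the shape handling added above: bounding boxes are coerced
# to (1, n_anchors, 4), and the fallback centroid is the box midpoint after taking
# the nanmean over anchors.
import numpy as np
import torch

def normalize_bbox(bbox):
    bbox = torch.as_tensor(bbox, dtype=torch.float32)
    if bbox.shape[0] and bbox.ndim == 1:  # (4,) -> (1, 4)
        bbox = bbox.unsqueeze(0)
    if bbox.shape[1] and bbox.ndim == 2:  # (n_anchors, 4) -> (1, n_anchors, 4)
        bbox = bbox.unsqueeze(0)
    return bbox

bbox = normalize_bbox([[0, 0, 64, 64], [32, 32, 96, 96]])
print(bbox.shape)  # torch.Size([1, 2, 4])

# Fallback centroid: average the anchors, then take the box midpoint.
y1, x1, y2, x2 = bbox.squeeze(dim=0).nanmean(dim=0)
centroid = np.array([(x1 + x2) / 2, (y1 + y2) / 2])
print(centroid)  # [48. 48.]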
@@ -287,14 +299,16 @@ def bbox(self, bbox: ArrayLike) -> None:

if self._bbox.shape[0] and len(self._bbox.shape) == 1:
self._bbox = self._bbox.unsqueeze(0)
if self._bbox.shape[1] and len(self._bbox.shape) == 2:
self._bbox = self._bbox.unsqueeze(0)

def has_bbox(self) -> bool:
"""Determine if the instance has a bbox.
Returns:
True if the instance has a bounding box, false otherwise.
"""
if self._bbox.shape[0] == 0:
if self._bbox.shape[1] == 0:
return False
else:
return True
@@ -318,14 +332,15 @@ def centroid(self, centroid: dict[str, ArrayLike]) -> None:
self._centroid = centroid

@property
def anchor(self) -> str:
def anchor(self) -> list[str]:
"""The anchor node name around which the crop was formed.
Returns:
the node name of the anchor around which the crop was formed
the list of anchors around which each crop was formed
"""
if self.centroid:
return list(self.centroid.keys())[0]
return list(self.centroid.keys())
return ""

@property
@@ -353,8 +368,8 @@ def crop(self, crop: ArrayLike) -> None:
self._crop = crop

if len(self._crop.shape) == 2:
self._crop = self._crop.unsqueeze(0).unsqueeze(0)
elif len(self._crop.shape) == 3:
self._crop = self._crop.unsqueeze(0)
if len(self._crop.shape) == 3:
self._crop = self._crop.unsqueeze(0)

def has_crop(self) -> bool:
@@ -528,8 +543,8 @@ def __init__(
video_id: int,
frame_id: int,
vid_file: str = "",
img_shape: ArrayLike = [0, 0, 0],
instances: List[Instance] = [],
img_shape: ArrayLike = None,
instances: List[Instance] = None,
asso_output: ArrayLike = None,
matches: tuple = None,
traj_score: Union[ArrayLike, dict] = None,
@@ -559,13 +574,16 @@ def __init__(
self._video = sio.Video(vid_file)
except ValueError:
self._video = vid_file

if isinstance(img_shape, torch.Tensor):
if img_shape is None:
self._img_shape = torch.tensor([0, 0, 0])
elif isinstance(img_shape, torch.Tensor):
self._img_shape = img_shape
else:
self._img_shape = torch.tensor([img_shape])

self._instances = instances
if instances is None:
self.instances = []
else:
self._instances = instances

self._asso_output = asso_output
self._matches = matches
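# Why the Frame defaults changed from `[]` / `[0, 0, 0]` to `None`: a mutable default
# is created once at function definition and shared across every call, so state can
# leak between Frame objects. A minimal illustration of the pitfall and the fix:
def bad(instances=[]):
    instances.append("x")
    return instances

def good(instances=None):
    if instances is None:
        instances = []
    instances.append("x")
    return instances

print(bad())   # ['x']
print(bad())   # ['x', 'x'] -- the same list object is reused
print(good())  # ['x']
print(good())  # ['x']      -- a fresh list on every call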
@@ -631,7 +649,7 @@ def to(self, map_location: str):
return self

def to_slp(
self, track_lookup: dict[int : sio.Track] = {}
self, track_lookup: dict[int, sio.Track] = {}
) -> tuple[sio.LabeledFrame, dict[int, sio.Track]]:
"""Convert Frame to sleap_io.LabeledFrame object.
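# On the `to_slp` annotation fix above: `dict[int : sio.Track]` is parsed as a slice
# subscript rather than a two-parameter generic, so static type checkers flag it;
# `dict[int, sio.Track]` is the valid form. Illustrated with builtin types:
print(dict[int:str])   # a slice-keyed alias, e.g. dict[slice(<class 'int'>, ...)]
print(dict[int, str])  # dict[int, str]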
120 changes: 91 additions & 29 deletions biogtr/datasets/sleap_dataset.py
@@ -23,7 +23,7 @@ def __init__(
video_files: list[str],
padding: int = 5,
crop_size: int = 128,
anchor: str = "",
anchors: Union[int, list[str], str] = "",
chunk: bool = True,
clip_length: int = 500,
mode: str = "train",
@@ -39,12 +39,15 @@
video_files: a list of paths to video files
padding: amount of padding around object crops
crop_size: the size of the object crops
anchor: the name of the anchor keypoint to be used as centroid for cropping.
If unavailable then crop around the midpoint between all visible anchors.
anchors: One of:
* a string indicating a single node to center crops around
* a list of skeleton node names to be used as the center of crops
* an int indicating the number of anchors to randomly select
If unavailable then crop around the midpoint between all visible anchors.
chunk: whether or not to chunk the dataset into batches
clip_length: the number of frames in each chunk
mode: `train` or `val`. Determines whether this dataset is used for
training or validation. Currently doesn't affect dataset logic
training or validation.
augmentations: An optional dict mapping augmentations to parameters. The keys
should map directly to augmentation classes in albumentations. Example:
augmentations = {
@@ -78,7 +81,19 @@ def __init__(
self.mode = mode.lower()
self.n_chunks = n_chunks
self.seed = seed
self.anchor = anchor.lower()

if isinstance(anchors, int):
self.anchors = anchors
elif isinstance(anchors, str):
self.anchors = [anchors.lower()]
else:
self.anchors = [anchor.lower() for anchor in anchors]

if (
isinstance(self.anchors, list) and len(self.anchors) == 0
) or self.anchors == 0:
raise ValueError(f"Must provide at least one anchor but got {self.anchors}")

self.verbose = verbose

# if self.seed is not None:
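# A standalone sketch (hypothetical `normalize_anchors` helper mirroring the
# constructor logic above): a bare string becomes a one-element list, list entries
# are lowercased, and an int is kept as the number of anchors to sample per instance.
from typing import Union

def normalize_anchors(anchors: Union[int, list[str], str]):
    if isinstance(anchors, int):
        out = anchors
    elif isinstance(anchors, str):
        out = [anchors.lower()]
    else:
        out = [a.lower() for a in anchors]
    if (isinstance(out, list) and len(out) == 0) or out == 0:
        raise ValueError(f"Must provide at least one anchor but got {out}")
    return out

print(normalize_anchors("Thorax"))            # ['thorax']
print(normalize_anchors(["Head", "Thorax"]))  # ['head', 'thorax']
print(normalize_anchors(3))                   # 3 (sample 3 anchors at random)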
@@ -165,6 +180,18 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict
print(f"Could not read frame {frame_ind} from {video_name} due to {e}")
continue

if len(img.shape) == 2:
    img = np.expand_dims(img, -1)
h, w, c = img.shape

if c == 1:
    img = np.concatenate(
        [img, img, img], axis=-1
    )  # convert grayscale to RGB

if np.issubdtype(img.dtype, np.integer):  # convert int to float
    img = img.astype(np.float32) / 255
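# A self-contained sketch of the frame preprocessing added above, run on a dummy
# grayscale frame: expand to (h, w, 1), tile to three channels, scale ints to [0, 1].
import numpy as np

img = np.random.randint(0, 255, size=(512, 512), dtype=np.uint8)
if img.ndim == 2:
    img = np.expand_dims(img, -1)                   # (h, w) -> (h, w, 1)
h, w, c = img.shape
if c == 1:
    img = np.concatenate([img, img, img], axis=-1)  # grayscale -> RGB
if np.issubdtype(img.dtype, np.integer):
    img = img.astype(np.float32) / 255              # int -> float in [0, 1]
print(img.shape, img.dtype)                         # (512, 512, 3) float32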

for instance in lf:
if instance.track is not None:
gt_track_id = video.tracks.index(instance.track)
@@ -247,41 +274,76 @@ def get_instances(self, label_idx: List[int], frame_idx: List[int]) -> list[dict
pose = shown_poses[j]

"""Check for anchor"""
if self.anchor == "random":
anchors = list(pose.keys()) + ["midpoint"]
anchor = np.random.choice(anchors)
elif self.anchor in pose:
anchor = self.anchor
crops = []
boxes = []
centroids = {}

if isinstance(self.anchors, int):
anchors_to_choose = list(pose.keys()) + ["midpoint"]
anchors = np.random.choice(anchors_to_choose, self.anchors)
else:
if self.verbose:
warnings.warn(
f"{self.anchor} not in {[key for key in pose.keys()]}! Defaulting to midpoint"
)
anchor = "midpoint"
anchors = self.anchors

if anchor != "midpoint":
centroid = pose[anchor]
for anchor in anchors:
if anchor == "midpoint" or anchor == "centroid":
centroid = np.nanmean(np.array(list(pose.values())), axis=0)

elif anchor in pose:
centroid = np.array(pose[anchor])
if np.isnan(centroid).any():
centroid = np.array([np.nan, np.nan])

if np.isnan(centroid).any():
elif anchor not in pose and len(anchors) == 1:
anchor = "midpoint"
centroid = np.nanmean(np.array(list(pose.values())), axis=0)
else:
# print(f'{self.anchor} not an available option amongst {pose.keys()}. Using midpoint')
centroid = np.nanmean(np.array(list(pose.values())), axis=0)

bbox = data_utils.pad_bbox(
data_utils.get_bbox(centroid, self.crop_size),
padding=self.padding,
)
elif anchor in pose:
centroid = np.array(pose[anchor])
if np.isnan(centroid).any():
centroid = np.array([np.nan, np.nan])

elif anchor not in pose and len(anchors) == 1:
anchor = "midpoint"
centroid = np.nanmean(np.array(list(pose.values())), axis=0)

else:
centroid = np.array([np.nan, np.nan])

if np.isnan(centroid).all():
bbox = torch.tensor([np.nan, np.nan, np.nan, np.nan])

else:
bbox = data_utils.pad_bbox(
data_utils.get_bbox(centroid, self.crop_size),
padding=self.padding,
)

if bbox.isnan().all():
crop = torch.zeros(
c,
self.crop_size + 2 * self.padding,
self.crop_size + 2 * self.padding,
dtype=img.dtype,
)
else:
crop = data_utils.crop_bbox(img, bbox)

crops.append(crop)
centroids[anchor] = centroid
boxes.append(bbox)

if len(crops) > 0:
crops = torch.concat(crops, dim=0)

crop = data_utils.crop_bbox(img, bbox)
if len(boxes) > 0:
boxes = torch.stack(boxes, dim=0)

instance = Instance(
gt_track_id=gt_track_ids[j],
pred_track_id=-1,
crop=crop,
centroid={anchor: centroid},
bbox=bbox,
crop=crops,
centroid=centroids,
bbox=boxes,
skeleton=skeleton,
pose=poses[j],
point_scores=point_scores[j],
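# A self-contained sketch of the multi-anchor flow in get_instances above: each
# selected anchor yields a centroid and a crop, crops are concatenated along dim 0,
# and centroids are keyed by anchor name. `center_crop` is a hypothetical stand-in
# for data_utils.get_bbox/pad_bbox/crop_bbox, not the repository implementation.
import numpy as np
import torch

def center_crop(img_t, centroid, size):
    """Return a (1, c, size, size) crop around (x, y), zero-padded at the edges."""
    c, h, w = img_t.shape
    out = torch.zeros(1, c, size, size, dtype=img_t.dtype)
    x, y = int(round(float(centroid[0]))), int(round(float(centroid[1])))
    y0, x0 = max(y - size // 2, 0), max(x - size // 2, 0)
    y1, x1 = min(y0 + size, h), min(x0 + size, w)
    patch = img_t[:, y0:y1, x0:x1]
    out[0, :, : patch.shape[1], : patch.shape[2]] = patch
    return out

img = torch.rand(3, 256, 256)                           # dummy (c, h, w) frame
pose = {"head": np.array([60.0, 40.0]), "thorax": np.array([90.0, 80.0])}
anchors = ["head", "thorax", "midpoint"]

crops, centroids = [], {}
for anchor in anchors:
    if anchor in pose:
        centroid = pose[anchor]
    else:                                               # midpoint of all visible nodes
        centroid = np.nanmean(np.array(list(pose.values())), axis=0)
    crops.append(center_crop(img, centroid, size=64))
    centroids[anchor] = centroid

crops = torch.concat(crops, dim=0)
print(crops.shape, list(centroids))  # torch.Size([3, 3, 64, 64]) ['head', 'thorax', 'midpoint']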
