In [1]:
import gc
import numpy as np
import sleap_io
import torch
import imageio
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from tqdm import tqdm
from biogtr.datasets.sleap_dataset import SleapDataset
from biogtr.datasets.microscopy_dataset import MicroscopyDataset
from biogtr.models.global_tracking_transformer import GlobalTrackingTransformer
from biogtr.training.losses import AssoLoss

import torchvision.transforms.functional as tvf
from torch.utils.data import Dataset


if __name__ == "__main__":
    # Smoke test: build a small GlobalTrackingTransformer, load one dataset and
    # run a single optimization step to check that the forward pass, the
    # association loss and the backward pass all execute.

    train_slp_files = ["/mnt/talmodata/datasets/mot/animal/sleap/benchmarks/flies13/190719_090330_wt_18159206_rig1.2@15000-17560.slp"]
    train_vid_files = ["/mnt/talmodata/datasets/mot/animal/sleap/benchmarks/flies13/190719_090330_wt_18159206_rig1.2@15000-17560.mp4"]

    train_micro_labels = ["/mnt/talmodata/datasets/mot/microscopy/ICY/5_cell/5_cells_1_gt.xml"]
    train_micro_tifs = ["/mnt/talmodata/datasets/mot/microscopy/ICY/5_cell/5_cells_1.tif"]

    # NOTE(review): `labels` is not read anywhere else in this cell; it is kept
    # in case other cells rely on it.
    labels = sleap_io.load_slp(train_slp_files[0])

    device = 'cpu'

    # Single width shared by the model dimension, the feed-forward layer and
    # the attention-head feature size.
    feats = 256

    tracking_transformer = GlobalTrackingTransformer(
        d_model=feats,
        num_encoder_layers=1,
        num_decoder_layers=1,
        dim_feedforward=feats,
        feature_dim_attn_head=feats,
    ).to(device)

    asso_loss = AssoLoss().to(device)

    optimizer = torch.optim.Adam(
        tracking_transformer.parameters(), lr=1e-4, betas=(0.9, 0.999)
    )

    # Which dataset to train on: "sleap" or "microscopy".
    # Previously both datasets were constructed and the SLEAP one was
    # immediately overwritten; now only the selected dataset is built. The
    # default matches the dataset that was effectively used before.
    dataset_type = "microscopy"

    if dataset_type == "sleap":
        train_ds = SleapDataset(
            train_slp_files,
            train_vid_files,
            padding=5,
            crop_size=128,
            chunk=True,
            clip_length=32,
            crop_type="centroid",
        )
    elif dataset_type == "microscopy":
        train_ds = MicroscopyDataset(
            train_micro_tifs,
            train_micro_labels,
            "ICY",
            padding=5,
            crop_size=20,
            chunk=True,
            clip_length=32,
        )
    else:
        raise ValueError(f"Unknown dataset_type: {dataset_type!r}")

    # Fetch the first item directly from the dataset. The name `instances` is
    # rebound by the training loop below.
    instances = next(iter(train_ds))

    train_loader = DataLoader(
        train_ds,
        batch_size=1,
        shuffle=True,
        collate_fn=train_ds.no_batching_fn,
        num_workers=0,
    )

    torch.cuda.empty_cache()
    gc.collect()

    num_epochs = 1

    for epoch in range(1, num_epochs + 1):
        print("Epoch: {:02d}/{:02d}".format(epoch, num_epochs))

        losses = []
        _ = tracking_transformer.train()

        print("TRAIN")
        loop = tqdm(
            enumerate(train_loader), position=0, leave=True, total=len(train_loader)
        )
        for i, instances in loop:
            instances = instances[0]  # For batch size of 1.

            asso_preds = tracking_transformer(instances)

            # Compute loss.
            loss = asso_loss(asso_preds, instances)
            print('loss: ', loss)

            # Read the scalar once and reuse it for the running mean and the
            # progress-bar text.
            loss_value = loss.item()
            losses.append(loss_value)

            loop.set_description(
                "current_loss: {:.5f} | LR: {:.5f}".format(
                    loss_value, optimizer.param_groups[0]["lr"]
                )
            )
            loop.set_postfix(loss=np.mean(losses))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Stop after the first batch: only one optimization step is exercised.
            break

        # Stop after the first epoch.
        break

Epoch: 01/01
TRAIN


current_loss: 1.80920 | LR: 0.00010:   0%|     | 0/7 [00:01<?, ?it/s, loss=1.81]

loss:  tensor(1.8092, grad_fn=<MulBackward0>)


current_loss: 1.80920 | LR: 0.00010:   0%|     | 0/7 [00:01<?, ?it/s, loss=1.81]


In [2]:
# Map every field of the first frame in `instances` to the dtype of its tensor.
dict((name, value.dtype) for name, value in instances[0].items())

{'video_id': torch.int64,
 'img_shape': torch.int64,
 'frame_id': torch.int64,
 'num_detected': torch.int64,
 'gt_track_ids': torch.int64,
 'bboxes': torch.float32,
 'crops': torch.float32,
 'features': torch.float32,
 'pred_track_ids': torch.int64,
 'asso_output': torch.float32,
 'matches': torch.float32,
 'traj_score': torch.float32}

In [3]:
# Display the full contents of the first frame dict in `instances`.
# NOTE(review): the repr prints the complete tensor for every field (including
# the image crops), so the output is long — consider showing shapes instead.
instances[0]

{'video_id': tensor([0]),
 'img_shape': tensor([[  1, 512, 512]]),
 'frame_id': tensor([192]),
 'num_detected': tensor([5]),
 'gt_track_ids': tensor([0, 1, 2, 3, 4]),
 'bboxes': tensor([[0.8320, 0.4023, 0.8906, 0.4609],
         [0.8340, 0.4043, 0.8926, 0.4629],
         [0.8164, 0.4121, 0.8750, 0.4707],
         [0.2422, 0.5215, 0.3008, 0.5801],
         [0.4570, 0.7148, 0.5156, 0.7734]]),
 'crops': tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]],
 
 
         [[[0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.],
           [0., 0., 0.,  ..., 0., 0., 0.]]],
 
 
         [[[0., 0., 0.,  ..., 0., 0.,

In [4]:
# import gc
# import numpy as np
# import sleap_io
# import torch
# import imageio
# import matplotlib.pyplot as plt
# # from dataset import AnimalDataset
# # from pynvml import *
# from torch.utils.data import DataLoader
# from tqdm import tqdm
# from biogtr.models.global_tracking_transformer import GlobalTrackingTransformer
# from biogtr.training.losses import AssoLoss

# import torchvision.transforms.functional as tvf
# from torch.utils.data import Dataset


# def pad_bbox(bbox, padding=16):
#     """Pad bounding box coordinates.

#     Args:
#         bbox: Bounding box in [y1, x1, y2, x2] format.
#         padding: Padding to add to each side in pixels.

#     Returns:
#         Padded bounding box in [y1, x1, y2, x2] format.
#     """
#     y1, x1, y2, x2 = bbox
#     y1, x1 = y1 - padding, x1 - padding
#     y2, x2 = y2 + padding, x2 + padding
#     return [y1, x1, y2, x2]


# def crop_bbox(img, bbox):
#     """Crop an image to a bounding box.

#     Args:
#         img: Image as a tensor of shape (channels, height, width).
#         bbox: Bounding box in [y1, x1, y2, x2] format.

#     Returns:
#         Cropped pixels as tensor of shape (channels, height, width).
#     """
#     # Crop to the bounding box.
#     y1, x1, y2, x2 = bbox
#     crop = tvf.crop(
#         img,
#         top=int(round(y1)),
#         left=int(round(x1)),
#         height=int(round(y2 - y1)),
#         width=int(round(x2 - x1)),
#     )

#     return crop


# def centroid_crop(instance, anchors, crop_size):
#     """Crop bbox around instance centroid. This is useful for ensuring that
#     crops are centered around each instance in the case of incorrect pose
#     estimates

#     Args:
#         instance: a labeled instance in a frame
#         anchors: indices of candidate anchor points, tried in order; the first
#             one whose x-coordinate is not NaN is used as the centroid
#         crop_size: Integer specifying the crop height and width

#     Returns:
#         Bounding box in [y1, x1, y2, x2] format.
#     """

#     for anchor in anchors:
#         cx, cy = instance[anchor].x, instance[anchor].y
#         if not np.isnan(cx):
#             break

#     bbox = [
#         -crop_size / 2 + cy,
#         -crop_size / 2 + cx,
#         crop_size / 2 + cy,
#         crop_size / 2 + cx,
#     ]

#     return bbox


# def resize_and_pad(img, output_size):
#     """Resize and pad an image to fit a square output size.

#     Args:
#         img: Image as a tensor of shape (channels, height, width).
#         output_size: Integer size of height and width of output.

#     Returns:
#         The image zero padded to be of shape (channels, output_size, output_size).
#     """
#     # Figure out how to scale without breaking aspect ratio.
#     img_height, img_width = img.shape[-2:]
#     if img_width < img_height:  # taller
#         crop_height = output_size
#         scale = crop_height / img_height
#         crop_width = int(img_width * scale)
#     else:  # wider
#         crop_width = output_size
#         scale = crop_width / img_width
#         crop_height = int(img_height * scale)

#     # Scale without breaking aspect ratio.
#     img = tvf.resize(img, size=[crop_height, crop_width])

#     # Pad to square.
#     img_height, img_width = img.shape[-2:]
#     hp1 = int((output_size - img_width) / 2)
#     vp1 = int((output_size - img_height) / 2)
#     hp2 = output_size - (img_width + hp1)
#     vp2 = output_size - (img_height + vp1)
#     padding = (hp1, vp1, hp2, vp2)
#     return tvf.pad(img, padding, 0, "constant")

# class AnimalDataset(Dataset):
#     def __init__(
#         self,
#         slp_files,
#         padding=5,
#         crop_size=128,
#         anchor_names=["thorax", "head"],
#         chunk=True,
#         clip_length=500,
#         crop_type="centroid",
#     ):
#         self.slp_files = slp_files
#         self.padding = padding
#         self.crop_size = crop_size
#         self.anchor_names = anchor_names
#         self.chunk = chunk
#         self.clip_length = clip_length
#         self.crop_type = crop_type

#         assert self.crop_type in ["centroid", "pose"], "Invalid crop type!"

#         self.labels = [sleap_io.load_slp(slp_file) for slp_file in self.slp_files]

#         # for label in self.labels:
#             # label.remove_empty_instances(keep_empty_frames=False)

#         self.frame_idx = [np.arange(len(label)) for label in self.labels]

#         if self.chunk:
#             self.chunks = [
#                 [i * self.clip_length for i in range(len(label) // self.clip_length)]
#                 for label in self.labels
#             ]

#             self.chunked_frame_idx, self.label_idx = [], []
#             for i, (split, frame_idx) in enumerate(zip(self.chunks, self.frame_idx)):
#                 frame_idx_split = np.split(frame_idx, split)[1:]
#                 self.chunked_frame_idx.extend(frame_idx_split)
#                 self.label_idx.extend(len(frame_idx_split) * [i])
#         else:
#             self.chunked_frame_idx = self.frame_idx
#             self.label_idx = [i for i in range(len(self.labels))]

#     def __len__(self):
#         return len(self.chunked_frame_idx)

#     def no_batching_fn(self, batch):
#         return batch

#     def __getitem__(self, idx):
#         label_idx = self.label_idx[idx]
#         frame_idx = self.chunked_frame_idx[idx]

#         video = self.labels[label_idx]


#         anchors = [
#             video.skeletons[0].node_names.index(anchor_name)
#             for anchor_name in self.anchor_names
#         ]


#         video_name = os.path.splitext(self.slp_files[label_idx])[0] + ".mp4"

#         vid_reader = imageio.get_reader(video_name, 'ffmpeg')

#         instances = []
#         for i in frame_idx:
#             gt_track_ids, poses, bboxes, crops = [], [], [], []

#             i = int(i)

#             lf = video[i]
#             lf_img = vid_reader.get_data(i)

#             img = tvf.to_tensor(lf_img)

#             _, h, w = img.shape

#             for instance in lf:
#                 # gt_track_ids
#                 gt_track_ids.append(video.tracks.index(instance.track))

#                 # poses
#                 poses.append(np.array(instance.numpy()).astype("float32"))

#                 # bboxes
#                 if self.crop_type == "centroid":
#                     bbox = pad_bbox(
#                         centroid_crop(instance, anchors, self.crop_size),
#                         padding=self.padding,
#                     )
#                 elif self.crop_type == "pose":
#                     points = np.array([[p.x, p.y] for p in instance.points])

#                     min_x = max(np.nanmin(points[:, 0]) - self.padding, 0)
#                     min_y = max(np.nanmin(points[:, 1]) - self.padding, 0)
#                     max_x = min(np.nanmax(points[:, 0]) + self.padding, w)
#                     max_y = min(np.nanmax(points[:, 1]) + self.padding, h)

#                     bbox = [min_x, min_y, max_x, max_y]

#                 bboxes.append(bbox)

#                 # crops
#                 if self.crop_type == "centroid":
#                     crop = crop_bbox(img, bbox)
#                 elif self.crop_type == "pose":
#                     crop = resize_and_pad(crop_bbox(img, bbox), self.crop_size)

#                 crops.append(crop)

#             instances.append(
#                 {
#                     "video_id": torch.from_numpy(np.array([label_idx])),
#                     "img": torch.Tensor(img),
#                     "img_shape": torch.from_numpy(np.array([img.shape])),
#                     "frame_id": torch.from_numpy(np.array([i])),
#                     "num_detected": torch.from_numpy(np.array([len(bboxes)])),
#                     "gt_track_ids": torch.Tensor(gt_track_ids).type(torch.int64),
#                     "poses": torch.Tensor(np.array(poses)),
#                     "bboxes": torch.Tensor(np.array(bboxes)),
#                     "crops": torch.stack(crops),
#                     "features": torch.Tensor([]),
#                     "pred_track_ids": torch.Tensor([-1 for _ in range(len(bboxes))]),
#                     "asso_output": torch.Tensor([]),
#                     "matches": torch.Tensor([]),
#                     "traj_score": torch.Tensor([]),
#                 }
#             )

#         return instances


# if __name__ == "__main__":

#     train_slp_files = ["/Volumes/talmodata/datasets/mot/animal/sleap/benchmarks/flies13/190719_090330_wt_18159206_rig1.2@15000-17560.mp4"]

#     labels = sleap_io.load_slp(train_slp_files[0])

#     all_anchors = labels.skeletons[0].node_names

#     main_anchors = ['thorax', 'head', 'abdomen']
#     anchors = main_anchors + [i for i in all_anchors if i not in main_anchors]

#     device = 'cpu'

#     feats = 256

#     embedding_meta = {
#             'embedding_type': 'fixed_pos',
#             'kwargs': {
#                 'temperature': 2,
#                 'scale': 32,
#                 'normalize': True
#                 # 'learn_pos_num': 16,
#                 # 'over_boxes': True
#             }
#     }

#     tracking_transformer = GlobalTrackingTransformer(
#         d_model=feats,
#         num_encoder_layers=1,
#         num_decoder_layers=1,
#         dim_feedforward=feats,
#         feature_dim_attn_head=feats,
#         embedding_meta=embedding_meta,
#         return_embedding=True
#     ).to(device)

#     asso_loss = AssoLoss().to(device)

#     optimizer = torch.optim.Adam(
#         tracking_transformer.parameters(), lr=1e-4, betas=(0.9, 0.999)
#     )


#     train_ds = AnimalDataset(
#         train_slp_files,
#         padding=5,
#         crop_size=128,
#         anchor_names=anchors,
#         chunk=True,
#         clip_length=32,
#         crop_type="centroid",
#     )

#     instances = next(iter(train_ds))

#     train_loader = DataLoader(
#         train_ds,
#         batch_size=1,
#         shuffle=True,
#         collate_fn=train_ds.no_batching_fn,
#         num_workers=0,
#     )

#     torch.cuda.empty_cache()
#     gc.collect()

#     num_epochs = 1

#     for epoch in range(1, num_epochs + 1):
#         print("Epoch: {:02d}/{:02d}".format(epoch, num_epochs))

#         losses = []
#         _ = tracking_transformer.train()

#         print("TRAIN")
#         loop = tqdm(
#             enumerate(train_loader), position=0, leave=True, total=len(train_loader)
#         )
#         for i, instances in loop:
#             instances = instances[0]  # For batch size of 1.

#             asso_preds, embedding = tracking_transformer(instances)

#             # Compute loss.
#             loss = asso_loss(asso_preds, instances)
#             print('loss: ', loss)
#             losses.append(loss.item())

#             loop.set_description(
#                 "current_loss: {:.5f} | LR: {:.5f}".format(
#                     loss.item(), optimizer.param_groups[0]["lr"]
#                 )
#             )
#             loop.set_postfix(loss=np.mean(losses))

#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()

#             break

#         break