In [None]:
import os
import clip
import torch
import copy
import math
from torch import nn
from tqdm import tqdm
from typing import Optional
import torch.nn.functional as F
from torch import Tensor
from torch.nested import as_nested_tensor
from torchvision.datasets import CocoDetection
from pycocotools.coco import COCO
from torch.utils.data import Dataset, DataLoader

from model import build_model
from detr.models.detr import PostProcess, SetCriterion
from detr.models.matcher import build_matcher
import detr.util.misc as utils
from cocoeval import CocoEvaluator

In [None]:
class DetrClipDataset(Dataset):
    """In-memory dataset of preprocessed DETR/CLIP samples for one COCO split.

    Every file in ``dataset_dir`` whose name contains ``detr_clip_{split}`` is
    loaded with ``torch.load`` and concatenated into ``self.data``, so each file
    must hold a sequence of sample dicts with the keys ``detr_f``, ``clip_f``,
    ``outputs`` (containing ``pred_logits`` and ``pred_boxes``) and ``targets``.
    The COCO annotation index for the split is kept in ``self.coco``.

    Args:
        dataset_dir: directory containing the preprocessed sample files.
        coco_dir: COCO root; annotations are read from ``<coco_dir>/annotations``.
        split: ``'train'`` or ``'val'``.
    """

    def __init__(self, dataset_dir, coco_dir, split) -> None:
        self.data = []
        assert split in ['train', 'val']
        ann_filename = 'instances_train2017.json' if split == 'train' else 'instances_val2017.json'
        self.coco = COCO(os.path.join(coco_dir, 'annotations', ann_filename))
        print('Loading preprocessed samples into memory...')
        # Sorted so the concatenation order (and therefore every sample index)
        # does not depend on the filesystem's arbitrary os.listdir() order.
        sample_filenames = sorted(fn for fn in os.listdir(dataset_dir) if f'detr_clip_{split}' in fn)
        for fn in tqdm(sample_filenames):
            # map_location='cpu': load regardless of the device the samples were
            # saved from; the training/eval loops move batches to `device` themselves.
            # NOTE: torch.load unpickles -- only load files you produced yourself.
            self.data += torch.load(os.path.join(dataset_dir, fn), map_location='cpu')

    def __len__(self):
        """Number of samples loaded across all matching files."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(detr_f, clip_f as float32, pred_logits, pred_boxes, targets)`` for sample ``idx``."""
        item = self.data[idx]
        detr_f, clip_f = item['detr_f'], item['clip_f']
        outputs, targets = item['outputs'], item['targets']
        detr_logits, detr_boxes = outputs['pred_logits'], outputs['pred_boxes']
        # CLIP features are cast to float32 (they may be stored in reduced precision -- TODO confirm).
        return detr_f, clip_f.float(), detr_logits, detr_boxes, targets

def collate_fn(batch):
    """Transpose a list of per-sample tuples into per-field batch entries.

    Every field except the last is stacked into one tensor along a new leading
    dimension. The last field (the per-image target dicts) is returned
    unstacked, as the tuple produced by ``zip``.

    Returns:
        list: ``[stacked_field_0, ..., stacked_field_{k-2}, targets_tuple]``.
    """
    fields = list(zip(*batch))
    targets = fields[-1]
    stacked = [torch.stack(fields[i]) for i in range(len(fields) - 1)]
    return stacked + [targets]

In [None]:
# Training split: preprocessed samples from ./data, COCO annotations from ./coco.
dataset_train = DetrClipDataset(dataset_dir='data', coco_dir='coco', split='train')
dataloader_train = DataLoader(
    dataset_train,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn,
)

In [None]:
# Pull a single batch from a fresh iterator to sanity-check shapes and dtypes.
dataloader_train_iter = iter(dataloader_train)
(detr_f, clip_f, detr_logits, detr_boxes, targets) = next(dataloader_train_iter)

In [None]:
tuple(t.size() for t in (detr_f, clip_f, detr_logits, detr_boxes))

In [None]:
tuple(t.dtype for t in (detr_f, clip_f))

In [None]:
class Args:
    """Hyper-parameter bag standing in for DETR's argparse namespace.

    Passed to ``build_model`` and ``build_matcher``, and read directly when the
    optimizer, LR scheduler and ``SetCriterion`` are set up. Any default can be
    overridden by keyword, e.g. ``Args(lr=3e-4)``. Unknown names raise
    ``TypeError`` so a typo cannot silently create an unused attribute.
    """

    def __init__(self, **overrides) -> None:
        # Transformer head (presumably consumed by build_model -- confirm).
        self.nhead = 8
        self.num_layers = 6
        self.dim_feedforward = 2048
        self.dropout = 0.1
        # Hungarian-matcher costs (presumably read by build_matcher).
        self.set_cost_class = 1
        self.set_cost_bbox = 5
        self.set_cost_giou = 2
        # Loss coefficients and the weight of the "no object" class.
        self.bbox_loss_coef = 5
        self.giou_loss_coef = 2
        self.eos_coef = 0.1
        # Optimizer / StepLR schedule.
        self.lr = 1e-4
        self.weight_decay = 1e-4
        self.lr_drop = 200

        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError(f"Args got an unexpected keyword argument {name!r}")
            setattr(self, name, value)

args = Args()
# NOTE(review): the class-embedding matrix is a random placeholder (and unseeded,
# so not reproducible). 92 rows presumably = 91 COCO class ids + 1 "no object"
# slot, matching SetCriterion(91, ...) used for the loss; 256/512 presumably the
# DETR feature size and the CLIP embedding size -- confirm against build_model.
model = build_model(256, 512, torch.randn(92, 512), args)

# Shape smoke test on the sample batch. no_grad: this forward pass exists only
# to inspect the output size, so don't build (and keep alive in `out`) an autograd graph.
with torch.no_grad():
    out = model(clip_f, detr_f)
out.size()

In [None]:
# Optimizer and step-wise LR schedule (StepLR multiplies the LR by its default
# gamma every `args.lr_drop` calls to lr_scheduler.step()).
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=args.lr,
    weight_decay=args.weight_decay,
)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_drop)

# Set-prediction loss terms to compute, and the coefficients for the weighted
# ones. 'cardinality' has no weight_dict entry, so it is reported but not weighted.
losses = ['labels', 'boxes', 'cardinality']
weight_dict = {
    'loss_ce': 1,
    'loss_bbox': args.bbox_loss_coef,
    'loss_giou': args.giou_loss_coef,
}

matcher = build_matcher(args)
criterion = SetCriterion(
    num_classes=91,
    matcher=matcher,
    weight_dict=weight_dict,
    eos_coef=args.eos_coef,
    losses=losses,
)
postprocessors = PostProcess()

In [None]:
# Training pass over the precomputed DETR/CLIP features. The model only predicts
# new class logits; boxes are the precomputed DETR boxes, passed through unchanged.
epoch = 0
max_norm = 0  # gradient-clipping threshold; 0 disables clipping
# Stop after this many optimization steps (None = full epoch). The default of 1
# is a single-batch smoke run -- TODO(review): set to None for real training.
max_train_steps = 1

# Prefer Apple MPS, then CUDA, then CPU so the cell runs on any machine.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model.to(device)
criterion.to(device)

model.train()
criterion.train()
metric_logger = utils.MetricLogger(delimiter="  ")
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
header = 'Epoch: [{}]'.format(epoch)
print_freq = 10
# Per-term loss coefficients; constant for the whole loop.
weight_dict = criterion.weight_dict

for step, (detr_fs, clip_fs, detr_logits, detr_boxes, targets) in enumerate(
        metric_logger.log_every(dataloader_train, print_freq, header)):
    # detr_logits (DETR's own class scores) is not used: the model predicts new logits.
    detr_fs, clip_fs = detr_fs.to(device), clip_fs.to(device)
    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

    logits = model(clip_fs, detr_fs)
    outputs = {'pred_logits': logits, 'pred_boxes': detr_boxes.to(device)}
    loss_dict = criterion(outputs, targets)
    # Weighted total used for backprop. Named loss_total (not `losses`) so it does
    # not overwrite the list of loss names that SetCriterion was built from.
    loss_total = sum(loss_dict[k] * weight_dict[k] for k in loss_dict if k in weight_dict)

    # reduce losses over all GPUs for logging purposes
    # (presumably a no-op in this single-process notebook)
    loss_dict_reduced = utils.reduce_dict(loss_dict)
    loss_dict_reduced_unscaled = {f'{k}_unscaled': v
                                  for k, v in loss_dict_reduced.items()}
    loss_dict_reduced_scaled = {k: v * weight_dict[k]
                                for k, v in loss_dict_reduced.items() if k in weight_dict}
    losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())

    loss_value = losses_reduced_scaled.item()

    if not math.isfinite(loss_value):
        # Raise rather than sys.exit(): SystemExit is awkward inside a notebook kernel.
        raise FloatingPointError(
            f"Loss is {loss_value}, stopping training: {loss_dict_reduced}")

    optimizer.zero_grad()
    loss_total.backward()
    if max_norm > 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()

    metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled)
    metric_logger.update(class_error=loss_dict_reduced['class_error'])
    metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    if max_train_steps is not None and step + 1 >= max_train_steps:
        break

# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
# NOTE(review): lr_scheduler is never stepped in this cell; call
# lr_scheduler.step() once per epoch when training for multiple epochs.

In [None]:
dataset_val = DetrClipDataset('data', 'coco', 'val')
# Iterate the validation split in a fixed order; shuffling has no benefit for
# evaluation and makes runs harder to compare.
dataloader_val = DataLoader(dataset_val, batch_size=8, shuffle=False, collate_fn=collate_fn)

In [None]:
# Evaluation: losses plus COCO bbox metrics for the re-classified DETR detections.
data_loader = dataloader_val
# Take the ground-truth COCO index from the loader's own dataset so predictions
# and annotations always come from the same split.
base_ds = data_loader.dataset.coco

# Prefer Apple MPS, then CUDA, then CPU so the cell runs on any machine.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

model.to(device)
criterion.to(device)

with torch.no_grad():
    model.eval()
    criterion.eval()

    metric_logger = utils.MetricLogger(delimiter="  ")
    metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Test:'

    iou_types = ('bbox',)
    coco_evaluator = CocoEvaluator(base_ds, iou_types)
    # Per-term loss coefficients; constant for the whole loop.
    weight_dict = criterion.weight_dict

    for detr_fs, clip_fs, detr_logits, detr_boxes, targets in metric_logger.log_every(data_loader, 10, header):
        # detr_logits (DETR's own class scores) is not used: the model predicts new logits.
        detr_fs, clip_fs = detr_fs.to(device), clip_fs.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        logits = model(clip_fs, detr_fs)
        outputs = {'pred_logits': logits, 'pred_boxes': detr_boxes.to(device)}

        loss_dict = criterion(outputs, targets)

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        loss_dict_reduced_scaled = {k: v * weight_dict[k]
                                    for k, v in loss_dict_reduced.items() if k in weight_dict}
        loss_dict_reduced_unscaled = {f'{k}_unscaled': v
                                      for k, v in loss_dict_reduced.items()}
        metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()),
                             **loss_dict_reduced_scaled,
                             **loss_dict_reduced_unscaled)
        metric_logger.update(class_error=loss_dict_reduced['class_error'])

        # Rescale predicted boxes to each image's original size for COCO scoring.
        orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
        results = postprocessors(outputs, orig_target_sizes)

        res = {target['image_id'].item(): output for target, output in zip(targets, results)}
        coco_evaluator.update(res)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    coco_evaluator.synchronize_between_processes()

    # accumulate predictions from all images
    coco_evaluator.accumulate()
    coco_evaluator.summarize()
    stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
    stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()