In [27]:
from collections import defaultdict
from collections.abc import Sequence

import numpy as np
import torch
import torchvision.transforms.functional as F
from mmcv.transforms import BaseTransform
from mmengine.utils import is_str
from PIL import Image

from mmpretrain.registry import TRANSFORMS
from mmpretrain.structures import DataSample, MultiTaskDataSample
from mmpretrain.datasets import PackInputs

from torch.utils.data import DataLoader

from mmengine.dataset import DefaultSampler, pseudo_collate, default_collate
from mmdet.datasets.objects365 import Objects365V2Dataset
from mmpretrain.datasets.transforms import *
from mmpretrain.models import ClsDataPreprocessor

from mmengine import Config
from mmpretrain.models import build_classifier

from projects.ma_clip.datasets import InstanceDataset, LoadInstanceImage
from projects.ma_clip.models import *
from projects.clip.models import *
from projects.clip.datasets import *

import math
from numbers import Number
from typing import List, Optional, Sequence, Tuple, Union

import torch
import torch.nn.functional as F
from mmengine.model import (BaseDataPreprocessor, ImgDataPreprocessor,
                            stack_batch)

from mmpretrain.registry import MODELS
from mmpretrain.structures import (DataSample, MultiTaskDataSample,
                                   batch_label_to_onehot, cat_batch_labels,
                                   tensor_split)

In [3]:
class PackMultiInputs(PackInputs):
    """Pack several result fields into one ``inputs`` list and a ``DataSample``.

    ``input_key`` may be a single key or a sequence of keys. Every key that is
    present in ``results`` is formatted with ``self.format_input`` and appended
    to ``packed['inputs']`` in the order the keys are listed.
    """

    def transform(self, results: dict) -> dict:
        """Pack the input data.

        Args:
            results (dict): Result dict produced by the preceding transforms.

        Returns:
            dict: A dict with two entries:

            - ``inputs`` (list): Formatted values of the keys listed in
              ``self.input_key`` that exist in ``results``. Missing keys are
              skipped, so the list can be shorter than ``self.input_key``.
            - ``data_samples`` (DataSample): Holds ``gt_label``, ``gt_score``
              and ``mask`` when present, plus the fields named by
              ``self.algorithm_keys`` and the metainfo named by
              ``self.meta_keys``.
        """
        # A plain string is itself a ``Sequence``; it has to be checked first,
        # otherwise a key such as 'img' would be iterated character by
        # character. The normalised list is kept local so the configured
        # ``self.input_key`` is not rewritten on every call.
        if isinstance(self.input_key, str) or \
                not isinstance(self.input_key, Sequence):
            input_keys = [self.input_key]
        else:
            input_keys = self.input_key

        packed_results = dict(inputs=[
            self.format_input(results[input_key]) for input_key in input_keys
            if input_key in results
        ])

        data_sample = DataSample()

        # Set default keys
        if 'gt_label' in results:
            data_sample.set_gt_label(results['gt_label'])
        if 'gt_score' in results:
            data_sample.set_gt_score(results['gt_score'])
        if 'mask' in results:
            data_sample.set_mask(results['mask'])

        # Set custom algorithm keys
        for key in self.algorithm_keys:
            if key in results:
                data_sample.set_field(results[key], key)

        # Set meta keys
        for key in self.meta_keys:
            if key in results:
                data_sample.set_field(results[key], key, field_type='metainfo')

        packed_results['data_samples'] = data_sample
        return packed_results

In [29]:
class VisionLanguageDataPreprocessor(ClsDataPreprocessor):
    """Data preprocessor for paired ``[vision, language]`` inputs.

    ``data['inputs']`` is expected to unpack into exactly two items: the
    vision input and the language input. The vision input is optionally
    channel-flipped, normalised and padded; the language input is only cast
    to float. Labels and scores found in the data samples are batched and
    written back to each sample.
    """

    def forward(self, data: dict, training: bool = False) -> dict:
        """Perform normalization, padding, bgr2rgb conversion and batch
        augmentation based on ``BaseDataPreprocessor``.

        Args:
            data (dict): data sampled from dataloader. ``data['inputs']``
                must unpack into ``(vision, language)``.
            training (bool): Whether to enable training time augmentation.

        Returns:
            dict: ``{'inputs': [vision, language], 'data_samples': ...}``
            where ``vision`` and ``language`` are batched tensors.
        """
        inputs = self.cast_data(data['inputs'])

        vision, language = inputs
        if isinstance(vision, torch.Tensor):
            # The branch if use `default_collate` as the collate_fn in the
            # dataloader.

            # ------ Vision ------
            # ------ To RGB ------
            if self.to_rgb and vision.size(1) == 3:
                vision = vision.flip(1)

            # --- Normalization ---
            vision = vision.float()
            if self._enable_normalize:
                vision = (vision - self.mean) / self.std

            # ------ Padding -----
            if self.pad_size_divisor > 1:
                h, w = vision.shape[-2:]

                # Round each spatial size up to a multiple of the divisor.
                target_h = math.ceil(
                    h / self.pad_size_divisor) * self.pad_size_divisor
                target_w = math.ceil(
                    w / self.pad_size_divisor) * self.pad_size_divisor
                pad_h = target_h - h
                pad_w = target_w - w
                vision = F.pad(vision, (0, pad_w, 0, pad_h), 'constant',
                               self.pad_value)
            # ----- Language -----
            # NOTE(review): if `language` holds token ids, a float cast may
            # not be what an embedding layer expects -- confirm downstream.
            language = language.float()
        else:
            # The branch if use `pseudo_collate` as the collate_fn in the
            # dataloader.

            processed_vision = []
            processed_language = []
            for vision_, language_ in zip(vision, language):
                # ------ Vision ------
                # ------ To RGB ------
                if self.to_rgb and vision_.size(0) == 3:
                    vision_ = vision_.flip(0)

                # --- Normalization ---
                vision_ = vision_.float()
                if self._enable_normalize:
                    vision_ = (vision_ - self.mean) / self.std

                # ----- Language -----
                language_ = language_.float()

                processed_vision.append(vision_)
                processed_language.append(language_)
            # Combine padding and stack
            vision = stack_batch(processed_vision, self.pad_size_divisor,
                                 self.pad_value)
            language = stack_batch(processed_language, self.pad_size_divisor,
                                   self.pad_value)

        data_samples = data.get('data_samples', None)
        sample_item = data_samples[0] if data_samples is not None else None

        if isinstance(sample_item, DataSample):
            batch_label = None
            batch_score = None

            if 'gt_label' in sample_item:
                gt_labels = [sample.gt_label for sample in data_samples]
                batch_label, label_indices = cat_batch_labels(gt_labels)
                batch_label = batch_label.to(self.device)
            if 'gt_score' in sample_item:
                gt_scores = [sample.gt_score for sample in data_samples]
                batch_score = torch.stack(gt_scores).to(self.device)
            elif self.to_onehot and 'gt_label' in sample_item:
                assert batch_label is not None, \
                    'Cannot generate onehot format labels because no labels.'
                num_classes = self.num_classes or sample_item.get(
                    'num_classes')
                assert num_classes is not None, \
                    'Cannot generate one-hot format labels because not set ' \
                    '`num_classes` in `data_preprocessor`.'
                batch_score = batch_label_to_onehot(
                    batch_label, label_indices, num_classes).to(self.device)

            # ----- Batch Augmentations ----
            # Augment the processed vision batch that is actually returned.
            # Passing the raw `[vision, language]` list and discarding the
            # result would change the scores without changing the images.
            if (training and self.batch_augments is not None
                    and batch_score is not None):
                vision, batch_score = self.batch_augments(vision, batch_score)

            # ----- scatter labels and scores to data samples ---
            if batch_label is not None:
                for sample, label in zip(
                        data_samples, tensor_split(batch_label,
                                                   label_indices)):
                    sample.set_gt_label(label)
            if batch_score is not None:
                for sample, score in zip(data_samples, batch_score):
                    sample.set_gt_score(score)
        elif isinstance(sample_item, MultiTaskDataSample):
            data_samples = self.cast_data(data_samples)

        return {'inputs': [vision, language], 'data_samples': data_samples}

In [37]:
# Transforms handed to the dataset below; the last step packs the 'img' and
# 'text' fields together as the model inputs.
pipeline = [
    LoadInstanceImage(channel_order='rgb', exp_factor=1.2, with_mask=False),
    ResizeEdge(edge='short', scale=256),
    RandomCrop(crop_size=224),
    RandomFlip(direction='horizontal', prob=0.5),
    PackMultiInputs(input_key=['img', 'text']),
]

# Objects365 annotations -> instance dataset -> vision/language dataset.
toy_dataset = VisionTemplateLanguageDataset(
    InstanceDataset(
        Objects365V2Dataset(
            ann_file='debug/train.json',
            data_prefix={'img': 'train/'},
            data_root='../data/Objects365/Obj365_v2/',
        ),
        filter_cfg={'min_size': 32},
    ),
    pipeline=pipeline,
)

sampler = DefaultSampler(toy_dataset, shuffle=True)
train_loader = DataLoader(
    dataset=toy_dataset,
    batch_size=4,
    sampler=sampler,
    collate_fn=default_collate,
)

# Per-channel statistics used to normalise the image input.
data_preprocessor = VisionLanguageDataPreprocessor(
    std=[51.5865, 50.847, 51.255],
    mean=[125.307, 122.961, 113.8575],
)



loading annotations into memory...
Done (t=0.01s)
creating index...
index created!


In [38]:
# Fetch exactly one batch and run it through the preprocessor in training
# mode. `next(iter(...))` raises immediately if the loader is empty instead of
# silently leaving `data_batch` undefined for the cells below.
data_batch = data_preprocessor(next(iter(train_loader)), training=True)

In [42]:
# Show the second entry of the preprocessed inputs (the language tensor).
data_batch['inputs'][1]

tensor([[[49406.,   320.,  2103.,  ...,     0.,     0.,     0.],
         [49406.,   320.,  1125.,  ...,     0.,     0.,     0.],
         [49406.,   320.,  8757.,  ...,     0.,     0.,     0.],
         ...,
         [49406.,   320.,  1125.,  ...,     0.,     0.,     0.],
         [49406.,   320.,  1125.,  ...,     0.,     0.,     0.],
         [49406.,   320.,  6325.,  ...,     0.,     0.,     0.]],

        [[49406.,   320.,  2103.,  ...,     0.,     0.,     0.],
         [49406.,   320.,  1125.,  ...,     0.,     0.,     0.],
         [49406.,   320.,  8757.,  ...,     0.,     0.,     0.],
         ...,
         [49406.,   320.,  1125.,  ...,     0.,     0.,     0.],
         [49406.,   320.,  1125.,  ...,     0.,     0.,     0.],
         [49406.,   320.,  6325.,  ...,     0.,     0.,     0.]],

        [[49406.,   320.,  2103.,  ...,     0.,     0.,     0.],
         [49406.,   320.,  1125.,  ...,     0.,     0.,     0.],
         [49406.,   320.,  8757.,  ...,     0.,     0.,   

In [2]:
import json

In [3]:
ori_path = '../data/Objects365/Obj365_v2/annotations/pseudo_mask/zhiyuan_objv2_train_pmask_patch0.json'

# Parse the whole annotation file into memory.
with open(ori_path) as ann_fp:
    ori_data = json.load(ann_fp)

In [6]:
# Show how many annotations were loaded, together with the first record.
len(ori_data['annotations']), ori_data['annotations'][0]

(4707109,
 {'id': 123,
  'iscrowd': 0,
  'isfake': 0,
  'area': 4003.056219153999,
  'isreflected': 0,
  'bbox': [570.0605468672,
   478.04602053120004,
   136.49035642880006,
   29.328491212799975],
  'image_id': 900003,
  'category_id': 22,
  'segmentation': {'size': [768, 1024],
   'counts': 'Po[=5jg04M2N2N2N1O2N1O1O1O1O010O000000000000000000010O0000000000000001O0000000000000001O01O00000000000000000010O000000000O100O100000001O0001O1O1O00001O00000000000000O10000001O01O0001O0000001O0000000000000000000000000000000000000000001O0000000OO1N210O1O01GfXOLZg03hXOKXg05jXOGYg0792N8H1000N4Ho`_7'},
  'iou': 0.9364637732505798})

In [7]:
save_path = '../data/Objects365/Obj365_v2/annotations/pseudo_mask/zhiyuan_objv2_train_pmask_mini.json'

# Number of leading annotations copied into the mini split.
max_annotations = 2000000

new_data = dict()
# Slicing already yields a new list, so no element-wise copy loop is needed.
new_data['annotations'] = ori_data['annotations'][:max_annotations]

img_list = [ann['image_id'] for ann in new_data['annotations']]
# Membership tests against a list scan it linearly; with millions of ids and
# one test per image that is quadratic. A set makes each lookup O(1).
img_id_set = set(img_list)

# Keep only the images referenced by at least one retained annotation,
# preserving their original order.
new_data['images'] = [
    img for img in ori_data['images'] if img['id'] in img_id_set
]

new_data['categories'] = ori_data['categories']

with open(save_path, 'w') as f:
    json.dump(new_data, f)

KeyboardInterrupt: 