In [48]:
import torch
# Sanity check: three boolean tensors built from the same literal pattern must
# compare equal element-wise. The 10-element pattern is repeated three times and
# wrapped in a leading batch dimension, giving shape (1, 30).
_pattern = [False, False, False, False, False, True, False, True, False, True]
x1, x2, x3 = (torch.tensor([_pattern * 3]) for _ in range(3))

# Compare every pair; each line prints a 0-dim boolean tensor.
for left, right in ((x1, x2), (x1, x3), (x2, x3)):
    print(torch.all(left == right))

tensor(True)
tensor(True)
tensor(True)


In [12]:
import torch
# Input deliberately unsorted and containing duplicates (the value 3 appears three times).
values = [6, 1, 2, 3, 3, 3, 4, 5]
x = torch.tensor(values)
# `sorted=False` only means the order is not guaranteed; the recorded run of
# this cell still returned the values in ascending order.
torch.unique(x, sorted=False)

tensor([1, 2, 3, 4, 5, 6])

In [45]:
from typing import Tuple

from numpy import ones_like


def _generate_random_perturbation_masks(
    total_perturbations_per_feature_group: int,
    feature_masks: Tuple[torch.Tensor, ...],
    perturbation_probability: float = 0.1,
    device: torch.device = torch.device("cpu"),
) -> tuple[torch.Tensor, ...]:
    if not isinstance(feature_masks, tuple):
        feature_masks = (feature_masks,)

    perturbation_masks = tuple()
    for feature_masks_per_type in feature_masks: # unpack the tuples of feature types
        perturbations_per_input_type = []
        for feature_masks_per_sample in feature_masks_per_type: # unpack the samples in a batch
            feature_count = len(torch.unique(feature_masks_per_sample))
            perturbation_masks_per_sample = torch.zeros((total_perturbations_per_feature_group, *feature_masks_per_sample.shape), dtype=torch.bool, device=device)
            for i in range(total_perturbations_per_feature_group):
                feature_drop_mask = torch.randperm(feature_count, device=device)[:int(perturbation_probability * feature_count)]
                for feature_idx in feature_drop_mask:
                    perturbation_masks_per_sample[i][perturbation_masks_per_sample[i] == feature_idx] = 1
            perturbations_per_input_type.append(
                perturbation_masks_per_sample
            )
        perturbation_masks += (torch.cat(perturbations_per_input_type, dim=0),)
    return perturbation_masks

def create_labelled_patch_grid(
    images: torch.Tensor, grid_size: int = 16
) -> torch.Tensor:
    """Build a per-image feature mask that labels square patches.

    Each image is divided into ``grid_size`` x ``grid_size`` patches that are
    numbered row by row starting at 0. Every pixel of the mask holds the number
    of the patch it belongs to.

    Args:
        images: Batch of images indexed as ``(N, C, H, W)``; only the last two
            dimensions of each image are used.
        grid_size: Side length of one patch in pixels; must be positive.

    Returns:
        An int64 tensor of shape ``(N, 1, H, W)``. When ``H`` or ``W`` is not a
        multiple of ``grid_size`` the last row/column of patches is truncated so
        the mask still covers the whole image.

    Raises:
        ValueError: If ``grid_size`` is smaller than 1.
    """
    if grid_size < 1:
        raise ValueError(f"grid_size must be a positive integer, got {grid_size}")

    feature_masks = []
    for image in images:
        # image dimensions are C x H x W
        height, width = image.shape[1], image.shape[2]
        # Ceil division so partial patches at the border get their own label
        # instead of being left out of the mask.
        patch_rows = -(-height // grid_size)
        patch_cols = -(-width // grid_size)
        mask = (
            torch.arange(patch_rows * patch_cols)
            .view((patch_rows, patch_cols))
            .repeat_interleave(grid_size, dim=0)
            .repeat_interleave(grid_size, dim=1)[:height, :width]
            .long()
            .unsqueeze(0)
        )
        feature_masks.append(mask)
    return torch.stack(feature_masks)

# Demo: label two random 224x224 RGB images with a 16-pixel patch grid and draw
# random perturbation masks over those patches.
total_perturbations_per_feature_group = 10
images = torch.randn(2, 3, 224, 224)
feature_masks = create_labelled_patch_grid(images, grid_size=16)

_generate_random_perturbation_masks(
    total_perturbations_per_feature_group=total_perturbations_per_feature_group,
    feature_masks=feature_masks,
    perturbation_probability=0.1,
)

perturbation_masks 1 torch.Size([20, 1, 224, 224])


(tensor([[[[False, False, False,  ..., False, False, False],
           [False, False, False,  ..., False, False, False],
           [False, False, False,  ..., False, False, False],
           ...,
           [False, False, False,  ..., False, False, False],
           [False, False, False,  ..., False, False, False],
           [False, False, False,  ..., False, False, False]]],
 
 
         [[[ True,  True,  True,  ...,  True,  True,  True],
           [ True,  True,  True,  ...,  True,  True,  True],
           [ True,  True,  True,  ...,  True,  True,  True],
           ...,
           [ True,  True,  True,  ...,  True,  True,  True],
           [ True,  True,  True,  ...,  True,  True,  True],
           [ True,  True,  True,  ...,  True,  True,  True]]],
 
 
         [[[False, False, False,  ..., False, False, False],
           [False, False, False,  ..., False, False, False],
           [False, False, False,  ..., False, False, False],
           ...,
           [False, False,

In [9]:
from enum import Enum
class ExplanationMetrics(str, Enum):
    """Identifiers of explanation metrics, grouped by category.

    The ``str`` mixin makes every member a string as well, so a member compares
    equal to its plain string value (e.g. ``"completeness"``). Iterating the
    enum yields the members in the declaration order below.
    """

    # axiomatic
    COMPLETENESS = "completeness"
    MONOTONICITY_CORR_AND_NON_SENS = "monotonicity_corr_and_non_sens"

    # complexity
    EFFECTIVE_COMPLEXITY = "effective_complexity"
    COMPLEXITY = "complexity"
    SPARSENESS = "sparseness"

    # faithfulness
    FAITHFULNESS_CORR = "faithfulness_corr"
    FAITHFULNESS_ESTIMATE = "faithfulness_estimate"
    MONOTONICITY = "monotonicity"
    INFIDELITY = "infidelity"
    AOPC = "aopc"

    # robustness
    SENSITIVITY = "sensitivity"

# Member count; the value is discarded because only the last expression of a
# cell is displayed.
len(ExplanationMetrics)
# String values of all members, in declaration order.
list(member.value for member in ExplanationMetrics)

['completeness',
 'monotonicity_corr_and_non_sens',
 'effective_complexity',
 'complexity',
 'sparseness',
 'faithfulness_corr',
 'faithfulness_estimate',
 'monotonicity',
 'infidelity',
 'aopc',
 'sensitivity']

In [2]:
# Parquet shard paths of the dataset: one test shard, four train shards and one
# validation shard.
DOWNLOAD_FILES = [
    "data/test-00000-of-00001-9c204eb3f4e11791.parquet",
    "data/train-00000-of-00004-b4aaeceff1d90ecb.parquet",
    "data/train-00001-of-00004-7dbbe248962764c5.parquet",
    "data/train-00002-of-00004-688fe1305a55e5cc.parquet",
    "data/train-00003-of-00004-2d0cd200555ed7fd.parquet",
    "data/validation-00000-of-00001-cc3c5779fe22e8ca.parquet",
]
# Repository identifier the shards belong to.
# NOTE(review): presumably a Hugging Face Hub dataset repo id, going by the name - confirm.
BASE_HF_REPO = "benjamin-paine/imagenet-1k-256x256"


Downloading readme: 100%|██████████| 88.1k/88.1k [00:00<00:00, 21.7MB/s]
Downloading data:   0%|          | 0/40 [00:09<?, ?files/s]


In [None]:
# Copyright 2022 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""CORD dataset"""


import dataclasses
import io
import json
import os
from pathlib import Path

import datasets
import numpy as np
import pandas as pd
import PIL
from huggingface_hub import hf_hub_download
from torchfusion.core.constants import DataKeys
from torchfusion.core.data.datasets.fusion_image_dataset import FusionImageDataset, FusionImageDatasetConfig
from torchfusion.core.data.datasets.fusion_ner_dataset import (
    FusionNERDataset,
    FusionNERDatasetConfig,
)
from torchfusion.core.data.text_utils.utilities import normalize_bbox
from torchfusion.core.utilities.logging import get_logger

# Module-level logger, created through torchfusion's get_logger and named after this module.
logger = get_logger(__name__)

# Citation for the dataset (currently left empty).
_CITATION = """"""

# Short description shown in the builder config.
_DESCRIPTION = """ImageNet-1k (256x256) dataset"""

# Dataset page of the Hub repository the parquet shards are downloaded from.
_HOMEPAGE = "https://huggingface.co/datasets/benjamin-paine/imagenet-1k-256x256"

# NOTE(review): value carried over unchanged from the script this file was
# adapted from; confirm it matches the actual license of the dataset.
_LICENSE = "Apache-2.0 license"

# Class labels are the integer ids 0..999.
_NAMES = list(range(1000))

# Parquet shards fetched from BASE_HF_REPO when no local data directory exists.
# This list must not be empty: nothing would be downloaded and example
# generation would then fail on the missing download directory.
DOWNLOAD_FILES = [
    "data/test-00000-of-00001-9c204eb3f4e11791.parquet",
    "data/train-00000-of-00004-b4aaeceff1d90ecb.parquet",
    "data/train-00001-of-00004-7dbbe248962764c5.parquet",
    "data/train-00002-of-00004-688fe1305a55e5cc.parquet",
    "data/train-00003-of-00004-2d0cd200555ed7fd.parquet",
    "data/validation-00000-of-00001-cc3c5779fe22e8ca.parquet",
]
BASE_HF_REPO = "benjamin-paine/imagenet-1k-256x256"


def download_files_from_hf(dataset_dir, files):
    """Download each of ``files`` from ``BASE_HF_REPO`` via ``hf_hub_download``.

    Args:
        dataset_dir: Base directory; files are stored below
            ``<dataset_dir>/huggingface/downloads``.
        files: Iterable of file names, passed one by one as ``filename``.
    """
    for remote_name in files:
        destination = Path(dataset_dir) / "huggingface" / "downloads"
        hf_hub_download(
            BASE_HF_REPO,
            repo_type="dataset",
            filename=remote_name,
            local_dir=destination,
        )


def convert_to_list(row):
    """Replace every NumPy array value of ``row`` with a plain Python list.

    Arrays whose elements are themselves arrays (e.g. object arrays holding
    one array per entry) are converted element by element so the result is a
    list of lists. Empty and 0-dimensional arrays are converted with
    ``ndarray.tolist()`` directly instead of being indexed, which previously
    raised ``IndexError``.

    Args:
        row: Mapping-like object supporting ``items()`` and item assignment.
            It is modified in place.

    Returns:
        The same ``row`` object, after conversion.
    """
    for key, value in row.items():
        if not isinstance(value, np.ndarray):
            continue
        # Only peek at the first element when one exists and can be indexed.
        if value.ndim > 0 and value.size > 0 and isinstance(value[0], np.ndarray):
            row[key] = [inner.tolist() for inner in value]
        else:
            row[key] = value.tolist()
    return row


@dataclasses.dataclass
class ImageNet1kConfig(FusionImageDatasetConfig):
    """Builder configuration for ``ImageNet1k``.

    Adds no fields of its own; everything is inherited from
    ``FusionImageDatasetConfig``.
    """

    pass


class ImageNet1k(FusionImageDataset):
    """Dataset builder for ImageNet-1k stored as parquet shards.

    Data is read from ``train``/``test``/``dev`` sub-directories of
    ``config.data_dir`` when all three exist; otherwise the shards listed in
    ``DOWNLOAD_FILES`` are downloaded from ``BASE_HF_REPO`` first.
    """

    VERSION = datasets.Version("1.0.0")

    BUILDER_CONFIGS = [
        ImageNet1kConfig(
            name="default",
            description=_DESCRIPTION,
            homepage=_HOMEPAGE,
            citation=_CITATION,
            license=_LICENSE,
            labels=_NAMES,
        ),
    ]

    def _split_generators(self, dl_manager):
        """Return the train/test/validation split generators.

        Sets ``self.load_data_from_parquet`` to ``True`` when the data had to be
        downloaded and to ``False`` when the local split directories are used.

        Args:
            dl_manager: Download manager supplied by the caller; not used here.

        Returns:
            A list with one ``datasets.SplitGenerator`` per split.
        """
        data_dir = Path(self.config.data_dir)
        # First expected split directory that is missing, if any.
        missing_dir = next(
            (name for name in ("train", "test", "dev") if not (data_dir / name).exists()),
            None,
        )

        if missing_dir is not None:
            logger.warning(
                f"Data directory {self.config.data_dir}/{missing_dir} does not exist. "
                f"Trying to download dataset from the base repository: {BASE_HF_REPO}"
            )
            self.load_data_from_parquet = True

            # Fetch the parquet shards from the Hugging Face Hub instead.
            download_files_from_hf(self.config.data_dir, DOWNLOAD_FILES)
            filepath = data_dir / "huggingface" / "downloads" / "data"

            # All splits share one directory; the split name selects the shards.
            return [
                datasets.SplitGenerator(
                    name=split,
                    gen_kwargs={"filepath": filepath, "split": str(split)},
                )
                for split in (
                    datasets.Split.TRAIN,
                    datasets.Split.TEST,
                    datasets.Split.VALIDATION,
                )
            ]

        self.load_data_from_parquet = False

        # NOTE(review): no "split" entry is passed here although
        # `_generate_examples_impl` declares `split` as a required parameter -
        # confirm how the base class supplies it for local directories.
        return [
            datasets.SplitGenerator(
                name=datasets.Split.TRAIN,
                gen_kwargs={"filepath": data_dir / "train"},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.TEST,
                gen_kwargs={"filepath": data_dir / "test"},
            ),
            datasets.SplitGenerator(
                name=datasets.Split.VALIDATION,
                gen_kwargs={"filepath": data_dir / "dev"},
            ),
        ]

    def _generate_examples_impl(
        self,
        split,
        filepath,
    ):
        """Yield ``(index, sample)`` pairs for every row of the split's shards.

        Shards are processed one at a time in sorted file-name order, so only a
        single shard is held in memory and the running index is reproducible.

        Args:
            split: Split name; every file in ``filepath`` whose name contains
                it is read as a parquet file.
            filepath: Directory containing the parquet files.

        Yields:
            ``(idx, sample)`` where ``sample`` is a dict of the row's columns
            plus ``image_path``, with ``DataKeys.INDEX`` set to the running
            index and ``DataKeys.IMAGE`` set to the decoded PIL image.

        Raises:
            ValueError: If no file in ``filepath`` matches ``split``.
        """
        # `import PIL` alone does not load the `PIL.Image` submodule.
        from PIL import Image

        filepath = Path(filepath)
        parquet_files = sorted(f for f in os.listdir(filepath) if split in f)
        if not parquet_files:
            raise ValueError(f"No files matching split '{split}' found in {filepath}")

        idx = 0
        for f in parquet_files:
            data_subset = pd.read_parquet(filepath / f)
            # Synthetic per-row identifier: "<shard path>_<row index>".
            data_subset["image_path"] = data_subset.index.astype(str)
            data_subset["image_path"] = (
                str(filepath / f) + "_" + data_subset["image_path"]
            )
            # Iterate rows; iterating the DataFrame itself yields column names.
            for _, row in data_subset.iterrows():
                sample = row.to_dict()
                sample[DataKeys.INDEX] = idx
                # assumes the "image" column holds a mapping with raw encoded
                # bytes under "bytes" - TODO confirm against the parquet schema
                sample[DataKeys.IMAGE] = Image.open(io.BytesIO(row["image"]["bytes"]))

                yield idx, sample
                idx += 1

# Smoke test: construct the builder with its default arguments.
ImageNet1k()