# New Image Classification Dataset

This notebook describes how to prepare a user-provided image classification dataset (one not available through Hugging Face or Torchvision) for use with the Armory Library.

In [None]:
import collections
from pathlib import Path

import datasets
from IPython.display import display
import pandas as pd

import armory.data
import armory.dataset
import armory.evaluation
import armory.examples

## [SAMPLE Public Dataset](https://github.com/benjaminlewis-afrl/SAMPLE_dataset_public)

The SAMPLE (Synthetic and Measured Paired Labeled Experiment) dataset consists of measured synthetic aperture radar (SAR) imagery from the MSTAR (Moving and Stationary Target Acquisition and Recognition) collection, paired with synthetic SAR imagery. The public version of this dataset contains data with azimuth angles between 10 and 80 degrees.

The MSTAR dataset contains SAR imagery of 10 types of military vehicles, shown in the electro-optical (EO) images below.

<img src="mstar_10_targets.png"
    alt="MSTAR 10 Targets"
    width="700">

[Song, H., Ji, K., Zhang, Y., Xing, X., & Zou, H. (2016). Sparse Representation-Based SAR Image Target Classification on the 10-Class MSTAR Data Set. Applied Sciences, 6(1), 26. doi:10.3390/app6010026.](https://www.mdpi.com/2076-3417/6/1/26)

## Download dataset

As a first step, we clone the [SAMPLE dataset repository](https://github.com/benjaminlewis-afrl/SAMPLE_dataset_public), which contains the real and synthetic SAR imagery, into a temporary directory.

In [None]:
tmp_dir = Path('/tmp')
sample_dir = tmp_dir / Path('SAMPLE_dataset_public')

In [None]:
%%bash -s $sample_dir

if [[ -d $1 ]]
then
    echo "$1 exists"
else
    git clone https://github.com/benjaminlewis-afrl/SAMPLE_dataset_public "$1"
fi
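The bash cell above relies on IPython cell magic. Outside Jupyter, the same clone-if-missing logic can be sketched in plain Python. This is a minimal sketch that assumes the `git` executable is on the `PATH`; the helper name `clone_if_missing` is hypothetical.

```python
import subprocess
from pathlib import Path


def clone_if_missing(repo_url: str, dest: Path) -> bool:
    """Clone repo_url into dest unless dest already exists.

    Returns True if a clone was performed, False if dest was already present.
    Assumes the `git` executable is available on the PATH.
    """
    dest = Path(dest)
    if dest.is_dir():
        print(f"{dest} exists")
        return False
    # check=True raises subprocess.CalledProcessError if the clone fails
    subprocess.run(["git", "clone", repo_url, str(dest)], check=True)
    return True


# Usage (same repository and destination as the bash cell):
# clone_if_missing(
#     "https://github.com/benjaminlewis-afrl/SAMPLE_dataset_public",
#     Path("/tmp") / "SAMPLE_dataset_public",
# )
```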

### Dataset structure

The SAMPLE dataset is organized according to the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.19.0/en/image_dataset#imagefolder) pattern. The imagery is provided in two normalizations: decibel (dB) and quarter power magnitude (QPM). For each normalization type, the real and synthetic SAR imagery is partitioned into folders by vehicle type.

In [None]:
!find $sample_dir -type d -not -path "$sample_dir/.git*" -not -path "$sample_dir/mat_files*" | sed -e "s/[^-][^\/]*\// |/g" -e "s/|\([^ ]\)/|-\1/"
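A pure-Python alternative to the `find`/`sed` one-liner can be sketched with `pathlib`. It prints a similar indented directory tree and, like the shell version, skips top-level entries whose names start with `.git` or `mat_files`. The helper `print_tree` is hypothetical, and its output format only approximates the shell version.

```python
from pathlib import Path


def print_tree(root: Path, exclude: tuple[str, ...] = (".git", "mat_files")) -> list[str]:
    """Print and return an indented listing of the directories under root.

    Top-level directories whose names start with any prefix in `exclude`
    are skipped, together with everything beneath them.
    """
    root = Path(root)
    lines: list[str] = []

    def walk(directory: Path, depth: int) -> None:
        for child in sorted(p for p in directory.iterdir() if p.is_dir()):
            if depth == 0 and any(child.name.startswith(prefix) for prefix in exclude):
                continue
            lines.append("  " * depth + "|-" + child.name)
            walk(child, depth + 1)

    walk(root, 0)
    for line in lines:
        print(line)
    return lines


# Usage: print_tree(sample_dir)
```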

## Load dataset

We load the QPM-normalized real SAR imagery by calling [`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.19.0/en/package_reference/loading_methods#datasets.load_dataset) with the [`ImageFolder`](https://huggingface.co/docs/datasets/v2.19.0/en/image_dataset#imagefolder) dataset builder. `ImageFolder` automatically infers the class labels from the directory names.

In [None]:
data_dir = sample_dir / Path("png_images", "qpm", "real")
raw_dataset = datasets.load_dataset('imagefolder', data_dir=data_dir)
raw_dataset
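Conceptually, `ImageFolder` treats the name of each image's parent directory as its class label. The following simplified sketch is not the Hugging Face implementation, and the helper `infer_labels` is hypothetical. It collects the class names from a `data_dir/<class_name>/<image file>` layout, sorts them (assumed here to match how the class names end up in `features['label'].names`), and pairs each image with its integer index into that sorted list.

```python
from pathlib import Path

# Image file extensions recognized by this sketch (an assumption, not exhaustive)
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}


def infer_labels(data_dir: Path) -> tuple[list[str], list[tuple[Path, int]]]:
    """Infer class names and (path, label index) pairs from a folder layout
    of the form data_dir/<class_name>/<image file>."""
    data_dir = Path(data_dir)
    image_paths = sorted(
        p for p in data_dir.rglob("*")
        if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES
    )
    # The class name is the name of the directory directly containing each image
    names = sorted({p.parent.name for p in image_paths})
    index = {name: i for i, name in enumerate(names)}
    examples = [(p, index[p.parent.name]) for p in image_paths]
    return names, examples


# Usage: names, examples = infer_labels(data_dir)
```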

### Verify dataset

Check that image labels have been inferred correctly.

In [None]:
mstar_labels: list[str] = raw_dataset['train'].features['label'].names
mstar_labels

Since the SAR imagery is monochrome, we define a transform that converts the images to RGB format and apply it to the entire dataset with the Hugging Face [`map`](https://huggingface.co/docs/datasets/v2.19.0/en/image_process#map) function.

In [None]:
def transforms(examples):
    examples["image"] = [image.convert("RGB") for image in examples["image"]]
    return examples

raw_dataset = raw_dataset.map(transforms, batched=True)
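Converting a single-channel image to RGB adds no color information. Assuming the PNGs load as 8-bit grayscale (PIL mode `"L"`), each pixel's intensity is copied into the red, green, and blue channels, so the three channels are identical. A minimal pure-Python illustration on a nested list of pixel values follows; the helper `gray_to_rgb` is hypothetical, and in the notebook PIL's `Image.convert("RGB")` does this work.

```python
def gray_to_rgb(gray: list[list[int]]) -> list[list[tuple[int, int, int]]]:
    """Replicate each grayscale intensity into an (R, G, B) triple."""
    return [[(v, v, v) for v in row] for row in gray]


# Usage:
# gray_to_rgb([[0, 128], [255, 64]])
# -> [[(0, 0, 0), (128, 128, 128)], [(255, 255, 255), (64, 64, 64)]]
```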

Display the first training image, along with its PIL image mode and class label.

In [None]:
mstar_example = raw_dataset['train'][0]

display(mstar_example['image'])
print(f"mode {mstar_example['image'].mode}")
print(f"label {mstar_labels[mstar_example['label']]}")

### Define train, validation and test splits

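The `datasets.load_dataset` function creates a single `train` split by default. By applying the [`datasets.Dataset.train_test_split`](https://huggingface.co/docs/datasets/v2.19.0/en/package_reference/main_classes#datasets.Dataset.train_test_split) method twice, we partition this data into `train`, `valid`, and `test` splits and collect them in a [`datasets.DatasetDict`](https://huggingface.co/docs/datasets/v2.19.0/en/package_reference/main_classes#datasets.DatasetDict). The first call holds out 30% of the images. The second call assigns two thirds of the held-out images to `test` and the remaining third to `valid`, giving an approximately 70/10/20 train/valid/test split. Stratifying on the `label` column keeps the class proportions roughly the same in every split.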

In [None]:
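# Hold out 30% of the images, stratified by class label
train_dataset = raw_dataset['train'].train_test_split(
    test_size=3/10,
    stratify_by_column='label'
)

# Split the held-out 30% into validation (1/3, ~10% overall) and test (2/3, ~20% overall)
test_dataset = train_dataset['test'].train_test_split(
    test_size=2/3,
    stratify_by_column='label'
)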

mstar_dataset = datasets.DatasetDict(
    {
        'train': train_dataset['train'],
        'valid': test_dataset['train'],
        'test': test_dataset['test']
    }
)

mstar_dataset
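The idea behind a stratified split is to split each class's examples separately, so every class contributes about the same fraction to each side. The following stdlib sketch illustrates that idea. It is not the algorithm used by `datasets.Dataset.train_test_split`; the helper `stratified_split`, its `seed` parameter, and the example labels are hypothetical. The usage applies it twice, the same way the cell above chains two `train_test_split` calls.

```python
import random
from collections import defaultdict


def stratified_split(labels: list[int], test_size: float, seed: int = 0) -> tuple[list[int], list[int]]:
    """Split example indices into (train, test) so that each class contributes
    approximately `test_size` of its examples to the test side."""
    rng = random.Random(seed)
    by_class: dict[int, list[int]] = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    train: list[int] = []
    test: list[int] = []
    for label in sorted(by_class):
        indices = by_class[label][:]
        rng.shuffle(indices)
        n_test = round(len(indices) * test_size)  # per-class test count
        test.extend(indices[:n_test])
        train.extend(indices[n_test:])
    return sorted(train), sorted(test)


# Usage with hypothetical labels: 100 examples of class 0, 50 of class 1
labels = [0] * 100 + [1] * 50
train_idx, held_out = stratified_split(labels, test_size=3/10)
held_labels = [labels[i] for i in held_out]
valid_pos, test_pos = stratified_split(held_labels, test_size=2/3)
valid_idx = [held_out[i] for i in valid_pos]  # "train" side of the second split
test_idx = [held_out[i] for i in test_pos]
print(len(train_idx), len(valid_idx), len(test_idx))  # 105 15 30
```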

### Dataset statistics

Using the Hugging Face [`map`](https://huggingface.co/docs/datasets/v2.19.0/en/image_process#map) function, which applies a function over an entire dataset, we can compute simple summary statistics. For example, the `count_labels` function counts the number of images of each class in a split; the per-split counts are then combined into a pandas DataFrame.

In [None]:
def count_labels(ds: datasets.Dataset) -> list[int]:
    ctr: collections.Counter[str] = collections.Counter()
    
    def inc_label(l: int) -> None:
        ctr[mstar_labels[l]] += 1
        
    ds.map(inc_label, input_columns=['label'])
    counts = [ctr[l] for l in mstar_labels]
    return counts

df = pd.DataFrame(
    {split: count_labels(mstar_dataset[split]) for split in mstar_dataset.keys()},
    index=mstar_labels
)
df
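Because `count_labels` only reads the `label` column, the same counts can be computed without `map` by counting the integer labels directly. A minimal sketch with `collections.Counter` follows; the helper `label_counts` is hypothetical. It takes a split's list of label indices (with `datasets` 2.x, `mstar_dataset[split]['label']` returns the label column as a list) and the class names, and returns the counts in class-name order.

```python
import collections


def label_counts(labels: list[int], names: list[str]) -> list[int]:
    """Count occurrences of each class index, ordered like `names`.

    Classes that never occur get a count of 0.
    """
    ctr = collections.Counter(labels)
    return [ctr[i] for i in range(len(names))]


# Usage with hypothetical data:
# label_counts([0, 2, 2, 1, 2], ["a", "b", "c"])  -> [1, 1, 3]
# Applied to the notebook (assumes datasets 2.x column access):
# label_counts(mstar_dataset['train']['label'], mstar_labels)
```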

A bar chart of the per-class counts shows that the real MSTAR data is fairly balanced.

In [None]:
df.plot(kind='bar', stacked=False, title='MSTAR Class Counts')
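To put a number on "fairly balanced", one simple measure is the imbalance ratio: the largest class count divided by the smallest. A value close to 1 indicates balanced classes. The helper `imbalance_ratio` below is hypothetical and can be applied to each split's column of `df`.

```python
def imbalance_ratio(counts: list[int]) -> float:
    """Ratio of the largest to the smallest class count (1.0 = perfectly balanced).

    Returns infinity if some class has no examples.
    """
    if not counts:
        raise ValueError("counts must be non-empty")
    smallest = min(counts)
    return float("inf") if smallest == 0 else max(counts) / smallest


# Usage (assumes the `df` built above):
# {split: imbalance_ratio(df[split].tolist()) for split in df.columns}
```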

### Save to disk or upload to an S3 bucket

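The new MSTAR dataset can be [saved to disk](https://huggingface.co/docs/datasets/v2.19.0/en/package_reference/main_classes#datasets.Dataset.save_to_disk) with `save_to_disk`, which accepts either a local path or a remote URI such as an S3 bucket (with credentials passed through `storage_options`).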

In [None]:
mstar_path = Path('mstar_10.hf')
mstar_dataset.save_to_disk(mstar_path)

print("Loading the dataset")
print(datasets.load_from_disk(mstar_path))

## Integrate into Armory

Having imported a subset of SAMPLE as a Hugging Face dataset, we are ready to plug the new dataset into the Armory Library framework. This consists of creating an `armory.dataset.ImageClassificationDataLoader`, which defines the underlying [PyTorch dataloader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader). The `dim` argument specifies the image layout (here channels-first, `CHW`), and the `armory.data.Scale` object specifies the data type and value range of the images (here unnormalized `uint8` values with a maximum of 255). The Armory dataloader is then wrapped in an `armory.evaluation.Dataset`.

In [None]:
batch_size = 16
shuffle = False

unnormalized_scale = armory.data.Scale(
    dtype=armory.data.DataType.UINT8,
    max=255,
)

mstar_dataloader = armory.dataset.ImageClassificationDataLoader(
    mstar_dataset['train'],
    dim=armory.data.ImageDimensions.CHW,
    scale=unnormalized_scale,
    image_key="image",
    label_key="label",
    batch_size=batch_size,
    shuffle=shuffle,
)

armory_dataset = armory.evaluation.Dataset(
    name="MSTAR-qpm-real",
    dataloader=mstar_dataloader,
)

armory_dataset
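The `dim` and `scale` arguments describe how the image tensors are laid out and what range their values cover. As an illustration only (this is not Armory's implementation), the sketch below reorders a height × width × channel (HWC) nested list, the layout you get when converting a PIL RGB image to an array, into channel × height × width (CHW) order. It also maps `uint8` values with a maximum of 255 into [0, 1]. The helpers `hwc_to_chw` and `normalize_uint8` are hypothetical.

```python
def hwc_to_chw(image: list[list[list[int]]]) -> list[list[list[int]]]:
    """Reorder an H x W x C nested list into C x H x W."""
    height, width, channels = len(image), len(image[0]), len(image[0][0])
    return [
        [[image[h][w][c] for w in range(width)] for h in range(height)]
        for c in range(channels)
    ]


def normalize_uint8(image, max_value: int = 255):
    """Recursively map integer pixel values in [0, max_value] to floats in [0, 1]."""
    if isinstance(image, list):
        return [normalize_uint8(x, max_value) for x in image]
    return image / max_value


# Usage: a 1 x 2 RGB image in HWC order
# hwc_to_chw([[[1, 2, 3], [4, 5, 6]]])  -> [[[1, 4]], [[2, 5]], [[3, 6]]]
# normalize_uint8([[0, 255]])           -> [[0.0, 1.0]]
```

The `max=255` in `unnormalized_scale` plays the role of `max_value` here.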

## Resources
- [SAMPLE Public Dataset](https://github.com/benjaminlewis-afrl/SAMPLE_dataset_public)

- [Lewis, B., Scarnati, T., Sudkamp, E., Nehrbass, J., Rosencrantz, S., & Zelnio, E. (2019, May). A SAR dataset for ATR development: the Synthetic and Measured Paired Labeled Experiment (SAMPLE). In Algorithms for Synthetic Aperture Radar Imagery XXVI (Vol. 10987, pp. 39-54). SPIE.](https://github.com/benjaminlewis-afrl/SAMPLE_dataset_public/blob/master/sample_public.pdf)
