In [1]:
from PIL import Image

from domainbed import datasets




In [2]:
# Benchmarks to index in the export loop; add "OfficeHome" / "DomainNet" here
# to include them as well.
DATASETS = ["VLCS", "PACS"]

In [3]:
import random
import torch
import argparse
from domainbed import hparams_registry
import imageio
import os
from tqdm import tqdm


In [5]:
# Dataset used by the interactive inspection cells that follow.
dataset_name = 'VLCS'

In [6]:
# Root directory passed to the dataset constructors. The path is relative, so
# it resolves against the notebook's working directory.
data_dir = "domainbed/data"


In [7]:
# Resolve the default ERM hyperparameters, then build the dataset with every
# environment index passed as the second constructor argument.
# NOTE(review): presumably that argument is the list of test environments —
# confirm against the domainbed dataset signature.
hparams = hparams_registry.default_hparams('ERM', dataset_name)
dataset = datasets.get_dataset_class(dataset_name)(
    data_dir, [*range(datasets.num_environments(dataset_name))], hparams
)

In [8]:
# Display the default ERM hyperparameters resolved for dataset_name.
hparams

{'data_augmentation': True,
 'resnet18': False,
 'resnet_dropout': 0.0,
 'class_balanced': False,
 'nonlinear_classifier': False,
 'lr': 5e-05,
 'weight_decay': 0.0,
 'batch_size': 32}

In [9]:
# Environment names declared by the dataset object.
print(dataset.ENVIRONMENTS)

['C', 'L', 'S', 'V']


In [10]:
# Class names of the last environment.
print(dataset[-1].classes)

['bird', 'car', 'chair', 'dog', 'person']


In [11]:
# Number of samples listed for the first environment.
len(dataset[0].samples)

1415

In [12]:
# NOTE(review): `extra_repr` is referenced here but never called, so the cell
# displays the bound-method repr (which happens to embed the dataset repr).
# Displaying `dataset[0]` directly was probably the intent — confirm.
dataset[0].extra_repr

<bound method VisionDataset.extra_repr of Dataset ImageFolder
    Number of datapoints: 1415
    Root location: domainbed/data/VLCS/Caltech101
    StandardTransform
Transform: Compose(
               Resize(size=(224, 224), interpolation=PIL.Image.BILINEAR)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )>

In [13]:
# Class names of the first environment.
dataset[0].classes

['bird', 'car', 'chair', 'dog', 'person']

In [14]:
import pandas as pd


In [15]:
from sklearn.model_selection import train_test_split

# Fixed seed so that re-running the cell reproduces the same train/val split.
SPLIT_SEED = 0
# Fraction of every (environment, class) group that goes to validation.
VAL_FRACTION = 0.2
# Columns of one index record; listed explicitly so an empty index still has them.
INDEX_COLUMNS = ["path", "label", "env", "env_id", "class_name", "sample_id"]


def _collect_valid_samples(dataset, dataset_name):
    """Return (records, invalid_paths) for every readable image of `dataset`.

    Each record is a dict with the keys in INDEX_COLUMNS. `sample_id` is the
    position of the sample inside its environment's `samples` list, counted
    over all samples (readable or not). Images that fail to open or decode
    are skipped and their paths are returned in `invalid_paths`.
    """
    records, invalid_paths = [], []
    for i, env in enumerate(dataset.ENVIRONMENTS):
        env_dataset = dataset[i]
        # Iterate over the bar itself so it advances for every sample,
        # including the ones that are skipped as invalid.
        with tqdm(env_dataset.samples, desc=f"Processing {dataset_name} {env}") as pbar:
            for j, (x, y) in enumerate(pbar):
                try:
                    # Fully decoding the file is the actual validity check.
                    # Calling verify() on the converted copy would do nothing,
                    # and the context manager makes sure the handle is closed.
                    with Image.open(x) as img:
                        img.convert("RGB")
                except Exception:
                    # `Exception` rather than a bare except, so Ctrl-C still
                    # interrupts a long run.
                    pbar.set_postfix_str(f"Invalid image: {x}")
                    invalid_paths.append(x)
                    continue
                records.append({"path": x,
                                "label": y,
                                "env": env,
                                "env_id": i,
                                "class_name": env_dataset.classes[y],
                                "sample_id": j})
    return records, invalid_paths


def _assign_train_val_split(df):
    """Add `env_cls` and a boolean `train` column to `df` and return it.

    Rows are grouped by environment and class name, and each group is split
    independently so train and val keep the same environment/class mix.
    """
    df["env_cls"] = df["env"] + "_" + df["class_name"]
    df["train"] = False

    for _, group in df.groupby("env_cls"):
        if len(group) < 2:
            # A single sample cannot be split; keep it for training instead
            # of letting train_test_split raise.
            df.loc[group.index, "train"] = True
            continue
        # No `stratify` needed: the label is constant within one env/class group.
        train, _val = train_test_split(
            group, test_size=VAL_FRACTION, random_state=SPLIT_SEED)
        df.loc[train.index, "train"] = True
    return df


# Build one CSV index per dataset: verify every image, record its metadata,
# then mark each row as train or val.
for dataset_name in DATASETS:
    hparams = hparams_registry.default_hparams('ERM', dataset_name)
    dataset = datasets.get_dataset_class(dataset_name)(
        data_dir,
        list(range(datasets.num_environments(dataset_name))),
        hparams)

    records, invalid_paths = _collect_valid_samples(dataset, dataset_name)
    if invalid_paths:
        # The progress-bar postfix only shows the latest failure, so list them all.
        print(f"{dataset_name}: skipped {len(invalid_paths)} unreadable image(s):")
        for invalid_path in invalid_paths:
            print(f"  {invalid_path}")

    df = pd.DataFrame(records, columns=INDEX_COLUMNS)
    df = _assign_train_val_split(df)
    df.to_csv(f"{dataset_name}.csv", index=False)

Processing VLCS C: 100%|██████████████████| 1415/1415 [00:01<00:00, 1393.17it/s]
Processing VLCS L: 100%|████████████████████| 2656/2656 [01:28<00:00, 29.99it/s]
Processing VLCS S: 100%|███████████████████| 3282/3282 [00:12<00:00, 271.88it/s]
Processing VLCS V: 100%|███████████████████| 3376/3376 [00:05<00:00, 594.37it/s]
Processing PACS A: 100%|██████████████████| 2048/2048 [00:01<00:00, 1946.22it/s]
Processing PACS C: 100%|██████████████████| 2344/2344 [00:00<00:00, 2352.51it/s]
Processing PACS P: 100%|██████████████████| 1670/1670 [00:00<00:00, 1982.62it/s]
Processing PACS S: 100%|██████████████████| 3929/3929 [00:01<00:00, 2030.33it/s]


In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:

# # plot class distribution between environments
# for dataset_name in ['PACS', 'VLCS']:
#     df = pd.read_csv(f"{dataset_name}.csv")
#     plt.figure()
#     df.groupby("env")["label"].value_counts().unstack().plot(kind="bar", stacked=True)
#     plt.title(dataset_name)
#     plt.show()

In [None]:
# plot class distribution between train/val
# Plot, per dataset, how the class labels are distributed between the
# validation (train=False) and training (train=True) rows of the CSV index.
os.makedirs("report", exist_ok=True)  # savefig fails if the folder is missing

for dataset_name in ['PACS', 'VLCS']:
    df = pd.read_csv(f"{dataset_name}.csv")

    # Hand pandas an explicit Axes: without `ax`, DataFrame.plot opens its own
    # figure and a preceding plt.figure() is left behind as an empty figure.
    fig, ax = plt.subplots()
    df.groupby("train")["label"].value_counts().unstack().plot(kind="bar", stacked=True, ax=ax)
    ax.set_title(dataset_name)

    # save to file
    fig.savefig(f"report/{dataset_name}_train_val_class_distribution.png", bbox_inches='tight', pad_inches=0.1, dpi=300, transparent=True)

    plt.show()
    # Release the figure so figures do not pile up across loop iterations.
    plt.close(fig)



In [None]:
# plot env distribution between train/val
# Plot, per dataset, how the environments are distributed between the
# validation (train=False) and training (train=True) rows of the CSV index.
os.makedirs("report", exist_ok=True)  # savefig fails if the folder is missing

for dataset_name in ['PACS', 'VLCS']:
    df = pd.read_csv(f"{dataset_name}.csv")

    # Hand pandas an explicit Axes: without `ax`, DataFrame.plot opens its own
    # figure and a preceding plt.figure() is left behind as an empty figure.
    fig, ax = plt.subplots()
    df.groupby("train")["env"].value_counts().unstack().plot(kind="bar", stacked=True, ax=ax)
    ax.set_title(dataset_name)

    # save to file
    fig.savefig(f"report/{dataset_name}_train_val_env_distribution.png", bbox_inches='tight', pad_inches=0.1, dpi=300, transparent=True)

    plt.show()
    # Release the figure so figures do not pile up across loop iterations.
    plt.close(fig)