# Connect to Drive

In [1]:
# This mounts your Google Drive to the Colab VM.
from google.colab import drive
drive.mount('/content/drive')

# TODO: Enter the foldername in your Drive where you have saved the unzipped
# assignment folder, e.g. 'colab/cs231n/assignments/assignment3/'
FOLDERNAME = 'colab/pytorch-image-models/'
assert FOLDERNAME is not None, "[!] Enter the foldername."

# Now that we've mounted your Drive, this ensures that
# the Python interpreter of the Colab VM can load
# python files from within it.
import sys
sys.path.append('/content/drive/My Drive/{}'.format(FOLDERNAME))

%cd /content/drive/My\ Drive/$FOLDERNAME

Mounted at /content/drive
/content/drive/My Drive/colab/pytorch-image-models


In [2]:
# Reload modified Python modules automatically before each cell runs, so edits
# to .py files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# Record the version of the `python` executable found on PATH.
# NOTE(review): this may differ from the kernel's interpreter; `sys.version` reports the kernel's.
!python --version

Python 3.10.12


# Create Metadata For ImageNet 1k

In [7]:
import os
import pandas as pd

def generate_metadata(root_dir, output_file, extensions=('.JPEG',)):
    """Write a CSV index of the images in a one-folder-per-class directory tree.

    Each immediate subdirectory of ``root_dir`` is treated as one class, and its
    name is used as the label. Only entries directly inside a class directory
    whose names end with one of ``extensions`` (case-sensitive) are listed.

    Args:
        root_dir: Directory containing one subdirectory per class.
        output_file: Path of the CSV to write; overwritten if it exists.
        extensions: Filename suffix, or tuple of suffixes, to include.
            Defaults to ``('.JPEG',)``.

    The CSV has two columns:
        ``image_path``: ``root_dir`` joined with the class directory and file name.
        ``label``: the class directory name.

    Rows are ordered by class directory, then file name. ``os.listdir`` returns
    entries in arbitrary order, so sorting makes the file reproducible.
    """
    # A bare string would otherwise be split into single characters by tuple().
    if isinstance(extensions, str):
        extensions = (extensions,)
    extensions = tuple(extensions)

    rows = []
    for label_dir in sorted(os.listdir(root_dir)):
        class_path = os.path.join(root_dir, label_dir)
        if not os.path.isdir(class_path):
            continue
        for image_file in sorted(os.listdir(class_path)):
            if image_file.endswith(extensions):
                rows.append({'image_path': os.path.join(class_path, image_file), 'label': label_dir})

    # Explicit columns keep the header even when no images are found, so the
    # file can still be read back with pd.read_csv.
    df = pd.DataFrame(rows, columns=['image_path', 'label'])
    df.to_csv(output_file, index=False)
    print(f"Metadata file created at {output_file}")

In [8]:
# Validation split: write one CSV row per image, labelled by its class folder.
output_file = 'datasets/imagenet1k/val_metadata.csv'
root_dir = 'datasets/imagenet1k/val'
generate_metadata(output_file=output_file, root_dir=root_dir)

Metadata file created at datasets/imagenet1k/val_metadata.csv


In [9]:
# Training split: write one CSV row per image, labelled by its class folder.
output_file = 'datasets/imagenet1k/train_metadata.csv'
root_dir = 'datasets/imagenet1k/train'
generate_metadata(output_file=output_file, root_dir=root_dir)

Metadata file created at datasets/imagenet1k/train_metadata.csv


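# ImageNet Loader with Stratified Sampler

Images are drawn with `WeightedRandomSampler` (with replacement). Each image is weighted by the inverse of its class's image count, so classes are balanced in expectation rather than exactly within every batch.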

In [None]:
!pip install timm
!pip install ml-collections

In [None]:
# Scratch (disabled): inspect per-class image counts in the validation metadata.
# import pandas as pd

# metadata_path = 'datasets/imagenet1k/val_metadata.csv'
# df = pd.read_csv(metadata_path)

# The 'label' column holds the class-folder name written by generate_metadata,
# not an integer class index.
# class_count = df['label'].value_counts().to_dict()
# class_count

In [None]:
# Scratch (disabled): per-row inverse-class-frequency weights. itertuples() yields
# (Index, image_path, label), hence the two ignored fields.
# weights = [1.0 / class_count[label] for _, _, label in df.itertuples()]
# weights

In [35]:
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
import timm
from torchvision import datasets, transforms

def create_stratified_loader(dataset, metadata_path, device=torch.device("cpu"),
                             batch_size=256, num_workers=4):
    """Build a timm loader that samples images with class-balancing weights.

    Each image is weighted by ``1 / (number of images with its label)`` and drawn
    with ``WeightedRandomSampler`` (with replacement). Every class is therefore
    equally likely on each draw; balance holds in expectation, not exactly
    within each batch.

    Args:
        dataset: Dataset passed to ``create_loader``. Index ``i`` must refer to
            the ``i``-th metadata row after sorting by ``image_path``.
        metadata_path: CSV with ``image_path`` and ``label`` columns, e.g. as
            written by ``generate_metadata``.
        device: Device passed to ``create_loader``.
        batch_size: Images per batch. Defaults to 256.
        num_workers: Number of loader workers. Defaults to 4.

    Returns:
        The loader returned by ``create_loader``. That function is imported
        from ``timm.data`` in another cell of this notebook.

    Raises:
        ValueError: If the metadata row count differs from ``len(dataset)``.
            In that case the per-index weights would refer to the wrong images.
    """
    # The CSV row order comes from os.listdir, which is arbitrary, so sort by path
    # to line rows up with dataset indices.
    # NOTE(review): assumes the dataset enumerates files in sorted path order
    # (presumably true for timm's folder reader with ImageNet-style names) — TODO confirm.
    df = (
        pd.read_csv(metadata_path)
        .sort_values('image_path', kind='stable')
        .reset_index(drop=True)
    )
    if len(df) != len(dataset):
        raise ValueError(
            f"{metadata_path} lists {len(df)} images but the dataset has "
            f"{len(dataset)}; sampler weights would be misaligned."
        )

    # Inverse class frequency for every row, in row order; looked up by column
    # name so extra CSV columns do not break it.
    class_count = df['label'].value_counts()
    weights = (1.0 / df['label'].map(class_count)).tolist()
    sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)

    # Use timm's create_loader to integrate the sampler.
    return create_loader(
        dataset,
        input_size=(3, 224, 224),
        batch_size=batch_size,
        use_prefetcher=True,
        interpolation="bicubic",
        mean=(0.5, 0.5, 0.5),
        std=(0.5, 0.5, 0.5),
        num_workers=num_workers,
        sampler=sampler,  # integrate our custom sampler
        crop_pct=0.9,
        crop_mode="center",
        crop_border_pixels=None,
        pin_memory=False,
        device=device,
        tf_preprocessing=False,
    )

In [31]:
from timm.data import create_dataset, create_loader

In [27]:
# ImageNet validation images read from the local `val` directory (no download),
# decoded in RGB mode; no class map or sample limit is applied.
dataset = create_dataset(
    name="",
    root="datasets/imagenet1k/val",
    split="validation",
    class_map=None,
    download=False,
    load_bytes=False,
    num_samples=None,
    input_img_mode="RGB",
    input_key=None,
    target_key=None,
)

In [36]:
# Class-balanced loader over the validation dataset, weighted from its metadata CSV.
val_metadata_path = 'datasets/imagenet1k/val_metadata.csv'
loader = create_stratified_loader(dataset, val_metadata_path)

def batch_loader(max_batches=None, data_loader=None):
    """Yield batches as ``{"x": inputs}`` dicts, dropping the targets.

    Args:
        max_batches: Maximum number of batches to yield (int). ``None`` yields
            every batch; zero or negative yields nothing.
        data_loader: Iterable of ``(inputs, targets)`` pairs. Defaults to the
            notebook-level ``loader``, looked up at call time.

    Yields:
        ``{"x": inputs}`` for each batch consumed.
    """
    from itertools import islice

    source = loader if data_loader is None else data_loader
    if max_batches is not None:
        max_batches = max(max_batches, 0)
    # islice stops before pulling another batch. An enumerate/break loop would
    # first load (and discard) one extra batch before noticing the limit.
    for inputs, _ in islice(source, max_batches):
        yield {"x": inputs}



In [None]:
# Smoke test: pull one batch and summarise it rather than printing the full tensor repr.
for batch in batch_loader(max_batches=1):
    x = batch["x"]
    print(
        f"x: shape={tuple(x.shape)}, dtype={x.dtype}, device={x.device}, "
        f"min={x.min().item():.3f}, max={x.max().item():.3f}"
    )

In [39]:
# Training images divided by 256 (the batch size used by create_stratified_loader).
pd.read_csv('datasets/imagenet1k/train_metadata.csv').shape[0] / 256

934.5625

In [40]:
# Fill the first three fields of the cache-path template; the two `{}` slots stay for a later .format().
"pca_cache/{}_{}_stratified_batches-{}_{{}}_{{}}.pt".format("vit", 25, 29)

'pca_cache/vit_25_stratified_batches-29_{}_{}.pt'

# Temp