In [None]:
import pandas as pd
import numpy as np
import torch
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt

In [None]:
def get_image_dataloaders(
    data_dir,
    batch_size=32,
    img_size=224,
    val_split=False,
    val_ratio=0.2,
    shuffle=True,
    seed=42
):
    """
    Build train and test DataLoaders from a directory of images.

    Two directory layouts are handled:

    * ``data_dir/train`` and ``data_dir/test`` both exist: each is loaded with
      ``ImageFolder``. If ``data_dir/val`` also exists, its samples are
      appended to the training set, with labels re-indexed by class *name* so
      they agree with the training set's ``class_to_idx``.
    * Otherwise ``data_dir`` itself is loaded as one ``ImageFolder`` and
      randomly split into a train subset and a test subset.

    Args:
        data_dir: Root directory of the image data.
        batch_size: Batch size used by both loaders.
        img_size: Images are resized to ``(img_size, img_size)``.
        val_split: Currently unused; kept so existing callers keep working.
        val_ratio: Fraction of samples placed in the *test* subset when the
            data has to be split randomly.
        shuffle: Whether the train loader shuffles. The test loader never does.
        seed: Seed for the random split.

    Returns:
        ``(train_loader, test_loader)``. In both layouts ``loader.dataset``
        exposes ``classes`` and ``class_to_idx``.

    Raises:
        ValueError: If ``val`` contains a class folder that ``train`` lacks.
    """
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
    ])

    # Check for 'train' and 'test' subfolders
    train_path = os.path.join(data_dir, 'train')
    test_path = os.path.join(data_dir, 'test')
    val_path = os.path.join(data_dir, 'val')

    if os.path.isdir(train_path) and os.path.isdir(test_path):
        train_dataset = datasets.ImageFolder(train_path, transform=transform)
        test_dataset = datasets.ImageFolder(test_path, transform=transform)
        # If val exists, merge with train
        if os.path.isdir(val_path):
            val_dataset = datasets.ImageFolder(val_path, transform=transform)
            # Each ImageFolder numbers its own class folders, so a raw val
            # index only means the same class in train if both folders hold
            # exactly the same classes. Translate through the class name.
            missing = sorted(
                set(val_dataset.classes) - set(train_dataset.class_to_idx)
            )
            if missing:
                raise ValueError(
                    f"Classes present in 'val' but not in 'train': {missing}"
                )
            remapped = [
                (path, train_dataset.class_to_idx[val_dataset.classes[target]])
                for path, target in val_dataset.samples
            ]
            # Extend in place so samples and targets stay aligned.
            train_dataset.samples.extend(remapped)
            train_dataset.targets.extend(target for _, target in remapped)
    else:
        # Assume data_dir contains class subfolders directly
        full_dataset = datasets.ImageFolder(data_dir, transform=transform)
        n_total = len(full_dataset)
        n_test = int(val_ratio * n_total)
        n_train = n_total - n_test
        train_dataset, test_dataset = random_split(
            full_dataset, [n_train, n_test],
            generator=torch.Generator().manual_seed(seed)
        )
        # random_split returns Subset objects, which do not carry the class
        # metadata; attach it so both layouts offer the same attributes.
        for subset in (train_dataset, test_dataset):
            subset.classes = full_dataset.classes
            subset.class_to_idx = full_dataset.class_to_idx

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader


In [None]:
# Build the train/test loaders for the Food-101 directory with default settings.
train_loader, test_loader = get_image_dataloaders(
    data_dir='../data/food-101',
)

In [None]:
# Bare expression: the notebook displays the repr of the train loader.
train_loader

In [None]:
def show_one_image_from_loader(loader, class_names=None):
    """
    Display the first image of the first batch of a DataLoader with its label.

    Args:
        loader: PyTorch DataLoader whose batches unpack as ``(images, labels)``,
            with each image laid out channel-first (it is transposed to
            channel-last for plotting).
        class_names: Optional sequence of class names indexed by label. When
            omitted or empty, the numeric label is shown. Array-like
            sequences are accepted as well as lists.
    """
    images, labels = next(iter(loader))
    # detach().cpu() keeps the conversion working for tensors that track
    # gradients or live on an accelerator; it is a no-op for plain CPU tensors.
    img_np = images[0].detach().cpu().numpy().transpose(1, 2, 0)
    label_idx = labels[0].item()

    imshow_kwargs = {}
    if img_np.shape[-1] == 1:
        # Some matplotlib versions reject (H, W, 1) arrays, so drop the
        # channel axis and draw single-channel images in grayscale.
        img_np = img_np[..., 0]
        imshow_kwargs["cmap"] = "gray"
    plt.imshow(img_np, **imshow_kwargs)

    # Test against None and length rather than truthiness: truthiness of a
    # multi-element array raises instead of answering.
    if class_names is not None and len(class_names) > 0:
        plt.title(f"Label: {class_names[label_idx]}")
    else:
        plt.title(f"Label: {label_idx}")
    plt.axis('off')
    plt.show()

# Preview one sample from each loader; numeric labels are shown because no
# class names are passed.
show_one_image_from_loader(train_loader, class_names=None)
show_one_image_from_loader(test_loader, class_names=None)

In [None]:
# Get class names from the train dataset
train_dataset = train_loader.dataset
test_dataset = test_loader.dataset

# When the loaders come from a random split, `.dataset` is a Subset that may
# not expose `classes` / `class_to_idx`. In that case read the metadata from
# the dataset the Subset wraps.
train_meta = (
    train_dataset if hasattr(train_dataset, "classes")
    else getattr(train_dataset, "dataset", train_dataset)
)
test_meta = (
    test_dataset if hasattr(test_dataset, "classes")
    else getattr(test_dataset, "dataset", test_dataset)
)

print("Train classes:", train_meta.classes)
print("Test classes: ", test_meta.classes)

# Check that all have the same number of classes
print("Train classes:", len(train_meta.classes))
print("Test classes: ", len(test_meta.classes))

# Check that class-to-index mapping is the same
print("Train class_to_idx:", train_meta.class_to_idx)
print("Test class_to_idx: ", test_meta.class_to_idx)

In [None]:
def update_master_classes(data_dir, master_file='./src/classes/master_classes.txt'):
    """
    Merge the class folders found in ``data_dir`` into the master class list.

    The master file holds one class name per line. Existing names are kept,
    the names of the sub-directories of ``data_dir`` are added, and the union
    is written back in sorted order.

    Args:
        data_dir: Directory whose immediate sub-directories are class folders.
        master_file: Path of the text file holding the master class list. It
            is created, along with any missing parent directories, if absent.
    """
    # Read existing classes, skipping blank lines so that an empty string
    # never ends up in the list as a class name.
    master_classes = set()
    if os.path.exists(master_file):
        with open(master_file, 'r') as f:
            master_classes = {line.strip() for line in f if line.strip()}

    # Only directories are class folders; stray files (README, .DS_Store, ...)
    # must not be recorded as classes.
    new_classes = {
        name for name in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, name))
    }

    # Update master list
    updated = master_classes | new_classes

    # open(..., 'w') does not create missing parent directories.
    parent_dir = os.path.dirname(master_file)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(master_file, 'w') as f:
        f.writelines(f"{cls}\n" for cls in sorted(updated))
    print(f"Updated master_classes.txt with {len(updated)} classes.")


In [None]:
# NOTE(review): scratch arithmetic with no stated meaning. 101000 is presumably
# the Food-101 image count and 2683 a batch/step count — TODO confirm, then
# replace with named constants or delete this cell.
101000/2683