In [None]:
import os, jax, torch, torchvision
from jax import numpy as jnp
from torch.utils.data import Dataset, DataLoader, default_collate

In [None]:
class Imagenet(Dataset):
    """Map-style dataset over a ``<root>/<class_dir>/<image file>`` directory tree.

    Each item is ``(img, label)`` where ``img`` is a float32 tensor scaled to
    ``[0, 1]`` in height-width-channel order and ``label`` is a float32 one-hot
    vector of length ``num_classes``.

    Images that are smaller than ``min_size`` on either spatial side, or that do
    not have exactly 3 channels, are skipped: the next acceptable image (wrapping
    around the path list) is returned instead. Nothing is deleted from disk and
    the dataset length never changes while it is being iterated.

    Args:
        data_path: Root directory whose immediate sub-directories are the classes.
        num_classes: Length of the one-hot label vector.
        min_size: Minimum accepted height and width in pixels.

    Raises:
        FileNotFoundError: If ``data_path`` cannot be walked (missing / not a directory).
    """

    def __init__(self, data_path="./data/Data/CLS-LOC/train", num_classes=1000, min_size=128):
        walker = os.walk(data_path, topdown=True)
        root_entry = next(walker, None)
        if root_entry is None:
            raise FileNotFoundError(f"dataset root not found or not a directory: {data_path!r}")
        # The first os.walk entry is the root; its sub-directories are the class names.
        # os.walk order is filesystem-dependent, so sort to get a reproducible
        # class -> index mapping across machines and runs.
        self.classes = sorted(root_entry[1])
        self.class_to_idx = {name: index for index, name in enumerate(self.classes)}
        self.num_classes = num_classes
        self.min_size = min_size
        # Sorted for a reproducible index -> sample mapping.
        self.paths = sorted(
            os.path.join(dirname, f) for (dirname, _, filenames) in walker for f in filenames
        )
        # Paths already found to fail the size/channel filter, so they are not decoded again.
        self._rejected = set()

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(img, label)`` for ``idx``, or for the next acceptable image after it.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
            RuntimeError: If no image in the dataset passes the filter.
        """
        n = len(self.paths)
        if not -n <= idx < n:
            raise IndexError(f"index {idx} out of range for dataset of size {n}")
        # Probe forward (wrapping) instead of deleting entries: mutating self.paths
        # mid-epoch would invalidate the indices the sampler has already drawn.
        for offset in range(n):
            fp = self.paths[(idx + offset) % n]
            if fp in self._rejected:
                continue
            img = self._load_image(fp)
            if img is None:
                self._rejected.add(fp)
                continue
            return img, self._one_hot_label(fp)
        raise RuntimeError("no image in the dataset passes the size/channel filter")

    def _load_image(self, fp):
        """Decode ``fp`` to a float (H, W, C) tensor in [0, 1]; return None if it fails the filter."""
        img = torchvision.io.read_image(fp).float() / 255.
        # read_image yields (C, H, W); permute (not swapaxes(0, -1), which would
        # give (W, H, C)) so that axis 0 is height and axis 1 is width.
        img = img.permute(1, 2, 0)
        height, width, channels = img.shape
        if height < self.min_size or width < self.min_size or channels != 3:
            return None
        return img

    def _one_hot_label(self, fp):
        """Build the one-hot label from the name of the directory that contains ``fp``."""
        class_name = os.path.basename(os.path.dirname(fp))
        if class_name not in self.class_to_idx:
            raise ValueError(f"{class_name!r} is not a class directory of this dataset")
        label = torch.zeros(self.num_classes, dtype=torch.float32)
        label[self.class_to_idx[class_name]] = 1.0
        return label

def jax_collate(batch):
    """Collate ``(img, label)`` torch samples into a pair of stacked JAX arrays.

    Images in a batch may differ in size, so every image is resized (bilinear,
    antialiased) to the smallest height and smallest width found in the batch
    before stacking.

    Args:
        batch: Non-empty sequence of ``(img, label)`` pairs. ``img`` is a float
            torch tensor whose axes 0 and 1 are the spatial axes and whose last
            axis is channels; ``label`` is a torch tensor of a fixed shape.

    Returns:
        ``(imgs, labels)`` as JAX arrays of shape
        ``(B, min_height, min_width, C)`` and ``(B, *label_shape)``.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("jax_collate() received an empty batch")
    # NOTE(review): debugging hook -- exposes the most recent raw batch as the
    # notebook-level global `batch`, which a later scratch cell reads. Kept so
    # that cell keeps working; remove both together once it is no longer needed.
    globals()["batch"] = batch
    imgs, labels = zip(*batch)
    # Find minimum height and width in this batch
    min_height = min(img.shape[0] for img in imgs)
    min_width = min(img.shape[1] for img in imgs)
    # Resize while the images are still torch tensors. interpolate() treats the
    # trailing two axes as spatial and needs a leading batch axis, so go
    # channels-last -> (1, C, H, W), resize, then back to channels-last.
    resized = [
        torch.nn.functional.interpolate(
            img.permute(2, 0, 1).unsqueeze(0),
            size=(min_height, min_width),
            mode="bilinear",
            align_corners=False,
            antialias=True,
        ).squeeze(0).permute(1, 2, 0)
        for img in imgs
    ]
    # Convert to jax and stack along a new leading batch axis
    imgs = jnp.stack([jnp.asarray(img.numpy()) for img in resized])
    labels = jnp.stack([jnp.asarray(label.numpy()) for label in labels])
    return imgs, labels

In [None]:
# Shuffled loader over the default Imagenet dataset; batches of 64 samples are
# assembled by jax_collate instead of torch's default collate.
ds = DataLoader(
    dataset=Imagenet(),
    batch_size=64,
    shuffle=True,
    collate_fn=jax_collate,
)

In [None]:
# Smoke test: pull one batch through the loader and show the image / label shapes.
_first_batch = next(iter(ds), None)
if _first_batch is not None:
    x, y = _first_batch
    print(x.shape, y.shape)

In [None]:
# Scratch walk-through of the collate steps on a small, explicitly built batch.
# The samples are fetched here so the cell does not depend on a `batch` global
# left behind by an earlier cell.
batch = [ds.dataset[i] for i in range(min(8, len(ds.dataset)))]
imgs, labels = zip(*batch)
# Find minimum height and width in this batch
min_height = min(img.shape[0] for img in imgs)
min_width = min(img.shape[1] for img in imgs)
# Resize images to the minimum height and width. The samples are channels-last,
# while interpolate() resizes the trailing two axes of a (N, C, H, W) tensor, so
# move channels first, add a batch axis, resize, and undo both afterwards.
imgs = [
    torch.nn.functional.interpolate(
        img.permute(2, 0, 1).unsqueeze(0),
        size=(min_height, min_width),
        mode="bilinear",
        align_corners=False,
        antialias=True,
    ).squeeze(0).permute(1, 2, 0)
    for img in imgs
]
# Concat
imgs = jnp.stack([jnp.asarray(img.numpy()) for img in imgs])
labels = jnp.stack([jnp.asarray(label.numpy()) for label in labels])