# 0. Colab housekeeping – GPU & Drive (optional)

In [None]:
!nvidia-smi -L        # confirm you got a GPU
from google.colab import drive
drive.mount("/content/drive")

# 1. Clone the repo

In [None]:
%cd /content
!git clone https://github.com/utkarsh231/ImageCompression vit-compression
%cd vit-compression

# 2. Install Python dependencies

In [None]:
!pip install -qr requirements.txt kagglehub

# 3. Set Kaggle credentials (for crawford/cat-dataset)

In [None]:
import json, os, textwrap, getpass, pathlib, pathlib
KAGGLE_JSON = textwrap.dedent("""
{
  "username":  "YOUR_KAGGLE_USERNAME",
  "key":       "YOUR_KAGGLE_API_KEY"
}
""").strip()

os.makedirs("/root/.kaggle", exist_ok=True)
with open("/root/.kaggle/kaggle.json", "w") as f:
    f.write(KAGGLE_JSON)
os.chmod("/root/.kaggle/kaggle.json", 0o600)

#  4. Quick sanity-check training run (10 epochs, tiny lr)
    Uses ImageNet sample tar shards shipped with WebDataset for a 2-minute smoke test.

In [None]:
!python train.py \
    data.dir=/content/vit-compression/sample_shards \
    data.val_dir=kaggle_cats \
    trainer.epochs=10 \
    trainer.lr=5e-4 \
    trainer.wandb=false \
    trainer.ckpt_dir=/content/drive/MyDrive/compression_ckpts

# (Expect loss, bpp, mse, ms_ssim, lpips logs every ~50 steps.)

# 5. Single-image inference

In [None]:
# Download a Kodak “barb” image for demo

# choose the last checkpoint that just finished (lam0.0015_e10.pt)
CKPT = "/content/drive/MyDrive/compression_ckpts/lam0.0015_e10.pt"

!python inference.py \
        --ckpt $CKPT \
        --img  barb.png

from IPython.display import Image, display
print("Input:")
display(Image("barb.png"))
print("Reconstruction:")
display(Image("barb.recon.png"))

In [6]:
from datasets import load_dataset
ds = load_dataset("timm/imagenet-w21-webp-wds", streaming=True, split="train")

sample = next(iter(ds))
print(sample.keys())        # → dict_keys(['__key__', 'webp', 'cls'])

dict_keys(['cls', 'json', 'webp', '__key__', '__url__'])


In [8]:
import io, json
from datasets import load_dataset
from PIL import Image
import torch
from torchvision.transforms import Compose, RandomResizedCrop, ToTensor, Normalize

# 1. stream the shards (no 850-GB download)
stream_ds = load_dataset(
    "timm/imagenet-w21-webp-wds",
    split="train",
    streaming=True,
)

# 2. torchvision-style transform
tfm = Compose([
    RandomResizedCrop(256),
    ToTensor(),
    Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
])

def iter_batches(dataset, batch_size=16):
    images, labels = [], []
    for sample in dataset:
        # already a PIL Image
        img = sample["webp"].convert("RGB")   # ensure 3-channel RGB
        images.append(tfm(img))
        labels.append(int(sample["cls"]))

        if len(images) == batch_size:
            yield torch.stack(images), torch.tensor(labels)
            images, labels = [], []

# 3. quick smoke test
batch_iter = iter_batches(stream_ds, batch_size=8)
x, y = next(batch_iter)
print(x.shape, y[:5])
# -> torch.Size([8, 3, 256, 256]) tensor([1234,  593,  821, ...])

torch.Size([8, 3, 256, 256]) tensor([ 6355, 14407, 15592,  3377,  9226])


In [10]:
from huggingface_hub import list_repo_files
files = list_repo_files("timm/imagenet-w21-webp-wds", repo_type="dataset")
print(files[:10])          # peek at first few entries

['.gitattributes', 'LICENSE', 'README.md', '_info.json', '_info.yaml', 'imagenet_w21_webp-train-0000.tar', 'imagenet_w21_webp-train-0001.tar', 'imagenet_w21_webp-train-0002.tar', 'imagenet_w21_webp-train-0003.tar', 'imagenet_w21_webp-train-0004.tar']


In [12]:
from huggingface_hub import hf_hub_download
from pathlib import Path

dest = Path("data/imagenet21k_wds")
dest.mkdir(parents=True, exist_ok=True)

repo = "timm/imagenet-w21-webp-wds"
train_shards = 64        # pick 4096 to grab the full set
val_shards   = 64        # 64 validation shards exist

# ---- train ---------------------------------------------------------------
for i in range(train_shards):
    fname = f"imagenet_w21_webp-train-{i:04d}.tar"      # <-- 4 digits
    hf_hub_download(repo, fname, repo_type="dataset",
                    local_dir=dest, force_download=False)

# ---- validation ----------------------------------------------------------
for i in range(val_shards):
    fname = f"imagenet_w21_webp-validation-{i:04d}.tar"
    hf_hub_download(repo, fname, repo_type="dataset",
                    local_dir=dest, force_download=False)



OSError: [Errno 28] No space left on device

In [None]:
import webdataset as wds
from huggingface_hub import HfFileSystem, get_token, hf_hub_url

# Login using e.g. `huggingface-cli login` to access this dataset
fs = HfFileSystem()
files = [fs.resolve_path(path) for path in fs.glob("hf://datasets/timm/imagenet-w21-webp-wds/**/*-train-*.tar")]
urls = [hf_hub_url(file.repo_id, file.path_in_repo, repo_type="dataset") for file in files]
urls = f"pipe: curl -s -L -H 'Authorization:Bearer {get_token()}' {'::'.join(urls)}"

ds = wds.WebDataset(urls).decode()

  from .autonotebook import tqdm as notebook_tqdm
