In [3]:
# preprocess_images.py
import os
import numpy as np
import pandas as pd
import random
from PIL import Image
import torch
from torchvision import transforms
from torchvision.models import resnet50, vit_h_14
from torchvision.models import ResNet50_Weights, ViT_H_14_Weights
from tqdm import tqdm

# === Paths ===
# Project root; the CSV's img_path values are joined onto this directory.
working_dir = "/data/lodhar2/milan"
# Destination for the generated .npz arrays (created if it does not exist).
output_dir = os.path.join(working_dir, "data")
os.makedirs(output_dir, exist_ok=True)
# Manifest listing every image with its label and train/val split.
csv_path = os.path.join(working_dir, "data/split_balanced_dataset.csv")

# === Reproducibility ===
def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: integer seed shared by every generator (default 42).

    NOTE: per the Python docs, PYTHONHASHSEED is read at interpreter startup,
    so assigning it here only influences subprocesses launched afterwards,
    not the hash seed of the already-running interpreter.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True  # pick deterministic cuDNN kernels
    cudnn.benchmark = False     # disable autotuning, which may vary kernel choice

seed_everything()

# === Load DataFrame ===
df = pd.read_csv(csv_path)

# Fail fast on a schema mismatch. The split filter reads "split", and
# process_and_save reads "img_path" and "Class" row by row, where a missing
# column would otherwise surface only as per-row failures deep in the run.
_required_columns = {"split", "img_path", "Class"}
_missing_columns = _required_columns - set(df.columns)
if _missing_columns:
    raise ValueError(
        f"{csv_path} is missing required columns: {sorted(_missing_columns)}"
    )

# === Shared preprocessing constants ===
# Crop box (left, upper, right, lower) that removes the black side bars and
# keeps a 1280x1080 region.
# NOTE(review): hard-coded for 1920x1080 frames (320 px trimmed from each
# side) — confirm every input image has that size.
BLACK_BAR_CROP_BOX = (320, 0, 1600, 1080)
# ImageNet channel statistics, shared by all four pipelines below.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def crop_black_bars(img):
    """Return *img* (a PIL image) cropped to ``BLACK_BAR_CROP_BOX``.

    A named module-level function instead of a lambda keeps the crop box in
    one place. It can also be pickled by reference, should these pipelines
    ever be handed to worker processes.
    """
    return img.crop(BLACK_BAR_CROP_BOX)


# === ViT Training Transform ===
# 518 px with bicubic resampling — presumably chosen to match the ViT_H_14
# SWAG weights' preprocessing; confirm against ViT_H_14_Weights.transforms().
vit_train_transform = transforms.Compose([
    transforms.Lambda(crop_black_bars),
    transforms.RandomResizedCrop(518, scale=(0.8, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.02),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# === ResNet Training Transform ===
resnet_train_transform = transforms.Compose([
    transforms.Lambda(crop_black_bars),
    transforms.RandomResizedCrop((224, 224), scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.02),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# === ViT Validation Transform ===
# Resize the shorter side to 518, then center-crop a 518x518 square.
vit_val_transform = transforms.Compose([
    transforms.Lambda(crop_black_bars),
    transforms.Resize(518, interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.CenterCrop(518),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# === ResNet Validation Transform ===
# NOTE(review): unlike the ViT validation pipeline (resize + center crop),
# this squashes the whole 1280x1080 crop to 224x224, changing the aspect
# ratio — confirm that is intended.
resnet_val_transform = transforms.Compose([
    transforms.Lambda(crop_black_bars),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# === Helper function ===
def process_and_save(df_split, transform, filename, verbose_name, *,
                     root_dir=None, out_dir=None):
    """Transform every image of one split and save the result as an ``.npz``.

    For each row, ``row["img_path"]`` is joined onto *root_dir*. The image is
    opened, converted to RGB, passed through *transform*, and turned into a
    NumPy array via ``.numpy()``. The arrays are stacked and written with
    ``np.savez`` under the key ``images``, with the ``row["Class"]`` values
    under ``labels``.

    Rows that fail (missing column, unreadable file, transform error) are
    reported and skipped. A row is appended only after every step has
    succeeded, so ``images`` and ``labels`` always stay aligned.

    Args:
        df_split: DataFrame whose rows provide ``img_path`` and ``Class``.
        transform: callable applied to the RGB PIL image; its result must
            have a ``.numpy()`` method (e.g. a ``torch.Tensor``).
        filename: name of the ``.npz`` file created inside *out_dir*.
        verbose_name: label shown in the progress bar and messages.
        root_dir: base directory for ``img_path``; defaults to the
            module-level ``working_dir``.
        out_dir: destination directory; defaults to the module-level
            ``output_dir``.

    Raises:
        ValueError: if no row could be processed, so there is nothing to save.
    """
    if root_dir is None:
        root_dir = working_dir
    if out_dir is None:
        out_dir = output_dir

    X = []
    y = []
    n_failed = 0
    rows = tqdm(df_split.iterrows(), total=len(df_split),
                desc=f"Processing {verbose_name}")
    for _, row in rows:
        # Bind img_path before the try, so the error message below never sees
        # an unbound name or a stale value left over from the previous row.
        img_path = row.get("img_path")
        try:
            img_path = os.path.join(root_dir, img_path)
            label = row["Class"]
            # The context manager releases the file handle Image.open keeps;
            # convert() loads the pixel data before the file is closed.
            with Image.open(img_path) as src:
                img = src.convert("RGB")
            array = transform(img).numpy()
        except Exception as e:  # best-effort: report and skip the bad row
            n_failed += 1
            print(f"Failed on {img_path}: {e}")
            continue
        # Append image and label together so the two lists cannot drift.
        X.append(array)
        y.append(label)

    if n_failed:
        print(f"Skipped {n_failed} of {len(df_split)} rows for {verbose_name}")
    if not X:
        raise ValueError(
            f"No images could be processed for {verbose_name}; "
            f"{filename} was not written"
        )

    X = np.stack(X)
    y = np.array(y)
    np.savez(os.path.join(out_dir, filename), images=X, labels=y)
    print(f"Saved {filename} with shape {X.shape} and {len(y)} labels")

# === Process splits ===
df_train = df.loc[df["split"] == "train"]
df_val = df.loc[df["split"] == "val"]

# Each job is (rows, transform, output .npz name, progress label).
# NOTE(review): the training transforms are random (RandomResizedCrop,
# RandomHorizontalFlip, ColorJitter), but they are applied only once here.
# The saved train arrays therefore hold a single fixed augmented view per
# image rather than fresh augmentation each epoch — confirm that is intended.
vit_jobs = [
    (df_train, vit_train_transform, "vit_train.npz", "ViT Train"),
    (df_val, vit_val_transform, "vit_val.npz", "ViT Val"),
]  # currently disabled: vit_jobs is defined but not executed
resnet_jobs = [
    (df_train, resnet_train_transform, "resnet_train.npz", "ResNet Train"),
    (df_val, resnet_val_transform, "resnet_val.npz", "ResNet Val"),
]

for split_rows, split_transform, out_name, label in resnet_jobs:
    process_and_save(split_rows, split_transform, out_name, label)

Processing ResNet Train: 100%|██████████| 398/398 [00:23<00:00, 16.74it/s]


Saved resnet_train.npz with shape (398, 3, 224, 224) and 398 labels


Processing ResNet Val: 100%|██████████| 221/221 [00:11<00:00, 18.56it/s]


Saved resnet_val.npz with shape (221, 3, 224, 224) and 221 labels
