In [1]:
import os
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
from torchvision.transforms import InterpolationMode
from PIL import Image


In [2]:
# Path to the label mapping file and validation annotations
LABEL_MAPPING_FILE = "data/imagenet_dataset/LOC_synset_mapping.txt"
VAL_ANNOTATIONS_FILE = "data/imagenet_dataset/LOC_val_solution.csv"

# Load the label mapping from LOC_synset_mapping.txt
def load_label_mapping():
    wnid_to_idx = {}
    idx_to_wnid = {}
    with open(LABEL_MAPPING_FILE, 'r') as f:
        lines = f.readlines()
        for idx, line in enumerate(lines):
            wnid = line.strip().split()[0]
            wnid_to_idx[wnid] = idx
            idx_to_wnid[idx] = wnid
    return wnid_to_idx, idx_to_wnid

# Load the validation annotations from LOC_val_solution.csv
def load_val_annotations():
    val_annotations = pd.read_csv(VAL_ANNOTATIONS_FILE)
    val_annotations['PredictionString'] = val_annotations['PredictionString'].apply(lambda x: x.split()[0])
    return val_annotations


In [None]:
class ImageNetDataset(Dataset):
    def __init__(self, data_dir, label_mapping, transform=None, is_train=True):
        self.data_dir = data_dir
        self.label_mapping = label_mapping
        self.transform = transform
        self.is_train = is_train
        self.image_paths = []
        self.labels = []
        
        # Load the train or validation data based on `is_train`
        if is_train:
            # Get all image paths and labels from train folder
            for wnid in os.listdir(data_dir):
                wnid_folder = os.path.join(data_dir, wnid)
                if os.path.isdir(wnid_folder):
                    count = 0
                    for img_file in os.listdir(wnid_folder):
                        if img_file.endswith('.JPEG'):
                            self.image_paths.append(os.path.join(wnid_folder, img_file))
                            self.labels.append(self.label_mapping[wnid])
                            count += 1
                            if count >= 1000:  # Limit to 1000 images per class
                                break
        else:
            # Get the validation images and their corresponding labels from LOC_val_solution.csv
            val_annotations = load_val_annotations()
            for _, row in val_annotations.iterrows():
                img_file = row['ImageId'] + '.JPEG'
                wnid = row['PredictionString']
                self.image_paths.append(os.path.join(data_dir, img_file))
                self.labels.append(self.label_mapping[wnid])
        
        self.labels = torch.tensor(self.labels)

        num_images = len(self.image_paths)
        self.images = torch.empty((num_images, 3, 224, 224), dtype=torch.float32)

        for idx, path in enumerate(self.image_paths):
            image = Image.open(path).convert("RGB")
            if self.transform:
                image = self.transform(image)
            self.images[idx] = image

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        label = self.labels[idx]
        image = self.images[idx]

        return image, label


In [4]:
# Mean and std for ImageNet
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

train_transforms = T.Compose([
    T.RandomResizedCrop(224, interpolation=InterpolationMode.BICUBIC),
    T.RandomHorizontalFlip(p=0.5),
    T.RandAugment(num_ops=2, magnitude=9),  # can tune magnitude
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

val_transforms = T.Compose([
    T.Resize(256, interpolation=InterpolationMode.BICUBIC),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])


In [5]:
# Load the label mapping
wnid_to_idx, idx_to_wnid = load_label_mapping()

# Define the paths to the dataset
train_data_dir = "data/imagenet_dataset/ILSVRC/Data/CLS-LOC/train"
val_data_dir = "data/imagenet_dataset/ILSVRC/Data/CLS-LOC/val"

# Create the train and test datasets
train_dataset = ImageNetDataset(data_dir=train_data_dir, label_mapping=wnid_to_idx, transform=train_transforms, is_train=True)
val_dataset = ImageNetDataset(data_dir=val_data_dir, label_mapping=wnid_to_idx, transform=val_transforms, is_train=False)

In [6]:
# Create the train and test dataloaders
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=16, pin_memory= True, persistent_workers=True)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=16, pin_memory=True, persistent_workers=True)

In [7]:
# Example usage
'''for images, labels in train_loader:
    print(images.shape, labels.shape)

for images, labels in val_loader:
    print(images.shape, labels.shape)'''

'for images, labels in train_loader:\n    print(images.shape, labels.shape)\n\nfor images, labels in val_loader:\n    print(images.shape, labels.shape)'

In [8]:
in_channels = 3                # change to 3 if you use CIFAR10 dataset
image_size = 224                # change to 32 if you use CIFAR10 dataset
num_classes = 1000

lr = 1e-3

patch_size = 16         # Each patch is 16x16, so 14x14 = 196 patches per image
hidden_dim = 768       # Token-mixing MLP hidden dim (formerly token_dim)
tokens_mlp_dim = 384    # Tokens MLP dim
channels_mlp_dim = 3072 # Channels MLP dim
num_blocks = 12         # Number of Mixer layers

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from utils import train
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [10]:
from MLP_Mixer import MLPMixer
model = MLPMixer(in_channels=in_channels, embedding_dim=hidden_dim, num_classes=num_classes, patch_size=patch_size, image_size=image_size, depth=num_blocks, token_intermediate_dim=tokens_mlp_dim, channel_intermediate_dim=channels_mlp_dim)
# If you have more than one GPU, wrap the model with DataParallel
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    model = nn.DataParallel(model)  # Wrap the model for multi-GPU usage

# Move the model to the GPU
model = model.to(device)

Using 4 GPUs!


In [11]:
from torch.optim.lr_scheduler import CosineAnnealingLR

# Loss Function (with label smoothing)
class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, smoothing=0.1):
        super(LabelSmoothingCrossEntropy, self).__init__()
        self.smoothing = smoothing

    def forward(self, x, target):
        log_probs = torch.nn.functional.log_softmax(x, dim=-1)
        nll_loss = -log_probs.gather(dim=-1, index=target.unsqueeze(1)).squeeze(1)
        smooth_loss = -log_probs.mean(dim=-1)
        return ((1 - self.smoothing) * nll_loss + self.smoothing * smooth_loss).mean()

In [12]:
criterion = LabelSmoothingCrossEntropy(smoothing=0.1)

# Optimizer
optimizer = optim.AdamW(model.parameters(), lr=3e-3, weight_decay=0.05)

In [None]:
train_metrics_3, val_metrics_3, test_metrics_3 = train(model, train_loader, val_loader, val_loader, 20, optimizer, criterion, False, device)


Epoch: 1 Total_Time: 10.9228 Average_Time_per_batch: 0.2731 Train_Accuracy: 0.0005 Train_Loss: 7.0185 Validation_Accuracy: 0.0011 Validation_Loss: 6.9354
Epoch: 2 Total_Time: 3.4567 Average_Time_per_batch: 0.0864 Train_Accuracy: 0.0015 Train_Loss: 6.9002 Validation_Accuracy: 0.0020 Validation_Loss: 6.9077

KeyboardInterrupt: 

: 

In [None]:
torch.cuda.empty_cache()

In [None]:
from DPN_Mixer import MLPMixer as DPNMixer
model = DPNMixer(in_channels=in_channels, embedding_dim=hidden_dim, num_classes=num_classes, patch_size=patch_size, image_size=image_size, depth=num_blocks, token_intermediate_dim=tokens_mlp_dim, channel_intermediate_dim=channels_mlp_dim)
# If you have more than one GPU, wrap the model with DataParallel
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    model = nn.DataParallel(model)  # Wrap the model for multi-GPU usage

# Move the model to the GPU
model = model.to(device)


optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

In [None]:
train_metrics_3, val_metrics_3, test_metrics_3 = train(model, train_loader, val_loader, val_loader, 20, optimizer, criterion, False, device)