In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import lmdb
from PIL import Image
import io
import torchvision.transforms as T
import os
import numpy as np


In [None]:
class ImageNetLMDBDataset(Dataset):
    def __init__(self, lmdb_path, transform=None):
        self.env = lmdb.open(lmdb_path, readonly=True, lock=False)
        self.transform = transform
        
        # Load entire dataset into memory once
        with self.env.begin() as txn:
            self.length = sum(1 for key, _ in txn.cursor() if key.startswith(b'image-'))
            
            # Pre-allocate memory for images (assuming images are 3x224x224)
            self.images = np.empty((self.length, 3, 224, 224), dtype=np.uint8)  # Store as uint8 for efficient memory usage
            self.labels = np.empty(self.length, dtype=np.int32)

            for idx in range(self.length):
                key_img = f'image-{idx:09d}'.encode()
                key_label = f'label-{idx:09d}'.encode()

                img_bytes = txn.get(key_img)
                label_bytes = txn.get(key_label)

                img = Image.open(io.BytesIO(img_bytes)).convert('RGB')
                img = np.array(img, dtype=np.uint8)  # Store image as uint8 numpy array
                self.images[idx] = img
                self.labels[idx] = int(label_bytes)

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        img = self.images[idx]
        label = self.labels[idx]
        
        if self.transform:
            img = self.transform(img)  # Transform (including conversion to float32, if necessary)
        
        return img, label

In [3]:
# Data transforms (basic example)
train_transform = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225])
])

val_transform = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225])
])

In [4]:
data_dir = 'data/imagenet_dataset'
# Initialize datasets
train_dataset = ImageNetLMDBDataset(os.path.join(data_dir, 'imagenet_train.lmdb'))
val_dataset = ImageNetLMDBDataset(os.path.join(data_dir, 'imagenet_val.lmdb'))


KeyboardInterrupt: 

In [5]:
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=16, pin_memory= True, persistent_workers=True)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False, num_workers=16, pin_memory=True, persistent_workers=True)


In [6]:
# Example usage
for images, labels in train_loader:
    print(images.shape, labels.shape)

for images, labels in val_loader:
    print(images.shape, labels.shape)

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/ediss_data/ediss4/sarosh/anaconda3/envs/saroshgpu/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 351, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/ediss_data/ediss4/sarosh/anaconda3/envs/saroshgpu/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
  File "/ediss_data/ediss4/sarosh/anaconda3/envs/saroshgpu/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 398, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/ediss_data/ediss4/sarosh/anaconda3/envs/saroshgpu/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 211, in collate
    return [
  File "/ediss_data/ediss4/sarosh/anaconda3/envs/saroshgpu/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 212, in <listcomp>
    collate(samples, collate_fn_map=collate_fn_map)
  File "/ediss_data/ediss4/sarosh/anaconda3/envs/saroshgpu/lib/python3.9/site-packages/torch/utils/data/_utils/collate.py", line 240, in collate
    raise TypeError(default_collate_err_msg_format.format(elem_type))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.Image.Image'>


In [7]:
in_channels = 3                # change to 3 if you use CIFAR10 dataset
image_size = 224                # change to 32 if you use CIFAR10 dataset
num_classes = 1000

lr = 1e-3

patch_size = 16         # Each patch is 16x16, so 14x14 = 196 patches per image
hidden_dim = 768       # Token-mixing MLP hidden dim (formerly token_dim)
tokens_mlp_dim = 384    # Tokens MLP dim
channels_mlp_dim = 3072 # Channels MLP dim
num_blocks = 12         # Number of Mixer layers

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from utils import train
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [9]:
print(device)

cuda


In [10]:
from MLP_Mixer import MLPMixer
model = MLPMixer(in_channels=in_channels, embedding_dim=hidden_dim, num_classes=num_classes, patch_size=patch_size, image_size=image_size, depth=num_blocks, token_intermediate_dim=tokens_mlp_dim, channel_intermediate_dim=channels_mlp_dim)
# If you have more than one GPU, wrap the model with DataParallel
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    model = nn.DataParallel(model)  # Wrap the model for multi-GPU usage

# Move the model to the GPU
model = model.to(device)


optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

Using 4 GPUs!


In [11]:
train_metrics_3, val_metrics_3, test_metrics_3 = train(model, train_loader, val_loader, val_loader, 20, optimizer, criterion, False, device)

KeyboardInterrupt: 

In [None]:
torch.cuda.empty_cache()

In [None]:
from DPN_Mixer import MLPMixer as DPNMixer
model = DPNMixer(in_channels=in_channels, embedding_dim=hidden_dim, num_classes=num_classes, patch_size=patch_size, image_size=image_size, depth=num_blocks, token_intermediate_dim=tokens_mlp_dim, channel_intermediate_dim=channels_mlp_dim)
# If you have more than one GPU, wrap the model with DataParallel
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs!")
    model = nn.DataParallel(model)  # Wrap the model for multi-GPU usage

# Move the model to the GPU
model = model.to(device)


optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

In [None]:
train_metrics_3, val_metrics_3, test_metrics_3 = train(model, train_loader, val_loader, val_loader, 20, optimizer, criterion, False, device)