In [2]:
!pip install lxml

Collecting lxml
  Downloading lxml-5.4.0-cp39-cp39-manylinux_2_28_x86_64.whl.metadata (3.5 kB)
Downloading lxml-5.4.0-cp39-cp39-manylinux_2_28_x86_64.whl (5.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.1/5.1 MB[0m [31m63.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: lxml
Successfully installed lxml-5.4.0


In [3]:
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
from lxml import etree

def parse_voc_xml(xml_path):
    tree = etree.parse(xml_path)
    root = tree.getroot()
    boxes, labels = [], []
    for obj in root.findall('object'):
        name = obj.find('name').text
        bbox = obj.find('bndbox')
        xmin = int(bbox.find('xmin').text)
        ymin = int(bbox.find('ymin').text)
        xmax = int(bbox.find('xmax').text)
        ymax = int(bbox.find('ymax').text)
        boxes.append([xmin, ymin, xmax, ymax])
        labels.append(name)
    return boxes, labels

class ImageNetCLSLDataset(Dataset):
    def __init__(self, root, split='train', transform=None):
        self.root = root
        self.split = split
        self.transform = transform
        
        if split == 'train':
            self.img_dir = os.path.join(root, 'Data/CLS-LOC/train')
            self.ann_dir = os.path.join(root, 'Annotations/CLS-LOC/train')
            self.samples = []
            for synset in os.listdir(self.img_dir):
                img_folder = os.path.join(self.img_dir, synset)
                ann_folder = os.path.join(self.ann_dir, synset)
                if not os.path.isdir(img_folder):
                    continue
                for fname in os.listdir(img_folder):
                    if fname.endswith('.JPEG'):
                        self.samples.append((
                            os.path.join(img_folder, fname),
                            os.path.join(ann_folder, fname.replace('.JPEG', '.xml')),
                            synset
                        ))
        elif split == 'val':
            self.img_dir = os.path.join(root, 'Data/CLS-LOC/val')
            self.ann_dir = os.path.join(root, 'Annotations/CLS-LOC/val')
            self.samples = []
            for fname in os.listdir(self.img_dir):
                if fname.endswith('.JPEG'):
                    self.samples.append((
                        os.path.join(self.img_dir, fname),
                        os.path.join(self.ann_dir, fname.replace('.JPEG', '.xml')),
                        None
                    ))
        else:
            raise ValueError("split must be 'train' or 'val'")
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        img_path, ann_path, synset = self.samples[idx]
        img = Image.open(img_path).convert('RGB')
        boxes, labels = parse_voc_xml(ann_path)
        if self.transform:
            img = self.transform(img)
        return img, boxes, labels  # You can map labels to class indices here if desired


In [5]:
from torchvision import transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

train_dataset = ImageNetCLSLDataset('data/imagenet_dataset/ILSVRC', split='train', transform=transform)
val_dataset = ImageNetCLSLDataset('data/imagenet_dataset/ILSVRC', split='val', transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)


In [6]:
import os
from torch.utils.data import Dataset
from PIL import Image

class ImageNetTestDataset(Dataset):
    def __init__(self, root, transform=None):
        self.img_dir = os.path.join(root, 'Data/CLS-LOC/test')
        self.transform = transform
        self.img_names = [f for f in os.listdir(self.img_dir) if f.endswith('.JPEG')]
        self.img_names.sort()  # Optional: to have a deterministic order

    def __len__(self):
        return len(self.img_names)
    
    def __getitem__(self, idx):
        fname = self.img_names[idx]
        img_path = os.path.join(self.img_dir, fname)
        img = Image.open(img_path).convert('RGB')
        if self.transform:
            img = self.transform(img)
        # Return image and its ID (strip extension for compatibility with submission format)
        img_id = os.path.splitext(fname)[0]
        return img, img_id


In [9]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

test_dataset = ImageNetTestDataset('data/imagenet_dataset/ILSVRC', transform=transform)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)


In [10]:
in_channels = 3                # change to 3 if you use CIFAR10 dataset
image_size = 32                # change to 32 if you use CIFAR10 dataset
num_classes = 10

lr = 1e-3
batch_size = 64

patch_size = 4         # Each patch is 16x16, so 2x2 = 4 patches per image
hidden_dim = 256       # Token-mixing MLP hidden dim (formerly token_dim)
tokens_mlp_dim = 512    # Tokens MLP dim
channels_mlp_dim = 2048 # Channels MLP dim
num_blocks = 6         # Number of Mixer layers

In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
from MLP_Mixer import MLPMixer
from utils import train

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MLPMixer(in_channels=in_channels, embedding_dim=hidden_dim, num_classes=num_classes, patch_size=patch_size, image_size=image_size, depth=num_blocks, token_intermediate_dim=tokens_mlp_dim, channel_intermediate_dim=channels_mlp_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

In [16]:
train_metrics_3, val_metrics_3, test_metrics_3 = train(model, train_loader, val_loader, test_loader, 30, optimizer, criterion, False)

OSError: Caught OSError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/ediss_data/ediss4/sarosh/anaconda3/envs/saroshgpu/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 351, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/ediss_data/ediss4/sarosh/anaconda3/envs/saroshgpu/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/ediss_data/ediss4/sarosh/anaconda3/envs/saroshgpu/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 52, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/tmp/ipykernel_1114428/4111982920.py", line 64, in __getitem__
    boxes, labels = parse_voc_xml(ann_path)
  File "/tmp/ipykernel_1114428/4111982920.py", line 8, in parse_voc_xml
    tree = etree.parse(xml_path)
  File "src/lxml/etree.pyx", line 3590, in lxml.etree.parse
  File "src/lxml/parser.pxi", line 1958, in lxml.etree._parseDocument
  File "src/lxml/parser.pxi", line 1984, in lxml.etree._parseDocumentFromURL
  File "src/lxml/parser.pxi", line 1887, in lxml.etree._parseDocFromFile
  File "src/lxml/parser.pxi", line 1200, in lxml.etree._BaseParser._parseDocFromFile
  File "src/lxml/parser.pxi", line 633, in lxml.etree._ParserContext._handleParseResultDoc
  File "src/lxml/parser.pxi", line 743, in lxml.etree._handleParseResult
  File "src/lxml/parser.pxi", line 670, in lxml.etree._raiseParseError
OSError: Error reading file 'data/imagenet_dataset/ILSVRC/Annotations/CLS-LOC/train/n03956157/n03956157_9945.xml': failed to load "data/imagenet_dataset/ILSVRC/Annotations/CLS-LOC/train/n03956157/n03956157_9945.xml": No such file or directory
