In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
from utils import train
# Ask PyTorch to release cached, currently unused GPU memory before any model is built.
torch.cuda.empty_cache()

In [2]:
# Model definition
class CIFAR10_MLP(nn.Module):
    """Fully connected classifier operating on flattened image batches.

    Each width in ``hidden_dims`` contributes one
    ``Linear -> BatchNorm1d -> ReLU -> Dropout`` stage. A final ``Linear``
    maps the last hidden width to ``num_classes`` outputs; no softmax is
    applied, so the module returns raw scores.

    Args:
        input_dim: Number of features per sample once the input is flattened.
        hidden_dims: Iterable of hidden-layer widths, in order. The default is
            a tuple so that no mutable object is shared between instances.
        num_classes: Width of the output layer.
        dropout: Drop probability passed to every ``nn.Dropout``.
    """

    def __init__(self, input_dim=3*32*32, hidden_dims=(2048, 1024, 512, 256), num_classes=10, dropout=0.5):
        super().__init__()
        layers = []
        prev_dim = input_dim
        for hdim in hidden_dims:
            # Layers are created in the same order as they run, so parameter
            # initialisation consumes the RNG in a stable order.
            layers.extend([
                nn.Linear(prev_dim, hdim),
                nn.BatchNorm1d(hdim),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
            ])
            prev_dim = hdim
        layers.append(nn.Linear(prev_dim, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten everything after the batch dimension and run the MLP.

        ``torch.flatten`` is used instead of ``view`` so that non-contiguous
        inputs (e.g. permuted tensors) are accepted as well; contiguous inputs
        are still flattened without a copy.
        """
        x = torch.flatten(x, 1)
        return self.net(x)

In [3]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# No augmentation in this section: both splits are converted to tensors and
# normalised per channel with the same mean / std constants.
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261)),
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261)),
])

# CIFAR-10 is stored under ./data and downloaded there when missing.
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=train_transform,
)
test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=test_transform,
)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=100,
    shuffle=False,
    num_workers=4,
)


Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Baseline: the default CIFAR10_MLP, moved to `device` and then compiled with
# TorchScript. The optimizer is built from the scripted module's parameters
# and uses Adam's default hyperparameters.
model = torch.jit.script(CIFAR10_MLP().to(device))
optimizer = optim.Adam(params=model.parameters())
criterion = nn.CrossEntropyLoss()

In [5]:
# Train the scripted baseline MLP. The arguments are positional and follow
# utils.train, whose signature is not shown in this notebook: the third (None)
# is presumably the validation loader and the fifth (30) the number of
# epochs -- TODO confirm against utils.train.
train_metrics, val_metrics, test_metrics = train(
    model,
    train_loader,
    None,
    test_loader,
    30,
    optimizer,
    criterion,
)


Epoch: 1 Total_Time: 1.6282 Average_Time_per_batch: 0.0042 Train_Accuracy: 0.3395 Train_Loss: 1.8272 
Epoch: 2 Total_Time: 1.4313 Average_Time_per_batch: 0.0037 Train_Accuracy: 0.4240 Train_Loss: 1.6124 
Epoch: 3 Total_Time: 1.4911 Average_Time_per_batch: 0.0038 Train_Accuracy: 0.4568 Train_Loss: 1.5268 
Epoch: 4 Total_Time: 1.3673 Average_Time_per_batch: 0.0035 Train_Accuracy: 0.4773 Train_Loss: 1.4695 
Epoch: 5 Total_Time: 1.4348 Average_Time_per_batch: 0.0037 Train_Accuracy: 0.4924 Train_Loss: 1.4291 
Epoch: 6 Total_Time: 1.5456 Average_Time_per_batch: 0.0040 Train_Accuracy: 0.5040 Train_Loss: 1.3967 
Epoch: 7 Total_Time: 1.4938 Average_Time_per_batch: 0.0038 Train_Accuracy: 0.5207 Train_Loss: 1.3581 
Epoch: 8 Total_Time: 1.5275 Average_Time_per_batch: 0.0039 Train_Accuracy: 0.5283 Train_Loss: 1.3309 
Epoch: 9 Total_Time: 1.5537 Average_Time_per_batch: 0.0040 Train_Accuracy: 0.5401 Train_Loss: 1.3053 
Epoch: 10 Total_Time: 1.5573 Average_Time_per_batch: 0.0040 Train_Accuracy: 0.545

In [6]:
from dpn_3.dpn import DPN as DPN_3
    
# DPN_3 on flattened 3x32x32 inputs. The meaning of the remaining positional
# arguments (100, 10, True) is defined in dpn_3.dpn, which is not shown here.
# The model is placed on `device` rather than on a hard-coded CUDA device so
# that the CPU fallback chosen when `device` was created also works here.
model_3 = DPN_3(3*32*32, 100, 10, True).to(device)
#model_3.compile()

In [7]:
# Trace model_3 with a random example batch of 128 flattened images. The
# example is moved to `device` (instead of a hard-coded .cuda()) so that it
# always lives on the same device as the model being traced.
model_3 = torch.jit.trace(model_3, torch.randn(128, 3*32*32).to(device))
# Optimizer and loss are rebuilt for the traced module; this rebinds the
# `optimizer` / `criterion` names used by the previous experiment.
optimizer = optim.Adam(model_3.parameters())
criterion = nn.CrossEntropyLoss()

In [8]:
# Train the traced DPN_3 with the same positional call shape as the baseline
# run (None is presumably the validation loader, 30 the epoch count -- TODO
# confirm against utils.train).
train_metrics_3, val_metrics_3, test_metrics_3 = train(
    model_3,
    train_loader,
    None,
    test_loader,
    30,
    optimizer,
    criterion,
)


Epoch: 1 Total_Time: 0.8701 Average_Time_per_batch: 0.0022 Train_Accuracy: 0.3900 Train_Loss: 1.7796 
Epoch: 2 Total_Time: 0.5746 Average_Time_per_batch: 0.0015 Train_Accuracy: 0.4561 Train_Loss: 1.6021 
Epoch: 3 Total_Time: 0.7166 Average_Time_per_batch: 0.0018 Train_Accuracy: 0.4825 Train_Loss: 1.5175 
Epoch: 4 Total_Time: 0.7614 Average_Time_per_batch: 0.0019 Train_Accuracy: 0.5026 Train_Loss: 1.4601 
Epoch: 5 Total_Time: 0.6123 Average_Time_per_batch: 0.0016 Train_Accuracy: 0.5191 Train_Loss: 1.4100 
Epoch: 6 Total_Time: 0.7444 Average_Time_per_batch: 0.0019 Train_Accuracy: 0.5350 Train_Loss: 1.3673 
Epoch: 7 Total_Time: 0.5066 Average_Time_per_batch: 0.0013 Train_Accuracy: 0.5467 Train_Loss: 1.3309 
Epoch: 8 Total_Time: 0.7800 Average_Time_per_batch: 0.0020 Train_Accuracy: 0.5559 Train_Loss: 1.2950 
Epoch: 9 Total_Time: 0.7000 Average_Time_per_batch: 0.0018 Train_Accuracy: 0.5681 Train_Loss: 1.2629 
Epoch: 10 Total_Time: 0.7348 Average_Time_per_batch: 0.0019 Train_Accuracy: 0.579

In [9]:
import torch
from torch import nn

# Hand-built initial weights for the DPN_2 model created in the next cell.
# Widths of the successive neuron blocks; the last block has one unit per class.
hidden_dims = [2048, 1024, 512, 256, 10]
total = sum(hidden_dims)

blocks = len(hidden_dims)
features = 3 * 32 * 32

# One weight matrix per neuron block. Block i has `features` input columns,
# where `features` starts at the flattened image size and grows by the width
# of every earlier block (presumably so a block can read the raw input plus
# all earlier blocks' outputs -- confirm against dpn_2). Entries are standard
# normal draws scaled by sqrt(1 / number_of_input_columns).
neural_blocks = []
for dim in hidden_dims:
    std_dev = torch.sqrt(torch.tensor(1 / features)).to(device)
    neural_blocks.append(torch.randn(dim, features).to(device) * std_dev)
    features += dim

# Regroup the same weights by input-column range instead of by neuron block:
# feature_blocks[i] covers the columns first introduced by block i and stacks,
# along dim 0, the rows of block i and of every later block for those columns.
feature_blocks = []
features_start = 0
for i in range(len(neural_blocks)):
    features_end = neural_blocks[i].shape[1]
    block = neural_blocks[i][:, features_start:]
    for j in range(i + 1, len(neural_blocks)):
        block = torch.cat((block, neural_blocks[j][:, features_start:features_end]), dim=0)
    feature_blocks.append(nn.Parameter(block))
    features_start = features_end

# Move the tensor to `device` BEFORE wrapping it: calling .to(device) on an
# nn.Parameter returns a plain, non-leaf Tensor whenever a copy is made (e.g.
# CPU -> GPU), which can neither be registered as a module parameter nor be
# updated by an optimizer. Wrapping last keeps `biases` a real leaf Parameter.
biases = nn.Parameter(torch.empty(total).uniform_(0.0, 1.0).to(device))

In [10]:
from dpn_2.dpn import DPN as DPN_2
    
# DPN_2 sized for `total` neurons and initialised from the hand-built tensors
# of the previous cell. The model is placed on `device` rather than on a
# hard-coded CUDA device, matching where `feature_blocks` and `biases` live.
model_2 = DPN_2(3*32*32, total, 10, True).to(device)
# NOTE(review): whether the appended blocks become trainable depends on the
# container type of `model_2.weights` (e.g. nn.ParameterList vs. plain list);
# that is defined in dpn_2.dpn and not visible here -- confirm.
model_2.weights.extend(feature_blocks)
model_2.biases = biases

In [11]:
# model_2 is trained as a regular (untraced) module; the torch.jit.trace call
# that was applied to model_3 stays disabled here:
#model_3 = torch.jit.trace(model_3, torch.randn(128, 3*32*32).cuda())
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model_2.parameters())

In [12]:
from utils import train
# Train DPN_2 with the same positional call shape as the earlier runs.
# NOTE(review): the results are stored in the *_3 names, which replaces the
# metrics recorded for model_3 -- consider dedicated names (e.g. *_2) once it
# is confirmed that no later cell relies on the current ones.
train_metrics_3, val_metrics_3, test_metrics_3 = train(
    model_2,
    train_loader,
    None,
    test_loader,
    30,
    optimizer,
    criterion,
)


Epoch: 1 Total_Time: 1.5930 Average_Time_per_batch: 0.0041 Train_Accuracy: 0.3827 Train_Loss: 2.1215 
Epoch: 2 Total_Time: 1.8245 Average_Time_per_batch: 0.0047 Train_Accuracy: 0.4491 Train_Loss: 1.7035 
Epoch: 3 Total_Time: 1.7723 Average_Time_per_batch: 0.0045 Train_Accuracy: 0.4753 Train_Loss: 1.6570 
Epoch: 4 Total_Time: 1.5580 Average_Time_per_batch: 0.0040 Train_Accuracy: 0.5011 Train_Loss: 1.5997 
Epoch: 5 Total_Time: 1.5248 Average_Time_per_batch: 0.0039 Train_Accuracy: 0.5266 Train_Loss: 1.5236 
Epoch: 6 Total_Time: 1.5118 Average_Time_per_batch: 0.0039 Train_Accuracy: 0.5463 Train_Loss: 1.4823 
Epoch: 7 Total_Time: 1.6712 Average_Time_per_batch: 0.0043 Train_Accuracy: 0.5641 Train_Loss: 1.4466 
Epoch: 8 Total_Time: 1.5555 Average_Time_per_batch: 0.0040 Train_Accuracy: 0.5808 Train_Loss: 1.4187 
Epoch: 9 Total_Time: 1.6176 Average_Time_per_batch: 0.0041 Train_Accuracy: 0.6018 Train_Loss: 1.3931 
Epoch: 10 Total_Time: 1.5681 Average_Time_per_batch: 0.0040 Train_Accuracy: 0.620

In [13]:
class ResNetMLP(nn.Module):
    """ResNet-18 feature extractor followed by a small MLP classifier head.

    Args:
        num_classes: Number of outputs produced by the head.
        mlp_hidden: Width of the single hidden layer in the head.
    """

    def __init__(self, num_classes=10, mlp_hidden=256):
        super().__init__()
        # weights=None: the ResNet-18 trunk starts from random initialisation.
        resnet = torchvision.models.resnet18(weights=None)
        # Stem adapted to 3x32x32 CIFAR-10 images: a 3x3, stride-1 convolution
        # and no max-pooling (replaced by an identity).
        resnet.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        resnet.maxpool = nn.Identity()
        # Keep every child module except the last one (the original FC layer).
        trunk = list(resnet.children())
        trunk.pop()
        self.backbone = nn.Sequential(*trunk)
        # Head: 512 backbone features -> mlp_hidden -> num_classes.
        head = [
            nn.Linear(512, mlp_hidden),
            nn.BatchNorm1d(mlp_hidden),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(mlp_hidden, num_classes),
        ]
        self.mlp = nn.Sequential(*head)

    def forward(self, x):
        """Run the backbone, flatten its output per sample, apply the head."""
        feats = self.backbone(x)  # [B, 512, 1, 1]
        feats = feats.view(feats.size(0), -1)
        return self.mlp(feats)

In [14]:
# --- 1. Data augmentation ---
# These definitions rebind `train_transform` / `test_transform`. Datasets and
# loaders created earlier keep the transform objects they were built with
# until they are re-created.
# Training images: random 32x32 crop after 4 pixels of padding, random
# horizontal flip, then tensor conversion and per-channel normalisation.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261)),
])

# Test images: no augmentation, same tensor conversion and normalisation.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261)),
])


In [15]:
# Re-create datasets and loaders so they use the current `train_transform` /
# `test_transform` objects; this rebinds the dataset and loader names.
train_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=train_transform,
)
test_dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=test_transform,
)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=100,
    shuffle=False,
    num_workers=4,
)


Files already downloaded and verified
Files already downloaded and verified


In [16]:
# ResNet-18 + MLP head experiment. `model`, `optimizer` and `criterion` are
# rebound here, replacing the objects of the previous experiment.
criterion = nn.CrossEntropyLoss()
model = ResNetMLP(num_classes=10, mlp_hidden=256).to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [17]:
# Train the ResNet-18 + MLP model. The meaning of the trailing False is
# defined by utils.train and is not visible in this notebook -- TODO confirm.
# NOTE(review): the *_3 metric names are reused, so the results of the
# previous run are replaced.
train_metrics_3, val_metrics_3, test_metrics_3 = train(
    model,
    train_loader,
    None,
    test_loader,
    30,
    optimizer,
    criterion,
    False,
)


Epoch: 1 Total_Time: 5.8375 Average_Time_per_batch: 0.0149 Train_Accuracy: 0.4787 Train_Loss: 1.4302 
Epoch: 2 Total_Time: 5.8193 Average_Time_per_batch: 0.0149 Train_Accuracy: 0.6587 Train_Loss: 0.9701 
Epoch: 3 Total_Time: 5.6937 Average_Time_per_batch: 0.0146 Train_Accuracy: 0.7331 Train_Loss: 0.7746 
Epoch: 4 Total_Time: 5.7803 Average_Time_per_batch: 0.0148 Train_Accuracy: 0.7775 Train_Loss: 0.6473 
Epoch: 5 Total_Time: 5.5938 Average_Time_per_batch: 0.0143 Train_Accuracy: 0.8079 Train_Loss: 0.5643 
Epoch: 6 Total_Time: 5.8141 Average_Time_per_batch: 0.0149 Train_Accuracy: 0.8296 Train_Loss: 0.5012 
Epoch: 7 Total_Time: 5.8533 Average_Time_per_batch: 0.0150 Train_Accuracy: 0.8456 Train_Loss: 0.4536 
Epoch: 8 Total_Time: 5.6534 Average_Time_per_batch: 0.0145 Train_Accuracy: 0.8598 Train_Loss: 0.4094 
Epoch: 9 Total_Time: 5.8281 Average_Time_per_batch: 0.0149 Train_Accuracy: 0.8721 Train_Loss: 0.3736 
Epoch: 10 Total_Time: 5.6774 Average_Time_per_batch: 0.0145 Train_Accuracy: 0.881

In [18]:
class ResNetDPN(nn.Module):
    """ResNet-18 feature extractor followed by a DPN_3 classifier head.

    Args:
        num_classes: Number of outputs; forwarded to DPN_3.
        mlp_hidden: Added to ``num_classes`` to form DPN_3's second argument.
            The exact meaning of DPN_3's arguments is defined in dpn_3.dpn,
            which is not shown here.
    """

    def __init__(self, num_classes=10, mlp_hidden=256):
        super().__init__()
        # weights=None: the ResNet-18 trunk starts from random initialisation.
        backbone = torchvision.models.resnet18(weights=None)
        # Change input conv layer for CIFAR-10 (3x32x32)
        backbone.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        backbone.maxpool = nn.Identity()  # Remove the first maxpool
        # Keep every child module except the last one (the original FC layer).
        self.backbone = nn.Sequential(*list(backbone.children())[:-1])
        # No device is chosen here: the head is a registered submodule, so it
        # follows the parent when the caller runs `.to(device)` / `.cuda()`.
        # Forcing `.cuda()` in the constructor made the class unusable on
        # machines without a GPU.
        self.dpn = DPN_3(512, mlp_hidden + num_classes, num_classes, True)

    def forward(self, x):
        """Run the backbone, flatten its output per sample, apply the DPN head."""
        x = self.backbone(x)  # [B, 512, 1, 1]
        x = x.view(x.size(0), -1)  # Flatten
        return self.dpn(x)

In [19]:
# ResNet-18 + DPN head experiment, configured like the ResNet-18 + MLP run
# (Adam, lr=1e-3, cross-entropy). The three names below are rebound.
criterion = nn.CrossEntropyLoss()
model = ResNetDPN(num_classes=10, mlp_hidden=256).to(device)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [20]:
# Train the ResNet-18 + DPN model with the same call shape as the
# ResNet-18 + MLP run (the trailing False is defined by utils.train -- TODO
# confirm its meaning).
# NOTE(review): the *_3 metric names are reused, so the results of the
# previous run are replaced.
train_metrics_3, val_metrics_3, test_metrics_3 = train(
    model,
    train_loader,
    None,
    test_loader,
    30,
    optimizer,
    criterion,
    False,
)


Epoch: 1 Total_Time: 5.4124 Average_Time_per_batch: 0.0138 Train_Accuracy: 0.5099 Train_Loss: 1.3393 
Epoch: 2 Total_Time: 5.4383 Average_Time_per_batch: 0.0139 Train_Accuracy: 0.6894 Train_Loss: 0.8729 
Epoch: 3 Total_Time: 5.4434 Average_Time_per_batch: 0.0139 Train_Accuracy: 0.7609 Train_Loss: 0.6868 
Epoch: 4 Total_Time: 5.3822 Average_Time_per_batch: 0.0138 Train_Accuracy: 0.8021 Train_Loss: 0.5758 
Epoch: 5 Total_Time: 5.4027 Average_Time_per_batch: 0.0138 Train_Accuracy: 0.8240 Train_Loss: 0.5060 
Epoch: 6 Total_Time: 5.4370 Average_Time_per_batch: 0.0139 Train_Accuracy: 0.8443 Train_Loss: 0.4505 
Epoch: 7 Total_Time: 5.3597 Average_Time_per_batch: 0.0137 Train_Accuracy: 0.8572 Train_Loss: 0.4100 
Epoch: 8 Total_Time: 5.3766 Average_Time_per_batch: 0.0138 Train_Accuracy: 0.8735 Train_Loss: 0.3685 
Epoch: 9 Total_Time: 5.4630 Average_Time_per_batch: 0.0140 Train_Accuracy: 0.8837 Train_Loss: 0.3339 
Epoch: 10 Total_Time: 5.4242 Average_Time_per_batch: 0.0139 Train_Accuracy: 0.894

In [21]:
# Hyperparameters shared by the MLPMixer and DPNMixer experiments below.
in_channels = 3                # input channels of the 3x32x32 CIFAR-10 images used in this notebook
image_size = 32                # image side length, passed to the mixers as image_size
num_classes = 10

lr = 1e-3
# NOTE(review): batch_size is not read by any visible cell; the data loaders
# above are built with batch_size=128 (train) and 100 (test).
batch_size = 64

# Assuming the mixers cut the image into non-overlapping patch_size x patch_size
# patches, a 32x32 image gives 8x8 = 64 patches -- TODO confirm in MLP_Mixer.
patch_size = 4         # patch side length, passed to the mixers as patch_size
hidden_dim = 256       # passed to the mixers as embedding_dim
tokens_mlp_dim = 512    # passed to the mixers as token_intermediate_dim
channels_mlp_dim = 2048 # passed to the mixers as channel_intermediate_dim
num_blocks = 6         # passed to the mixers as depth

In [22]:
from MLP_Mixer import MLPMixer
# MLP-Mixer experiment, configured from the hyperparameter cell and placed on
# `device` before the optimizer is built. `model`, `optimizer` and
# `criterion` are rebound here.
model = MLPMixer(
    in_channels=in_channels,
    embedding_dim=hidden_dim,
    num_classes=num_classes,
    patch_size=patch_size,
    image_size=image_size,
    depth=num_blocks,
    token_intermediate_dim=tokens_mlp_dim,
    channel_intermediate_dim=channels_mlp_dim,
).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=lr)

In [23]:
# Train the MLP-Mixer with the same call shape as the ResNet runs (the
# trailing False is defined by utils.train -- TODO confirm its meaning).
# NOTE(review): the *_3 metric names are reused, so the results of the
# previous run are replaced.
train_metrics_3, val_metrics_3, test_metrics_3 = train(
    model,
    train_loader,
    None,
    test_loader,
    30,
    optimizer,
    criterion,
    False,
)


Epoch: 1 Total_Time: 5.8143 Average_Time_per_batch: 0.0149 Train_Accuracy: 0.3541 Train_Loss: 1.7426 
Epoch: 2 Total_Time: 5.8923 Average_Time_per_batch: 0.0151 Train_Accuracy: 0.5220 Train_Loss: 1.3224 
Epoch: 3 Total_Time: 5.8795 Average_Time_per_batch: 0.0150 Train_Accuracy: 0.5831 Train_Loss: 1.1611 
Epoch: 4 Total_Time: 5.6536 Average_Time_per_batch: 0.0145 Train_Accuracy: 0.6250 Train_Loss: 1.0460 
Epoch: 5 Total_Time: 5.8334 Average_Time_per_batch: 0.0149 Train_Accuracy: 0.6547 Train_Loss: 0.9683 
Epoch: 6 Total_Time: 5.8275 Average_Time_per_batch: 0.0149 Train_Accuracy: 0.6791 Train_Loss: 0.9019 
Epoch: 7 Total_Time: 5.8835 Average_Time_per_batch: 0.0150 Train_Accuracy: 0.7012 Train_Loss: 0.8452 
Epoch: 8 Total_Time: 5.8281 Average_Time_per_batch: 0.0149 Train_Accuracy: 0.7163 Train_Loss: 0.7949 
Epoch: 9 Total_Time: 5.3544 Average_Time_per_batch: 0.0137 Train_Accuracy: 0.7301 Train_Loss: 0.7606 
Epoch: 10 Total_Time: 5.7278 Average_Time_per_batch: 0.0146 Train_Accuracy: 0.749

In [24]:
from DPN_Mixer import MLPMixer as DPNMixer
# DPN-Mixer experiment with the same hyperparameters as the MLP-Mixer run.
# The model is moved to `device` at construction time, as in the other
# experiments, so the optimizer is always built from parameters that already
# live on the training device. A later `model.to(device)` is then a no-op.
model = DPNMixer(
    in_channels=in_channels,
    embedding_dim=hidden_dim,
    num_classes=num_classes,
    patch_size=patch_size,
    image_size=image_size,
    depth=num_blocks,
    token_intermediate_dim=tokens_mlp_dim,
    channel_intermediate_dim=channels_mlp_dim,
).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

In [25]:
# Ask PyTorch to release cached, currently unused GPU memory before training.
torch.cuda.empty_cache()

# Multi-GPU wrapping is currently disabled. It is kept as real comments
# (rather than as a bare string-literal statement) so it is clearly inert:
# if torch.cuda.device_count() > 1:
#     print(f"Using {torch.cuda.device_count()} GPUs!")
#     model = nn.DataParallel(model)  # Wrap the model in DataParallel

# Make sure the model lives on the training device.
model = model.to(device)

In [26]:
# Train the DPN-Mixer. Unlike the other runs, the fifth positional argument is
# 20 instead of 30 (presumably the epoch count -- TODO confirm against
# utils.train, which also defines the meaning of the trailing False).
# NOTE(review): the *_3 metric names are reused, so the results of the
# previous run are replaced.
train_metrics_3, val_metrics_3, test_metrics_3 = train(
    model,
    train_loader,
    None,
    test_loader,
    20,
    optimizer,
    criterion,
    False,
)


Epoch: 1 Total_Time: 1.6306 Average_Time_per_batch: 0.0042 Train_Accuracy: 0.3585 Train_Loss: 1.7414 
Epoch: 2 Total_Time: 1.6806 Average_Time_per_batch: 0.0043 Train_Accuracy: 0.5029 Train_Loss: 1.3698 
Epoch: 3 Total_Time: 1.6416 Average_Time_per_batch: 0.0042 Train_Accuracy: 0.5422 Train_Loss: 1.2660 
Epoch: 4 Total_Time: 1.6444 Average_Time_per_batch: 0.0042 Train_Accuracy: 0.5618 Train_Loss: 1.2049 
Epoch: 5 Total_Time: 1.6157 Average_Time_per_batch: 0.0041 Train_Accuracy: 0.5816 Train_Loss: 1.1590 
Epoch: 6 Total_Time: 1.6627 Average_Time_per_batch: 0.0043 Train_Accuracy: 0.5980 Train_Loss: 1.1109 
Epoch: 7 Total_Time: 1.6787 Average_Time_per_batch: 0.0043 Train_Accuracy: 0.6097 Train_Loss: 1.0813 
Epoch: 8 Total_Time: 1.6217 Average_Time_per_batch: 0.0041 Train_Accuracy: 0.6231 Train_Loss: 1.0459 
Epoch: 9 Total_Time: 1.6613 Average_Time_per_batch: 0.0042 Train_Accuracy: 0.6361 Train_Loss: 1.0202 
Epoch: 10 Total_Time: 1.6438 Average_Time_per_batch: 0.0042 Train_Accuracy: 0.643