In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch
import torch.nn.functional as F
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

Using device: cuda


In [2]:
# Normalization values for CIFAR-10
mean = (0.4914, 0.4822, 0.4465)
std  = (0.2023, 0.1994, 0.2010)

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2
)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2
)


Files already downloaded and verified
Files already downloaded and verified


In [4]:
class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut (projection) if shape changes (stride != 1 or channels differ)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = self.relu(out)
        return out


In [5]:
class ResNet56(nn.Module):
    def __init__(self, block=BasicBlock, num_classes=10):
        super(ResNet56, self).__init__()
        self.in_channels = 16

        # Initial conv: 3x3, 16 filters
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)

        # Each layer: 9 blocks
        self.layer1 = self._make_layer(block, 16, blocks=9, stride=1)
        self.layer2 = self._make_layer(block, 32, blocks=9, stride=2)
        self.layer3 = self._make_layer(block, 64, blocks=9, stride=2)

        # Global average pool and final classification
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

    def _make_layer(self, block, out_channels, blocks, stride):
        """Create a 'layer' consisting of `blocks` residual blocks."""
        strides = [stride] + [1]*(blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_channels, out_channels, s))
            self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out

In [5]:
def train_one_epoch(model, loader, criterion, optimizer):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = 100.0 * correct / total
    return epoch_loss, epoch_acc

def evaluate(model, loader, criterion):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = 100.0 * correct / total
    return epoch_loss, epoch_acc


In [6]:
model = ResNet56().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

In [7]:
#transformations for metamorphic testing
transformations = [
    transforms.RandomRotation(degrees=10),  # Small rotation
    transforms.ColorJitter(brightness=0.2),  # Slight brightness change
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1))  # Small translation
]

In [8]:
reject_diffs = []

# Evaluate reject differential over the test set
for i, (images, labels) in enumerate(tqdm(testloader)):
    
    images = images.to(device)

    # Original prediction
    with torch.no_grad():
        original_output = model(images)
        original_prob = F.softmax(original_output, dim=1)               # shape [batch_size, 10]
        original_confidence, original_pred = torch.max(original_prob, 1) # shape [batch_size]

    # For metamorphic transform
    for tf in transformations:
       
        transformed_images = tf(images.cpu()).to(device)

        with torch.no_grad():
            transformed_output = model(transformed_images)
            transformed_prob = F.softmax(transformed_output, dim=1)
            transformed_confidence, transformed_pred = torch.max(transformed_prob, 1)

        # Compute absolute difference in confidence for each sample in the batch
    
        diff_tensor = torch.abs(original_confidence - transformed_confidence)

        # Convert to a single scalar by taking the mean over the batch
        diff_mean = diff_tensor.mean().item()

       
        reject_diffs.append(diff_mean)

       
        print(f"Batch {i} | Transform: {tf.__class__.__name__} | "
              f"Reject Differential (mean): {diff_mean:.4f}")


print("Number of recorded reject differentials:", len(reject_diffs))
print("Example of first few reject_diffs:", reject_diffs[:5])


  1%|█▋                                                                                                                                                                   | 1/100 [00:00<01:11,  1.39it/s]

Batch 0 | Transform: RandomRotation | Reject Differential (mean): 0.0166
Batch 0 | Transform: ColorJitter | Reject Differential (mean): 0.0294
Batch 0 | Transform: RandomAffine | Reject Differential (mean): 0.0156
Batch 1 | Transform: RandomRotation | Reject Differential (mean): 0.0000
Batch 1 | Transform: ColorJitter | Reject Differential (mean): 0.0282
Batch 1 | Transform: RandomAffine | Reject Differential (mean): 0.0175
Batch 2 | Transform: RandomRotation | Reject Differential (mean): 0.0000
Batch 2 | Transform: ColorJitter | Reject Differential (mean): 0.0275


  5%|████████▎                                                                                                                                                            | 5/100 [00:00<00:13,  6.93it/s]

Batch 2 | Transform: RandomAffine | Reject Differential (mean): 0.0188
Batch 3 | Transform: RandomRotation | Reject Differential (mean): 0.0166
Batch 3 | Transform: ColorJitter | Reject Differential (mean): 0.0295
Batch 3 | Transform: RandomAffine | Reject Differential (mean): 0.0197
Batch 4 | Transform: RandomRotation | Reject Differential (mean): 0.0181
Batch 4 | Transform: ColorJitter | Reject Differential (mean): 0.0228
Batch 4 | Transform: RandomAffine | Reject Differential (mean): 0.0191
Batch 5 | Transform: RandomRotation | Reject Differential (mean): 0.0188
Batch 5 | Transform: ColorJitter | Reject Differential (mean): 0.0235
Batch 5 | Transform: RandomAffine | Reject Differential (mean): 0.0180
Batch 6 | Transform: RandomRotation | Reject Differential (mean): 0.0000


  9%|██████████████▊                                                                                                                                                      | 9/100 [00:01<00:08, 11.35it/s]

Batch 6 | Transform: ColorJitter | Reject Differential (mean): 0.0265
Batch 6 | Transform: RandomAffine | Reject Differential (mean): 0.0200
Batch 7 | Transform: RandomRotation | Reject Differential (mean): 0.0164
Batch 7 | Transform: ColorJitter | Reject Differential (mean): 0.0287
Batch 7 | Transform: RandomAffine | Reject Differential (mean): 0.0210
Batch 8 | Transform: RandomRotation | Reject Differential (mean): 0.0174
Batch 8 | Transform: ColorJitter | Reject Differential (mean): 0.0298
Batch 8 | Transform: RandomAffine | Reject Differential (mean): 0.0196
Batch 9 | Transform: RandomRotation | Reject Differential (mean): 0.0162
Batch 9 | Transform: ColorJitter | Reject Differential (mean): 0.0260
Batch 9 | Transform: RandomAffine | Reject Differential (mean): 0.0205


 13%|█████████████████████▎                                                                                                                                              | 13/100 [00:01<00:06, 14.28it/s]

Batch 10 | Transform: RandomRotation | Reject Differential (mean): 0.0180
Batch 10 | Transform: ColorJitter | Reject Differential (mean): 0.0303
Batch 10 | Transform: RandomAffine | Reject Differential (mean): 0.0153
Batch 11 | Transform: RandomRotation | Reject Differential (mean): 0.0221
Batch 11 | Transform: ColorJitter | Reject Differential (mean): 0.0302
Batch 11 | Transform: RandomAffine | Reject Differential (mean): 0.0216
Batch 12 | Transform: RandomRotation | Reject Differential (mean): 0.0000
Batch 12 | Transform: ColorJitter | Reject Differential (mean): 0.0308
Batch 12 | Transform: RandomAffine | Reject Differential (mean): 0.0187
Batch 13 | Transform: RandomRotation | Reject Differential (mean): 0.0000
Batch 13 | Transform: ColorJitter | Reject Differential (mean): 0.0297
Batch 13 | Transform: RandomAffine | Reject Differential (mean): 0.0203


 17%|███████████████████████████▉                                                                                                                                        | 17/100 [00:01<00:05, 15.22it/s]

Batch 14 | Transform: RandomRotation | Reject Differential (mean): 0.0193
Batch 14 | Transform: ColorJitter | Reject Differential (mean): 0.0263
Batch 14 | Transform: RandomAffine | Reject Differential (mean): 0.0176
Batch 15 | Transform: RandomRotation | Reject Differential (mean): 0.0206
Batch 15 | Transform: ColorJitter | Reject Differential (mean): 0.0264
Batch 15 | Transform: RandomAffine | Reject Differential (mean): 0.0182
Batch 16 | Transform: RandomRotation | Reject Differential (mean): 0.0178
Batch 16 | Transform: ColorJitter | Reject Differential (mean): 0.0285
Batch 16 | Transform: RandomAffine | Reject Differential (mean): 0.0179
Batch 17 | Transform: RandomRotation | Reject Differential (mean): 0.0181


 21%|██████████████████████████████████▍                                                                                                                                 | 21/100 [00:01<00:04, 16.43it/s]

Batch 17 | Transform: ColorJitter | Reject Differential (mean): 0.0268
Batch 17 | Transform: RandomAffine | Reject Differential (mean): 0.0196
Batch 18 | Transform: RandomRotation | Reject Differential (mean): 0.0189
Batch 18 | Transform: ColorJitter | Reject Differential (mean): 0.0281
Batch 18 | Transform: RandomAffine | Reject Differential (mean): 0.0187
Batch 19 | Transform: RandomRotation | Reject Differential (mean): 0.0000
Batch 19 | Transform: ColorJitter | Reject Differential (mean): 0.0258
Batch 19 | Transform: RandomAffine | Reject Differential (mean): 0.0177
Batch 20 | Transform: RandomRotation | Reject Differential (mean): 0.0165
Batch 20 | Transform: ColorJitter | Reject Differential (mean): 0.0311
Batch 20 | Transform: RandomAffine | Reject Differential (mean): 0.0164


 25%|█████████████████████████████████████████                                                                                                                           | 25/100 [00:02<00:04, 16.92it/s]

Batch 21 | Transform: RandomRotation | Reject Differential (mean): 0.0220
Batch 21 | Transform: ColorJitter | Reject Differential (mean): 0.0319
Batch 21 | Transform: RandomAffine | Reject Differential (mean): 0.0192
Batch 22 | Transform: RandomRotation | Reject Differential (mean): 0.0172
Batch 22 | Transform: ColorJitter | Reject Differential (mean): 0.0274
Batch 22 | Transform: RandomAffine | Reject Differential (mean): 0.0186
Batch 23 | Transform: RandomRotation | Reject Differential (mean): 0.0166
Batch 23 | Transform: ColorJitter | Reject Differential (mean): 0.0267
Batch 23 | Transform: RandomAffine | Reject Differential (mean): 0.0196
Batch 24 | Transform: RandomRotation | Reject Differential (mean): 0.0186
Batch 24 | Transform: ColorJitter | Reject Differential (mean): 0.0242
Batch 24 | Transform: RandomAffine | Reject Differential (mean): 0.0201


 29%|███████████████████████████████████████████████▌                                                                                                                    | 29/100 [00:02<00:04, 17.39it/s]

Batch 25 | Transform: RandomRotation | Reject Differential (mean): 0.0158
Batch 25 | Transform: ColorJitter | Reject Differential (mean): 0.0225
Batch 25 | Transform: RandomAffine | Reject Differential (mean): 0.0156
Batch 26 | Transform: RandomRotation | Reject Differential (mean): 0.0000
Batch 26 | Transform: ColorJitter | Reject Differential (mean): 0.0272
Batch 26 | Transform: RandomAffine | Reject Differential (mean): 0.0219
Batch 27 | Transform: RandomRotation | Reject Differential (mean): 0.0181
Batch 27 | Transform: ColorJitter | Reject Differential (mean): 0.0258
Batch 27 | Transform: RandomAffine | Reject Differential (mean): 0.0226
Batch 28 | Transform: RandomRotation | Reject Differential (mean): 0.0156
Batch 28 | Transform: ColorJitter | Reject Differential (mean): 0.0244
Batch 28 | Transform: RandomAffine | Reject Differential (mean): 0.0159


 33%|██████████████████████████████████████████████████████                                                                                                              | 33/100 [00:02<00:03, 17.67it/s]

Batch 29 | Transform: RandomRotation | Reject Differential (mean): 0.0000
Batch 29 | Transform: ColorJitter | Reject Differential (mean): 0.0239
Batch 29 | Transform: RandomAffine | Reject Differential (mean): 0.0194
Batch 30 | Transform: RandomRotation | Reject Differential (mean): 0.0152
Batch 30 | Transform: ColorJitter | Reject Differential (mean): 0.0320
Batch 30 | Transform: RandomAffine | Reject Differential (mean): 0.0185
Batch 31 | Transform: RandomRotation | Reject Differential (mean): 0.0177
Batch 31 | Transform: ColorJitter | Reject Differential (mean): 0.0236
Batch 31 | Transform: RandomAffine | Reject Differential (mean): 0.0204
Batch 32 | Transform: RandomRotation | Reject Differential (mean): 0.0195
Batch 32 | Transform: ColorJitter | Reject Differential (mean): 0.0257
Batch 32 | Transform: RandomAffine | Reject Differential (mean): 0.0187


 37%|████████████████████████████████████████████████████████████▋                                                                                                       | 37/100 [00:02<00:03, 17.80it/s]

Batch 33 | Transform: RandomRotation | Reject Differential (mean): 0.0175
Batch 33 | Transform: ColorJitter | Reject Differential (mean): 0.0235
Batch 33 | Transform: RandomAffine | Reject Differential (mean): 0.0195
Batch 34 | Transform: RandomRotation | Reject Differential (mean): 0.0191
Batch 34 | Transform: ColorJitter | Reject Differential (mean): 0.0347
Batch 34 | Transform: RandomAffine | Reject Differential (mean): 0.0188
Batch 35 | Transform: RandomRotation | Reject Differential (mean): 0.0214
Batch 35 | Transform: ColorJitter | Reject Differential (mean): 0.0281
Batch 35 | Transform: RandomAffine | Reject Differential (mean): 0.0211
Batch 36 | Transform: RandomRotation | Reject Differential (mean): 0.0225
Batch 36 | Transform: ColorJitter | Reject Differential (mean): 0.0263
Batch 36 | Transform: RandomAffine | Reject Differential (mean): 0.0192


 41%|███████████████████████████████████████████████████████████████████▏                                                                                                | 41/100 [00:03<00:03, 17.87it/s]

Batch 37 | Transform: RandomRotation | Reject Differential (mean): 0.0186
Batch 37 | Transform: ColorJitter | Reject Differential (mean): 0.0228
Batch 37 | Transform: RandomAffine | Reject Differential (mean): 0.0184
Batch 38 | Transform: RandomRotation | Reject Differential (mean): 0.0138
Batch 38 | Transform: ColorJitter | Reject Differential (mean): 0.0296
Batch 38 | Transform: RandomAffine | Reject Differential (mean): 0.0145
Batch 39 | Transform: RandomRotation | Reject Differential (mean): 0.0175
Batch 39 | Transform: ColorJitter | Reject Differential (mean): 0.0202
Batch 39 | Transform: RandomAffine | Reject Differential (mean): 0.0190
Batch 40 | Transform: RandomRotation | Reject Differential (mean): 0.0000
Batch 40 | Transform: ColorJitter | Reject Differential (mean): 0.0263
Batch 40 | Transform: RandomAffine | Reject Differential (mean): 0.0199


 45%|█████████████████████████████████████████████████████████████████████████▊                                                                                          | 45/100 [00:03<00:03, 17.88it/s]

Batch 41 | Transform: RandomRotation | Reject Differential (mean): 0.0000
Batch 41 | Transform: ColorJitter | Reject Differential (mean): 0.0290
Batch 41 | Transform: RandomAffine | Reject Differential (mean): 0.0210
Batch 42 | Transform: RandomRotation | Reject Differential (mean): 0.0000
Batch 42 | Transform: ColorJitter | Reject Differential (mean): 0.0281
Batch 42 | Transform: RandomAffine | Reject Differential (mean): 0.0180
Batch 43 | Transform: RandomRotation | Reject Differential (mean): 0.0076
Batch 43 | Transform: ColorJitter | Reject Differential (mean): 0.0318
Batch 43 | Transform: RandomAffine | Reject Differential (mean): 0.0181
Batch 44 | Transform: RandomRotation | Reject Differential (mean): 0.0171
Batch 44 | Transform: ColorJitter | Reject Differential (mean): 0.0294
Batch 44 | Transform: RandomAffine | Reject Differential (mean): 0.0195


 49%|████████████████████████████████████████████████████████████████████████████████▎                                                                                   | 49/100 [00:03<00:02, 17.80it/s]

Batch 45 | Transform: RandomRotation | Reject Differential (mean): 0.0184
Batch 45 | Transform: ColorJitter | Reject Differential (mean): 0.0339
Batch 45 | Transform: RandomAffine | Reject Differential (mean): 0.0174
Batch 46 | Transform: RandomRotation | Reject Differential (mean): 0.0171
Batch 46 | Transform: ColorJitter | Reject Differential (mean): 0.0289
Batch 46 | Transform: RandomAffine | Reject Differential (mean): 0.0181
Batch 47 | Transform: RandomRotation | Reject Differential (mean): 0.0206
Batch 47 | Transform: ColorJitter | Reject Differential (mean): 0.0259
Batch 47 | Transform: RandomAffine | Reject Differential (mean): 0.0202
Batch 48 | Transform: RandomRotation | Reject Differential (mean): 0.0199
Batch 48 | Transform: ColorJitter | Reject Differential (mean): 0.0283
Batch 48 | Transform: RandomAffine | Reject Differential (mean): 0.0221


 53%|██████████████████████████████████████████████████████████████████████████████████████▉                                                                             | 53/100 [00:03<00:02, 17.74it/s]

Batch 49 | Transform: RandomRotation | Reject Differential (mean): 0.0185
Batch 49 | Transform: ColorJitter | Reject Differential (mean): 0.0285
Batch 49 | Transform: RandomAffine | Reject Differential (mean): 0.0000
Batch 50 | Transform: RandomRotation | Reject Differential (mean): 0.0213
Batch 50 | Transform: ColorJitter | Reject Differential (mean): 0.0271
Batch 50 | Transform: RandomAffine | Reject Differential (mean): 0.0218
Batch 51 | Transform: RandomRotation | Reject Differential (mean): 0.0192
Batch 51 | Transform: ColorJitter | Reject Differential (mean): 0.0302
Batch 51 | Transform: RandomAffine | Reject Differential (mean): 0.0178
Batch 52 | Transform: RandomRotation | Reject Differential (mean): 0.0196
Batch 52 | Transform: ColorJitter | Reject Differential (mean): 0.0277
Batch 52 | Transform: RandomAffine | Reject Differential (mean): 0.0187


 57%|█████████████████████████████████████████████████████████████████████████████████████████████▍                                                                      | 57/100 [00:03<00:02, 17.35it/s]

Batch 53 | Transform: RandomRotation | Reject Differential (mean): 0.0181
Batch 53 | Transform: ColorJitter | Reject Differential (mean): 0.0292
Batch 53 | Transform: RandomAffine | Reject Differential (mean): 0.0156
Batch 54 | Transform: RandomRotation | Reject Differential (mean): 0.0000
Batch 54 | Transform: ColorJitter | Reject Differential (mean): 0.0280
Batch 54 | Transform: RandomAffine | Reject Differential (mean): 0.0169
Batch 55 | Transform: RandomRotation | Reject Differential (mean): 0.0192
Batch 55 | Transform: ColorJitter | Reject Differential (mean): 0.0309
Batch 55 | Transform: RandomAffine | Reject Differential (mean): 0.0196
Batch 56 | Transform: RandomRotation | Reject Differential (mean): 0.0184
Batch 56 | Transform: ColorJitter | Reject Differential (mean): 0.0258
Batch 56 | Transform: RandomAffine | Reject Differential (mean): 0.0184


 61%|████████████████████████████████████████████████████████████████████████████████████████████████████                                                                | 61/100 [00:04<00:02, 17.85it/s]

Batch 57 | Transform: RandomRotation | Reject Differential (mean): 0.0157
Batch 57 | Transform: ColorJitter | Reject Differential (mean): 0.0271
Batch 57 | Transform: RandomAffine | Reject Differential (mean): 0.0196
Batch 58 | Transform: RandomRotation | Reject Differential (mean): 0.0134
Batch 58 | Transform: ColorJitter | Reject Differential (mean): 0.0231
Batch 58 | Transform: RandomAffine | Reject Differential (mean): 0.0162
Batch 59 | Transform: RandomRotation | Reject Differential (mean): 0.0219
Batch 59 | Transform: ColorJitter | Reject Differential (mean): 0.0293
Batch 59 | Transform: RandomAffine | Reject Differential (mean): 0.0213
Batch 60 | Transform: RandomRotation | Reject Differential (mean): 0.0174
Batch 60 | Transform: ColorJitter | Reject Differential (mean): 0.0315
Batch 60 | Transform: RandomAffine | Reject Differential (mean): 0.0182


 63%|███████████████████████████████████████████████████████████████████████████████████████████████████████▎                                                            | 63/100 [00:04<00:02, 17.97it/s]

Batch 61 | Transform: RandomRotation | Reject Differential (mean): 0.0143
Batch 61 | Transform: ColorJitter | Reject Differential (mean): 0.0305
Batch 61 | Transform: RandomAffine | Reject Differential (mean): 0.0195
Batch 62 | Transform: RandomRotation | Reject Differential (mean): 0.0181
Batch 62 | Transform: ColorJitter | Reject Differential (mean): 0.0297
Batch 62 | Transform: RandomAffine | Reject Differential (mean): 0.0176
Batch 63 | Transform: RandomRotation | Reject Differential (mean): 0.0173
Batch 63 | Transform: ColorJitter | Reject Differential (mean): 0.0296
Batch 63 | Transform: RandomAffine | Reject Differential (mean): 0.0208
Batch 64 | Transform: RandomRotation | Reject Differential (mean): 0.0000
Batch 64 | Transform: ColorJitter | Reject Differential (mean): 0.0218


 67%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                                      | 67/100 [00:04<00:01, 17.51it/s]

Batch 64 | Transform: RandomAffine | Reject Differential (mean): 0.0188
Batch 65 | Transform: RandomRotation | Reject Differential (mean): 0.0000
Batch 65 | Transform: ColorJitter | Reject Differential (mean): 0.0285
Batch 65 | Transform: RandomAffine | Reject Differential (mean): 0.0200
Batch 66 | Transform: RandomRotation | Reject Differential (mean): 0.0172
Batch 66 | Transform: ColorJitter | Reject Differential (mean): 0.0326
Batch 66 | Transform: RandomAffine | Reject Differential (mean): 0.0184
Batch 67 | Transform: RandomRotation | Reject Differential (mean): 0.0145
Batch 67 | Transform: ColorJitter | Reject Differential (mean): 0.0244
Batch 67 | Transform: RandomAffine | Reject Differential (mean): 0.0180


 71%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                               | 71/100 [00:04<00:01, 17.08it/s]

Batch 68 | Transform: RandomRotation | Reject Differential (mean): 0.0184
Batch 68 | Transform: ColorJitter | Reject Differential (mean): 0.0243
Batch 68 | Transform: RandomAffine | Reject Differential (mean): 0.0211
Batch 69 | Transform: RandomRotation | Reject Differential (mean): 0.0000
Batch 69 | Transform: ColorJitter | Reject Differential (mean): 0.0236
Batch 69 | Transform: RandomAffine | Reject Differential (mean): 0.0182
Batch 70 | Transform: RandomRotation | Reject Differential (mean): 0.0182
Batch 70 | Transform: ColorJitter | Reject Differential (mean): 0.0264
Batch 70 | Transform: RandomAffine | Reject Differential (mean): 0.0157
Batch 71 | Transform: RandomRotation | Reject Differential (mean): 0.0195
Batch 71 | Transform: ColorJitter | Reject Differential (mean): 0.0268
Batch 71 | Transform: RandomAffine | Reject Differential (mean): 0.0160


 75%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                         | 75/100 [00:04<00:01, 17.30it/s]

Batch 72 | Transform: RandomRotation | Reject Differential (mean): 0.0179
Batch 72 | Transform: ColorJitter | Reject Differential (mean): 0.0301
Batch 72 | Transform: RandomAffine | Reject Differential (mean): 0.0187
Batch 73 | Transform: RandomRotation | Reject Differential (mean): 0.0164
Batch 73 | Transform: ColorJitter | Reject Differential (mean): 0.0258
Batch 73 | Transform: RandomAffine | Reject Differential (mean): 0.0154
Batch 74 | Transform: RandomRotation | Reject Differential (mean): 0.0208
Batch 74 | Transform: ColorJitter | Reject Differential (mean): 0.0261
Batch 74 | Transform: RandomAffine | Reject Differential (mean): 0.0199
Batch 75 | Transform: RandomRotation | Reject Differential (mean): 0.0000
Batch 75 | Transform: ColorJitter | Reject Differential (mean): 0.0275
Batch 75 | Transform: RandomAffine | Reject Differential (mean): 0.0166


 79%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                  | 79/100 [00:05<00:01, 17.45it/s]

Batch 76 | Transform: RandomRotation | Reject Differential (mean): 0.0170
Batch 76 | Transform: ColorJitter | Reject Differential (mean): 0.0253
Batch 76 | Transform: RandomAffine | Reject Differential (mean): 0.0160
Batch 77 | Transform: RandomRotation | Reject Differential (mean): 0.0000
Batch 77 | Transform: ColorJitter | Reject Differential (mean): 0.0288
Batch 77 | Transform: RandomAffine | Reject Differential (mean): 0.0196
Batch 78 | Transform: RandomRotation | Reject Differential (mean): 0.0160
Batch 78 | Transform: ColorJitter | Reject Differential (mean): 0.0273
Batch 78 | Transform: RandomAffine | Reject Differential (mean): 0.0177
Batch 79 | Transform: RandomRotation | Reject Differential (mean): 0.0176
Batch 79 | Transform: ColorJitter | Reject Differential (mean): 0.0263
Batch 79 | Transform: RandomAffine | Reject Differential (mean): 0.0202


 83%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                            | 83/100 [00:05<00:01, 16.58it/s]

Batch 80 | Transform: RandomRotation | Reject Differential (mean): 0.0160
Batch 80 | Transform: ColorJitter | Reject Differential (mean): 0.0294
Batch 80 | Transform: RandomAffine | Reject Differential (mean): 0.0196
Batch 81 | Transform: RandomRotation | Reject Differential (mean): 0.0177
Batch 81 | Transform: ColorJitter | Reject Differential (mean): 0.0293
Batch 81 | Transform: RandomAffine | Reject Differential (mean): 0.0168
Batch 82 | Transform: RandomRotation | Reject Differential (mean): 0.0000
Batch 82 | Transform: ColorJitter | Reject Differential (mean): 0.0266
Batch 82 | Transform: RandomAffine | Reject Differential (mean): 0.0226
Batch 83 | Transform: RandomRotation | Reject Differential (mean): 0.0000


 85%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                        | 85/100 [00:05<00:00, 15.75it/s]

Batch 83 | Transform: ColorJitter | Reject Differential (mean): 0.0289
Batch 83 | Transform: RandomAffine | Reject Differential (mean): 0.0193
Batch 84 | Transform: RandomRotation | Reject Differential (mean): 0.0176
Batch 84 | Transform: ColorJitter | Reject Differential (mean): 0.0276
Batch 84 | Transform: RandomAffine | Reject Differential (mean): 0.0164
Batch 85 | Transform: RandomRotation | Reject Differential (mean): 0.0182
Batch 85 | Transform: ColorJitter | Reject Differential (mean): 0.0275
Batch 85 | Transform: RandomAffine | Reject Differential (mean): 0.0197
Batch 86 | Transform: RandomRotation | Reject Differential (mean): 0.0172
Batch 86 | Transform: ColorJitter | Reject Differential (mean): 0.0288


 89%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                  | 89/100 [00:05<00:00, 16.58it/s]

Batch 86 | Transform: RandomAffine | Reject Differential (mean): 0.0174
Batch 87 | Transform: RandomRotation | Reject Differential (mean): 0.0200
Batch 87 | Transform: ColorJitter | Reject Differential (mean): 0.0290
Batch 87 | Transform: RandomAffine | Reject Differential (mean): 0.0201
Batch 88 | Transform: RandomRotation | Reject Differential (mean): 0.0198
Batch 88 | Transform: ColorJitter | Reject Differential (mean): 0.0301
Batch 88 | Transform: RandomAffine | Reject Differential (mean): 0.0188
Batch 89 | Transform: RandomRotation | Reject Differential (mean): 0.0000
Batch 89 | Transform: ColorJitter | Reject Differential (mean): 0.0267
Batch 89 | Transform: RandomAffine | Reject Differential (mean): 0.0224
Batch 90 | Transform: RandomRotation | Reject Differential (mean): 0.0000


 93%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌           | 93/100 [00:06<00:00, 17.01it/s]

Batch 90 | Transform: ColorJitter | Reject Differential (mean): 0.0276
Batch 90 | Transform: RandomAffine | Reject Differential (mean): 0.0204
Batch 91 | Transform: RandomRotation | Reject Differential (mean): 0.0182
Batch 91 | Transform: ColorJitter | Reject Differential (mean): 0.0271
Batch 91 | Transform: RandomAffine | Reject Differential (mean): 0.0188
Batch 92 | Transform: RandomRotation | Reject Differential (mean): 0.0173
Batch 92 | Transform: ColorJitter | Reject Differential (mean): 0.0278
Batch 92 | Transform: RandomAffine | Reject Differential (mean): 0.0192
Batch 93 | Transform: RandomRotation | Reject Differential (mean): 0.0000
Batch 93 | Transform: ColorJitter | Reject Differential (mean): 0.0226
Batch 93 | Transform: RandomAffine | Reject Differential (mean): 0.0210


 97%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████     | 97/100 [00:06<00:00, 17.15it/s]

Batch 94 | Transform: RandomRotation | Reject Differential (mean): 0.0000
Batch 94 | Transform: ColorJitter | Reject Differential (mean): 0.0255
Batch 94 | Transform: RandomAffine | Reject Differential (mean): 0.0207
Batch 95 | Transform: RandomRotation | Reject Differential (mean): 0.0156
Batch 95 | Transform: ColorJitter | Reject Differential (mean): 0.0301
Batch 95 | Transform: RandomAffine | Reject Differential (mean): 0.0180
Batch 96 | Transform: RandomRotation | Reject Differential (mean): 0.0186
Batch 96 | Transform: ColorJitter | Reject Differential (mean): 0.0267
Batch 96 | Transform: RandomAffine | Reject Differential (mean): 0.0166
Batch 97 | Transform: RandomRotation | Reject Differential (mean): 0.0178
Batch 97 | Transform: ColorJitter | Reject Differential (mean): 0.0236


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 15.44it/s]

Batch 97 | Transform: RandomAffine | Reject Differential (mean): 0.0188
Batch 98 | Transform: RandomRotation | Reject Differential (mean): 0.0196
Batch 98 | Transform: ColorJitter | Reject Differential (mean): 0.0267
Batch 98 | Transform: RandomAffine | Reject Differential (mean): 0.0201
Batch 99 | Transform: RandomRotation | Reject Differential (mean): 0.0190
Batch 99 | Transform: ColorJitter | Reject Differential (mean): 0.0299
Batch 99 | Transform: RandomAffine | Reject Differential (mean): 0.0224
Number of recorded reject differentials: 300
Example of first few reject_diffs: [0.016567975282669067, 0.029449326917529106, 0.015580164268612862, 0.0, 0.028186988085508347]





In [11]:
def train_one_epoch(model, loader, criterion, optimizer):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = 100.0 * correct / total
    return epoch_loss, epoch_acc

def evaluate(model, loader, criterion):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = 100.0 * correct / total
    return epoch_loss, epoch_acc


In [12]:
train_losses = []
test_losses = []

num_epochs = 10

for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, trainloader, criterion, optimizer)
    test_loss, test_acc = evaluate(model, testloader, criterion)

    # Append losses to track
    train_losses.append(train_loss)
    test_losses.append(test_loss)

    print(f"Epoch [{epoch+1}/{num_epochs}]")
    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"  Test  Loss: {test_loss:.4f}, Test  Acc: {test_acc:.2f}%")

Epoch [1/10]
  Train Loss: 1.9900, Train Acc: 26.17%
  Test  Loss: 1.6644, Test  Acc: 38.49%
Epoch [2/10]
  Train Loss: 1.5954, Train Acc: 40.75%
  Test  Loss: 1.5084, Test  Acc: 44.93%
Epoch [3/10]
  Train Loss: 1.3474, Train Acc: 51.05%
  Test  Loss: 1.3064, Test  Acc: 53.22%
Epoch [4/10]
  Train Loss: 1.0923, Train Acc: 60.93%
  Test  Loss: 1.0881, Test  Acc: 63.14%
Epoch [5/10]
  Train Loss: 0.8739, Train Acc: 69.33%
  Test  Loss: 0.8101, Test  Acc: 72.19%
Epoch [6/10]
  Train Loss: 0.7360, Train Acc: 74.36%
  Test  Loss: 0.8194, Test  Acc: 72.40%
Epoch [7/10]
  Train Loss: 0.6523, Train Acc: 77.52%
  Test  Loss: 0.7859, Test  Acc: 74.14%
Epoch [8/10]
  Train Loss: 0.5937, Train Acc: 79.39%
  Test  Loss: 0.6970, Test  Acc: 76.56%
Epoch [9/10]
  Train Loss: 0.5485, Train Acc: 81.08%
  Test  Loss: 0.5401, Test  Acc: 81.47%
Epoch [10/10]
  Train Loss: 0.5092, Train Acc: 82.56%
  Test  Loss: 0.5279, Test  Acc: 82.33%


In [13]:
from sklearn.metrics import precision_score, recall_score, f1_score

def validate(model, testloader, device, criterion):
    model.eval()
    val_running_loss = 0.0
    val_running_correct = 0
    all_preds = []
    all_targets = []
    
    # Switch off gradient tracking for validation
    with torch.no_grad():
        for data, target in testloader:
            data, target = data.to(device), target.to(device)
            
            # Forward pass
            output = model(data)
            loss = criterion(output, target)
            
            # Accumulate loss * batch_size for accurate average
            val_running_loss += loss.item() * data.size(0)
            
            # Predictions
            _, preds = torch.max(output, 1)
            val_running_correct += (preds == target).sum().item()
            
            # Store for metrics
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
    
    # Calculate average loss
    total_samples = len(testloader.dataset)
    val_loss = val_running_loss / total_samples
    
    # Calculate accuracy
    val_accuracy = 100.0 * val_running_correct / total_samples
    
    # Calculate precision, recall, F1 (weighted or macro—your choice)
    precision = precision_score(all_targets, all_preds, average='weighted')
    recall    = recall_score(all_targets, all_preds, average='weighted')
    f1        = f1_score(all_targets, all_preds, average='weighted')
    
    return val_loss, val_accuracy, precision, recall, f1


In [14]:
model_save_path = 'resnet56_RejDiff_cifar10.pth'
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")

Model saved to resnet56_RejDiff_cifar10.pth


In [15]:
model.load_state_dict(torch.load(model_save_path))
mode = model.to(device)

  model.load_state_dict(torch.load(model_save_path))


In [16]:
final_val_loss, final_val_accuracy, final_precision, final_recall, final_f1 = validate( model,  testloader, device=device, criterion=criterion)

print(f"\nFinal Test Results - Loss: {final_val_loss:.4f}, "f"Accuracy: {final_val_accuracy:.2f}%, "f"Precision: {final_precision:.2f}, "f"Recall: {final_recall:.2f}, "f"F1 Score: {final_f1:.2f}")


Final Test Results - Loss: 0.5279, Accuracy: 82.33%, Precision: 0.82, Recall: 0.82, F1 Score: 0.82
