In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import time
import torch.nn.functional as F
import torch.nn as nn
import matplotlib.pyplot as plt
from torchvision import models
from sklearn.metrics import precision_score, recall_score, f1_score
from tqdm import tqdm

from sklearn.metrics import precision_score, recall_score, f1_score

# Select the compute device once for the whole notebook: CUDA when a GPU is
# visible to PyTorch, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Using device:", device)

Using device: cuda


In [3]:
# Normalization statistics for Fashion-MNIST.
# The dataset is single-channel; the images are replicated to three channels
# below so that they match the 3-channel input convolution of the network, and
# the same statistic is therefore repeated once per channel.
# NOTE(review): 0.2860 / 0.3530 are the commonly cited training-set mean/std
# for Fashion-MNIST -- recompute them if exact values matter.
mean = (0.2860, 0.2860, 0.2860)
std  = (0.3530, 0.3530, 0.3530)

# Training-time pipeline: light augmentation at the native 28x28 resolution so
# that training and test images have the same spatial size.
transform_train = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # 1-channel PIL image -> 3 channels
    transforms.RandomCrop(28, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Test-time pipeline: no augmentation, same channel layout and normalization.
transform_test = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

trainset = torchvision.datasets.FashionMNIST(
    root='./data', train=True, download=True, transform=transform_train
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2
)

testset = torchvision.datasets.FashionMNIST(
    root='./data', train=False, download=True, transform=transform_test
)
# The test loader is not shuffled so that batch indices are stable across runs.
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2
)


In [3]:
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    Every convolution is bias-free and followed by batch normalization. The
    block emits ``planes * expansion`` channels; any spatial down-sampling is
    performed by the 3x3 convolution (and by the shortcut projection).

    Args:
        in_planes: Number of channels of the input feature map.
        planes: Width of the two inner convolutions.
        stride: Stride of the 3x3 convolution and of the shortcut projection.
    """

    # Ratio between the block's output channels and ``planes``.
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = planes * self.expansion

        # Main branch: squeeze channels, spatial 3x3, then expand channels.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # Shortcut branch: identity when input and output shapes already agree,
        # otherwise a strided 1x1 projection followed by batch normalization.
        shapes_differ = stride != 1 or in_planes != out_planes
        if shapes_differ:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        features = self.conv1(x)
        features = F.relu(self.bn1(features), inplace=True)
        features = self.conv2(features)
        features = F.relu(self.bn2(features), inplace=True)
        features = self.bn3(self.conv3(features))
        # No activation on the main branch before the residual sum.
        features = features + self.shortcut(x)
        return F.relu(features, inplace=True)


In [4]:
class ResNet164(nn.Module):
    """Three-stage bottleneck ResNet for small images.

    With the default ``num_blocks=(18, 18, 18)`` the network has
    3 * 18 * 3 = 162 convolutions inside the residual blocks, plus the stem
    convolution and the final linear layer.

    Args:
        block: Residual block class. It must expose an ``expansion`` attribute
            and accept ``(in_planes, planes, stride)``.
        num_blocks: Number of residual blocks in each of the three stages.
        num_classes: Size of the output logit vector.
    """

    # A tuple default avoids sharing one mutable list between all instances.
    def __init__(self, block=Bottleneck, num_blocks=(18, 18, 18), num_classes=10):
        super(ResNet164, self).__init__()
        # Channel count fed to the next block; updated by ``_make_layer``.
        self.in_planes = 16

        # Stem: 3x3 convolution on a 3-channel input, resolution preserved.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn1   = nn.BatchNorm2d(16)

        # Three stages; the second and third halve the spatial resolution.
        self.layer1 = self._make_layer(block, planes=16,  num_blocks=num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, planes=32,  num_blocks=num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, planes=64,  num_blocks=num_blocks[2], stride=2)

        self.linear = nn.Linear(64 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``.

        Updates ``self.in_planes`` so that each block receives the channel
        count produced by the previous one.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of 3-channel images to ``num_classes`` logits."""
        out = F.relu(self.bn1(self.conv1(x)), inplace=True)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)

        # Global average pooling over the full spatial extent. The adaptive
        # form also handles non-square feature maps, where a kernel derived
        # from the width alone would leave more than one spatial position.
        out = F.adaptive_avg_pool2d(out, 1)
        out = out.flatten(1)
        out = self.linear(out)
        return out

In [5]:
def train_one_epoch(model, loader, criterion, optimizer, device=None):
    """Run one optimization pass over ``loader`` and report loss and accuracy.

    Args:
        model: Network to train; it is switched to training mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable ``criterion(outputs, labels)``. The epoch loss
            is a per-sample average only if it returns the batch mean.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device the batches are moved to. When omitted, the device of
            the model's first parameter is used (CPU if it has no parameters),
            so the function does not depend on a notebook-global variable.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: the sample-weighted loss and the
        accuracy in percent. ``(0.0, 0.0)`` is returned for an empty loader.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by its size so a smaller final batch does not
        # skew the epoch average.
        running_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    # An empty loader would otherwise raise ZeroDivisionError below.
    if total == 0:
        return 0.0, 0.0

    epoch_loss = running_loss / total
    epoch_acc = 100.0 * correct / total
    return epoch_loss, epoch_acc

def evaluate(model, testloader, criterion, device=None):
    """Compute loss and accuracy of ``model`` on ``testloader`` without gradients.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        testloader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable ``criterion(outputs, labels)``. The reported
            loss is a per-sample average only if it returns the batch mean.
        device: Device the batches are moved to. When omitted, the device of
            the model's first parameter is used (CPU if it has no parameters),
            so the function does not depend on a notebook-global variable.

    Returns:
        Tuple ``(test_loss, test_acc)``: the sample-weighted loss and the
        accuracy in percent. ``(0.0, 0.0)`` is returned for an empty loader.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Weight the batch loss by its size before averaging.
            running_loss += loss.item() * images.size(0)

            # Predicted class = index of the largest logit.
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    # An empty loader would otherwise raise ZeroDivisionError below.
    if total == 0:
        return 0.0, 0.0

    test_loss = running_loss / total
    test_acc = 100.0 * correct / total

    return test_loss, test_acc

In [6]:
# Build the network with its default configuration and move it to the
# selected device, then create the classification loss and the SGD optimizer.
model = ResNet164()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)


In [7]:
# Perturbations used as metamorphic relations: each one is meant to be small
# enough that the predicted class should not change.
# NOTE(review): these are applied to batches coming from `testloader`, i.e. to
# tensors that were already normalized; ColorJitter on float tensors appears to
# clamp values to [0, 1] -- confirm this is the intended input range.
transformations = [
    # Rotation by a random angle of at most 10 degrees.
    transforms.RandomRotation(10),
    # Mild random brightness change.
    transforms.ColorJitter(brightness=0.2),
    # Pure translation (no rotation) by up to 10% of width/height.
    transforms.RandomAffine(0, translate=(0.1, 0.1)),
]

In [8]:
import torch
import torch.nn.functional as F

def plurality_vote_prediction(model, images, transformations, device):
    """Accept or reject each prediction by a plurality vote over transformed inputs.

    The model (put into evaluation mode) classifies the original batch and one
    transformed copy of the batch per entry in ``transformations``. Each
    transformation is called once on the whole batch, on a CPU copy, and its
    result is moved to ``device``. For every sample the most frequent class
    among the transformed predictions is compared with the original
    prediction; the sample is accepted only when the two agree. With no
    transformations, every sample is accepted.

    Args:
        model: Classifier returning one row of class scores per sample.
        images: Batch tensor, already on the device the model runs on.
        transformations: Sequence of callables applied to the batch.
        device: Device the transformed batches are moved to.

    Returns:
        accepted_preds: Tensor of shape [batch_size] holding the original
            prediction for accepted samples and -1 for rejected ones.
        accepted_mask: Boolean tensor of shape [batch_size], True where the
            sample was accepted.
    """
    model.eval()

    def _predict(batch):
        # Index of the most probable class for every sample, gradient-free.
        with torch.no_grad():
            probabilities = F.softmax(model(batch), dim=1)
            return torch.max(probabilities, dim=1)[1]

    # Prediction on the untouched batch.
    original_preds = _predict(images)

    # One vote per transformation; each entry has shape [batch_size].
    votes = [_predict(tf(images.cpu()).to(device)) for tf in transformations]

    if votes:
        # Stack to [num_transforms, batch_size] and take the per-sample mode.
        mode_preds = torch.stack(votes, dim=0).mode(dim=0).values
    else:
        mode_preds = original_preds

    # Keep the original prediction where it matches the vote, -1 elsewhere.
    accepted_mask = original_preds == mode_preds
    accepted_preds = original_preds.masked_fill(~accepted_mask, -1)

    return accepted_preds, accepted_mask

In [9]:
from tqdm import tqdm

def test_with_plurality_vote(model, loader, transformations, device, verbose=True):
    """Evaluate ``model`` on ``loader`` with the plurality-vote (PluVot) filter.

    For every batch, ``plurality_vote_prediction`` decides which samples are
    accepted (original prediction equals the mode of the transformed
    predictions). Accuracy is measured on the accepted samples only.

    Args:
        model: Classifier to evaluate; it is switched to evaluation mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        transformations: Callables forwarded to ``plurality_vote_prediction``.
        device: Device the batches are moved to.
        verbose: When True (default), one line with the acceptance count is
            written per batch. Set to False to keep only the progress bar and
            the final summary.

    Returns:
        Tuple ``(accepted_accuracy, overall_acceptance_rate)``, both in
        percent. Each value is 0.0 when its denominator is zero (nothing
        accepted, or an empty loader).
    """
    model.eval()
    total_samples = 0
    total_accepted = 0
    correct_accepted = 0

    for i, (images, labels) in enumerate(tqdm(loader)):
        images, labels = images.to(device), labels.to(device)

        # Get PluVot predictions and acceptance mask for the batch.
        accepted_preds, accepted_mask = plurality_vote_prediction(model, images, transformations, device)

        batch_size = labels.size(0)
        total_samples += batch_size
        batch_accepted = accepted_mask.sum().item()
        total_accepted += batch_accepted

        # Correct predictions among accepted samples; an all-False mask selects
        # nothing and contributes zero.
        correct_accepted += (accepted_preds[accepted_mask] == labels[accepted_mask]).sum().item()

        if verbose:
            # tqdm.write keeps the message from corrupting the progress bar.
            tqdm.write(f"Batch {i} | Accepted: {batch_accepted}/{batch_size}")

    # Guard both ratios: an empty loader previously raised ZeroDivisionError
    # here even though the accuracy below was already protected.
    overall_acceptance_rate = 100.0 * total_accepted / total_samples if total_samples > 0 else 0.0
    accepted_accuracy = 100.0 * correct_accepted / total_accepted if total_accepted > 0 else 0.0

    print(f"\nOverall Acceptance Rate: {overall_acceptance_rate:.2f}%")
    print(f"Accuracy among Accepted Samples: {accepted_accuracy:.2f}%")

    return accepted_accuracy, overall_acceptance_rate


In [10]:
# Run the plurality-vote evaluation on the test loader and report both metrics.
# NOTE(review): no training loop or checkpoint load is visible before this
# call; an accuracy of 10% on a 10-class problem is what an untrained model
# would give -- confirm the model is trained before trusting these numbers.
accepted_acc, acceptance_rate = test_with_plurality_vote(
    model=model,
    loader=testloader,
    transformations=transformations,
    device=device,
)
print(f"Plurality Vote Accuracy among accepted samples: {accepted_acc:.2f}%")
print(f"Overall Acceptance Rate: {acceptance_rate:.2f}%")

  2%|█▋                                                                                    | 2/100 [00:01<01:10,  1.39it/s]

Batch 0 | Accepted: 100/100
Batch 1 | Accepted: 100/100


  4%|███▍                                                                                  | 4/100 [00:01<00:32,  2.92it/s]

Batch 2 | Accepted: 100/100
Batch 3 | Accepted: 100/100


  6%|█████▏                                                                                | 6/100 [00:02<00:22,  4.27it/s]

Batch 4 | Accepted: 100/100
Batch 5 | Accepted: 100/100


  8%|██████▉                                                                               | 8/100 [00:02<00:17,  5.22it/s]

Batch 6 | Accepted: 100/100
Batch 7 | Accepted: 100/100


 10%|████████▌                                                                            | 10/100 [00:02<00:15,  5.82it/s]

Batch 8 | Accepted: 100/100
Batch 9 | Accepted: 100/100


 12%|██████████▏                                                                          | 12/100 [00:03<00:14,  6.17it/s]

Batch 10 | Accepted: 100/100
Batch 11 | Accepted: 100/100


 14%|███████████▉                                                                         | 14/100 [00:03<00:13,  6.34it/s]

Batch 12 | Accepted: 100/100
Batch 13 | Accepted: 100/100


 16%|█████████████▌                                                                       | 16/100 [00:03<00:13,  6.45it/s]

Batch 14 | Accepted: 100/100
Batch 15 | Accepted: 100/100


 18%|███████████████▎                                                                     | 18/100 [00:04<00:12,  6.49it/s]

Batch 16 | Accepted: 100/100
Batch 17 | Accepted: 100/100


 20%|█████████████████                                                                    | 20/100 [00:04<00:12,  6.46it/s]

Batch 18 | Accepted: 100/100
Batch 19 | Accepted: 100/100


 22%|██████████████████▋                                                                  | 22/100 [00:04<00:12,  6.39it/s]

Batch 20 | Accepted: 100/100
Batch 21 | Accepted: 100/100


 24%|████████████████████▍                                                                | 24/100 [00:05<00:11,  6.39it/s]

Batch 22 | Accepted: 100/100
Batch 23 | Accepted: 100/100


 26%|██████████████████████                                                               | 26/100 [00:05<00:11,  6.41it/s]

Batch 24 | Accepted: 100/100
Batch 25 | Accepted: 100/100


 28%|███████████████████████▊                                                             | 28/100 [00:05<00:11,  6.41it/s]

Batch 26 | Accepted: 100/100
Batch 27 | Accepted: 100/100


 30%|█████████████████████████▌                                                           | 30/100 [00:06<00:10,  6.40it/s]

Batch 28 | Accepted: 100/100
Batch 29 | Accepted: 100/100


 32%|███████████████████████████▏                                                         | 32/100 [00:06<00:10,  6.38it/s]

Batch 30 | Accepted: 100/100
Batch 31 | Accepted: 100/100


 34%|████████████████████████████▉                                                        | 34/100 [00:06<00:10,  6.38it/s]

Batch 32 | Accepted: 100/100
Batch 33 | Accepted: 100/100


 36%|██████████████████████████████▌                                                      | 36/100 [00:06<00:09,  6.47it/s]

Batch 34 | Accepted: 100/100
Batch 35 | Accepted: 100/100


 38%|████████████████████████████████▎                                                    | 38/100 [00:07<00:09,  6.50it/s]

Batch 36 | Accepted: 100/100
Batch 37 | Accepted: 100/100


 40%|██████████████████████████████████                                                   | 40/100 [00:07<00:09,  6.53it/s]

Batch 38 | Accepted: 100/100
Batch 39 | Accepted: 100/100


 42%|███████████████████████████████████▋                                                 | 42/100 [00:07<00:08,  6.55it/s]

Batch 40 | Accepted: 100/100
Batch 41 | Accepted: 100/100


 44%|█████████████████████████████████████▍                                               | 44/100 [00:08<00:08,  6.53it/s]

Batch 42 | Accepted: 100/100
Batch 43 | Accepted: 100/100


 46%|███████████████████████████████████████                                              | 46/100 [00:08<00:08,  6.53it/s]

Batch 44 | Accepted: 100/100
Batch 45 | Accepted: 100/100


 48%|████████████████████████████████████████▊                                            | 48/100 [00:08<00:07,  6.55it/s]

Batch 46 | Accepted: 100/100
Batch 47 | Accepted: 100/100


 50%|██████████████████████████████████████████▌                                          | 50/100 [00:09<00:07,  6.55it/s]

Batch 48 | Accepted: 100/100
Batch 49 | Accepted: 100/100


 52%|████████████████████████████████████████████▏                                        | 52/100 [00:09<00:07,  6.51it/s]

Batch 50 | Accepted: 100/100
Batch 51 | Accepted: 100/100


 54%|█████████████████████████████████████████████▉                                       | 54/100 [00:09<00:07,  6.52it/s]

Batch 52 | Accepted: 100/100
Batch 53 | Accepted: 100/100


 56%|███████████████████████████████████████████████▌                                     | 56/100 [00:10<00:06,  6.54it/s]

Batch 54 | Accepted: 100/100
Batch 55 | Accepted: 100/100


 58%|█████████████████████████████████████████████████▎                                   | 58/100 [00:10<00:06,  6.49it/s]

Batch 56 | Accepted: 100/100
Batch 57 | Accepted: 99/100


 60%|███████████████████████████████████████████████████                                  | 60/100 [00:10<00:06,  6.51it/s]

Batch 58 | Accepted: 100/100
Batch 59 | Accepted: 100/100


 62%|████████████████████████████████████████████████████▋                                | 62/100 [00:10<00:05,  6.51it/s]

Batch 60 | Accepted: 100/100
Batch 61 | Accepted: 100/100


 64%|██████████████████████████████████████████████████████▍                              | 64/100 [00:11<00:05,  6.50it/s]

Batch 62 | Accepted: 100/100
Batch 63 | Accepted: 100/100


 66%|████████████████████████████████████████████████████████                             | 66/100 [00:11<00:05,  6.51it/s]

Batch 64 | Accepted: 100/100
Batch 65 | Accepted: 100/100


 68%|█████████████████████████████████████████████████████████▊                           | 68/100 [00:11<00:04,  6.51it/s]

Batch 66 | Accepted: 100/100
Batch 67 | Accepted: 100/100


 70%|███████████████████████████████████████████████████████████▍                         | 70/100 [00:12<00:04,  6.53it/s]

Batch 68 | Accepted: 100/100
Batch 69 | Accepted: 100/100


 72%|█████████████████████████████████████████████████████████████▏                       | 72/100 [00:12<00:04,  6.50it/s]

Batch 70 | Accepted: 100/100
Batch 71 | Accepted: 100/100


 74%|██████████████████████████████████████████████████████████████▉                      | 74/100 [00:12<00:03,  6.51it/s]

Batch 72 | Accepted: 100/100
Batch 73 | Accepted: 100/100


 76%|████████████████████████████████████████████████████████████████▌                    | 76/100 [00:13<00:03,  6.49it/s]

Batch 74 | Accepted: 100/100
Batch 75 | Accepted: 100/100


 78%|██████████████████████████████████████████████████████████████████▎                  | 78/100 [00:13<00:03,  6.47it/s]

Batch 76 | Accepted: 100/100
Batch 77 | Accepted: 100/100


 80%|████████████████████████████████████████████████████████████████████                 | 80/100 [00:13<00:03,  6.49it/s]

Batch 78 | Accepted: 100/100
Batch 79 | Accepted: 100/100


 82%|█████████████████████████████████████████████████████████████████████▋               | 82/100 [00:14<00:02,  6.53it/s]

Batch 80 | Accepted: 100/100
Batch 81 | Accepted: 100/100


 84%|███████████████████████████████████████████████████████████████████████▍             | 84/100 [00:14<00:02,  6.50it/s]

Batch 82 | Accepted: 100/100
Batch 83 | Accepted: 100/100


 86%|█████████████████████████████████████████████████████████████████████████            | 86/100 [00:14<00:02,  6.52it/s]

Batch 84 | Accepted: 100/100
Batch 85 | Accepted: 100/100


 88%|██████████████████████████████████████████████████████████████████████████▊          | 88/100 [00:14<00:01,  6.54it/s]

Batch 86 | Accepted: 100/100
Batch 87 | Accepted: 100/100


 90%|████████████████████████████████████████████████████████████████████████████▌        | 90/100 [00:15<00:01,  6.52it/s]

Batch 88 | Accepted: 100/100
Batch 89 | Accepted: 100/100


 92%|██████████████████████████████████████████████████████████████████████████████▏      | 92/100 [00:15<00:01,  6.52it/s]

Batch 90 | Accepted: 100/100
Batch 91 | Accepted: 100/100


 94%|███████████████████████████████████████████████████████████████████████████████▉     | 94/100 [00:15<00:00,  6.53it/s]

Batch 92 | Accepted: 100/100
Batch 93 | Accepted: 100/100


 96%|█████████████████████████████████████████████████████████████████████████████████▌   | 96/100 [00:16<00:00,  6.54it/s]

Batch 94 | Accepted: 100/100
Batch 95 | Accepted: 100/100


 98%|███████████████████████████████████████████████████████████████████████████████████▎ | 98/100 [00:16<00:00,  6.51it/s]

Batch 96 | Accepted: 100/100
Batch 97 | Accepted: 100/100


100%|████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:16<00:00,  5.95it/s]

Batch 98 | Accepted: 100/100
Batch 99 | Accepted: 100/100

Overall Acceptance Rate: 99.99%
Accuracy among Accepted Samples: 10.00%
Plurality Vote Accuracy among accepted samples: 10.00%
Overall Acceptance Rate: 99.99%



