In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch
import torch.nn.functional as F
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

Using device: cuda


In [2]:
# Normalization values for CIFAR-10
mean = (0.4914, 0.4822, 0.4465)
std  = (0.2023, 0.1994, 0.2010)

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2
)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2
)


Files already downloaded and verified
Files already downloaded and verified


In [3]:
class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut (projection) if shape changes (stride != 1 or channels differ)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = self.relu(out)
        return out


In [4]:
class ResNet56(nn.Module):
    def __init__(self, block=BasicBlock, num_classes=10):
        super(ResNet56, self).__init__()
        self.in_channels = 16

        # Initial conv: 3x3, 16 filters
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)

        # Each layer: 9 blocks
        self.layer1 = self._make_layer(block, 16, blocks=9, stride=1)
        self.layer2 = self._make_layer(block, 32, blocks=9, stride=2)
        self.layer3 = self._make_layer(block, 64, blocks=9, stride=2)

        # Global average pool and final classification
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

    def _make_layer(self, block, out_channels, blocks, stride):
        """Create a 'layer' consisting of `blocks` residual blocks."""
        strides = [stride] + [1]*(blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_channels, out_channels, s))
            self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out

In [5]:
def train_one_epoch(model, loader, criterion, optimizer):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = 100.0 * correct / total
    return epoch_loss, epoch_acc

def evaluate(model, loader, criterion):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = 100.0 * correct / total
    return epoch_loss, epoch_acc


In [6]:
model = ResNet56().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

In [7]:
#transformations for metamorphic testing
transformations = [
    transforms.RandomRotation(degrees=10),  # Small rotation
    transforms.ColorJitter(brightness=0.2),  # Slight brightness change
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1))  # Small translation
]

In [8]:
import torch
import torch.nn.functional as F

def weighted_vote_prediction(model, images, transformations, weights, device, num_classes=10):
    """
      weighted_preds: Tensor of shape [batch_size] with the weighted vote result.
      accepted_mask: Boolean tensor of shape [batch_size] indicating accepted samples.
    """
    model.eval()

    # 1) Original predictions
    with torch.no_grad():
        original_out = model(images)
        _, original_preds = torch.max(F.softmax(original_out, dim=1), dim=1)  # shape: [batch_size]

    # Initialize weighted vote sum using one-hot encoding of original predictions and its weight.
    weighted_sum = F.one_hot(original_preds, num_classes=num_classes).float() * weights[0]

    # 2) For each metamorphic transformation, get predictions and add weighted one-hot votes.
    for i, tf in enumerate(transformations):
        # Apply transformation on CPU and then move to device.
        x_tf = tf(images.cpu()).to(device)
        with torch.no_grad():
            out_tf = model(x_tf)
            _, preds_tf = torch.max(F.softmax(out_tf, dim=1), dim=1)
        # Convert predictions to one-hot and multiply by weight.
        weighted_sum += F.one_hot(preds_tf, num_classes=num_classes).float() * weights[i+1]

    # 3) Final weighted vote: for each sample, take the argmax of the weighted sum.
    weighted_preds = weighted_sum.argmax(dim=1)  # shape: [batch_size]

    # 4) Determine acceptance: accept if weighted vote equals the original prediction.
    accepted_mask = (weighted_preds == original_preds)

    return weighted_preds, accepted_mask


In [9]:
from tqdm import tqdm

def test_with_weighted_vote(model, loader, transformations, weights, device):
    """
    Evaluates the model on the test set 'loader' using weighted voting.
    For each batch:
      - Computes weighted vote predictions (using original + transformed outputs).
      - Accepts a sample only if the weighted vote equals the original prediction.
    
    Prints per-batch information and returns:
      - The accuracy among accepted samples.
      - The overall acceptance rate.
    """
    model.eval()
    total_samples = 0
    total_accepted = 0
    correct_accepted = 0

    for i, (images, labels) in enumerate(tqdm(loader)):
        images, labels = images.to(device), labels.to(device)

        # Get weighted vote predictions and acceptance mask for the batch.
        weighted_preds, accepted_mask = weighted_vote_prediction(model, images, transformations, weights, device)

        batch_size = labels.size(0)
        total_samples += batch_size
        batch_accepted = accepted_mask.sum().item()
        total_accepted += batch_accepted

        # Compute correct predictions among accepted samples.
        if batch_accepted > 0:
            batch_correct = (weighted_preds[accepted_mask] == labels[accepted_mask]).sum().item()
            correct_accepted += batch_correct

        print(f"Batch {i} | Accepted: {batch_accepted}/{batch_size}")

    overall_acceptance_rate = 100.0 * total_accepted / total_samples
    accepted_accuracy = 100.0 * correct_accepted / total_accepted if total_accepted > 0 else 0.0

    print(f"\nOverall Acceptance Rate: {overall_acceptance_rate:.2f}%")
    print(f"Accuracy among Accepted Samples: {accepted_accuracy:.2f}%")
    
    return accepted_accuracy, overall_acceptance_rate

In [10]:
weights = [0.5, 0.25, 0.25, 0.25]

acc_weighted, acceptance_rate = test_with_weighted_vote(model, testloader, transformations, weights, device)
print(f"Weighted Vote Accuracy among accepted samples: {acc_weighted:.2f}%")
print(f"Overall Acceptance Rate: {acceptance_rate:.2f}%")

  3%|██▌                                                                                   | 3/100 [00:00<00:24,  3.91it/s]

Batch 0 | Accepted: 100/100
Batch 1 | Accepted: 100/100
Batch 2 | Accepted: 100/100
Batch 3 | Accepted: 100/100


  7%|██████                                                                                | 7/100 [00:01<00:10,  8.73it/s]

Batch 4 | Accepted: 100/100
Batch 5 | Accepted: 100/100
Batch 6 | Accepted: 100/100
Batch 7 | Accepted: 100/100


 11%|█████████▎                                                                           | 11/100 [00:01<00:07, 12.35it/s]

Batch 8 | Accepted: 100/100
Batch 9 | Accepted: 100/100
Batch 10 | Accepted: 100/100
Batch 11 | Accepted: 100/100


 16%|█████████████▌                                                                       | 16/100 [00:01<00:05, 16.07it/s]

Batch 12 | Accepted: 100/100
Batch 13 | Accepted: 100/100
Batch 14 | Accepted: 100/100
Batch 15 | Accepted: 100/100
Batch 16 | Accepted: 100/100


 22%|██████████████████▋                                                                  | 22/100 [00:01<00:04, 18.57it/s]

Batch 17 | Accepted: 100/100
Batch 18 | Accepted: 100/100
Batch 19 | Accepted: 100/100
Batch 20 | Accepted: 100/100
Batch 21 | Accepted: 100/100


 25%|█████████████████████▎                                                               | 25/100 [00:02<00:03, 19.35it/s]

Batch 22 | Accepted: 100/100
Batch 23 | Accepted: 100/100
Batch 24 | Accepted: 100/100
Batch 25 | Accepted: 100/100
Batch 26 | Accepted: 100/100


 31%|██████████████████████████▎                                                          | 31/100 [00:02<00:03, 20.75it/s]

Batch 27 | Accepted: 100/100
Batch 28 | Accepted: 100/100
Batch 29 | Accepted: 100/100
Batch 30 | Accepted: 100/100
Batch 31 | Accepted: 100/100


 37%|███████████████████████████████▍                                                     | 37/100 [00:02<00:02, 21.56it/s]

Batch 32 | Accepted: 100/100
Batch 33 | Accepted: 100/100
Batch 34 | Accepted: 100/100
Batch 35 | Accepted: 100/100
Batch 36 | Accepted: 100/100


 40%|██████████████████████████████████                                                   | 40/100 [00:02<00:02, 21.75it/s]

Batch 37 | Accepted: 100/100
Batch 38 | Accepted: 100/100
Batch 39 | Accepted: 99/100
Batch 40 | Accepted: 100/100
Batch 41 | Accepted: 100/100


 46%|███████████████████████████████████████                                              | 46/100 [00:03<00:02, 22.01it/s]

Batch 42 | Accepted: 100/100
Batch 43 | Accepted: 100/100
Batch 44 | Accepted: 100/100
Batch 45 | Accepted: 100/100
Batch 46 | Accepted: 100/100


 52%|████████████████████████████████████████████▏                                        | 52/100 [00:03<00:02, 22.04it/s]

Batch 47 | Accepted: 100/100
Batch 48 | Accepted: 100/100
Batch 49 | Accepted: 100/100
Batch 50 | Accepted: 100/100
Batch 51 | Accepted: 100/100


 55%|██████████████████████████████████████████████▊                                      | 55/100 [00:03<00:02, 22.12it/s]

Batch 52 | Accepted: 100/100
Batch 53 | Accepted: 100/100
Batch 54 | Accepted: 100/100
Batch 55 | Accepted: 100/100
Batch 56 | Accepted: 100/100


 61%|███████████████████████████████████████████████████▊                                 | 61/100 [00:03<00:01, 22.20it/s]

Batch 57 | Accepted: 100/100
Batch 58 | Accepted: 100/100
Batch 59 | Accepted: 100/100
Batch 60 | Accepted: 100/100
Batch 61 | Accepted: 100/100


 67%|████████████████████████████████████████████████████████▉                            | 67/100 [00:03<00:01, 22.29it/s]

Batch 62 | Accepted: 100/100
Batch 63 | Accepted: 100/100
Batch 64 | Accepted: 100/100
Batch 65 | Accepted: 100/100
Batch 66 | Accepted: 100/100


 70%|███████████████████████████████████████████████████████████▍                         | 70/100 [00:04<00:01, 22.35it/s]

Batch 67 | Accepted: 100/100
Batch 68 | Accepted: 100/100
Batch 69 | Accepted: 100/100
Batch 70 | Accepted: 100/100
Batch 71 | Accepted: 100/100


 76%|████████████████████████████████████████████████████████████████▌                    | 76/100 [00:04<00:01, 22.35it/s]

Batch 72 | Accepted: 100/100
Batch 73 | Accepted: 100/100
Batch 74 | Accepted: 100/100
Batch 75 | Accepted: 100/100
Batch 76 | Accepted: 100/100


 82%|█████████████████████████████████████████████████████████████████████▋               | 82/100 [00:04<00:00, 22.36it/s]

Batch 77 | Accepted: 100/100
Batch 78 | Accepted: 100/100
Batch 79 | Accepted: 100/100
Batch 80 | Accepted: 100/100
Batch 81 | Accepted: 100/100


 85%|████████████████████████████████████████████████████████████████████████▎            | 85/100 [00:04<00:00, 22.30it/s]

Batch 82 | Accepted: 100/100
Batch 83 | Accepted: 100/100
Batch 84 | Accepted: 100/100
Batch 85 | Accepted: 100/100
Batch 86 | Accepted: 100/100


 91%|█████████████████████████████████████████████████████████████████████████████▎       | 91/100 [00:05<00:00, 22.14it/s]

Batch 87 | Accepted: 100/100
Batch 88 | Accepted: 100/100
Batch 89 | Accepted: 100/100
Batch 90 | Accepted: 100/100
Batch 91 | Accepted: 100/100


 97%|██████████████████████████████████████████████████████████████████████████████████▍  | 97/100 [00:05<00:00, 22.17it/s]

Batch 92 | Accepted: 100/100
Batch 93 | Accepted: 100/100
Batch 94 | Accepted: 100/100
Batch 95 | Accepted: 100/100
Batch 96 | Accepted: 100/100


100%|████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 18.12it/s]

Batch 97 | Accepted: 100/100
Batch 98 | Accepted: 100/100
Batch 99 | Accepted: 100/100

Overall Acceptance Rate: 99.99%
Accuracy among Accepted Samples: 10.00%
Weighted Vote Accuracy among accepted samples: 10.00%
Overall Acceptance Rate: 99.99%



