In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch
import torch.nn.functional as F
from tqdm import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Using device:", device)

Using device: cuda


In [2]:
# Normalization values for CIFAR-10
mean = (0.4914, 0.4822, 0.4465)
std  = (0.2023, 0.1994, 0.2010)

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

trainset = torchvision.datasets.FashionMNIST(
    root='./data', train=True, download=True, transform=transform_train
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True, num_workers=2
)

testset = torchvision.datasets.FashionMNIST(
    root='./data', train=False, download=True, transform=transform_test
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=2
)

In [3]:
class BasicBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut (projection) if shape changes (stride != 1 or channels differ)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = self.relu(out)
        return out


In [4]:
class ResNet20(nn.Module):
    def __init__(self, block=BasicBlock, num_classes=10):
        super(ResNet20, self).__init__()
        self.in_channels = 16

        # Initial conv: 3x3, 16 filters
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)

        # Each "layer" is a sequence of blocks
        self.layer1 = self._make_layer(block, out_channels=16, blocks=3, stride=1)
        self.layer2 = self._make_layer(block, out_channels=32, blocks=3, stride=2)
        self.layer3 = self._make_layer(block, out_channels=64, blocks=3, stride=2)

        # Global average pool and final classification layer
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

    def _make_layer(self, block, out_channels, blocks, stride):
        """Create a stage of 'blocks' residual blocks."""
        strides = [stride] + [1] * (blocks - 1)  
        layers = []
        for s in strides:
            layers.append(block(self.in_channels, out_channels, s))
            self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avgpool(out)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out


In [5]:
def train_one_epoch(model, loader, criterion, optimizer):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = 100.0 * correct / total
    return epoch_loss, epoch_acc

def evaluate(model, loader, criterion):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * images.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = 100.0 * correct / total
    return epoch_loss, epoch_acc


In [6]:
model = ResNet20().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

In [7]:
#transformations for metamorphic testing
transformations = [
    transforms.RandomRotation(degrees=10),  # Small rotation
    transforms.ColorJitter(brightness=0.2),  # Slight brightness change
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1))  # Small translation
]

In [8]:
import torch
import torch.nn.functional as F

def majority_vote_prediction(model, images, transformations, device):
    """
    Given a batch of images and a list of transformations,
    compute the model's predictions on:
        1) the original images,
        2) each transformed version,
    then take a majority vote across all predictions for each sample.
    
    Returns a tensor of shape [batch_size] with the final voted class.
    """
    model.eval()

    # 1) Original predictions
    with torch.no_grad():
        original_out = model(images)
        _, original_preds = torch.max(F.softmax(original_out, dim=1), dim=1)  # shape [batch_size]

    # 2) Collect predictions in a list
    all_preds = [original_preds]

    # 3) For each metamorphic transform
    for tf in transformations:
        # apply transform on CPU, then move back to device
        x_tf = tf(images.cpu()).to(device)

        with torch.no_grad():
            out_tf = model(x_tf)
            _, preds_tf = torch.max(F.softmax(out_tf, dim=1), dim=1)

        all_preds.append(preds_tf)

    # 4) Stack them into shape [num_transforms+1, batch_size]
    stacked_preds = torch.stack(all_preds, dim=0)

    # 5) Majority vote across the first dimension
    voted_preds = stacked_preds.mode(dim=0).values  # shape [batch_size]
    return voted_preds

In [9]:
from tqdm import tqdm

def test_with_majority_vote(model, loader, transformations, device):
    """
    Evaluates the model on 'loader' using majority-voted predictions (original + transforms).
    Prints per-batch accuracy and returns the final overall accuracy.
    """
    model.eval()
    correct = 0
    total = 0

    for i, (images, labels) in enumerate(tqdm(loader)):
        images, labels = images.to(device), labels.to(device)

        # Get majority-voted predictions
        voted_preds = majority_vote_prediction(model, images, transformations, device)

        # Compare to ground-truth
        batch_correct = (voted_preds == labels).sum().item()
        batch_size = labels.size(0)

        correct += batch_correct
        total += batch_size

        # Print a line for each batch
        print(f"Batch {i} | MajVot Accuracy: {100.0 * batch_correct / batch_size:.2f}% | Batch Size: {batch_size}")

    # Final overall accuracy
    overall_accuracy = 100.0 * correct / total
    print(f"\nFinal MajVot Accuracy (across all test samples): {overall_accuracy:.2f}%")
    return overall_accuracy

In [13]:
majvot_acc = test_with_majority_vote(model, testloader, transformations, device)
print(f"Majority Vote final accuracy: {majvot_acc:.2f}%")

  4%|███▍                                                                                  | 4/100 [00:00<00:06, 13.76it/s]

Batch 0 | MajVot Accuracy: 7.00% | Batch Size: 100
Batch 1 | MajVot Accuracy: 9.00% | Batch Size: 100
Batch 2 | MajVot Accuracy: 8.00% | Batch Size: 100
Batch 3 | MajVot Accuracy: 8.00% | Batch Size: 100
Batch 4 | MajVot Accuracy: 7.00% | Batch Size: 100
Batch 5 | MajVot Accuracy: 5.00% | Batch Size: 100
Batch 6 | MajVot Accuracy: 11.00% | Batch Size: 100


 12%|██████████▏                                                                          | 12/100 [00:00<00:03, 27.30it/s]

Batch 7 | MajVot Accuracy: 12.00% | Batch Size: 100
Batch 8 | MajVot Accuracy: 8.00% | Batch Size: 100
Batch 9 | MajVot Accuracy: 13.00% | Batch Size: 100
Batch 10 | MajVot Accuracy: 9.00% | Batch Size: 100
Batch 11 | MajVot Accuracy: 11.00% | Batch Size: 100
Batch 12 | MajVot Accuracy: 9.00% | Batch Size: 100
Batch 13 | MajVot Accuracy: 9.00% | Batch Size: 100
Batch 14 | MajVot Accuracy: 13.00% | Batch Size: 100


 22%|██████████████████▋                                                                  | 22/100 [00:00<00:02, 35.32it/s]

Batch 15 | MajVot Accuracy: 6.00% | Batch Size: 100
Batch 16 | MajVot Accuracy: 10.00% | Batch Size: 100
Batch 17 | MajVot Accuracy: 11.00% | Batch Size: 100
Batch 18 | MajVot Accuracy: 9.00% | Batch Size: 100
Batch 19 | MajVot Accuracy: 10.00% | Batch Size: 100
Batch 20 | MajVot Accuracy: 7.00% | Batch Size: 100
Batch 21 | MajVot Accuracy: 6.00% | Batch Size: 100
Batch 22 | MajVot Accuracy: 12.00% | Batch Size: 100
Batch 23 | MajVot Accuracy: 13.00% | Batch Size: 100


 32%|███████████████████████████▏                                                         | 32/100 [00:01<00:01, 38.19it/s]

Batch 24 | MajVot Accuracy: 11.00% | Batch Size: 100
Batch 25 | MajVot Accuracy: 9.00% | Batch Size: 100
Batch 26 | MajVot Accuracy: 8.00% | Batch Size: 100
Batch 27 | MajVot Accuracy: 10.00% | Batch Size: 100
Batch 28 | MajVot Accuracy: 9.00% | Batch Size: 100
Batch 29 | MajVot Accuracy: 7.00% | Batch Size: 100
Batch 30 | MajVot Accuracy: 11.00% | Batch Size: 100
Batch 31 | MajVot Accuracy: 12.00% | Batch Size: 100
Batch 32 | MajVot Accuracy: 10.00% | Batch Size: 100


 41%|██████████████████████████████████▊                                                  | 41/100 [00:01<00:01, 37.89it/s]

Batch 33 | MajVot Accuracy: 8.00% | Batch Size: 100
Batch 34 | MajVot Accuracy: 6.00% | Batch Size: 100
Batch 35 | MajVot Accuracy: 12.00% | Batch Size: 100
Batch 36 | MajVot Accuracy: 11.00% | Batch Size: 100
Batch 37 | MajVot Accuracy: 11.00% | Batch Size: 100
Batch 38 | MajVot Accuracy: 7.00% | Batch Size: 100
Batch 39 | MajVot Accuracy: 9.00% | Batch Size: 100
Batch 40 | MajVot Accuracy: 7.00% | Batch Size: 100


 46%|███████████████████████████████████████                                              | 46/100 [00:01<00:01, 38.69it/s]

Batch 41 | MajVot Accuracy: 6.00% | Batch Size: 100
Batch 42 | MajVot Accuracy: 15.00% | Batch Size: 100
Batch 43 | MajVot Accuracy: 7.00% | Batch Size: 100
Batch 44 | MajVot Accuracy: 15.00% | Batch Size: 100
Batch 45 | MajVot Accuracy: 11.00% | Batch Size: 100
Batch 46 | MajVot Accuracy: 7.00% | Batch Size: 100
Batch 47 | MajVot Accuracy: 11.00% | Batch Size: 100
Batch 48 | MajVot Accuracy: 12.00% | Batch Size: 100
Batch 49 | MajVot Accuracy: 10.00% | Batch Size: 100


 55%|██████████████████████████████████████████████▊                                      | 55/100 [00:01<00:01, 38.15it/s]

Batch 50 | MajVot Accuracy: 7.00% | Batch Size: 100
Batch 51 | MajVot Accuracy: 8.00% | Batch Size: 100
Batch 52 | MajVot Accuracy: 11.00% | Batch Size: 100
Batch 53 | MajVot Accuracy: 12.00% | Batch Size: 100
Batch 54 | MajVot Accuracy: 12.00% | Batch Size: 100
Batch 55 | MajVot Accuracy: 7.00% | Batch Size: 100
Batch 56 | MajVot Accuracy: 5.00% | Batch Size: 100
Batch 57 | MajVot Accuracy: 3.00% | Batch Size: 100


 64%|██████████████████████████████████████████████████████▍                              | 64/100 [00:01<00:00, 39.08it/s]

Batch 58 | MajVot Accuracy: 6.00% | Batch Size: 100
Batch 59 | MajVot Accuracy: 9.00% | Batch Size: 100
Batch 60 | MajVot Accuracy: 5.00% | Batch Size: 100
Batch 61 | MajVot Accuracy: 13.00% | Batch Size: 100
Batch 62 | MajVot Accuracy: 9.00% | Batch Size: 100
Batch 63 | MajVot Accuracy: 6.00% | Batch Size: 100
Batch 64 | MajVot Accuracy: 8.00% | Batch Size: 100
Batch 65 | MajVot Accuracy: 12.00% | Batch Size: 100


 73%|██████████████████████████████████████████████████████████████                       | 73/100 [00:02<00:00, 38.78it/s]

Batch 66 | MajVot Accuracy: 7.00% | Batch Size: 100
Batch 67 | MajVot Accuracy: 9.00% | Batch Size: 100
Batch 68 | MajVot Accuracy: 11.00% | Batch Size: 100
Batch 69 | MajVot Accuracy: 8.00% | Batch Size: 100
Batch 70 | MajVot Accuracy: 14.00% | Batch Size: 100
Batch 71 | MajVot Accuracy: 8.00% | Batch Size: 100
Batch 72 | MajVot Accuracy: 11.00% | Batch Size: 100
Batch 73 | MajVot Accuracy: 10.00% | Batch Size: 100


 83%|██████████████████████████████████████████████████████████████████████▌              | 83/100 [00:02<00:00, 41.22it/s]

Batch 74 | MajVot Accuracy: 10.00% | Batch Size: 100
Batch 75 | MajVot Accuracy: 9.00% | Batch Size: 100
Batch 76 | MajVot Accuracy: 7.00% | Batch Size: 100
Batch 77 | MajVot Accuracy: 6.00% | Batch Size: 100
Batch 78 | MajVot Accuracy: 10.00% | Batch Size: 100
Batch 79 | MajVot Accuracy: 14.00% | Batch Size: 100
Batch 80 | MajVot Accuracy: 16.00% | Batch Size: 100
Batch 81 | MajVot Accuracy: 11.00% | Batch Size: 100
Batch 82 | MajVot Accuracy: 8.00% | Batch Size: 100


 88%|██████████████████████████████████████████████████████████████████████████▊          | 88/100 [00:02<00:00, 40.89it/s]

Batch 83 | MajVot Accuracy: 12.00% | Batch Size: 100
Batch 84 | MajVot Accuracy: 11.00% | Batch Size: 100
Batch 85 | MajVot Accuracy: 11.00% | Batch Size: 100
Batch 86 | MajVot Accuracy: 10.00% | Batch Size: 100
Batch 87 | MajVot Accuracy: 8.00% | Batch Size: 100
Batch 88 | MajVot Accuracy: 8.00% | Batch Size: 100
Batch 89 | MajVot Accuracy: 8.00% | Batch Size: 100
Batch 90 | MajVot Accuracy: 5.00% | Batch Size: 100
Batch 91 | MajVot Accuracy: 12.00% | Batch Size: 100


100%|████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:02<00:00, 36.37it/s]

Batch 92 | MajVot Accuracy: 9.00% | Batch Size: 100
Batch 93 | MajVot Accuracy: 9.00% | Batch Size: 100
Batch 94 | MajVot Accuracy: 7.00% | Batch Size: 100
Batch 95 | MajVot Accuracy: 13.00% | Batch Size: 100
Batch 96 | MajVot Accuracy: 9.00% | Batch Size: 100
Batch 97 | MajVot Accuracy: 9.00% | Batch Size: 100
Batch 98 | MajVot Accuracy: 9.00% | Batch Size: 100
Batch 99 | MajVot Accuracy: 7.00% | Batch Size: 100

Final MajVot Accuracy (across all test samples): 9.34%
Majority Vote final accuracy: 9.34%





In [14]:
def validate(model, testloader, device, criterion):
    model.eval()
    val_running_loss = 0.0
    val_running_correct = 0
    all_preds = []
    all_targets = []
    
    # Switch off gradient tracking for validation
    with torch.no_grad():
        for data, target in testloader:
            data, target = data.to(device), target.to(device)
            
            # Forward pass
            output = model(data)
            loss = criterion(output, target)
            
            # Accumulate loss * batch_size for accurate average
            val_running_loss += loss.item() * data.size(0)
            
            # Predictions
            _, preds = torch.max(output, 1)
            val_running_correct += (preds == target).sum().item()
            
            # Store for metrics
            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(target.cpu().numpy())
    
    # Calculate average loss
    total_samples = len(testloader.dataset)
    val_loss = val_running_loss / total_samples
    
    # Calculate accuracy
    val_accuracy = 100.0 * val_running_correct / total_samples
    
    # Calculate precision, recall, F1 (weighted or macro—your choice)
    precision = precision_score(all_targets, all_preds, average='weighted')
    recall    = recall_score(all_targets, all_preds, average='weighted')
    f1        = f1_score(all_targets, all_preds, average='weighted')
    
    return val_loss, val_accuracy, precision, recall, f1
