In [17]:
import random
import numpy as np
import torch
import torchvision.models as models
import torch.nn as nn
from pathlib import Path
from torch.utils.data import Dataset, DataLoader
from PIL import Image

# Pick the compute backend once for the whole notebook.
# Preference order: CUDA GPU, then Apple-Silicon MPS, then plain CPU.
# set_seed() seeds both CUDA and MPS, so device selection covers the same backends.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")



def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and torch's
    global seed, then the CUDA and MPS generators when those backends are
    available. Finally the cuDNN flags are set so that cuDNN uses
    deterministic algorithms and does not auto-tune (benchmark) kernels.

    Args:
        seed: value passed unchanged to every seeding function.
    """
    # Host-side generators, in the order: Python, NumPy, torch.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # CUDA: current device first, then all devices.
    if torch.cuda.is_available():
        for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            cuda_seed_fn(seed)

    # Apple-Silicon (MPS) generator.
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)

    # These flags are set unconditionally, whether or not CUDA is in use.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [18]:

class FullDataset(Dataset):
    """Image-classification dataset that merges every shard under one folder.

    Expected layout: ``shards_dir/<shard>/<class_name>/<image file>``.
    Every file with a ``.jpg``, ``.jpeg`` or ``.png`` suffix (case-insensitive)
    found two directory levels below ``shards_dir`` becomes one sample whose
    label is the name of its class folder.

    Attributes:
        samples: list of ``(image_path, class_name)`` tuples, in sorted
            shard / class / file order.
        transform: optional callable applied to the RGB PIL image.
        class_to_idx: maps each class name seen in ``samples`` to its index
            in the alphabetically sorted list of class names.
    """

    # File suffixes (compared lower-cased) that are treated as images.
    IMAGE_SUFFIXES = (".jpg", ".jpeg", ".png")

    def __init__(self, shards_dir, transform=None):
        """
        shards_dir: path to the parent folder containing all shards
        transform: torchvision transforms (or any callable taking a PIL image)
        """
        self.samples = []
        self.transform = transform
        shards_dir = Path(shards_dir)

        # Path.iterdir() yields entries in arbitrary, filesystem-dependent
        # order; sorting at every level makes sample indices reproducible.
        for shard_folder in sorted(shards_dir.iterdir()):
            if not shard_folder.is_dir():
                continue
            for class_folder in sorted(shard_folder.iterdir()):
                if not class_folder.is_dir():
                    continue
                for img_path in sorted(class_folder.iterdir()):
                    if img_path.suffix.lower() in self.IMAGE_SUFFIXES:
                        # store (path, class_name)
                        self.samples.append((img_path, class_folder.name))

        # build class->index mapping from the classes that actually have samples
        class_names = sorted({cls for _, cls in self.samples})
        self.class_to_idx = {cls: idx for idx, cls in enumerate(class_names)}

    def __len__(self):
        """Return the number of image samples found across all shards."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, class_index)``.

        The image is converted to RGB and, when a transform was given,
        passed through it; otherwise the PIL image itself is returned.
        """
        img_path, label = self.samples[idx]
        # Image.open is lazy; the context manager guarantees the underlying
        # file is closed once convert() has produced an independent RGB copy.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, self.class_to_idx[label]

In [19]:
num_classes = 3  # your dataset

def get_model():
    """Build a ResNet-18 with a `num_classes`-way classifier head on `device`.

    The global seed is reset to 42 first, so every call starts from the same
    random state. The network is created without pretrained weights
    (``weights=None``) and its final fully connected layer is replaced by a
    fresh linear layer with the same input width.

    Returns:
        The model after being moved to the notebook-wide `device`.
    """
    set_seed(42)
    net = models.resnet18(weights=None)
    head_in_features = net.fc.in_features
    net.fc = nn.Linear(head_in_features, num_classes)
    net = net.to(device)
    return net

In [20]:
import torch
from pathlib import Path

# Folder containing your model files
model_folder = Path("my_models")

# Per-shard checkpoints: every .pt file whose name starts with 'model_'.
# sorted() keeps the ensemble order stable from run to run (lexicographic by path).
model_files = sorted(model_folder.glob("model_*.pt"))

# Fail fast: with no checkpoints the ensemble would be empty and the
# evaluation further down would only break with an unrelated-looking error.
if not model_files:
    raise FileNotFoundError(
        f"No 'model_*.pt' checkpoints found in {model_folder.resolve()}"
    )

# List to store loaded models
models_essemble = []

for file in model_files:
    model = get_model()  # initialize architecture
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints you trust; consider weights_only=True if the
    # installed torch version supports it.
    state_dict = torch.load(file, map_location=device)  # load the saved weights
    model.load_state_dict(state_dict)  # load weights into model
    models_essemble.append(model)

print(f"Loaded {len(models_essemble)} models:")
for i, m in enumerate(models_essemble):
    print(f"Model {i}: {type(m)}")

Loaded 3 models:
Model 0: <class 'torchvision.models.resnet.ResNet'>
Model 1: <class 'torchvision.models.resnet.ResNet'>
Model 2: <class 'torchvision.models.resnet.ResNet'>


In [21]:
from torch.utils.data import DataLoader
from torchvision import transforms

# Preprocessing applied to every image: resize to a fixed 128x128, then
# convert the PIL image to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize(size=(128, 128)),
        transforms.ToTensor(),
    ]
)

set_seed(42)

# All shards are evaluated together, in dataset order (no shuffling),
# 32 images per batch.
full_dataset = FullDataset("shards", transform)
full_test_loader = DataLoader(dataset=full_dataset, shuffle=False, batch_size=32)

In [22]:
import torch

def ensemble_predict(models, input_tensor, device):
    """Predict class indices for a batch by combining several models' outputs.

    Every model is switched to eval mode and moved to `device` (both are
    side effects on the caller's models), then run on the same batch with
    gradients disabled. The outputs are summed element-wise and the
    prediction is the argmax of that sum along dim 1. Summing and averaging
    differ only by a positive constant factor, so the argmax is that of the
    average output as well.

    Args:
        models: non-empty iterable of shard models.
        input_tensor: batch of images (already on device).
        device: device each model is moved to before the forward pass.

    Returns:
        Tensor of predicted class indices, one per row of the model outputs
        (assumes each model returns a (batch, num_classes) tensor - TODO confirm).

    Raises:
        ValueError: if `models` is empty.
    """
    with torch.no_grad():
        summed = None
        for model in models:
            model.eval()
            model.to(device)
            logits = model(input_tensor)
            summed = logits if summed is None else summed + logits

    # An empty ensemble used to surface as "'int' object has no attribute
    # 'argmax'"; report the real problem instead.
    if summed is None:
        raise ValueError("ensemble_predict needs at least one model")

    return summed.argmax(dim=1)

def evaluate_ensemble(shard_models, test_loader, class_names, device):
    """Measure overall and per-class accuracy of the ensemble on a loader.

    Each batch is moved to `device`, classified with `ensemble_predict`, and
    the hits are tallied per true class.

    Args:
        shard_models: models forwarded to `ensemble_predict`.
        test_loader: iterable yielding ``(imgs, labels)`` batches; labels are
            expected to be a 1-D tensor of integer class indices.
        class_names: class names; position ``i`` names label index ``i``.
        device: device the batches (and models) are placed on.

    Returns:
        ``(overall_acc, class_acc)`` where ``overall_acc`` is the percentage
        of correct predictions and ``class_acc`` maps each class name to its
        percentage accuracy. A class with no samples gets 0, and an empty
        loader yields an overall accuracy of 0.

    Raises:
        IndexError: if a label falls outside ``range(len(class_names))``.
    """
    num_classes = len(class_names)

    # Integer per-class tallies, kept on the CPU.
    class_correct = torch.zeros(num_classes, dtype=torch.long)
    class_total = torch.zeros(num_classes, dtype=torch.long)

    for imgs, labels in test_loader:
        imgs = imgs.to(device)
        labels = labels.to(device)

        preds = ensemble_predict(shard_models, imgs, device)

        # One device->CPU transfer per batch instead of one .item() per sample.
        hits = (preds == labels).cpu()
        labels_cpu = labels.cpu().long()

        if labels_cpu.numel() and (
            labels_cpu.min() < 0 or labels_cpu.max() >= num_classes
        ):
            raise IndexError(
                f"label index outside [0, {num_classes}) - class_names does "
                "not match the labels produced by the dataset"
            )

        class_total += torch.bincount(labels_cpu, minlength=num_classes)
        class_correct += torch.bincount(labels_cpu[hits], minlength=num_classes)

    # overall accuracy (0 when the loader produced no samples)
    correct = int(class_correct.sum())
    total = int(class_total.sum())
    overall_acc = 100 * correct / total if total else 0.0

    # per-class accuracy; classes never seen get 0 instead of dividing by zero
    class_acc = {}
    for name, n_ok, n_all in zip(class_names, class_correct.tolist(), class_total.tolist()):
        class_acc[name] = 100 * n_ok / n_all if n_all > 0 else 0.0

    return overall_acc, class_acc

In [23]:
# Take the class names from the dataset that produced the label indices,
# ordered by index, so names and labels cannot drift apart (the folder
# listing of a single shard may lack a class or contain extra folders).
class_names = sorted(full_dataset.class_to_idx, key=full_dataset.class_to_idx.get)
overall, per_class = evaluate_ensemble(models_essemble, full_test_loader, class_names, device)

print(f"SISA Ensemble Accuracy: {overall:.2f}%\n")
print("Class-by-Class Accuracy:")
for cls, acc in per_class.items():
    print(f"{cls}: {acc:.2f}%")

SISA Ensemble Accuracy: 67.77%

Class-by-Class Accuracy:
cat: 48.71%
dog: 76.00%
horse: 77.56%


In [24]:
# Replace the first ensemble member with a retrained checkpoint.
# NOTE(review): presumably index 0 is the model for shard 0 (model_files is
# sorted, so "model_0.pt" would come first) - confirm against the training notebook.
# NOTE(review): this mutates models_essemble in place; re-running the earlier
# evaluation cell after this one will score the modified ensemble.
retrained_model = get_model()
state_dict = torch.load("my_models/retrained_model_0.pt", map_location=device)  # load the saved weights
retrained_model.load_state_dict(state_dict)  # load weights into model
models_essemble[0] = retrained_model 

In [25]:
# Re-evaluate now that models_essemble[0] has been replaced by the retrained model.
# Class names come from the dataset that produced the label indices, ordered
# by index, so names and labels cannot drift apart (the folder listing of a
# single shard may lack a class or contain extra folders).
class_names = sorted(full_dataset.class_to_idx, key=full_dataset.class_to_idx.get)
overall, per_class = evaluate_ensemble(models_essemble, full_test_loader, class_names, device)

print(f"SISA Ensemble Accuracy: {overall:.2f}%\n")
print("Class-by-Class Accuracy:")
for cls, acc in per_class.items():
    print(f"{cls}: {acc:.2f}%")

SISA Ensemble Accuracy: 57.85%

Class-by-Class Accuracy:
cat: 3.76%
dog: 88.78%
horse: 78.00%
