In [230]:
import random
import numpy as np
import torch

def set_seed(seed):
    random.seed(seed)            # Python random
    np.random.seed(seed)         # NumPy
    torch.manual_seed(seed)      # CPU seed

    # For GPU (CUDA)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # For MPS (Apple Silicon)
    if torch.backends.mps.is_available():
        torch.mps.manual_seed(seed)

    # Make CuDNN deterministic (if using CUDA)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Data Loading

In [231]:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pathlib import Path

set_seed(42)
class ShardDataset(Dataset):
    def __init__(self, shard_path, transform=None):
        """
        shard_path: path to a single shard folder
        transform: torchvision transforms
        """
        self.shard_path = Path(shard_path)
        self.samples = []

        # read all images
        for class_folder in self.shard_path.iterdir():
            if class_folder.is_dir():
                for img_path in class_folder.iterdir():
                    if img_path.suffix.lower() in [".jpg", ".jpeg", ".png"]:
                        self.samples.append((img_path, class_folder.name))

        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label

In [232]:
from torch.utils.data import DataLoader
from torchvision import transforms


# transforms for all shards
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])

num_shards = 3
shard_loaders = []

set_seed(42)
for shard_idx in range(num_shards):
    shard_path = f"shards/shard_{shard_idx}"
    dataset = ShardDataset(shard_path, transform=transform)
    loader = DataLoader(dataset, batch_size=50, shuffle=True)
    shard_loaders.append(loader)
    print(f"Shard {shard_idx} size: {len(dataset)} images")

Shard 0 size: 900 images
Shard 1 size: 900 images
Shard 2 size: 850 images


In [233]:

class FullDataset(Dataset):
    def __init__(self, shards_dir, transform=None):
        """
        shards_dir: path to the parent folder containing all shards
        transform: torchvision transforms
        """
        self.samples = []
        self.transform = transform
        shards_dir = Path(shards_dir)

        # iterate all shards and all class folders
        for shard_folder in shards_dir.iterdir():
            if shard_folder.is_dir():
                for class_folder in shard_folder.iterdir():
                    if class_folder.is_dir():
                        for img_path in class_folder.iterdir():
                            if img_path.suffix.lower() in [".jpg", ".jpeg", ".png"]:
                                # store (path, class_name)
                                self.samples.append((img_path, class_folder.name))

        # build class->index mapping
        class_names = sorted({cls for _, cls in self.samples})
        self.class_to_idx = {cls: idx for idx, cls in enumerate(class_names)}

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        label_idx = self.class_to_idx[label]
        return img, label_idx

In [234]:
from torch.utils.data import DataLoader
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((128,128)),
    transforms.ToTensor()
])

set_seed(42)
full_dataset = FullDataset(shards_dir="shards", transform=transform)
full_loader = DataLoader(full_dataset, batch_size=32, shuffle=True)
full_test_loader = DataLoader(full_dataset, batch_size=32, shuffle=False)


print(f"Full dataset size: {len(full_dataset)}")

Full dataset size: 2650


# SISA Modeling

In [236]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
import os
import time

n_epochs = 50

device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

num_classes = 3  # your dataset

# Example: small CNN
def get_model():
    set_seed(42)
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model.to(device)

shard_models = []

class_names = sorted([p.name for p in Path("shards/shard_0").iterdir() if p.is_dir()])
class_to_idx = {name: idx for idx, name in enumerate(class_names)}
print("Class mapping:", class_to_idx)

subfolder_path = "my_models/"

start_time = time.time()
set_seed(42)
for shard_idx, loader in enumerate(shard_loaders):
    start = time.time()

    print(f"\nTraining model for shard {shard_idx}")
    model = get_model()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    
    # simple epoch loop
    for epoch in range(n_epochs):  # small number of epochs for demo
        for imgs, labels in loader:
            imgs = imgs.to(device)
            # convert string labels to integers
            labels_idx = torch.tensor([class_to_idx[lbl] for lbl in labels], dtype=torch.long).to(device)
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels_idx)
            loss.backward()
            optimizer.step()
    shard_models.append(model)

    end = time.time()
    print(f"Shard Training Time: {end - start:.5f} seconds")
    name = "model_"+ str(shard_idx) + ".pt"
    full_path = os.path.join(subfolder_path, name)
    torch.save(model.state_dict(), full_path)
    print(f"Total Training Time: {end - start:.5f} seconds")

Class mapping: {'cat': 0, 'dog': 1, 'horse': 2}

Training model for shard 0
Shard Training Time: 176.06704 seconds
Total Training Time: 176.06704 seconds

Training model for shard 1
Shard Training Time: 96.21077 seconds
Total Training Time: 96.21077 seconds

Training model for shard 2
Shard Training Time: 88.92099 seconds
Total Training Time: 88.92099 seconds


## Prediction Functions

In [237]:
import torch

def ensemble_predict(models, input_tensor, device):
    """
    models: list of shard models
    input_tensor: batch of images (already on device)
    returns: ensemble predictions (argmax over average logits)
    """
    with torch.no_grad():
        outputs = []
        for model in models:
            model.eval()
            model.to(device)
            outputs.append(model(input_tensor))
        avg_output = sum(outputs)
        # avg_output = sum(outputs) / len(outputs)
        preds = avg_output.argmax(dim=1)
    return preds

def evaluate_ensemble(shard_models, test_loader, class_names, device):
    num_classes = len(class_names)

    # overall accuracy
    correct = 0
    total = 0

    # class-wise tracking
    class_correct = torch.zeros(num_classes)
    class_total   = torch.zeros(num_classes)

    for imgs, labels in test_loader:
        imgs = imgs.to(device)
        labels = labels.to(device)

        preds = ensemble_predict(shard_models, imgs, device)

        # overall accuracy
        correct += (preds == labels).sum().item()
        total += labels.size(0)

        # class-wise accuracy
        for i in range(len(labels)):
            label = labels[i].item()
            class_total[label] += 1
            if preds[i].item() == label:
                class_correct[label] += 1

    # overall accuracy
    overall_acc = 100 * correct / total

    # per-class accuracy
    class_acc = {}
    for i in range(num_classes):
        acc = 100 * class_correct[i] / class_total[i] if class_total[i] > 0 else 0
        class_acc[class_names[i]] = acc.item()

    return overall_acc, class_acc

In [238]:
overall, per_class = evaluate_ensemble(shard_models, full_test_loader, class_names, device)

print(f"SISA Ensemble Accuracy: {overall:.2f}%\n")
print("Class-by-Class Accuracy:")
for cls, acc in per_class.items():
    print(f"{cls}: {acc:.2f}%")

SISA Ensemble Accuracy: 67.77%

Class-by-Class Accuracy:
cat: 48.71%
dog: 76.00%
horse: 77.56%


In [239]:
# remove entire shard model

def remove_index(lst, idx):
    return lst[:idx] + lst[idx+1:]

dropped_cat_models = remove_index(shard_models, 0)

overall, per_class = evaluate_ensemble(dropped_cat_models, full_test_loader, class_names, device)

print(f"SISA Ensemble Accuracy: {overall:.2f}%\n")
print("Class-by-Class Accuracy:")
for cls, acc in per_class.items():
    print(f"{cls}: {acc:.2f}%")

SISA Ensemble Accuracy: 60.04%

Class-by-Class Accuracy:
cat: 11.41%
dog: 91.67%
horse: 74.33%


### Retrain Single Shard

In [240]:
class DroppedShardDataset(Dataset):
    def __init__(self, shard_path, transform=None, drop_class=None):
        self.transform = transform
        self.samples = []

        shard_path = Path(shard_path)

        # Load samples
        for cls_folder in shard_path.iterdir():
            if cls_folder.is_dir():
                cls_name = cls_folder.name
                # Skip filtered class
                if drop_class is not None and cls_name == drop_class:
                    continue

                for img in cls_folder.glob("*.jpg"):
                    self.samples.append((img, cls_name))
                for img in cls_folder.glob("*.jpeg"):
                    self.samples.append((img, cls_name))
                for img in cls_folder.glob("*.png"):
                    self.samples.append((img, cls_name))

        # Rebuild class_to_idx after dropping
        self.class_to_idx = {c: i for i, c in enumerate(sorted({c for _, c in self.samples}))}

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        img = Image.open(img_path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label

In [241]:
dropped_dataset = DroppedShardDataset("shards/shard_0", transform=transform, drop_class="cat")
loader = DataLoader(dropped_dataset, batch_size=32, shuffle=True)

In [242]:
def retrain_shard(dataloader,
    epochs=100,
    lr=1e-4,
    device="mps",
    
):
    start = time.time()
    # Build model
    model = get_model()

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Training loop
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0

        for imgs, labels in dataloader:
            imgs = imgs.to(device)
            labels_idx = torch.tensor([class_to_idx[lbl] for lbl in labels], dtype=torch.long).to(device)
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels_idx)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

    end = time.time()
    print(f"Shard Training Time: {end - start:.5f} seconds")
    return model

In [243]:
retrain_model = retrain_shard(loader)

full_path = "my_models/retrained_model_0.pt"
torch.save(retrain_model.state_dict(), full_path)

model1 = [retrain_model]
evaluate_ensemble(model1, full_test_loader, class_names, device)

Shard Training Time: 36.28261 seconds


(51.283018867924525,
 {'cat': 0.0, 'dog': 70.44444274902344, 'horse': 80.55555725097656})

In [244]:
shard_models[0] = model

overall, per_class = evaluate_ensemble(shard_models, full_test_loader, class_names, device)

print(f"SISA Ensemble Accuracy: {overall:.2f}%\n")
print("Class-by-Class Accuracy:")
for cls, acc in per_class.items():
    print(f"{cls}: {acc:.2f}%")

SISA Ensemble Accuracy: 49.51%

Class-by-Class Accuracy:
cat: 0.71%
dog: 45.11%
horse: 100.00%


# Full Model Training

In [203]:
import torch.optim as optim
import torch
import torch.nn as nn
import torchvision.models as models

set_seed(42)
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
num_classes = len(full_dataset.class_to_idx)

model_full = models.resnet18(weights=None)
model_full.fc = nn.Linear(model_full.fc.in_features, num_classes)
model_full = model_full.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_full.parameters(), lr=1e-3)

epochs = 5

start = time.time()

for epoch in range(epochs):
    model_full.train()
    running_loss = 0.0
    for imgs, labels in full_loader:
        imgs = imgs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model_full(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    
    model_full.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for imgs, labels in full_test_loader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            outputs = model_full(imgs)
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    end = time.time()
    print(f"Epoch Training Time: {end - start:.5f} seconds")
    print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(full_loader):.4f}")
    print(f"Full model accuracy: {100 * correct / total:.2f}%")
print(f"Total Training Time: {end - start:.5f} seconds")

# full_path = "my_models/full_model.pt"
# torch.save(model_full.state_dict(), full_path)


Epoch Training Time: 30.06417 seconds
Epoch 1/5, Loss: 1.3612
Full model accuracy: 28.70%
Epoch Training Time: 59.49050 seconds
Epoch 2/5, Loss: 1.0335
Full model accuracy: 44.82%
Epoch Training Time: 89.44748 seconds
Epoch 3/5, Loss: 0.8667
Full model accuracy: 62.52%
Epoch Training Time: 119.59852 seconds
Epoch 4/5, Loss: 0.7210
Full model accuracy: 65.69%
Epoch Training Time: 148.99872 seconds
Epoch 5/5, Loss: 0.5715
Full model accuracy: 82.52%
Total Training Time: 148.99872 seconds
