## Dependencies


In [None]:
import sys
import json
from pathlib import Path
from torchvision import datasets

## Utils


In [None]:
# Project root: the third ancestor of the resolved working directory.
project_root = Path.cwd().resolve().parents[2]

# Register the root on sys.path (only once) so project-local modules import.
root_entry = str(project_root)
if root_entry not in sys.path:
    sys.path.append(root_entry)

# Dataset directory under the project root; created if it does not exist yet.
data_root = project_root.joinpath("data")
data_root.mkdir(parents=True, exist_ok=True)

# Imported only after the sys.path update above.
# NOTE(review): presumably notebook_setup is resolved via the path entry
# added above — confirm.
from notebook_setup import setup_notebook

setup_notebook()

In [None]:
# Device configuration and seeding helper
from utils.utils import DEVICE, set_seed

print(f"Device used: {DEVICE}")

# Set random seed for reproducibility (called without arguments)
set_seed()

# Model factory
from models.simple_cnn import init_model_cnn

# SISA helpers: create/update the SISA structure, recreate the dataloaders
# from the JSON file, and aggregate the SISA models via weighted voting
from methods.sisa.sisa_utils import (
    create_sisa_structure,
    evaluate_aggregated_model,
    recreate_sisa_dataloaders,
    update_sisa_structure,
)

# SISA train, test and retrain loops
from methods.sisa.sisa_train_test import (
    retrain_sisa_framework,
    sisa_test,
    sisa_train,
)

## Parameters (arbitrarily chosen)


In [None]:
# Training hyper-parameters.
BATCH_SIZE = 64
LEARNING_RATE = 0.001
EPOCHS = 10

# SISA partitioning: number of shards and number of slices per shard.
SHARDS = 3
SLICES = 5

# `stt` was used here without ever being bound in this notebook, which raises
# NameError when the cell runs. Bind it locally so the cell is self-contained.
# NOTE(review): assumes `stt` abbreviates methods.sisa.sisa_train_test (the
# module the train/test loops are imported from) and that the module reads a
# module-level MULTI_GPU flag — confirm.
import methods.sisa.sisa_train_test as stt

stt.MULTI_GPU = True

# SISA structure


## Get data transform


In [None]:
# Unpack the result of init_model_cnn() and keep only its last element as the
# data transform; every preceding element is discarded via `*_`.
*_, transform = init_model_cnn()
# Hand-written alternative transform (disabled, kept for reference):
# transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

## Load dataset


In [None]:
# Options shared by both MNIST splits: storage root, transform, and download.
mnist_common = {"root": data_root, "transform": transform, "download": True}

# Training and test splits differ only in the `train` flag.
train_dataset = datasets.MNIST(train=True, **mnist_common)
test_dataset = datasets.MNIST(train=False, **mnist_common)

## Create SISA structure


In [None]:
# Build the SISA structure for the training set with SHARDS shards and SLICES
# slices per shard.
# NOTE(review): presumably this writes "sisa_structure.json" in the working
# directory (the file the dataloaders are rebuilt from) — confirm.
create_sisa_structure(train_dataset, shards=SHARDS, slices_per_shard=SLICES)

In [None]:
# JSON file describing the SISA structure; passed to
# recreate_sisa_dataloaders as `info_file_path`.
sisa_structure_file = "sisa_structure.json"

## Prepare DataLoaders from json file


In [None]:
# Arguments for rebuilding the dataloaders from the SISA structure file.
# NOTE(review): val_ratio is presumably the validation fraction — confirm.
loader_args = {
    "datasets": (train_dataset, test_dataset),
    "info_file_path": sisa_structure_file,
    "batch_size": BATCH_SIZE,
    "val_ratio": 0.1,
}

# The helper returns a pair: the dataloaders and the classes.
dataloaders, classes = recreate_sisa_dataloaders(**loader_args)

# Train using SISA framework


## Models output path


In [None]:
# Directory (relative to the working directory) handed to sisa_train as
# `save_models_metrics_dir` and to sisa_test as `saved_models_metrics_dir`.
save_models_metrics_dir = "./sisa_models"

## Call train


In [None]:
import time

# Arguments for the initial SISA training run.
train_kwargs = {
    "dataloaders": dataloaders,
    "num_epochs": EPOCHS,
    "save_models_metrics_dir": save_models_metrics_dir,
    "init_model_func": init_model_cnn,
    "learning_rate": LEARNING_RATE,
}

# Measure the wall-clock duration of the training call only.
start_time = time.perf_counter()
sisa_train(**train_kwargs)
end_time = time.perf_counter()

elapsed_time = end_time - start_time
print(f"Execution time: {elapsed_time:.6f} seconds")

## Call test


In [None]:
# Test the models saved under `save_models_metrics_dir` by the training step.
test_kwargs = {
    "dataloaders": dataloaders,
    "saved_models_metrics_dir": save_models_metrics_dir,
    "init_model_func": init_model_cnn,
    "clear_solo_models_preds": True,
}
sisa_test(**test_kwargs)

## Shard models assembly (weighted voting, w=accuracy)


In [None]:
# Load the evaluation results that are handed to evaluate_aggregated_model.
# NOTE(review): presumably this file is produced by sisa_test — confirm.
# The encoding is explicit because JSON text is UTF-8, whereas open() would
# otherwise decode with the platform's locale encoding.
with open("sisa_final_evaluation.json", "r", encoding="utf-8") as f:
    results = json.load(f)

In [None]:
# Evaluate the aggregated shard models from the loaded results and the
# classes returned by recreate_sisa_dataloaders.
evaluate_aggregated_model(results, classes)

# SISA re-train


## Paths to data info


In [None]:
# File paths for the unlearning step; all four are passed, in this order, to
# update_sisa_structure. Relative paths resolve against the working directory.

# Samples to unlearn, stored in the sibling "naive" directory.
mnist_samples_to_delete = Path("..", "naive", "mnist_samples_to_unlearn_30per.json")

# SISA structure file created earlier in this notebook.
mnist_sisa_structure = "sisa_structure.json"

# NOTE(review): "strucute" looks like a typo of "structure"; the name is kept
# unchanged because other files or notebooks may rely on it — confirm.
updated_sisa_structure = "updated_sisa_strucute.json"

deleted_samples = "deleted_samples.json"

## Update sisa structure -> delete samples


In [None]:
# Apply the deletion request to the SISA structure. Arguments, in order: the
# samples-to-delete file, the current structure file, the updated structure
# file, and the deleted-samples file.
# NOTE(review): the return value is forwarded to retrain_sisa_framework as
# `affected_shards`; presumably it identifies the shards that lost samples —
# confirm.
affected_shards = update_sisa_structure(
    mnist_samples_to_delete,
    mnist_sisa_structure,
    updated_sisa_structure,
    deleted_samples,
)

## Recreate DataLoaders from json file


In [None]:
# Arguments for rebuilding the dataloaders, this time from the updated
# (post-deletion) SISA structure file.
# NOTE(review): val_ratio is presumably the validation fraction — confirm.
loader_args = {
    "datasets": (train_dataset, test_dataset),
    "info_file_path": updated_sisa_structure,
    "batch_size": BATCH_SIZE,
    "val_ratio": 0.1,
}

# Rebinds `dataloaders` and `classes`, replacing the pre-deletion ones.
dataloaders, classes = recreate_sisa_dataloaders(**loader_args)

## Models output path


In [None]:
# Directory (relative to the working directory) handed to
# retrain_sisa_framework as `save_models_metrics_dir` and to sisa_test as
# `saved_models_metrics_dir`; distinct from the initial "./sisa_models".
save_path = "./sisa_updated_models"

## Retrain the affected shards


In [None]:
import time

# Arguments for retraining; `affected_shards` comes from update_sisa_structure.
retrain_kwargs = {
    "dataloaders": dataloaders,
    "affected_shards": affected_shards,
    "num_epochs": EPOCHS,
    "save_models_metrics_dir": save_path,
    "init_model_func": init_model_cnn,
    "learning_rate": LEARNING_RATE,
}

# Measure the wall-clock duration of the retraining call only.
start_time = time.perf_counter()
retrain_sisa_framework(**retrain_kwargs)
end_time = time.perf_counter()

elapsed_time = end_time - start_time
print(f"Execution time: {elapsed_time:.6f} seconds")

## Re-test new SISA models


In [None]:
# Test the retrained models saved under `save_path`, using the dataloaders
# rebuilt from the updated SISA structure.
test_kwargs = {
    "dataloaders": dataloaders,
    "saved_models_metrics_dir": save_path,
    "init_model_func": init_model_cnn,
    "clear_solo_models_preds": True,
}
sisa_test(**test_kwargs)

## Shard models assembly (weighted voting, w=accuracy)


In [None]:
# Load the evaluation results that are handed to evaluate_aggregated_model.
# NOTE(review): this is the same file name that was read after the initial
# training; presumably the second sisa_test run overwrote it with results for
# the retrained models — confirm, otherwise these are stale results.
# The encoding is explicit because JSON text is UTF-8, whereas open() would
# otherwise decode with the platform's locale encoding.
with open("sisa_final_evaluation.json", "r", encoding="utf-8") as f:
    results = json.load(f)

In [None]:
# Evaluate the aggregated shard models again, now with the results and the
# classes obtained after the unlearning / retraining steps.
evaluate_aggregated_model(results, classes)