In [2]:
%env CUDA_VISIBLE_DEVICES=0
import os
import time
import pickle
import random
from functools import partial

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from torch.multiprocessing import Process
from torch.utils.tensorboard import SummaryWriter
import torchvision

from training_utils import train_with_centerclip, NormalParticipant
from attacks import SignFlipper, LabelFlipper, ConstantDirection, DelayedGradientAttacker
from resnet import ResNet18

# One intra-op CPU thread per process; presumably to avoid oversubscription,
# since NUM_WORKERS worker processes are launched on this same host.
torch.set_num_threads(1)

# Rendezvous for torch.distributed's default env:// initialisation: all ranks
# run locally; the port is drawn at random from 30000-60000 (presumably to
# avoid collisions with other jobs on the machine).
os.environ.update({
    'MASTER_ADDR': '127.0.0.1',
    'MASTER_PORT': str(random.randint(30000, 60000)),
})

# Fetch both CIFAR-10 splits up front, presumably so the worker processes find
# the data already on disk instead of each trying to download it.
for _is_train in (True, False):
    torchvision.datasets.CIFAR10(root='./data', train=_is_train, download=True)
del _is_train


def run_worker(config, rank: int, backend='gloo'):
    """Run one distributed training worker for the experiment in ``config``.

    Joins the default process group as ``rank`` out of ``config.NUM_WORKERS``
    ranks (rendezvous via the MASTER_ADDR / MASTER_PORT environment variables
    set at module level), then delegates training to ``train_with_centerclip``.
    Only rank 0 creates a TensorBoard ``SummaryWriter`` and logs every public
    attribute of ``config`` to it as text; other ranks pass ``writer=None``.

    Args:
        config: class whose attributes hold the experiment settings
            (NUM_WORKERS, EXP_NAME, ...).
        rank: index of this worker within the process group.
        backend: ``torch.distributed`` backend name.

    Note:
        The log directory comes from the module-level global ``LOGS_FOLDER``.
        NOTE(review): this presumably relies on workers being started with the
        ``fork`` start method after that global is assigned -- confirm.
    """
    dist.init_process_group(backend, rank=rank, world_size=config.NUM_WORKERS)

    # Every rank uses the first visible CUDA device (or CPU if none).
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    writer = None
    try:
        if rank == 0:
            writer = SummaryWriter('./{}/{}_rank{}_{}.{:0>2d}.{:0>2d}_{:0>2d}:{:0>2d}:{:0>2d}'.format(
                    LOGS_FOLDER, config.EXP_NAME, rank, *time.gmtime()[:6]))
            writer.add_text('config', '\n'.join(f'{k}: {v}' for k, v in config.__dict__.items()
                                                if not k.startswith('_')))

        train_with_centerclip(config, device, writer, verbose=1)
    finally:
        if writer is not None:
            # multiprocessing children terminate via os._exit, so atexit hooks
            # never run: close explicitly to flush buffered TensorBoard events.
            writer.close()
        dist.destroy_process_group()

env: CUDA_VISIBLE_DEVICES=0
Files already downloaded and verified
Files already downloaded and verified


In [None]:
from itertools import product

# Sweep grid. itertools.product varies the right-most axis fastest, so runs are
# launched in exactly the same order as the equivalent nested for-loops.
SEEDS = range(5)
BYZANTINE_SETTINGS = [(7, 1), (3, 1), (7, 10)]  # (num_byzantines, attack_every)
TAUS = [1.0, 10.0]
ATTACK_STARTS = ['1k', '10k']
ATTACKS = [
    ('baseline', NormalParticipant),
    ('signflip', SignFlipper),
    ('labelflip', LabelFlipper),
    ('constantdirection', ConstantDirection),
    ('delayedgrads', DelayedGradientAttacker),
]

# NOTE: the loop variables deliberately stay module-level globals: the
# ``config`` class body below reads them, and run_worker reads LOGS_FOLDER.
for seed, (num_byzantines, attack_every), tau, attack_start, (name, attack) in product(
        SEEDS, BYZANTINE_SETTINGS, TAUS, ATTACK_STARTS, ATTACKS):
    print(f'seed = {seed}')
    print(f'num_byzantines = {num_byzantines}')
    print(f'attack_every = {attack_every}')
    print(f'tau = {tau}')
    print(f'attack_start = {attack_start}')
    print(f'attack = {name}')

    LOGS_FOLDER = f'{name}_at_{attack_start}_attackers_{num_byzantines}_every_{attack_every}'
    print(name, attack, LOGS_FOLDER)

    class config:
        MODEL = ResNet18
        AUGMENT_DATA = True

        GLOBAL_SEED = seed
        NUM_WORKERS = 16
        MAX_EPOCHS_PER_WORKER = 4
        BATCH_SIZE_PER_WORKER = 8
        EVAL_BATCH_SIZE = 16
        EVAL_EVERY = 50

        CCLIP_TAU = tau
        CCLIP_MAX_ITERS = 500
        CCLIP_EPS = 1e-6

        BASE_LR = 0.01
        MOMENTUM = 0.9
        NESTEROV = True
        WEIGHT_DECAY = 5e-4
        COSINE_T_MAX_RATE = 1.0

        BENIGN_PARTICIPANT = NormalParticipant

        # the 'baseline' run has no Byzantine workers regardless of the grid setting
        NUM_BYZANTINES = 0 if name == 'baseline' else num_byzantines
        BYZANTINE_PARTICIPANT = partial(
            attack, ban_prob=1. / NUM_WORKERS * ((NUM_WORKERS - NUM_BYZANTINES) / NUM_WORKERS),
            attack_start=1000, direction_seed=seed, attack_every=attack_every)
        # note: attack_start parameter above is relative to the start checkpoint, so attackers that init at 9k will attack after 10k steps
        # rank 0 is never chosen as a Byzantine worker (sample is drawn from ranks 1..NUM_WORKERS-1)
        BYZANTINE_IDS = random.Random(GLOBAL_SEED).sample(range(1, NUM_WORKERS), NUM_BYZANTINES)

        EARLY_STOP_STEPS = 5_100 if attack_start == '1k' else 14_100
        EXP_NAME = f"resnet_decentclip_tau{CCLIP_TAU}_max_iters{CCLIP_MAX_ITERS}_seed{GLOBAL_SEED}"

    # Process index == distributed rank.
    processes = []
    for rank in range(config.NUM_WORKERS):
        p = Process(target=run_worker, args=(config, rank), daemon=True)
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    # Surface crashed workers instead of silently moving on to the next run;
    # a negative exit code means the worker was killed by a signal.
    failed = [(rank, p.exitcode) for rank, p in enumerate(processes) if p.exitcode != 0]
    if failed:
        print(f'WARNING: run {LOGS_FOLDER} (seed={seed}, tau={tau}) had failing '
              f'workers as (rank, exitcode): {failed}')