## ALBERT Experiments for the paper "Secure Distributed Training at Scale"

This notebook implements the ALBERT experiments for the paper. It uses a version of the [hivemind](https://github.com/learning-at-home/hivemind) library modified so that some peers may be programmed to become Byzantine and perform various types of attacks on the training process. The resulting plots are posted to the [Wandb](https://wandb.ai) service.

### Step 1: Dataset Preparation

1. Run the following commands locally to prepare the Wikitext103 dataset for training:

```
git clone https://github.com/learning-at-home/hivemind
cd hivemind
pip install -e .
cd examples/albert
pip install -r requirements.txt
python tokenize_wikitext103.py
```

2. Upload an archive with preprocessed data to some URL accessible to the AWS machines (e.g. `https://hivemind-data.s3.us-east-2.amazonaws.com/wikitext103_preprocessed.tar`).

### Step 2: Setting General Parameters

In this part of the notebook, you can set general parameters of your AWS configuration.

It also defines the `kill_instances(experiment_name)` function, which you can use to manually terminate all instances related to your experiment if necessary.

In [None]:
import boto3
# the code below assumes that you configure boto3 with your AWS account
# https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html
ec2 = boto3.resource('ec2')
client = boto3.client('ec2')

In [None]:
data_path = "https://"  # TODO: Set URL of the preprocessed wikitext103 dataset from the previous step
aws_key_name = ""       # TODO: Update with your aws key name
subnet = ""             # TODO: Update with your subnet name or skip entirely
security_group = ""     # TODO: Set security group
WandB_API_key = ""      # TODO: Set wandb.ai API key

image_id = "ami-0db67995cd75f5a9f"
coordinator_type = "r5.large"
dht_port = 31337
num_workers = 16
n_attackers = 7

# The lower bound of the probability that a Byzantine will be revealed on the current step
check_proba = (num_workers - n_attackers) / num_workers * 1 / (num_workers - 1)
print(f'check_proba = {check_proba}')

use_internal_routing = True  # Whether to use AWS internal (private) IP addresses
# Note: Setting use_internal_routing = False will cause all nodes to communicate over the public network.
#       This may incur additional charges. Change it only if you know what you're doing!
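
To make the `check_proba` formula concrete, the sketch below recomputes it as a standalone function. The helper name `detection_lower_bound` is hypothetical, and the reading of the two factors (the chance that a single validator is honest, divided by the number of other peers it could check) is an assumption made for illustration; the notebook itself only states that the value is a lower bound.

```python
def detection_lower_bound(num_workers: int, n_attackers: int) -> float:
    """Same arithmetic as `check_proba` above (hypothetical helper, not part of the notebook)."""
    if num_workers < 2 or not 0 <= n_attackers < num_workers:
        raise ValueError("need num_workers >= 2 and 0 <= n_attackers < num_workers")
    # Assumed interpretation: fraction of peers that are honest ...
    honest_fraction = (num_workers - n_attackers) / num_workers
    # ... spread over the (num_workers - 1) other peers a validator could check.
    return honest_fraction / (num_workers - 1)


p = detection_lower_bound(num_workers=16, n_attackers=7)
print(f"check_proba = {p}")
```

With the defaults used here (16 workers, 7 attackers) the value is 9/240 = 0.0375.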

In [None]:
def kill_instances(experiment_name):
    """Terminate running instances and cancel active spot requests tagged with this experiment."""
    existing_instances = ec2.instances.filter(Filters=[
        {'Name': 'instance-state-name', 'Values': ['running']},
        {'Name': 'tag:experiment', 'Values': [experiment_name]},
    ])
    ins = list(existing_instances)
    if ins:
        print(f"Found {len(ins)} running instance(s) for {experiment_name}: {ins}")
        for i in ins:
            print(i.public_ip_address, i.private_ip_address)

    # Terminate the instances, then cancel the matching active spot requests
    existing_instances.terminate()
    requests_to_shutdown = []
    for request in client.describe_spot_instance_requests()['SpotInstanceRequests']:
        # Requests without tags may lack the 'Tags' key, so fall back to an empty list
        if request['State'] == 'active' and \
                any(tag['Key'] == 'experiment' and tag['Value'] == experiment_name
                    for tag in request.get('Tags', [])):
            requests_to_shutdown.append(request['SpotInstanceRequestId'])
    if requests_to_shutdown:
        client.cancel_spot_instance_requests(
            SpotInstanceRequestIds=requests_to_shutdown)
    print('Instances terminated')
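
The tag-matching step of `kill_instances` can be checked offline, without touching AWS. The sketch below pulls it out into a pure function; `select_spot_requests` and the sample records are hypothetical, and treating a missing `Tags` key as "no tags" is an assumption about how untagged requests appear in the response.

```python
def select_spot_requests(requests, experiment_name):
    """Return IDs of active spot requests tagged experiment=<experiment_name>."""
    selected = []
    for request in requests:
        if request.get('State') != 'active':
            continue
        tags = request.get('Tags', [])  # assumption: untagged requests may omit 'Tags'
        if any(tag.get('Key') == 'experiment' and tag.get('Value') == experiment_name
               for tag in tags):
            selected.append(request['SpotInstanceRequestId'])
    return selected


# Hypothetical records shaped like entries of the 'SpotInstanceRequests' list
demo_tag = [{'Key': 'experiment', 'Value': 'demo'}]
sample = [
    {'SpotInstanceRequestId': 'sir-aaa', 'State': 'active', 'Tags': demo_tag},
    {'SpotInstanceRequestId': 'sir-bbb', 'State': 'active'},
    {'SpotInstanceRequestId': 'sir-ccc', 'State': 'closed', 'Tags': demo_tag},
    {'SpotInstanceRequestId': 'sir-ddd', 'State': 'active',
     'Tags': [{'Key': 'experiment', 'Value': 'other'}]},
]
print(select_spot_requests(sample, 'demo'))  # ['sir-aaa']
```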

### Step 3: Define Coordinator Setup Code

A coordinator is an instance that welcomes new peers into a decentralized training run.

In [None]:
# Raw string: the backslashes belong to the grep pattern, not to Python escape sequences
get_ip_cmd = r"export IP=$(ifconfig eth0 | grep -Eo 'inet (addr:)?([0-9]*\.){3}[0-9]*' | grep -Eo '([0-9]*\.){3}[0-9]*')"
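
The two chained `grep` calls first isolate the `inet ...` token and then strip everything but the dotted address. A rough Python equivalent is sketched below so the pattern can be tried locally; `extract_ipv4` and both sample `ifconfig` outputs are hypothetical, and the function returns only the first match.

```python
import re

# Same pattern as the grep pipeline, with a capture group instead of a second grep
IP_PATTERN = re.compile(r"inet (?:addr:)?((?:[0-9]*\.){3}[0-9]*)")


def extract_ipv4(ifconfig_output):
    """Return the first IPv4 address that follows 'inet ' or 'inet addr:', else None."""
    match = IP_PATTERN.search(ifconfig_output)
    return match.group(1) if match else None


# Hypothetical outputs in the newer and the older ifconfig formats
sample_new = (
    "eth0: flags=4163<UP,BROADCAST,RUNNING,MULTICAST>  mtu 9001\n"
    "        inet 172.31.5.17  netmask 255.255.240.0  broadcast 172.31.15.255\n"
    "        inet6 fe80::1  prefixlen 64  scopeid 0x20<link>\n"
)
sample_old = (
    "eth0      Link encap:Ethernet  HWaddr 0a:1b:2c:3d:4e:5f\n"
    "          inet addr:10.0.0.12  Bcast:10.0.0.255  Mask:255.255.255.0\n"
)

print(extract_ipv4(sample_new))  # 172.31.5.17
print(extract_ipv4(sample_old))  # 10.0.0.12
```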

In [None]:
coordinator_script = f'''#!/bin/bash -ex
exec > >(tee /var/log/user-command.log|logger -t user-data -s 2>/dev/console) 2>&1

# note: we configure rsyslog to forward logs from all trainers
sudo sh -c 'cat <<"EOF" >> /etc/rsyslog.conf
$ModLoad imudp
$UDPServerRun 514

$ModLoad imtcp
$InputTCPServerRun 514

$FileCreateMode 0644
$DirCreateMode 0755
$Umask 0022

$template RemoteLogs,"/var/log/rsyslog/%HOSTNAME%.log"
*.*  ?RemoteLogs
& ~
EOF'
sudo systemctl restart rsyslog

{get_ip_cmd if use_internal_routing else ''}

# NOTE: docker run must be called without -it as there is no tty
# check machine's /var/log/user-command.log for details

docker run --name trainer_run --ipc=host --net=host learningathome/hivemind:master bash -c """
set -euxo pipefail

rm -rf hivemind
git clone https://github.com/yandex-research/btard
cd btard/albert/hivemind

pip install -e .
pip install transformers==4.5.1
cd ../experiments


ulimit -n 4096


sh -c 'cat <<"EOF" >> ~/.netrc
machine api.wandb.ai
  login user
  password {WandB_API_key}
EOF'


if [ "%initial_state_url%" != "" ]; then
    wget -q "%initial_state_url%" -O initial_state.pickle
fi

HIVEMIND_TAU=%tau% HIVEMIND_THREADS=256 WANDB_ENTITY=learning-at-home python ./run_first_peer.py --dht_listen_on [::]:{dht_port} {'--address $IP' if use_internal_routing else ''} \
 --experiment_prefix %experiment_name% --wandb_project Runs \
 --compression NONE --metadata_expiration 180 --averaging_timeout 60 --averaging_expiration 10 \
 $(if [ "%initial_step%" != "0" ]; then echo --initial_state_path initial_state.pickle; fi)
"""
'''

In [None]:
def create_coordinator(tau, experiment_name, initial_step=0, initial_state_url=''):
    coordinator, = ec2.create_instances(
        ImageId=image_id, InstanceType=coordinator_type,
        MinCount=1, MaxCount=1,
        SecurityGroupIds=[security_group], SubnetId=subnet,
        KeyName=aws_key_name,
        UserData=coordinator_script
            .replace('%tau%', str(tau))
            .replace('%initial_state_url%', initial_state_url)
            .replace('%initial_step%', str(initial_step))
            .replace('%experiment_name%', experiment_name),
        TagSpecifications=[{'ResourceType': 'instance', 'Tags': [
            {'Key':'experiment', 'Value': experiment_name},
            {'Key':'role', 'Value': 'first_peer'}
        ]}]
    )
    coordinator.wait_until_running()
    coordinator, = list(ec2.instances.filter(InstanceIds=[coordinator.id]))

    print('Created coordinator:', coordinator.private_ip_address, coordinator.public_ip_address)

    if use_internal_routing:
        coordinator_ip = coordinator.private_ip_address
    else:
        coordinator_ip = coordinator.public_ip_address

    coordinator_endpoint = f"{coordinator_ip}:{dht_port}"
    print('coordinator_endpoint =', coordinator_endpoint)
    
    return {'ip': coordinator_ip, 'endpoint': coordinator_endpoint}
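
The startup scripts are filled in with chained `.replace('%name%', ...)` calls, which silently leave a placeholder in place if one call is forgotten. The sketch below shows one way to guard against that; `render_user_data` is a hypothetical helper, not part of the notebook. It only looks for lowercase placeholders, so the rsyslog property `%HOSTNAME%` in the coordinator script is left alone.

```python
import re


def render_user_data(template, **values):
    """Fill %name% placeholders and raise if any lowercase placeholder is left unfilled."""
    rendered = template
    for name, value in values.items():
        rendered = rendered.replace(f"%{name}%", str(value))
    leftover = sorted(set(re.findall(r"%([a-z_]+)%", rendered)))
    if leftover:
        raise KeyError(f"unfilled placeholders: {leftover}")
    return rendered


demo_template = "HIVEMIND_TAU=%tau% python ./run_first_peer.py --experiment_prefix %experiment_name%"
print(render_user_data(demo_template, tau=0.125, experiment_name="baseline_tau_0.125"))
```

Calling it as `render_user_data(coordinator_script, tau=tau, ...)` would be a drop-in alternative to the chained calls, under the assumption that no substituted value itself contains a `%name%` pattern.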

### Step 4: Define Worker Setup Code

Workers are preemptible GPU instances that compute gradients and perform averaging. In this example, each worker is an instance with a single Tesla T4 GPU.

You will typically run 16 workers. Some of them may be Byzantine.

In [None]:
worker_script = f'''#!/bin/bash -ex
exec > >(tee /var/log/user-command.log|logger -t user-data -s 2>/dev/console) 2>&1

set -euxo pipefail
cd ~

sudo sh -c 'cat <<"EOF" >> /etc/rsyslog.conf

user.* @@%coordinator_ip%:514

EOF'
sudo systemctl restart rsyslog


{get_ip_cmd if use_internal_routing else ''}


docker run --name hivemind_run --gpus all --ipc=host --net=host learningathome/hivemind:master bash -c """

rm -rf hivemind
git clone https://github.com/yandex-research/btard
cd btard/albert/hivemind

pip install -e .
pip install transformers==4.5.1
cd ../experiments


mkdir -p ~/data
wget -qO- {data_path} | tar xzf -


sh -c 'cat <<"EOF" >> ~/.netrc
machine api.wandb.ai
  login user
  password {WandB_API_key}
EOF'


ulimit -n 4096

HIVEMIND_TAU=%tau% ATTACK_TYPE=%attack_type% ATTACK_START=%attack_start% CHECK_PROBA={check_proba} \
  DIRECTION_SEED=%seed% \
  WANDB_PROJECT=%experiment_name% WANDB_ENTITY=learning-at-home WANDB_WATCH=false \
  HIVEMIND_THREADS=256 python run_trainer.py \
  --output_dir ./outputs --overwrite_output_dir \
  {'--endpoint $IP'+':*' if use_internal_routing else ''} \
  --logging_dir ./logs --logging_first_step --logging_steps 100 \
  --initial_peers %coordinator_endpoint%  --run_name aws_worker \
  --experiment_prefix %experiment_name% --seed %seed% --compression NONE --metadata_expiration 180 \
  --averaging_timeout 60 --averaging_expiration 10 --statistics_expiration 60
"""
'''

In [None]:
def create_instance(worker_type, attack_type, tau, experiment_name, coordinator_ip, coordinator_endpoint, seed,
                    attack_start):
    new_worker, = ec2.create_instances(
        ImageId=image_id, InstanceType=worker_type,
        MinCount=1, MaxCount=1,
        UserData=worker_script
            .replace('%attack_type%', attack_type)
            .replace('%attack_start%', str(attack_start))
            .replace('%tau%', str(tau))
            .replace('%tau_underscore%', str(tau).replace('.', '_'))
            .replace('%experiment_name%', experiment_name)
            .replace('%coordinator_ip%', coordinator_ip)
            .replace('%coordinator_endpoint%', coordinator_endpoint)
            .replace('%seed%', str(seed)),
        SecurityGroupIds=[security_group], SubnetId=subnet,
        KeyName=aws_key_name,
        InstanceMarketOptions={
            "MarketType": "spot",
            "SpotOptions": {
                "SpotInstanceType": "one-time",
                "InstanceInterruptionBehavior": "terminate"
            }
        },
        TagSpecifications=[
            {'ResourceType': 'instance', 'Tags': [
                {'Key': 'experiment', 'Value': experiment_name},
                {'Key': 'role', 'Value': 'gpu_worker'}
            ]},
            {'ResourceType': 'spot-instances-request', 'Tags': [
                {'Key': 'experiment', 'Value': experiment_name},
                {'Key': 'role', 'Value': 'gpu_worker'}
            ]},
        ],
    )
    return new_worker

In [None]:
import time
import traceback

def run_workers(tau, experiment_name, coordinator_ip, coordinator_endpoint,
                n_attackers, time_limit=3 * 3600, intended_attack='NONE', **kwargs):
    stop_time = time.time() + time_limit
    while time.time() < stop_time:
        existing_instances = list(ec2.instances.filter(Filters=[
            {'Name': 'instance-state-name', 'Values': ['running']},
            {'Name': 'tag:experiment', 'Values': [experiment_name]},
        ]))

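        # The coordinator carries the same experiment tag, hence the "+ 1"
        count_needed = num_workers + 1 - len(existing_instances)
        if count_needed > 0:
            # The first n_attackers workers to be spawned get the attack; later ones are honest
            attack_type = intended_attack if count_needed > num_workers - n_attackers else 'NONE'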
            
            print(f"Need {count_needed} more workers. Trying to spawn one")
            instance_types = ['g4dn.2xlarge']
            for i, worker_type in enumerate(instance_types):
                try:
                    new_worker = create_instance(
                        worker_type, attack_type, tau, experiment_name, coordinator_ip, coordinator_endpoint, **kwargs)
                    new_worker.wait_until_running()
                    new_worker, = list(ec2.instances.filter(InstanceIds=[new_worker.id]))
                    print("CREATED ONE WORKER!", worker_type, attack_type,
                          new_worker.public_ip_address, new_worker.private_ip_address)
                    break
                except Exception as e:
                    print('Failed:', worker_type, e)
                    traceback.print_exc()
                    
        time.sleep(30)
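
The way `run_workers` decides which workers become Byzantine depends only on how many tagged instances are already running. The sketch below isolates that rule in a pure function so it can be traced by hand; `next_attack_type` is a hypothetical name, and its defaults mirror the parameters set in Step 2.

```python
def next_attack_type(n_existing, num_workers=16, n_attackers=7, intended_attack='SIGN_FLIPPING'):
    """Mirror the decision in run_workers; n_existing includes the coordinator."""
    count_needed = num_workers + 1 - n_existing
    if count_needed <= 0:
        return None  # the fleet is full, nothing to spawn
    return intended_attack if count_needed > num_workers - n_attackers else 'NONE'


# Fill an empty fleet: at the start only the coordinator is running
roles = [next_attack_type(n_existing) for n_existing in range(1, 17)]
print(roles.count('SIGN_FLIPPING'), roles.count('NONE'))  # 7 9
```

One consequence of this rule is that a worker spawned to replace a single preempted instance always gets `'NONE'`, so a preempted attacker is replaced by an honest peer.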

### Step 5: Running the Training from Scratch

This step runs the training from scratch so that you can collect checkpoints for steps 950 and 4950, where we are going to test different attack types. Training from scratch is expected to run for $\approx 3$ days. It should be done separately for each value of $\tau$ (the CenteredClip parameter).

Once the workers have started (7-10 minutes into training), you will be able to see the training progress in your Wandb account:

<img src="images/scratch_wandb.png" width="500">

In [None]:
tau = 0.125  # TODO: Set tau for CenteredClip
seed = 1337
attack_start = 0
intended_attack = 'NONE'

experiment_name = f"baseline_tau_{tau}"

try:
    print(f'\n[*] {experiment_name}: Creating coordinator...')
    while True:
        try:
            coordinator_info = create_coordinator(tau, experiment_name)
            break
        except Exception as e:
            print('[-] Failed to create coordinator:', e)
            traceback.print_exc()
            time.sleep(30)
    time.sleep(5 * 60)

    print(f'\n[*] {experiment_name}: Running workers...')
    run_workers(tau, experiment_name, coordinator_info['ip'], coordinator_info['endpoint'], n_attackers,
                time_limit=7 * 24 * 3600, intended_attack=intended_attack, seed=seed, attack_start=attack_start)
finally:
    print(f'\n[*] {experiment_name}: Stopping instances...')
    kill_instances(experiment_name)
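
The cell above retries `create_coordinator` forever with a 30-second pause. If you prefer to give up after a fixed number of failures, the pattern can be wrapped as in the sketch below; `retry` and its parameters are hypothetical, and the injectable `sleep` argument exists only so the demo runs instantly.

```python
import time


def retry(fn, delay_seconds=30.0, max_attempts=None, sleep=time.sleep):
    """Call fn() until it succeeds; re-raise the last error after max_attempts failures."""
    attempt = 0
    while True:
        attempt += 1
        try:
            return fn()
        except Exception as e:
            if max_attempts is not None and attempt >= max_attempts:
                raise
            print(f'[-] Attempt {attempt} failed: {e}')
            sleep(delay_seconds)


# Demo with a stand-in that fails twice before succeeding
calls = {'n': 0}

def flaky():
    calls['n'] += 1
    if calls['n'] < 3:
        raise RuntimeError('capacity not available')
    return 'ok'

result = retry(flaky, delay_seconds=0, sleep=lambda seconds: None)
print(result, calls['n'])
```

In the notebook this would be used as `retry(lambda: create_coordinator(tau, experiment_name), max_attempts=10)`, where the limit of 10 is an arbitrary illustrative choice.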

You should wait until the training reaches $\approx 5000$ steps:

<img src="images/convergence_wandb.png" width="500">

Then you can obtain the model checkpoints dumped for the steps 950 and 4950:

1. Connect to any worker using SSH and its public IP from the previous cell's output;
2. Attach to the Docker container: `docker exec -it hivemind_run bash`;
3. Download the checkpoints (`*.pickle` files) from the hivemind directory inside Docker to your local machine;
4. Upload them to some URL accessible to the AWS machines.

After that, you can wait until the model converges or move on to testing various attack types.

### Step 6: Simulating Various Attack Types

Now you can load one of the checkpoints and simulate different attack types (sign flipping, label flipping, random direction, or no attack at all) at this stage of training. One attack simulation runs for 4 hours and performs $\approx 250$ training steps. It makes sense to run each simulation several times with different `seed` values to estimate the loss variance and observe worst-case scenarios for the given attack.

Similarly to the previous step, each simulation generates a separate Wandb experiment. The loss curve will demonstrate that this time the experiment begins midway through training and the attack temporarily increases the loss value after the step number `attack_start`.

The other figure shows that all Byzantine peers are eventually revealed and banned. In real BTARD-SGD, this happens when a randomly chosen validating peer discovers that a Byzantine peer has sent fake gradients. In our simulation, however, a Byzantine peer simply bans itself with probability `check_proba`, which is chosen to be the lower bound of the probability that it would be revealed in real BTARD-SGD with one validator.
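
To get a feel for how quickly the Byzantine peers disappear, the sketch below computes the expected number that are still active after a given number of steps. It assumes that each Byzantine peer draws an independent self-ban with probability `check_proba` once per training step after `attack_start`; that per-step, independent model is an assumption for illustration, and `expected_active_byzantines` is a hypothetical helper.

```python
def expected_active_byzantines(n_attackers, check_proba, steps_since_attack_start):
    """Expected number of not-yet-banned Byzantine peers under independent per-step checks."""
    return n_attackers * (1 - check_proba) ** steps_since_attack_start


# Same defaults as in Step 2: 16 workers, 7 attackers
p = (16 - 7) / 16 * 1 / (16 - 1)
for k in (0, 25, 50, 100, 200):
    print(k, round(expected_active_byzantines(7, p, k), 3))
```

Under this model the expected count decays geometrically, so within the $\approx 200$ steps that follow `attack_start` in one simulation it falls well below one peer.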

<img src="images/attack_wandb.png" width="750">

In [None]:
initial_state_url = "https://"     # TODO: Set the URL to the checkpoint obtained in the previous step
seed = 1337                        # TODO: Set random seed
tau = 0.125                        # TODO: Set tau for CenteredClip
initial_step = 950                 # TODO: Set initial step (which state is to be loaded)
attack_start = 1000                # TODO: Set the step where the attack starts
intended_attack = 'SIGN_FLIPPING'  # TODO: Set the attack_type
                                   #       (one of NONE, SIGN_FLIPPING, LABEL_FLIPPING, or CONSTANT_DIRECTION)

experiment_name = f"{intended_attack.lower()}_at_{initial_step}_tau_{tau}_seed_{seed}"

try:
    print(f'\n[*] {experiment_name}: Creating coordinator...')
    while True:
        try:
            coordinator_info = create_coordinator(tau, experiment_name,
                                                  initial_step=initial_step, initial_state_url=initial_state_url)
            break
        except Exception as e:
            print('[-] Failed to create coordinator:', e)
            traceback.print_exc()
            time.sleep(30)
    time.sleep(5 * 60)

    print(f'\n[*] {experiment_name}: Running workers...')
    run_workers(tau, experiment_name, coordinator_info['ip'], coordinator_info['endpoint'], n_attackers,
                time_limit=4 * 3600, intended_attack=intended_attack, seed=seed, attack_start=attack_start)
finally:
    print(f'\n[*] {experiment_name}: Stopping instances...')
    kill_instances(experiment_name)
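
Since each simulation is meant to be repeated with several seeds, it helps to list the planned runs up front. The sketch below enumerates experiment names with the same naming scheme as the cell above; `experiment_names` is a hypothetical helper and the extra seed values are arbitrary examples.

```python
from itertools import product


def experiment_names(attacks, initial_steps, taus, seeds):
    """Build one name per (attack, checkpoint step, tau, seed) combination."""
    return [
        f"{attack.lower()}_at_{step}_tau_{tau}_seed_{seed}"
        for attack, step, tau, seed in product(attacks, initial_steps, taus, seeds)
    ]


names = experiment_names(
    attacks=['SIGN_FLIPPING', 'LABEL_FLIPPING'],
    initial_steps=[950],
    taus=[0.125],
    seeds=[1337, 1338, 1339],  # hypothetical seed choices
)
for name in names:
    print(name)
```

Each name doubles as the `experiment` tag on the AWS instances, so keeping the names distinct also keeps `kill_instances` from touching another run.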