In [None]:
import boto3
# Every cell below reaches EC2 through these two handles, built from whatever
# credentials and region boto3 is configured with; see
# https://boto3.amazonaws.com/v1/documentation/api/latest/guide/quickstart.html
# `ec2` is the high-level resource API, `client` the low-level client API.
ec2, client = boto3.resource('ec2'), boto3.client('ec2')

In [None]:
experiment_name = "Your_name"
# ^-- must be unique per experiment: every instance and spot request is tagged with it


coordinator_type = "r5.large"  ## CPU instance type for the coordinator and auxiliary peers
dht_port = 31337  ## port the coordinator's DHT listens on
worker_type = "g4dn.2xlarge"  ## GPU worker instance type
num_workers = 16  ## number of GPU instances in your experiment
num_aux = 4  ## number of auxiliary CPU instances (maybe zero)

## bandwidth of GPU peers, mbps: one {"up", "down"} entry per GPU worker.
## Built with a comprehension so every worker gets its own dict; `[{...}] * k` makes
## all k entries the same object, so editing one would silently change the others.
bands = [
    {"up": mbps, "down": mbps}
    for mbps in [200] * 4 + [100] * 8 + [50] * 4
]
# Fail here, not with an IndexError deep inside the worker launch loop.
assert len(bands) == num_workers, (
    f"bands has {len(bands)} entries but num_workers is {num_workers}")

image_id = "ami-0db67995cd75f5a9f"  ## AMI ids are region-specific
aws_key_name = "aws"  ## update with your aws key name
subnet = ""  ## update with your subnet id (subnet-...) or skip entirely
security_group = ""  ## your security group id (sg-...)
data_path = ""  ## URL of a .tar.gz archive with the wikitext103 dataset (fetched with wget)
repo_path = ""  ## git URL of the repo with code of our `src` library

In [None]:
# Check that the experiment name is not already in use: the cells below launch a
# new coordinator and more peers tagged with `experiment_name`.
# Set `allow_existing_experiment = True` if you want to add more instances to an
# existing experiment.
allow_existing_experiment = False

existing_instances = ec2.instances.filter(Filters=[
    {'Name': 'instance-state-name', 'Values': ['running']},
    {'Name': 'tag:experiment', 'Values': [experiment_name]},
])
running_instances = list(existing_instances)
if running_instances:
    print(f"Already running {experiment_name}: {len(running_instances)} instance(s)")
    for instance in running_instances:
        print(instance.id, instance.public_ip_address, instance.private_ip_address)
    if not allow_existing_experiment:
        # Stop "Run All" before a second coordinator is launched under the same name.
        raise RuntimeError(
            f"experiment {experiment_name!r} already has {len(running_instances)} running "
            f"instance(s); pick another experiment_name, terminate them, or set "
            f"allow_existing_experiment = True")

In [None]:
# Tear down this experiment: cancel its spot instance requests and terminate all of
# its instances. Set `teardown_experiment = True` and run this cell to do it; it is
# False by default so "Run All" never destroys a running experiment.
# Instances are re-queried here so this cell does not depend on variables left
# over from other cells.
teardown_experiment = False
if teardown_experiment:
    # Cancelling a spot request does not terminate its instance, so both steps run.
    spot_requests = client.describe_spot_instance_requests(Filters=[
        {'Name': 'tag:experiment', 'Values': [experiment_name]},
        {'Name': 'state', 'Values': ['open', 'active']},
    ])['SpotInstanceRequests']
    requests_to_shutdown = [request['SpotInstanceRequestId'] for request in spot_requests]
    if requests_to_shutdown:
        client.cancel_spot_instance_requests(
            SpotInstanceRequestIds=requests_to_shutdown)
    ec2.instances.filter(Filters=[
        {'Name': 'instance-state-name', 'Values': ['pending', 'running', 'stopping', 'stopped']},
        {'Name': 'tag:experiment', 'Values': [experiment_name]},
    ]).terminate()

### Stage 1: run coordinator

The coordinator is an instance that welcomes new peers into a decentralized training run. If the coordinator is down, new peers can still join by using any of the existing peers as their initial peer.

In [None]:
import os

# Weights & Biases API key, read from the WANDB_API_KEY environment variable so the
# secret is never pasted into (and saved with) this notebook.
# NOTE: the key is embedded in the instances' user-data scripts, and EC2 user data
# can be read back through the EC2 API and from inside each instance.
WandB_API_key = os.environ.get("WANDB_API_KEY", "")
if not WandB_API_key:
    print("WANDB_API_KEY is not set; peers will get an empty W&B password in ~/.netrc")

In [None]:
# User-data script for the coordinator (first peer), run by EC2 at first boot.
# The exec line mirrors all output to /var/log/user-command.log, syslog and the
# serial console.
# The repo is cloned into an explicit `DeDLOC` directory so the `cd` below works
# whatever the remote repository is called. The W&B key is written to ~/.netrc,
# which is then made owner-only (chmod 600).
coordinator_script = f'''#!/bin/bash -ex
exec > >(tee /var/log/user-command.log|logger -t user-data -s 2>/dev/console) 2>&1

git clone {repo_path} DeDLOC
cd DeDLOC

pip install -e ./src
cd albert

ulimit -n 4096

sh -c 'cat <<"EOF" >> ~/.netrc
machine api.wandb.ai
  login user
  password {WandB_API_key}
EOF'
chmod 600 ~/.netrc

HIVEMIND_THREADS=128 python ./run_first_peer.py --dht_listen_on [::]:{dht_port} \
 --experiment_prefix {experiment_name} --wandb_project Runs
'''

In [None]:
# Launch the coordinator (first peer) as an on-demand instance, wait until it is
# running, and derive the endpoint the other peers pass as `--initial_peers`.
# Empty `subnet` / `security_group` settings are left out of the request rather
# than sent as empty strings, which are not valid EC2 ids.
network_options = {}
if security_group:
    network_options['SecurityGroupIds'] = [security_group]
if subnet:
    network_options['SubnetId'] = subnet

coordinator, = ec2.create_instances(
    ImageId=image_id, InstanceType=coordinator_type,
    MinCount=1, MaxCount=1,
    KeyName=aws_key_name, UserData=coordinator_script,
    TagSpecifications=[{'ResourceType': 'instance', 'Tags': [
        {'Key': 'experiment', 'Value': experiment_name},
        {'Key': 'role', 'Value': 'first_peer'},
    ]}],
    **network_options,
)
coordinator.wait_until_running()
coordinator.reload()  # re-describe the instance so its IP addresses are populated

coordinator_ip = coordinator.public_ip_address
if coordinator_ip is None:
    # Without this check the endpoint would silently become "None:<port>".
    raise RuntimeError(
        f"Coordinator {coordinator.id} has no public IP address "
        f"(private IP: {coordinator.private_ip_address}). Launch it in a subnet that "
        f"assigns public IPs, or set coordinator_endpoint manually.")

coordinator_endpoint = f"{coordinator_ip}:{dht_port}"
print(coordinator_endpoint)

### Stage 1.5: run auxiliary CPU peers

Auxiliary peers are CPU instances that receive gradients from the workers and take part in averaging them.

In [None]:
# User-data script for each auxiliary CPU peer, run by EC2 at first boot. It:
#   1. caps eth0 traffic at 500 Mbps up/down with wondershaper (rates given in Kbps),
#   2. clones the repo into an explicit `DeDLOC` directory and installs the `src` library,
#   3. downloads and unpacks the dataset archive from `data_path`,
#   4. stores the W&B key in ~/.netrc and makes that file owner-only,
#   5. starts run_aux.py, joining the run through `coordinator_endpoint`.
# NOTE(review): the archive is unpacked into the current directory (albert/), not
# into the ~/data created just before it -- confirm which layout run_aux.py expects.
# NOTE(review): traffic is shaped to 500 Mbps but `--bandwidth 1000` is reported,
# whereas workers report their shaped upload rate -- confirm 1000 is intended.
aux_script = f'''#!/bin/bash -ex
exec > >(tee /var/log/user-command.log|logger -t user-data -s 2>/dev/console) 2>&1

sudo yum install tc -y

git clone https://github.com/magnific0/wondershaper.git
cd wondershaper
sudo ./wondershaper -a eth0 -u {500*1024} -d {500*1024}
cd ..

git clone {repo_path} DeDLOC
cd DeDLOC

pip install -e ./src
cd albert

mkdir -p ~/data
wget -qO- {data_path} | tar xzf -

ulimit -n 4096

sh -c 'cat <<"EOF" >> ~/.netrc
machine api.wandb.ai
  login user
  password {WandB_API_key}
EOF'
chmod 600 ~/.netrc

WANDB_PROJECT={experiment_name} HIVEMIND_THREADS=128 python run_aux.py \
  --output_dir ./outputs --overwrite_output_dir \
  --logging_dir ./logs --logging_first_step --logging_steps 100 \
  --initial_peers {coordinator_endpoint} \
  --experiment_prefix {experiment_name} --seed 42 --averaging_timeout 120 --fp16 False --bandwidth 1000
'''

In [None]:
# Launch the auxiliary CPU peers one at a time (on-demand, same instance type as the
# coordinator), waiting for each to boot and printing its private and public IPs.
# Empty `subnet` / `security_group` settings are left out of the request rather
# than sent as empty strings, which are not valid EC2 ids.
network_options = {}
if security_group:
    network_options['SecurityGroupIds'] = [security_group]
if subnet:
    network_options['SubnetId'] = subnet

for _ in range(num_aux):
    aux, = ec2.create_instances(
        ImageId=image_id, InstanceType=coordinator_type,
        MinCount=1, MaxCount=1,
        KeyName=aws_key_name, UserData=aux_script,
        TagSpecifications=[{'ResourceType': 'instance', 'Tags': [
            {'Key': 'experiment', 'Value': experiment_name},
            {'Key': 'role', 'Value': 'aux_peer'},
        ]}],
        **network_options,
    )
    aux.wait_until_running()
    aux.reload()  # re-describe the instance so its IP addresses are populated

    print(aux.private_ip_address, aux.public_ip_address)

### Stage 2: run workers

Workers are preemptible (spot) GPU instances that compute gradients and take part in averaging. In this example, each worker is an instance with a single NVIDIA Tesla T4 GPU.

In [None]:
def gen_w(params):
    """Build the EC2 user-data script for one GPU worker.

    Args:
        params: bandwidth limits for this worker in Mbps, a dict with numeric
            "up" and "down" keys (e.g. one entry of `bands`).

    Returns:
        A bash script (str) that does the following:
        - caps eth0 traffic with wondershaper;
        - clones the repo into `DeDLOC` and installs the `src` library;
        - downloads and unpacks the dataset archive;
        - stores the W&B key in an owner-only ~/.netrc;
        - starts run_trainer.py, which joins the run through
          `coordinator_endpoint` and reports the upload limit via `--bandwidth`.
    """
    # wondershaper takes its -u/-d rates in Kbps; `params` is in Mbps.
    upload_kbps = params["up"] * 1024
    download_kbps = params["down"] * 1024
    # NOTE(review): the dataset archive is unpacked into the current directory
    # (albert/), not into the ~/data created just before it -- confirm which
    # layout run_trainer.py expects.
    worker_script = f'''#!/bin/bash -ex
exec > >(tee /var/log/user-command.log|logger -t user-data -s 2>/dev/console) 2>&1

sudo yum install tc -y

git clone https://github.com/magnific0/wondershaper.git
cd wondershaper
sudo ./wondershaper -a eth0 -u {upload_kbps} -d {download_kbps}
cd ..

git clone {repo_path} DeDLOC
cd DeDLOC

pip install -e ./src
cd albert

mkdir -p ~/data
wget -qO- {data_path} | tar xzf -

sh -c 'cat <<"EOF" >> ~/.netrc
machine api.wandb.ai
  login user
  password {WandB_API_key}
EOF'
chmod 600 ~/.netrc

ulimit -n 4096

WANDB_PROJECT={experiment_name} HIVEMIND_THREADS=128 python run_trainer.py \
  --output_dir ./outputs --overwrite_output_dir \
  --logging_dir ./logs --logging_first_step --logging_steps 100 \
  --initial_peers {coordinator_endpoint} --run_name aws_worker \
  --experiment_prefix {experiment_name} --seed 42 --client_mode False --averaging_timeout 120  --bandwidth {params["up"]}
'''
    return worker_script

In [None]:
def create_instance(worker_type, i):
    """Launch one GPU worker as a one-time spot instance.

    Args:
        worker_type: EC2 instance type to request, e.g. "g4dn.xlarge".
        i: worker slot index. It selects the bandwidth limits `bands[i]` passed
            to `gen_w`. It is also recorded as the `type` tag on the instance
            and on its spot request.

    Returns:
        The new `ec2.Instance` exactly as returned by the launch request; call
        `wait_until_running()` and `reload()` before reading its IP addresses.
    """
    # Empty `subnet` / `security_group` settings are left out of the request rather
    # than sent as empty strings, which are not valid EC2 ids.
    network_options = {}
    if security_group:
        network_options['SecurityGroupIds'] = [security_group]
    if subnet:
        network_options['SubnetId'] = subnet

    worker_tags = [
        {'Key': 'experiment', 'Value': experiment_name},
        {'Key': 'role', 'Value': 'gpu_worker'},
        {'Key': 'type', 'Value': str(i)},
    ]
    new_worker, = ec2.create_instances(
        ImageId=image_id, InstanceType=worker_type,
        MinCount=1, MaxCount=1,
        # `bands` holds the per-worker bandwidth limits (the original `s` was undefined).
        UserData=gen_w(bands[i]),
        KeyName=aws_key_name,
        InstanceMarketOptions={
            "MarketType": "spot",
            "SpotOptions": {
                "SpotInstanceType": "one-time",
                "InstanceInterruptionBehavior": "terminate",
            },
        },
        TagSpecifications=[
            {'ResourceType': 'instance', 'Tags': worker_tags},
            {'ResourceType': 'spot-instances-request', 'Tags': worker_tags},
        ],
        **network_options,
    )
    return new_worker

In [None]:
import time

# Keep `num_workers` GPU workers alive. Workers are spot instances, so AWS can
# reclaim them at any time; this loop refills every free worker slot. Slot i is
# the worker launched by `create_instance(..., i)`, which tags it `type=i`.
# Only instances tagged role=gpu_worker are counted, so the coordinator and the
# auxiliary peers of the same experiment do not mask missing workers.
keep_watching = True  # False: stop once all slots are filled instead of re-checking every 60 s
# Instance types tried in order for each new worker. The configured `worker_type`
# is the fallback, and duplicates are removed. The global `worker_type` is never reassigned.
worker_type_candidates = list(dict.fromkeys(["g4dn.xlarge", worker_type]))
retry_delay = 10  # seconds to pause after every candidate type failed, to avoid hammering the API


def _occupied_worker_slots():
    """Return the `type` tag values (slot indices as str) of this experiment's live GPU workers.

    Pending instances count as occupied so a worker that is still booting is not
    launched a second time.
    """
    live_workers = ec2.instances.filter(Filters=[
        {'Name': 'instance-state-name', 'Values': ['pending', 'running']},
        {'Name': 'tag:experiment', 'Values': [experiment_name]},
        {'Name': 'tag:role', 'Values': ['gpu_worker']},
    ])
    return {
        tag['Value']
        for instance in live_workers
        for tag in (instance.tags or [])
        if tag['Key'] == 'type'
    }


def _launch_worker(slot):
    """Launch a worker for `slot`, trying each candidate type in order; return True on success."""
    for candidate_type in worker_type_candidates:
        try:
            new_worker = create_instance(candidate_type, slot)
            new_worker.wait_until_running()
            new_worker.reload()  # re-describe the instance so its IP addresses are populated
        # Exception, not BaseException: a kernel interrupt (KeyboardInterrupt) must still stop the loop.
        except Exception as e:
            print("FAILED", candidate_type, e)
            continue
        print("CREATED ONE WORKER!", candidate_type,
              new_worker.public_ip_address, new_worker.private_ip_address)
        return True
    return False


while True:
    occupied_slots = _occupied_worker_slots()
    free_slots = [slot for slot in range(num_workers) if str(slot) not in occupied_slots]
    if free_slots:
        print(f"Need {len(free_slots)} more workers. Trying to spawn one")
        if not _launch_worker(free_slots[0]):
            time.sleep(retry_delay)
    elif keep_watching:
        print("Enough workers already, check back in 60s...")
        time.sleep(60)
    else:
        print("Enough workers already")
        break