Skip to content

Commit

Permalink
added ray as a distributor
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaishaal Shankar committed Jan 22, 2023
1 parent e559206 commit 6183e5a
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 1 deletion.
67 changes: 67 additions & 0 deletions examples/ray_example/cluster_minimal.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# A unique identifier for the head node and workers of this cluster.
cluster_name: minimal
min_workers: 0
max_workers: 10
upscaling_speed: 1.0
available_node_types:
ray.head.default:
resources: {}
node_config:
ImageId: ami-0ea1c7db66fee3098
InstanceType: m5.24xlarge
# if you have an IamInstanceProfile fill it out here...
#IamInstanceProfile:
# Arn: <instance_profile_arn>
ray.worker.default:
min_workers: 0
max_workers: 500
node_config:
ImageId: ami-0ea1c7db66fee3098
InstanceType: m5.24xlarge
InstanceMarketOptions:
MarketType: spot
# if you have an IamInstanceProfile fill it out here...
#IamInstanceProfile:
# Arn: <instance_profile_arn>

# Cloud-provider specific configuration.
provider:
type: aws
region: us-east-1

initialization_commands:
- wget https://secure.nic.cz/files/knot-resolver/knot-resolver-release.deb
- sudo dpkg -i knot-resolver-release.deb
- sudo apt update
- sudo apt install -y knot-resolver
- sudo sh -c 'echo `hostname -I` `hostname` >> /etc/hosts'
- sudo sh -c 'echo nameserver 127.0.0.1 > /etc/resolv.conf'
- sudo systemctl stop systemd-resolved
- sudo systemctl start kresd@1.service
- sudo systemctl start kresd@2.service
- sudo systemctl start kresd@3.service
- sudo systemctl start kresd@4.service
- sudo systemctl start kresd@5.service
- sudo systemctl start kresd@6.service
- sudo systemctl start kresd@7.service
- sudo systemctl start kresd@8.service
- sudo apt-get install ffmpeg libsm6 libxext6 -y

setup_commands:
- wget https://repo.anaconda.com/miniconda/Miniconda3-py39_22.11.1-1-Linux-x86_64.sh -O miniconda.sh
- bash ~/miniconda.sh -f -b -p miniconda3/
- echo 'export PATH="$HOME/miniconda3/bin/:$PATH"' >> ~/.bashrc
# if you have AWS CREDS fill them out here
#- echo 'export AWS_ACCESS_KEY_ID=<AWS_KEY>' >> ~/.bashrc
#- echo 'export AWS_SECRET_ACCESS_KEY=<AWS_SECRET_KEY>' >> ~/.bashrc
- pip install --upgrade pip setuptools wheel
- pip install ray
- pip uninstall -y img2dataset
- pip install git+https://github.com/vaishaal/img2dataset.git@7aadba42f8008106bd38475e06e78e79dfe4bbeb
- pip install opencv-python --upgrade
- wandb login <YOUR_WANDB_API_KEY>  # NOTE(review): a real API key was committed here — revoke it and never commit credentials
- pip install s3fs==2022.11.0
- pip install botocore==1.27.59

head_setup_commands: []

49 changes: 49 additions & 0 deletions examples/ray_example/ray_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import sys
import time
from collections import Counter

import ray
from img2dataset import download

import argparse




@ray.remote
def main(args):
    """Run one img2dataset download job as a Ray task.

    Args:
        args: argparse.Namespace with ``url_list`` (input parquet path, fed to
            ``url_list``) and ``out_folder`` (destination for the webdataset
            shards, fed to ``output_folder``).

    The heavy lifting is delegated to ``img2dataset.download`` with
    ``distributor="ray"``, so the download itself fans shards out across the
    whole Ray cluster; this wrapper only fixes the job configuration.
    """
    download(
        processes_count=1,
        thread_count=32,
        retries=0,
        timeout=10,
        url_list=args.url_list,
        image_size=512,
        resize_only_if_bigger=True,
        resize_mode="keep_ratio_largest",
        skip_reencode=True,
        output_folder=args.out_folder,
        output_format="webdataset",
        input_format="parquet",
        url_col="url",
        caption_col="alt",
        enable_wandb=True,
        # 48*120*2 shards per subjob — presumably sized to the cluster; confirm
        subjob_size=48*120*2,
        number_sample_per_shard=10000,
        distributor="ray",
        oom_shard_count=8,
        compute_hash="sha256",
        save_additional_columns=["uid"]
    )

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--url_list", help="path to the input parquet url list")
    parser.add_argument("--out_folder", help="output folder for the webdataset shards")
    args = parser.parse_args()
    # Connect to the already-running cluster started by `ray up` on the head node.
    ray.init(address="localhost:6379")
    # BUG FIX: `main` is a @ray.remote function, so it cannot be called
    # directly (that raises TypeError). Submit it with .remote() and block
    # on the result with ray.get so the script waits for the job to finish.
    ray.get(main.remote(args))




19 changes: 19 additions & 0 deletions examples/ray_example/readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Parallelizing Img2Dataset using Ray
If you do not want to set up a PySpark cluster, you can set up a Ray cluster instead. Functionally the two are
close to the same, but Ray handles a larger number of tasks better and does not have the "staged" nature of
Spark, which is great if you have a large queue of tasks and do not want to be vulnerable to the stragglers in each batch.
The tooling for setting up a Ray cluster on AWS is also slightly better at the time of writing (Jan 2023).

## Instructions for running a large img2dataset job on a ray cluster on AWS
First install ray:
``` pip install ray ```

If you are on AWS you can spin up a ray cluster this way:

``` ray up cluster_minimal.yaml ```

Then you can run your job:
```ray submit cluster_minimal.yaml ray_example.py -- --url_list <url_list> --out_folder <out_folder>```

Using the above code I was able to achieve a maximum download rate of 220,000 images/second on a cluster of 100 m5.24xlarge (9600 cores).

23 changes: 23 additions & 0 deletions img2dataset/distributor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from tqdm import tqdm



def retrier(runf, failed_shards, max_shard_retry):
# retry failed shards max_shard_retry times
for i in range(max_shard_retry):
Expand Down Expand Up @@ -65,6 +66,28 @@ def run(gen):

retrier(run, failed_shards, max_shard_retry)

try:
    import ray  # pylint: disable=import-outside-toplevel

    @ray.remote
    def ray_download(downloader, shards):
        """Run one downloader shard as a Ray task.

        `shards` is resolved with ray.get before the call — presumably the
        reader yields object refs put into the object store; confirm against
        the reader implementation.
        """
        status, row = downloader(ray.get(shards))
        return status, row

    def ray_distributor(processes_count, downloader, reader, _, __):  # type: ignore
        """Fan every shard from `reader` out as a Ray task and wait for all of them.

        `processes_count` and the last two positional slots (subjob_size,
        max_shard_retry in the other distributors) are accepted only for
        interface parity with multiprocessing_distributor/pyspark_distributor;
        the Ray backend does not use them, so failed shards are NOT retried here.
        """
        tasks = [ray_download.remote(downloader, shard) for shard in reader]
        # Block until every shard task has completed.
        ray.get(tasks)

except ModuleNotFoundError:
    def ray_distributor(processes_count, downloader, reader, subjob_size, max_shard_retry):  # type: ignore
        """Fallback when ray is not installed: fail loudly.

        The previous fallback silently returned None, which made a job with
        distributor="ray" appear to succeed while downloading nothing.
        """
        raise ModuleNotFoundError(
            "ray is not installed; run `pip install ray` to use distributor='ray'"
        )



@contextmanager
def _spark_session(processes_count: int):
Expand Down
4 changes: 3 additions & 1 deletion img2dataset/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)
from .reader import Reader
from .downloader import Downloader
from .distributor import multiprocessing_distributor, pyspark_distributor
from .distributor import multiprocessing_distributor, pyspark_distributor, ray_distributor
import fsspec
import sys
import signal
Expand Down Expand Up @@ -244,6 +244,8 @@ def signal_handler(signal_arg, frame): # pylint: disable=unused-argument
distributor_fn = multiprocessing_distributor
elif distributor == "pyspark":
distributor_fn = pyspark_distributor
elif distributor == "ray":
distributor_fn = ray_distributor
else:
raise ValueError(f"Distributor {distributor} not supported")

Expand Down

0 comments on commit 6183e5a

Please sign in to comment.