Adding ray as a distributor #272

Merged
merged 15 commits on Aug 20, 2023
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
@@ -63,5 +63,7 @@ jobs:
      - name: Unit tests
        run: |
          source .env/bin/activate
          ray start --head --disable-usage-stats
          ray start --address='127.0.0.1:6379'
          make test
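For local debugging, roughly the same single-node setup as this workflow can be reproduced by hand (a sketch; `ray stop` is the standard command to shut the local cluster down again):

```
source .env/bin/activate
ray start --head --disable-usage-stats    # bring up a local single-node ray cluster
ray start --address='127.0.0.1:6379'      # optionally attach a second local node to it
make test                                 # run the test suite, including the ray distributor test
ray stop                                  # tear the local cluster down afterwards
```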

67 changes: 67 additions & 0 deletions examples/ray_example/cluster_minimal.yaml
@@ -0,0 +1,67 @@
# A unique identifier for the head node and workers of this cluster.
cluster_name: minimal
min_workers: 0
max_workers: 10
upscaling_speed: 1.0
available_node_types:
  ray.head.default:
    resources: {}
    node_config:
      ImageId: ami-0ea1c7db66fee3098
      InstanceType: m5.24xlarge
      # if you have an IamInstanceProfile fill it out here...
      #IamInstanceProfile:
      #  Arn: <instance_profile_arn>
  ray.worker.default:
    min_workers: 0
    max_workers: 500
    node_config:
      ImageId: ami-0ea1c7db66fee3098
      InstanceType: m5.24xlarge
      InstanceMarketOptions:
        MarketType: spot
      # if you have an IamInstanceProfile fill it out here...
      #IamInstanceProfile:
      #  Arn: <instance_profile_arn>

# Cloud-provider specific configuration.
provider:
  type: aws
  region: us-east-1

# install knot-resolver so each node runs a local caching DNS resolver
# (the downloader generates a lot of DNS traffic, so keeping resolution on-node helps)
initialization_commands:
  - wget https://secure.nic.cz/files/knot-resolver/knot-resolver-release.deb
  - sudo dpkg -i knot-resolver-release.deb
  - sudo apt update
  - sudo apt install -y knot-resolver
  - sudo sh -c 'echo `hostname -I` `hostname` >> /etc/hosts'
  - sudo sh -c 'echo nameserver 127.0.0.1 > /etc/resolv.conf'
  - sudo systemctl stop systemd-resolved
  - sudo systemctl start kresd@1.service
  - sudo systemctl start kresd@2.service
  - sudo systemctl start kresd@3.service
  - sudo systemctl start kresd@4.service
  - sudo systemctl start kresd@5.service
  - sudo systemctl start kresd@6.service
  - sudo systemctl start kresd@7.service
  - sudo systemctl start kresd@8.service
  - sudo apt-get install ffmpeg libsm6 libxext6 -y

setup_commands:
  - wget https://repo.anaconda.com/miniconda/Miniconda3-py39_22.11.1-1-Linux-x86_64.sh -O miniconda.sh
  - bash ~/miniconda.sh -f -b -p miniconda3/
  - echo 'export PATH="$HOME/miniconda3/bin/:$PATH"' >> ~/.bashrc
  # if you have AWS CREDS fill them out here
  #- echo 'export AWS_ACCESS_KEY_ID=<AWS_KEY>' >> ~/.bashrc
  #- echo 'export AWS_SECRET_ACCESS_KEY=<AWS_SECRET_KEY>' >> ~/.bashrc
  - pip install --upgrade pip setuptools wheel
  - pip install ray
  - pip uninstall -y img2dataset
  - pip install git+https://github.com/vaishaal/img2dataset.git@7aadba42f8008106bd38475e06e78e79dfe4bbeb
  - pip install opencv-python --upgrade
  # log in with your own wandb API key (the example enables wandb reporting)
  - wandb login <WANDB_API_KEY>
  - pip install s3fs==2022.11.0
  - pip install botocore==1.27.59

head_setup_commands: []

49 changes: 49 additions & 0 deletions examples/ray_example/ray_example.py
@@ -0,0 +1,49 @@
import sys
import time
from collections import Counter

import ray
from img2dataset import download

import argparse


@ray.remote
def main(args):
    download(
        processes_count=1,
        thread_count=32,
        retries=0,
        timeout=10,
        url_list=args.url_list,
        image_size=512,
        resize_only_if_bigger=True,
        resize_mode="keep_ratio_largest",
        skip_reencode=True,
        output_folder=args.out_folder,
        output_format="webdataset",
        input_format="parquet",
        url_col="url",
        caption_col="alt",
        enable_wandb=True,
        subjob_size=48 * 120 * 2,
        number_sample_per_shard=10000,
        distributor="ray",
        oom_shard_count=8,
        compute_hash="sha256",
        save_additional_columns=["uid"],
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--url_list")
    parser.add_argument("--out_folder")
    args = parser.parse_args()
    # connect to the running ray cluster started by `ray up`
    ray.init(address="localhost:6379")
    # main is decorated with @ray.remote, so it must be launched with .remote()
    ray.get(main.remote(args))

19 changes: 19 additions & 0 deletions examples/ray_example/readme.md
@@ -0,0 +1,19 @@
# Parallelizing Img2Dataset using Ray
If you do not want to set up a PySpark cluster, you can set up a Ray cluster instead. Functionally the two are
close to the same, but Ray handles a larger number of tasks better and does not have the "staged" nature of
Spark, which is great if you have a large queue of tasks and don't want to be vulnerable to the stragglers in each batch.
The tooling for setting up a Ray cluster on AWS is also slightly better at the time of writing this document (Jan 2023).
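Concretely, switching between the two is just a matter of the `distributor` argument passed to `img2dataset.download`; a minimal sketch (the bucket paths are placeholders):

```python
from img2dataset import download

download(
    url_list="s3://my-bucket/url-shards/",    # placeholder: parquet files with "url" and "alt" columns
    input_format="parquet",
    url_col="url",
    caption_col="alt",
    output_folder="s3://my-bucket/images/",   # placeholder output location
    output_format="webdataset",
    distributor="ray",                        # "multiprocessing" and "pyspark" are the other options
)
```

See `ray_example.py` in this folder for the full set of options used in the benchmark below.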

## Instructions for running a large img2dataset job on a ray cluster on AWS
First install ray:
``` pip install ray ```

If you are on AWS you can spin up a ray cluster this way:

``` ray up cluster_minimal.yaml ```

Then you can run your job:
```ray submit cluster_minimal.yaml ray_example.py -- --url_list <url_list> --out_folder <out_folder>```
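
While the job runs you can follow the autoscaler logs, and when you are finished tear the cluster down (these are standard Ray cluster-launcher commands, not specific to this example):

``` ray monitor cluster_minimal.yaml ```

``` ray down cluster_minimal.yaml ```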

Using the above code I was able to achieve a maximum download rate of 220,000 images/second on a cluster of 100 m5.24xlarge instances (9,600 cores).
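That works out to roughly 220,000 / 9,600 ≈ 23 images per second per core.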

27 changes: 25 additions & 2 deletions img2dataset/distributor.py
@@ -28,7 +28,7 @@ def multiprocessing_distributor(processes_count, downloader, reader, _, max_shard_retry):

    def run(gen):
        failed_shards = []
        for (status, row) in tqdm(process_pool.imap_unordered(downloader, gen)):
        for status, row in tqdm(process_pool.imap_unordered(downloader, gen)):
            if status is False:
                failed_shards.append(row)
        return failed_shards
@@ -56,7 +56,7 @@ def run(gen):
        failed_shards = []
        for batch in batcher(gen, subjob_size):
            rdd = spark.sparkContext.parallelize(batch, len(batch))
            for (status, row) in rdd.map(downloader).collect():
            for status, row in rdd.map(downloader).collect():
                if status is False:
                    failed_shards.append(row)
        return failed_shards
@@ -66,6 +66,29 @@ def run(gen):
    retrier(run, failed_shards, max_shard_retry)


try:
    import ray  # pylint: disable=import-outside-toplevel

    @ray.remote
    def ray_download(downloader, shards):
        status, row = downloader(shards)
        return status, row

    def ray_distributor(processes_count, downloader, reader, _, max_shard_retry):  # type: ignore
        # pylint: disable=unused-argument
        # submit one ray task per shard, then wait for all of them to finish
        ret = []
        for task in reader:
            ret.append(ray_download.remote(downloader, task))
        ray.get(ret)

except ModuleNotFoundError:

    def ray_distributor(processes_count, downloader, reader, subjob_size, max_shard_retry):  # type: ignore # pylint: disable=unused-argument
        # ray is not installed, so the "ray" distributor is not available
        return None
@contextmanager
def _spark_session(processes_count: int):
"""Create and close a spark session if none exist"""
8 changes: 7 additions & 1 deletion img2dataset/main.py
@@ -15,7 +15,11 @@
)
from .reader import Reader
from .downloader import Downloader
from .distributor import multiprocessing_distributor, pyspark_distributor
from .distributor import (
    multiprocessing_distributor,
    pyspark_distributor,
    ray_distributor,
)
import fsspec
import sys
import signal
@@ -244,6 +248,8 @@ def signal_handler(signal_arg, frame): # pylint: disable=unused-argument
        distributor_fn = multiprocessing_distributor
    elif distributor == "pyspark":
        distributor_fn = pyspark_distributor
    elif distributor == "ray":
        distributor_fn = ray_distributor
    else:
        raise ValueError(f"Distributor {distributor} not supported")

4 changes: 2 additions & 2 deletions img2dataset/writer.py
@@ -18,8 +18,8 @@ def __init__(self, output_file, schema, buffer_size=100):
        self.schema = schema
        self._initiatlize_buffer()
        fs, output_path = fsspec.core.url_to_fs(output_file)

        self.output_fd = fs.open(output_path, "wb")
        # larger block size for writing big shards to S3
        self.output_fd = fs.open(output_path, "wb", block_size=200000000)
        self.parquet_writer = pq.ParquetWriter(self.output_fd, schema)

    def _initiatlize_buffer(self):
1 change: 1 addition & 0 deletions requirements-test.txt
@@ -12,3 +12,4 @@ tensorflow
tensorflow_io
types-requests
types-pkg_resources
ray
1 change: 1 addition & 0 deletions tests/test_main.py
@@ -365,6 +365,7 @@ def test_relative_path(tmp_path):
    [
        "multiprocessing",
        "pyspark",
        "ray",
    ],
)
def test_distributors(distributor, tmp_path):