In [1]:
import os


# Repository id of the model to fetch (e.g. "hf_org/model_id"), supplied
# through the PARAM_NAME environment variable.
model_id = os.environ.get("PARAM_NAME")
if not model_id:
    # Covers both an unset and an empty variable; fail fast with a hint on
    # how to provide the value.
    raise ValueError(
        "Missing required environment variable PARAM_NAME. "
        "Set `params: {name: hf_org/model_id}` in the model spec."
    )

# Directory the files are written to. An unset *or empty* OUTPUT_DIR falls back
# to the default, so an empty value can no longer produce paths under "/".
output_dir = os.environ.get("OUTPUT_DIR") or "/content/model"

# snapshot_download(repo_id=model_id, local_dir=output_dir, local_dir_use_symlinks=False, revision="main")

In [3]:
from huggingface_hub.hf_api import model_info

# Fetch the repository metadata for the requested model.
model = model_info(model_id)

# Repo-relative path of every file the repository lists.
filenames = list(map(lambda sibling: sibling.rfilename, model.siblings))
filenames

['.gitattributes',
 'LICENSE.md',
 'README.md',
 'config.json',
 'flax_model.msgpack',
 'generation_config.json',
 'merges.txt',
 'pytorch_model.bin',
 'special_tokens_map.json',
 'tf_model.h5',
 'tokenizer_config.json',
 'vocab.json']

In [4]:
# Keep every entry except those located under the "coreml/" folder.
filenames = [name for name in filenames if not name.startswith("coreml/")]
filenames

['.gitattributes',
 'LICENSE.md',
 'README.md',
 'config.json',
 'flax_model.msgpack',
 'generation_config.json',
 'merges.txt',
 'pytorch_model.bin',
 'special_tokens_map.json',
 'tf_model.h5',
 'tokenizer_config.json',
 'vocab.json']

In [5]:
import urllib.request
from huggingface_hub import hf_hub_url
from concurrent.futures import ThreadPoolExecutor, as_completed


def download_file(filename: str) -> str:
    """Download a single repository file into ``output_dir``.

    Args:
        filename: Path of the file relative to the repository root; it may
            contain sub-folders (e.g. ``"folder/file.bin"``).

    Returns:
        The local path the file was written to.

    Any error raised while resolving the URL or retrieving the file
    propagates to the caller.
    """
    destination = f"{output_dir}/{filename}"

    # urlretrieve does not create missing directories, so make sure the parent
    # exists; this covers both nested repo paths and a missing output_dir.
    # exist_ok=True keeps this safe when several workers hit the same folder.
    parent_dir = os.path.dirname(destination)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    print(f"Downloading {filename} to {destination}")
    url = hf_hub_url(model_id, filename)
    urllib.request.urlretrieve(url, destination)
    return destination

# Submit one download per file to a pool of up to 10 worker threads.
with ThreadPoolExecutor(max_workers=10) as executor:
    processes = [executor.submit(download_file, filename) for filename in filenames]

    # Consume the futures *inside* the `with` block: leaving the block waits
    # for every task, so iterating afterwards would only report once all
    # downloads are already done. Here each file is reported as it completes,
    # and `result()` re-raises any exception from the worker.
    for task in as_completed(processes):
        print(f"Finished downloading {task.result()}")

Downloading .gitattributes to /content/model/.gitattributes
Downloading LICENSE.md to /content/model/LICENSE.md
Downloading README.md to /content/model/README.md
Downloading config.json to /content/model/config.json
Downloading flax_model.msgpack to /content/model/flax_model.msgpack
Downloading generation_config.json to /content/model/generation_config.json
Downloading merges.txt to /content/model/merges.txt
Downloading pytorch_model.bin to /content/model/pytorch_model.bin
Downloading special_tokens_map.json to /content/model/special_tokens_map.json
Downloading tf_model.h5 to /content/model/tf_model.h5
Downloading tokenizer_config.json to /content/model/tokenizer_config.json
Downloading vocab.json to /content/model/vocab.json
Finished downloading /content/model/generation_config.json
Finished downloading /content/model/README.md
Finished downloading /content/model/flax_model.msgpack
Finished downloading /content/model/vocab.json
Finished downloading /content/model/LICENSE.md
Finished d

In [6]:
! ls -lash /content/model

total 724M
   0 drwxr-xr-x 14 root root  448 Jul 16 17:12 .
4.0K drwxr-xr-x  1 root root 4.0K Jul 16 17:11 ..
4.0K -rw-r--r--  1 root root 1.2K Jul 16 17:12 .gitattributes
 12K -rw-r--r--  1 root root  11K Jul 16 17:12 LICENSE.md
8.0K -rw-r--r--  1 root root 7.0K Jul 16 17:12 README.md
4.0K -rw-r--r--  1 root root  651 Jul 16 17:12 config.json
241M -rw-r--r--  1 root root 239M Jul 16 17:13 flax_model.msgpack
4.0K -rw-r--r--  1 root root  137 Jul 16 17:12 generation_config.json
448K -rw-r--r--  1 root root 446K Jul 16 17:12 merges.txt
241M -rw-r--r--  1 root root 239M Jul 16 17:13 pytorch_model.bin
4.0K -rw-r--r--  1 root root  441 Jul 16 17:12 special_tokens_map.json
241M -rw-r--r--  1 root root 240M Jul 16 17:13 tf_model.h5
4.0K -rw-r--r--  1 root root  685 Jul 16 17:12 tokenizer_config.json
880K -rw-r--r--  1 root root 878K Jul 16 17:12 vocab.json
