Skip to content

Commit

Permalink
update api cache format
Browse files Browse the repository at this point in the history
  • Loading branch information
jstzwj committed Aug 15, 2024
1 parent ba65afd commit 5b74eba
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 85 deletions.
43 changes: 1 addition & 42 deletions olah/proxy/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
ORIGINAL_LOC,
)
from olah.cache.olah_cache import OlahCache
from olah.utils.cache_utils import _read_cache_request, _write_cache_request
from olah.utils.url_utils import (
RemoteInfo,
add_query_param,
Expand Down Expand Up @@ -80,48 +81,6 @@ def get_contiguous_ranges(
range_start_pos = end_pos
return ranges_and_cache_list


async def _write_cache_request(
head_path: str, status_code: int, headers: Dict[str, str], content: bytes
) -> None:
"""
Write the request's status code, headers, and content to a cache file.
Args:
head_path (str): The path to the cache file.
status_code (int): The status code of the request.
headers (Dict[str, str]): The dictionary of response headers.
content (bytes): The content of the request.
Returns:
None
"""
rq = {
"status_code": status_code,
"headers": headers,
"content": content.hex(),
}
with open(head_path, "w", encoding="utf-8") as f:
f.write(json.dumps(rq, ensure_ascii=False))


async def _read_cache_request(head_path: str) -> Dict[str, str]:
"""
Read the request's status code, headers, and content from a cache file.
Args:
head_path (str): The path to the cache file.
Returns:
Dict[str, str]: A dictionary containing the status code, headers, and content of the request.
"""
with open(head_path, "r", encoding="utf-8") as f:
rq = json.loads(f.read())

rq["content"] = bytes.fromhex(rq["content"])
return rq


async def _file_full_header(
app,
save_path: str,
Expand Down
39 changes: 23 additions & 16 deletions olah/proxy/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,16 @@
import httpx
from olah.constants import CHUNK_SIZE, WORKER_API_TIMEOUT

from olah.utils.cache_utils import _read_cache_request, _write_cache_request
from olah.utils.rule_utils import check_cache_rules_hf
from olah.utils.repo_utils import get_org_repo
from olah.utils.file_utils import make_dirs


async def _meta_cache_generator(save_path: str):
    """
    Stream a cached meta response: first yield its headers dict, then its body.

    Reads the cache file written by `_write_cache_request` via
    `_read_cache_request` instead of streaming raw file chunks, so the cached
    headers are replayed to the client as well.

    Args:
        save_path (str): Path to the JSON cache file for this meta request.

    Yields:
        Dict[str, str]: the cached response headers.
        bytes: the cached response body.
    """
    # (Reconstructed from the diff: the scraped text interleaved the old
    # chunked-file body with this new cache-request body.)
    cache_rq = await _read_cache_request(save_path)
    yield cache_rq["headers"]
    yield cache_rq["content"]


async def meta_proxy_cache(
Expand All @@ -38,12 +35,16 @@ async def meta_proxy_cache(
commit: str,
request: Request,
):
headers = {k: v for k, v in request.headers.items()}
headers.pop("host")

# save
method = request.method.lower()
repos_path = app.app_settings.repos_path
save_dir = os.path.join(
repos_path, f"api/{repo_type}/{org}/{repo}/revision/{commit}"
)
save_path = os.path.join(save_dir, "meta.json")
save_path = os.path.join(save_dir, f"meta_{method}.json")
make_dirs(save_path)

# url
Expand All @@ -57,15 +58,16 @@ async def meta_proxy_cache(
headers["authorization"] = request.headers["authorization"]
async with httpx.AsyncClient() as client:
response = await client.request(
method="GET",
method=request.method,
url=meta_url,
headers=headers,
timeout=WORKER_API_TIMEOUT,
follow_redirects=True,
)
if response.status_code == 200:
with open(save_path, "wb") as meta_file:
meta_file.write(response.content)
await _write_cache_request(
save_path, response.status_code, response.headers, response.content
)
else:
raise Exception(
f"Cannot get the branch info from the url {meta_url}, status: {response.status_code}"
Expand All @@ -77,16 +79,18 @@ async def _meta_proxy_generator(
headers: Dict[str, str],
meta_url: str,
allow_cache: bool,
method: str,
save_path: str,
):
async with httpx.AsyncClient(follow_redirects=True) as client:
content_chunks = []
async with client.stream(
method="GET",
method=method,
url=meta_url,
headers=headers,
timeout=WORKER_API_TIMEOUT,
) as response:
response_status_code = response.status_code
response_headers = response.headers
yield response_headers

Expand All @@ -99,8 +103,10 @@ async def _meta_proxy_generator(
content = bytearray()
for chunk in content_chunks:
content += chunk
with open(save_path, "wb") as f:
f.write(bytes(content))

await _write_cache_request(
save_path, response_status_code, response_headers, bytes(content)
)


async def meta_generator(
Expand All @@ -115,11 +121,12 @@ async def meta_generator(
headers.pop("host")

# save
method = request.method.lower()
repos_path = app.app_settings.repos_path
save_dir = os.path.join(
repos_path, f"api/{repo_type}/{org}/{repo}/revision/{commit}"
)
save_path = os.path.join(save_dir, "meta.json")
save_path = os.path.join(save_dir, f"meta_{method}.json")
make_dirs(save_path)

use_cache = os.path.exists(save_path)
Expand All @@ -136,6 +143,6 @@ async def meta_generator(
yield item
else:
async for item in _meta_proxy_generator(
app, headers, meta_url, allow_cache, save_path
app, headers, meta_url, allow_cache, method, save_path
):
yield item
47 changes: 29 additions & 18 deletions olah/proxy/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,16 @@
import httpx
from olah.constants import CHUNK_SIZE, WORKER_API_TIMEOUT

from olah.utils.cache_utils import _read_cache_request, _write_cache_request
from olah.utils.rule_utils import check_cache_rules_hf
from olah.utils.repo_utils import get_org_repo
from olah.utils.file_utils import make_dirs


async def _tree_cache_generator(save_path: str):
    """
    Stream a cached tree response: first yield its headers dict, then its body.

    Mirrors `_meta_cache_generator` in meta.py — loads the JSON cache file
    via `_read_cache_request` so the originally cached headers are replayed
    instead of an empty header dict.

    Args:
        save_path (str): Path to the JSON cache file for this tree request.

    Yields:
        Dict[str, str]: the cached response headers.
        bytes: the cached response body.
    """
    # (Reconstructed from the diff: the scraped text interleaved the old
    # chunked-file body with this new cache-request body.)
    cache_rq = await _read_cache_request(save_path)
    yield cache_rq["headers"]
    yield cache_rq["content"]


async def tree_proxy_cache(
Expand All @@ -38,10 +35,16 @@ async def tree_proxy_cache(
commit: str,
request: Request,
):
headers = {k: v for k, v in request.headers.items()}
headers.pop("host")

# save
method = request.method.lower()
repos_path = app.app_settings.repos_path
save_dir = os.path.join(repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}")
save_path = os.path.join(save_dir, "tree.json")
save_dir = os.path.join(
repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}"
)
save_path = os.path.join(save_dir, f"tree_{method}.json")
make_dirs(save_path)

# url
Expand All @@ -55,15 +58,16 @@ async def tree_proxy_cache(
headers["authorization"] = request.headers["authorization"]
async with httpx.AsyncClient() as client:
response = await client.request(
method="GET",
method=request.method,
url=tree_url,
headers=headers,
timeout=WORKER_API_TIMEOUT,
follow_redirects=True,
)
if response.status_code == 200:
with open(save_path, "wb") as tree_file:
tree_file.write(response.content)
await _write_cache_request(
save_path, response.status_code, response.headers, response.content
)
else:
raise Exception(
f"Cannot get the branch info from the url {tree_url}, status: {response.status_code}"
Expand All @@ -75,16 +79,18 @@ async def _tree_proxy_generator(
headers: Dict[str, str],
tree_url: str,
allow_cache: bool,
method: str,
save_path: str,
):
async with httpx.AsyncClient(follow_redirects=True) as client:
content_chunks = []
async with client.stream(
method="GET",
method=method,
url=tree_url,
headers=headers,
timeout=WORKER_API_TIMEOUT,
) as response:
response_status_code = response.status_code
response_headers = response.headers
yield response_headers

Expand All @@ -97,8 +103,10 @@ async def _tree_proxy_generator(
content = bytearray()
for chunk in content_chunks:
content += chunk
with open(save_path, "wb") as f:
f.write(bytes(content))

await _write_cache_request(
save_path, response_status_code, response_headers, bytes(content)
)


async def tree_generator(
Expand All @@ -113,9 +121,12 @@ async def tree_generator(
headers.pop("host")

# save
method = request.method.lower()
repos_path = app.app_settings.repos_path
save_dir = os.path.join(repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}")
save_path = os.path.join(save_dir, "tree.json")
save_dir = os.path.join(
repos_path, f"api/{repo_type}/{org}/{repo}/tree/{commit}"
)
save_path = os.path.join(save_dir, f"tree_{method}.json")
make_dirs(save_path)

use_cache = os.path.exists(save_path)
Expand All @@ -132,6 +143,6 @@ async def tree_generator(
yield item
else:
async for item in _tree_proxy_generator(
app, headers, tree_url, allow_cache, save_path
app, headers, tree_url, allow_cache, method, save_path
):
yield item
14 changes: 11 additions & 3 deletions olah/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,30 +173,36 @@ async def meta_proxy_common(repo_type: str, org: str, repo: str, commit: str, re
return Response(status_code=504)


@app.head("/api/{repo_type}/{org_repo}")
@app.get("/api/{repo_type}/{org_repo}")
async def meta_proxy(repo_type: str, org_repo: str, request: Request):
    """
    Serve repo meta info for single-segment `org_repo` paths (GET and HEAD).

    Resolves the newest commit online (unless running offline, where "main"
    is assumed) and delegates to `meta_proxy_common`.
    """
    org, repo = parse_org_repo(org_repo)
    if org is None and repo is None:
        return error_repo_not_found()

    # Offline mode skips the upstream lookup and pins the default branch.
    if app.app_settings.config.offline:
        new_commit = "main"
    else:
        new_commit = await get_newest_commit_hf(app, repo_type, org, repo)
        if new_commit is None:
            return error_repo_not_found()

    return await meta_proxy_common(
        repo_type=repo_type, org=org, repo=repo, commit=new_commit, request=request
    )

@app.head("/api/{repo_type}/{org}/{repo}")
@app.get("/api/{repo_type}/{org}/{repo}")
async def meta_proxy(repo_type: str, org: str, repo: str, request: Request):
    """
    Serve repo meta info for explicit `org`/`repo` paths (GET and HEAD).

    Resolves the newest commit online (unless running offline, where "main"
    is assumed) and delegates to `meta_proxy_common`.

    NOTE(review): this handler reuses the name `meta_proxy` already defined
    for the org_repo route above; FastAPI registers both routes at decoration
    time so both work, but the redefinition shadows the earlier function
    object (flake8 F811) — consider a distinct name.
    """
    # Offline mode skips the upstream lookup and pins the default branch.
    if app.app_settings.config.offline:
        new_commit = "main"
    else:
        new_commit = await get_newest_commit_hf(app, repo_type, org, repo)
        if new_commit is None:
            return error_repo_not_found()

    return await meta_proxy_common(
        repo_type=repo_type, org=org, repo=repo, commit=new_commit, request=request
    )


@app.head("/api/{repo_type}/{org}/{repo}/revision/{commit}")
@app.get("/api/{repo_type}/{org}/{repo}/revision/{commit}")
async def meta_proxy_commit2(
repo_type: str, org: str, repo: str, commit: str, request: Request
Expand All @@ -205,7 +211,7 @@ async def meta_proxy_commit2(
repo_type=repo_type, org=org, repo=repo, commit=commit, request=request
)


@app.head("/api/{repo_type}/{org_repo}/revision/{commit}")
@app.get("/api/{repo_type}/{org_repo}/revision/{commit}")
async def meta_proxy_commit(repo_type: str, org_repo: str, commit: str, request: Request):
org, repo = parse_org_repo(org_repo)
Expand All @@ -230,6 +236,7 @@ async def tree_proxy_common(repo_type: str, org: str, repo: str, commit: str, re
git_path = os.path.join(mirror_path, repo_type, org, repo)
if os.path.exists(git_path):
local_repo = LocalMirrorRepo(git_path, repo_type, org, repo)
# TODO: Local git repo trees
tree_data = local_repo.get_tree(commit)
if tree_data is None:
continue
Expand Down Expand Up @@ -269,7 +276,7 @@ async def tree_proxy_common(repo_type: str, org: str, repo: str, commit: str, re
traceback.print_exc()
return Response(status_code=504)


@app.head("/api/{repo_type}/{org}/{repo}/tree/{commit}")
@app.get("/api/{repo_type}/{org}/{repo}/tree/{commit}")
async def tree_proxy_commit2(
repo_type: str, org: str, repo: str, commit: str, request: Request
Expand All @@ -279,6 +286,7 @@ async def tree_proxy_commit2(
)


@app.head("/api/{repo_type}/{org_repo}/tree/{commit}")
@app.get("/api/{repo_type}/{org_repo}/tree/{commit}")
async def tree_proxy_commit(repo_type: str, org_repo: str, commit: str, request: Request):
org, repo = parse_org_repo(org_repo)
Expand Down
Loading

0 comments on commit 5b74eba

Please sign in to comment.