Skip to content

Commit 880262d

Browse files
authored
Merge pull request #16393 from github/redsun82/lfs
Bazel: improved lazy lfs files
2 parents 95ff5ba + 6cbe16e commit 880262d

File tree

4 files changed

+221
-0
lines changed

4 files changed

+221
-0
lines changed

.lfsconfig

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
[lfs]
2+
# codeql is publicly forked by many users, and we don't want any LFS file polluting their working
3+
# copies. We therefore exclude everything by default.
4+
# For files required by bazel builds, use rules in `misc/bazel/lfs.bzl` to download them on demand.
5+
fetchinclude = /nothing

misc/bazel/internal/BUILD.bazel

Whitespace-only changes.

misc/bazel/internal/git_lfs_probe.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
#!/usr/bin/env python3
2+
3+
"""
4+
Probe lfs files.
5+
For each source file provided as output, this will print:
6+
* "local", if the source file is not an LFS pointer
7+
* the sha256 hash, a space character and a transient download link obtained via the LFS protocol otherwise
8+
"""
9+
10+
import sys
11+
import pathlib
12+
import subprocess
13+
import os
14+
import shutil
15+
import json
16+
import urllib.request
17+
from urllib.parse import urlparse
18+
import re
19+
import base64
20+
from dataclasses import dataclass
21+
22+
23+
@dataclass
class Endpoint:
    """An LFS batch-API endpoint: its URL plus the HTTP headers to send to it."""

    href: str
    headers: dict[str, str]

    def update_headers(self, d: dict[str, str]):
        """Merge `d` into the headers, normalizing each key with `str.capitalize`
        (same normalization `urllib.request` applies to header names)."""
        for key, value in d.items():
            self.headers[key.capitalize()] = value
30+
31+
32+
# Resolve every file passed on the command line to an absolute path, then find
# the git checkout containing them: start from the common ancestor directory of
# all sources and ask git for the repository top level.
sources = [pathlib.Path(arg).resolve() for arg in sys.argv[1:]]
source_dir = pathlib.Path(os.path.commonpath(src.parent for src in sources))
# NOTE: rebound on purpose — from here on `source_dir` is the repository
# top-level path as a string, used as `cwd` for all git subprocesses below.
source_dir = subprocess.check_output(["git", "rev-parse", "--show-toplevel"], cwd=source_dir, text=True).strip()
35+
36+
37+
def get_env(s, sep="="):
    """Parse `key<sep>value` lines of `s` into a dict.

    Keys are matched non-greedily, so the first occurrence of `sep` on a line
    splits key from value. On duplicate keys, the first value wins.
    """
    pattern = re.compile(fr'(.*?){sep}(.*)', re.M)
    parsed = {}
    for match in pattern.finditer(s):
        key, value = match.groups()
        if key not in parsed:
            parsed[key] = value
    return parsed
42+
43+
44+
def git(*args, **kwargs):
    """Run a git command in the repository root and return its stripped stdout.

    Extra `kwargs` (e.g. `check`, `input`) are forwarded to `subprocess.run`.
    """
    command = ("git",) + args
    completed = subprocess.run(command, stdout=subprocess.PIPE, text=True, cwd=source_dir, **kwargs)
    return completed.stdout.strip()
46+
47+
48+
def get_endpoint():
    """Discover the LFS batch API endpoint and the auth headers to use with it.

    The endpoint URL comes from `git lfs env`. Authorization is resolved by
    trying, in order:
    1. `git-lfs-authenticate` over SSH, when `git lfs env` reports an SSH endpoint;
    2. the `http.<url>.extraheader` git config (how actions/checkout persists
       credentials);
    3. the GITHUB_TOKEN environment variable;
    4. `git credential fill` (possibly backed by a credential helper like the
       one installed by gh).
    """
    lfs_env = get_env(subprocess.check_output(["git", "lfs", "env"], text=True, cwd=source_dir))
    endpoint = next(v for k, v in lfs_env.items() if k.startswith('Endpoint'))
    # the value looks like "<url> (auth=...)"; keep only the url
    endpoint, _, _ = endpoint.partition(' ')
    ssh_endpoint = lfs_env.get(" SSH")
    endpoint = Endpoint(endpoint, {
        "Content-Type": "application/vnd.git-lfs+json",
        "Accept": "application/vnd.git-lfs+json",
    })
    if ssh_endpoint:
        # see https://github.com/git-lfs/git-lfs/blob/main/docs/api/authentication.md
        server, _, path = ssh_endpoint.partition(":")
        ssh_command = shutil.which(os.environ.get("GIT_SSH", os.environ.get("GIT_SSH_COMMAND", "ssh")))
        assert ssh_command, "no ssh command found"
        resp = json.loads(subprocess.check_output([ssh_command, server, "git-lfs-authenticate", path, "download"]))
        # bugfix: the fallback must be the current href *string*, not the
        # Endpoint object itself (which would break urlparse below)
        endpoint.href = resp.get("href", endpoint.href)
        endpoint.update_headers(resp.get("header", {}))
    url = urlparse(endpoint.href)
    # this is how actions/checkout persist credentials
    # see https://github.com/actions/checkout/blob/44c2b7a8a4ea60a981eaca3cf939b5f4305c123b/src/git-auth-helper.ts#L56-L63
    auth = git("config", f"http.{url.scheme}://{url.netloc}/.extraheader")
    endpoint.update_headers(get_env(auth, sep=": "))
    if "GITHUB_TOKEN" in os.environ:
        endpoint.headers["Authorization"] = f"token {os.environ['GITHUB_TOKEN']}"
    if "Authorization" not in endpoint.headers:
        # last chance: use git credentials (possibly backed by a credential helper like the one installed by gh)
        # see https://git-scm.com/docs/git-credential
        credentials = get_env(git("credential", "fill", check=True,
                                  # drop leading / from url.path
                                  input=f"protocol={url.scheme}\nhost={url.netloc}\npath={url.path[1:]}\n"))
        auth = base64.b64encode(f'{credentials["username"]}:{credentials["password"]}'.encode()).decode('ascii')
        endpoint.headers["Authorization"] = f"Basic {auth}"
    return endpoint
81+
82+
83+
# see https://github.com/git-lfs/git-lfs/blob/310d1b4a7d01e8d9d884447df4635c7a9c7642c2/docs/api/basic-transfers.md
def get_locations(objects):
    """Return one location string per entry of `objects`, in input order.

    `objects` is a list parallel to the probed sources: `None` for plain
    (non-LFS) files, or a `{"oid": ..., "size": ...}` dict for LFS pointers.
    Plain files map to "local"; LFS objects are resolved through the LFS batch
    API into "<sha256> <transient download url>".
    """
    endpoint = get_endpoint()
    # remember which input slots hold real LFS objects; the batch response
    # only contains entries for those
    indexes = [i for i, o in enumerate(objects) if o]
    ret = ["local" for _ in objects]
    req = urllib.request.Request(
        f"{endpoint.href}/objects/batch",
        headers=endpoint.headers,
        data=json.dumps({
            "operation": "download",
            "transfers": ["basic"],
            "objects": [o for o in objects if o],
            "hash_algo": "sha256",
        }).encode("ascii"),
    )
    with urllib.request.urlopen(req) as resp:
        data = json.load(resp)
    # bugfix: report the number of returned objects, not the number of keys in
    # the response dict
    assert len(data["objects"]) == len(indexes), \
        f"received {len(data['objects'])} objects, expected {len(indexes)}"
    for i, obj in zip(indexes, data["objects"]):
        ret[i] = f'{obj["oid"]} {obj["actions"]["download"]["href"]}'
    return ret
104+
105+
106+
def get_lfs_object(path):
    """Return the LFS descriptor for `path`, or `None` if it is a plain file.

    A file is considered an LFS pointer iff it starts with the LFS spec header.
    In that case the pointer's `oid`/`size` fields are parsed into
    `{"oid": <sha256 hex>, "size": <int>}` as expected by the LFS batch API.

    Raises AssertionError if the pointer uses an oid type other than sha256.
    """
    with open(path, 'rb') as fileobj:
        lfs_header = "version https://git-lfs.github.com/spec".encode()
        actual_header = fileobj.read(len(lfs_header))
        if lfs_header != actual_header:
            # regular file checked in directly, nothing to download
            return None
        # pointer bodies are "key value" lines, e.g. "oid sha256:...", "size 123"
        data = get_env(fileobj.read().decode('ascii'), sep=' ')
        assert data['oid'].startswith('sha256:'), f"unknown oid type: {data['oid']}"
        _, _, sha256 = data['oid'].partition(':')
        size = int(data['size'])
        return {"oid": sha256, "size": size}
118+
119+
120+
# Entry point: probe each source file, then print one line per file in input
# order — either "local" or "<sha256> <download url>" (see module docstring).
objects = [get_lfs_object(src) for src in sources]
for resp in get_locations(objects):
    print(resp)

misc/bazel/lfs.bzl

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
def lfs_smudge(repository_ctx, srcs, extract = False, stripPrefix = None):
    """Materialize the LFS-tracked files `srcs` into this repository.

    `git_lfs_probe.py` prints, for each source, either "local" (the checkout
    already has the real file) or "<sha256> <transient url>". Local files are
    symlinked (or extracted when `extract` is True); remote ones are downloaded
    (and optionally extracted) with sha256 verification.
    """
    for src in srcs:
        repository_ctx.watch(src)
    script = Label("//misc/bazel/internal:git_lfs_probe.py")
    # prefer python3, falling back to plain python (e.g. some Windows setups)
    python = repository_ctx.which("python3") or repository_ctx.which("python")
    if not python:
        fail("Neither python3 nor python executables found")
    repository_ctx.report_progress("querying LFS url(s) for: %s" % ", ".join([src.basename for src in srcs]))
    res = repository_ctx.execute([python, script] + srcs, quiet = True)
    if res.return_code != 0:
        fail("git LFS probing failed while instantiating @%s:\n%s" % (repository_ctx.name, res.stderr))

    # the probe prints one line per input, in order
    for src, loc in zip(srcs, res.stdout.splitlines()):
        if loc == "local":
            if extract:
                repository_ctx.report_progress("extracting local %s" % src.basename)
                repository_ctx.extract(src, stripPrefix = stripPrefix)
            else:
                repository_ctx.report_progress("symlinking local %s" % src.basename)
                repository_ctx.symlink(src, src.basename)
        else:
            sha256, _, url = loc.partition(" ")
            if extract:
                # we can't use skylib's `paths.split_extension`, as that only gets the last extension, so `.tar.gz`
                # or similar wouldn't work
                # it doesn't matter if file is something like some.name.zip and possible_extension == "name.zip",
                # download_and_extract will just append ".name.zip" its internal temporary name, so extraction works
                possible_extension = ".".join(src.basename.rsplit(".", 2)[-2:])
                repository_ctx.report_progress("downloading and extracting remote %s" % src.basename)
                repository_ctx.download_and_extract(url, sha256 = sha256, stripPrefix = stripPrefix, type = possible_extension)
            else:
                repository_ctx.report_progress("downloading remote %s" % src.basename)
                repository_ctx.download(url, src.basename, sha256 = sha256)
34+
35+
def _download_and_extract_lfs(repository_ctx):
    """Implementation of `lfs_archive`: smudge and extract one LFS archive,
    then install the repository BUILD file from `build_file_content` or
    `build_file` (at most one of the two may be set)."""
    attr = repository_ctx.attr
    src = repository_ctx.path(attr.src)
    if attr.build_file_content and attr.build_file:
        fail("You should specify only one among build_file_content and build_file for rule @%s" % repository_ctx.name)
    lfs_smudge(repository_ctx, [src], extract = True, stripPrefix = attr.strip_prefix)
    # neither attribute being set is allowed: the archive may ship its own BUILD files
    if attr.build_file_content:
        repository_ctx.file("BUILD.bazel", attr.build_file_content)
    elif attr.build_file:
        repository_ctx.symlink(attr.build_file, "BUILD.bazel")
45+
46+
def _download_lfs(repository_ctx):
    """Implementation of `lfs_files`: smudge a set of LFS files (either the
    explicit `srcs` labels, or the direct contents of `dir`) and generate a
    BUILD file exporting them, plus a filegroup named after the repository."""
    attr = repository_ctx.attr
    # `srcs` and `dir` are mutually exclusive, and one of them is required
    if int(bool(attr.srcs)) + int(bool(attr.dir)) != 1:
        fail("Exactly one between `srcs` and `dir` must be defined for @%s" % repository_ctx.name)
    if attr.srcs:
        srcs = [repository_ctx.path(src) for src in attr.srcs]
    else:
        dir = repository_ctx.path(attr.dir)
        if not dir.is_dir:
            fail("`dir` not a directory in @%s" % repository_ctx.name)
        # only direct children; subdirectories are deliberately skipped
        srcs = [f for f in dir.readdir() if not f.is_dir]
    lfs_smudge(repository_ctx, srcs)

    # with bzlmod the name is qualified with `~` separators, and we want the base name here
    name = repository_ctx.name.split("~")[-1]
    repository_ctx.file("BUILD.bazel", """
exports_files({files})

filegroup(
    name = "{name}",
    srcs = {files},
    visibility = ["//visibility:public"],
)
""".format(name = name, files = repr([src.basename for src in srcs])))
70+
71+
# Repository rule: lazily download and extract a single LFS-tracked archive
# (see `_download_and_extract_lfs` above for the implementation details).
lfs_archive = repository_rule(
    doc = "Export the contents from an on-demand LFS archive. The corresponding path should be added to be ignored " +
          "in `.lfsconfig`.",
    implementation = _download_and_extract_lfs,
    attrs = {
        "src": attr.label(mandatory = True, doc = "Local path to the LFS archive to extract."),
        "build_file_content": attr.string(doc = "The content for the BUILD file for this repository. " +
                                                "Either build_file or build_file_content can be specified, but not both."),
        "build_file": attr.label(doc = "The file to use as the BUILD file for this repository. " +
                                       "Either build_file or build_file_content can be specified, but not both."),
        "strip_prefix": attr.string(default = "", doc = "A directory prefix to strip from the extracted files. "),
    },
)
84+
85+
# Repository rule: lazily download individual LFS-tracked files and expose
# them via `exports_files` plus a filegroup (see `_download_lfs` above).
lfs_files = repository_rule(
    doc = "Export LFS files for on-demand download. Exactly one between `srcs` and `dir` must be defined. The " +
          "corresponding paths should be added to be ignored in `.lfsconfig`.",
    implementation = _download_lfs,
    attrs = {
        "srcs": attr.label_list(doc = "Local paths to the LFS files to export."),
        "dir": attr.label(doc = "Local path to a directory containing LFS files to export. Only the direct contents " +
                                "of the directory are exported"),
    },
)

0 commit comments

Comments
 (0)