@@ -34,29 +34,17 @@ def _build_state_json(

# Collect commit order (newest → older) across signals
commits: List[str] = []
commit_times: Dict[str, str] = {}
seen = set()
for s in signals:
for c in s.commits:
if c.head_sha not in seen:
seen.add(c.head_sha)
commits.append(c.head_sha)
commit_times[c.head_sha] = c.timestamp.isoformat()

# Compute minimal started_at per commit (for timestamp context)
commit_times: Dict[str, str] = {}
for sha in commits:
tmin_iso: str | None = None
for s in signals:
# find commit in this signal
sc = next((cc for cc in s.commits if cc.head_sha == sha), None)
if not sc or not sc.events:
continue
# events are sorted oldest first
t = sc.events[0].started_at
ts_iso = t.isoformat()
if tmin_iso is None or ts_iso < tmin_iso:
tmin_iso = ts_iso
if tmin_iso is not None:
commit_times[sha] = tmin_iso
# sorting commits by their timestamp
commits.sort(key=lambda sha: commit_times[sha], reverse=True)

# Build columns with outcomes, notes, and per-commit events
cols = []
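Note: the hunk above drops the per-commit scan for the earliest started_at and instead reads the timestamp that SignalCommit now carries. A minimal sketch of the resulting ordering logic, with simplified stand-ins for Signal and SignalCommit (only head_sha and timestamp are assumed):

from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Dict, List

@dataclass
class CommitStub:  # stand-in for SignalCommit: only the fields used below
    head_sha: str
    timestamp: datetime

@dataclass
class SignalStub:  # stand-in for Signal
    commits: List[CommitStub] = field(default_factory=list)

def order_commits(signals: List[SignalStub]) -> List[str]:
    """Collect unique commit shas across signals and sort newest -> older."""
    commits: List[str] = []
    commit_times: Dict[str, str] = {}
    seen = set()
    for s in signals:
        for c in s.commits:
            if c.head_sha not in seen:
                seen.add(c.head_sha)
                commits.append(c.head_sha)
                commit_times[c.head_sha] = c.timestamp.isoformat()
    # isoformat() strings compare the same way as the naive datetimes behind them
    commits.sort(key=lambda sha: commit_times[sha], reverse=True)
    return commits

now = datetime(2024, 1, 1, 12, 0)
signals = [SignalStub(commits=[CommitStub("aaa", now - timedelta(hours=1)), CommitStub("bbb", now)])]
assert order_commits(signals) == ["bbb", "aaa"]

Sorting on the stored ISO strings matches the datetime order as long as all timestamps share the same timezone handling.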
aws/lambda/pytorch-auto-revert/pytorch_auto_revert/signal.py (3 changes: 2 additions & 1 deletion)
@@ -100,8 +100,9 @@ def is_failure(self) -> bool:
class SignalCommit:
"""All events for a single commit, ordered oldest → newest by start time."""

def __init__(self, head_sha: str, events: List[SignalEvent]):
def __init__(self, head_sha: str, timestamp: datetime, events: List[SignalEvent]):
self.head_sha = head_sha
self.timestamp = timestamp
# enforce events ordered by time, then by wf_run_id (oldest first)
self.events = (
sorted(events, key=lambda e: (e.started_at, e.wf_run_id)) if events else []
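For context, here is a stripped-down version of the updated constructor together with a call site that passes the new required timestamp argument; EventStub stands in for the real SignalEvent and only carries the two fields the sort key uses:

from dataclasses import dataclass
from datetime import datetime
from typing import List

@dataclass
class EventStub:  # stand-in for SignalEvent; only the sort-key fields
    started_at: datetime
    wf_run_id: int

class SignalCommit:
    """All events for a single commit, ordered oldest -> newest by start time."""

    def __init__(self, head_sha: str, timestamp: datetime, events: List[EventStub]):
        self.head_sha = head_sha
        self.timestamp = timestamp  # push timestamp, now carried with the sha
        # enforce events ordered by time, then by wf_run_id (oldest first)
        self.events = (
            sorted(events, key=lambda e: (e.started_at, e.wf_run_id)) if events else []
        )

commit = SignalCommit(
    head_sha="abc123",
    timestamp=datetime(2024, 1, 1, 12, 0),
    events=[
        EventStub(started_at=datetime(2024, 1, 1, 13, 0), wf_run_id=2),
        EventStub(started_at=datetime(2024, 1, 1, 12, 30), wf_run_id=1),
    ],
)
assert [e.wf_run_id for e in commit.events] == [1, 2]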
@@ -64,10 +64,18 @@ def _fmt_event_name(
# -----------------------------
def extract(self) -> List[Signal]:
"""Extract Signals for configured workflows within the lookback window."""
# Fetch commits first to ensure we include commits without jobs
commits = self._datasource.fetch_commits_in_time_range(
repo_full_name=self.repo_full_name,
lookback_hours=self.lookback_hours,
)

# Fetch jobs for these commits
jobs = self._datasource.fetch_jobs_for_workflows(
repo_full_name=self.repo_full_name,
workflows=self.workflows,
lookback_hours=self.lookback_hours,
head_shas=[sha for sha, _ in commits],
)

# Select jobs to participate in test-track details fetch
@@ -76,8 +84,8 @@ def extract(self) -> List[Signal]:
test_track_job_ids, failed_job_ids=failed_job_ids
)

test_signals = self._build_test_signals(jobs, test_rows)
job_signals = self._build_non_test_signals(jobs)
test_signals = self._build_test_signals(jobs, test_rows, commits)
job_signals = self._build_non_test_signals(jobs, commits)
# Deduplicate events within commits across all signals as a final step.
# GitHub-specific behavior like "rerun failed" can reuse job instances for reruns.
# When that happens, the jobs have identical timestamps but DIFFERENT job ids.
@@ -101,7 +109,11 @@ def _dedup_signal_events(self, signals: List[Signal]) -> List[Signal]:
continue
filtered.append(e)
prev_key = key
new_commits.append(SignalCommit(head_sha=c.head_sha, events=filtered))
new_commits.append(
SignalCommit(
head_sha=c.head_sha, timestamp=c.timestamp, events=filtered
)
)
deduped.append(
Signal(key=s.key, workflow_name=s.workflow_name, commits=new_commits)
)
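The loop above keeps only the first event in each run of consecutive duplicates. The key function sits above the visible hunk, so the fields used in this sketch (event name plus started_at) are an assumption made for illustration, not the exact key from signal_extraction.py:

from datetime import datetime
from typing import Dict, List, Optional, Tuple

def dedup_consecutive(events: List[Dict]) -> List[Dict]:
    """Keep the first event of each run of consecutive duplicates (assumed key: name + started_at)."""
    filtered: List[Dict] = []
    prev_key: Optional[Tuple[str, datetime]] = None
    for e in events:  # events are already sorted oldest -> newest
        key = (e["name"], e["started_at"])
        if key == prev_key:
            # e.g. GitHub "rerun failed" reusing a job instance: identical
            # timestamps but different job ids collapse into a single event
            continue
        filtered.append(e)
        prev_key = key
    return filtered

t = datetime(2024, 1, 1, 12, 0)
events = [
    {"name": "linux-test", "started_at": t, "job_id": 101},
    {"name": "linux-test", "started_at": t, "job_id": 102},
]
assert len(dedup_consecutive(events)) == 1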
@@ -145,6 +157,7 @@ def _build_test_signals(
self,
jobs: List[JobRow],
test_rows: List[TestRow],
commits: List[Tuple[Sha, datetime]],
) -> List[Signal]:
"""Build per-test Signals across commits, scoped to job base.

Expand All @@ -155,9 +168,15 @@ def _build_test_signals(
- If test_run_s3 rows exist → FAILURE if any failing/errored else SUCCESS
- Else if group pending → PENDING
- Else → no event (missing)

Args:
jobs: List of job rows from the datasource
test_rows: List of test rows from the datasource
commits: Ordered list of (sha, timestamp) tuples (newest → older)
"""

jobs_by_id = {j.job_id: j for j in jobs}
commit_timestamps = dict(commits)

index_by_commit_job_base_wf_run_attempt: JobAggIndex[
Tuple[Sha, WorkflowName, JobBaseName, WfRunId, RunAttempt]
@@ -172,11 +191,6 @@ def _build_test_signals(
),
)

# Preserve newest → older commit order from the datasource
commit_shas = index_by_commit_job_base_wf_run_attempt.unique_values(
lambda j: j.head_sha
)

run_ids_attempts = index_by_commit_job_base_wf_run_attempt.group_map_values_by(
key_fn=lambda j: (j.head_sha, j.workflow_name, j.base_name),
value_fn=lambda j: (j.wf_run_id, j.run_attempt),
@@ -225,7 +239,7 @@ def _build_test_signals(
)

# y-axis: commits (newest → older)
for commit_sha in commit_shas:
for commit_sha, _ in commits:
events: List[SignalEvent] = []

# x-axis: events for the signal
@@ -286,7 +300,13 @@ def _build_test_signals(
has_any_events = True

# important to always include the commit, even if no events
commit_objs.append(SignalCommit(head_sha=commit_sha, events=events))
commit_objs.append(
SignalCommit(
head_sha=commit_sha,
timestamp=commit_timestamps[commit_sha],
events=events,
)
)

if has_any_events:
signals.append(
@@ -295,9 +315,19 @@ def _build_test_signals(

return signals

def _build_non_test_signals(self, jobs: List[JobRow]) -> List[Signal]:
# Build Signals keyed by normalized job base name per workflow.
# Aggregate across shards within (wf_run_id, run_attempt) using JobAggIndex.
def _build_non_test_signals(
self, jobs: List[JobRow], commits: List[Tuple[Sha, datetime]]
) -> List[Signal]:
"""Build Signals keyed by normalized job base name per workflow.

Aggregate across shards within (wf_run_id, run_attempt) using JobAggIndex.

Args:
jobs: List of job rows from the datasource
commits: Ordered list of (sha, timestamp) tuples (newest → older)
"""

commit_timestamps = dict(commits)

index = JobAggIndex.from_rows(
jobs,
@@ -310,9 +340,6 @@ def _build_non_test_signals(self, jobs: List[JobRow]) -> List[Signal]:
),
)

# Preserve commit order as first-seen in the job rows (datasource orders newest→older).
commit_shas = index.unique_values(lambda j: j.head_sha)

# Map (sha, workflow, base) -> [attempt_keys]
groups_index = index.group_keys_by(
key_fn=lambda j: (j.head_sha, j.workflow_name, j.base_name)
@@ -329,7 +356,7 @@ def _build_non_test_signals(self, jobs: List[JobRow]) -> List[Signal]:
# Track failure types across all attempts/commits for this base
has_relevant_failures = False # at least one non-test failure observed

for sha in commit_shas:
for sha, _ in commits:
attempt_keys: List[
Tuple[Sha, WorkflowName, JobBaseName, WfRunId, RunAttempt]
] = groups_index.get((sha, wf_name, base_name), [])
@@ -374,7 +401,11 @@ def _build_non_test_signals(
)

# important to always include the commit, even if no events
commit_objs.append(SignalCommit(head_sha=sha, events=events))
commit_objs.append(
SignalCommit(
head_sha=sha, timestamp=commit_timestamps[sha], events=events
)
)

# Emit job signal when failures were present and failures were NOT exclusively test-caused
if has_relevant_failures:
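To make the outcome rules from the _build_test_signals docstring concrete, here is a toy mapping for a single (commit, job base, attempt) group; the dict-shaped rows and the failing flag are assumptions, since the real code derives these states through JobAggIndex aggregations:

from enum import Enum
from typing import Dict, List, Optional

class Outcome(Enum):
    SUCCESS = "success"
    FAILURE = "failure"
    PENDING = "pending"

def test_group_outcome(test_rows: List[Dict], group_pending: bool) -> Optional[Outcome]:
    """Docstring rules for one group: rows decide failure/success, otherwise pending or missing."""
    if test_rows:
        return Outcome.FAILURE if any(r["failing"] for r in test_rows) else Outcome.SUCCESS
    if group_pending:
        return Outcome.PENDING
    return None  # no event for this group ("missing" cell on the signal grid)

assert test_group_outcome([{"failing": True}], group_pending=False) is Outcome.FAILURE
assert test_group_outcome([], group_pending=True) is Outcome.PENDING
assert test_group_outcome([], group_pending=False) is None

Whatever the outcome, both builders still append a SignalCommit for every sha (the "important to always include the commit, even if no events" branches), so commits without data show up as gaps in events rather than missing rows.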
@@ -24,20 +24,65 @@ class SignalExtractionDatasource:
Encapsulates ClickHouse queries used by the signal extraction layer.
"""

def fetch_commits_in_time_range(
self, *, repo_full_name: str, lookback_hours: int
) -> List[tuple[Sha, datetime]]:
"""
Fetch all commits pushed to main within the lookback window.
Returns list of (sha, timestamp) tuples ordered newest → older.
"""
lookback_time = datetime.now() - timedelta(hours=lookback_hours)

query = """
SELECT head_commit.id AS sha, max(head_commit.timestamp) AS ts
FROM default.push
WHERE head_commit.timestamp >= {lookback_time:DateTime}
AND ref = 'refs/heads/main'
AND dynamoKey like {repo:String}
GROUP BY sha
ORDER BY ts DESC
"""

params = {
"lookback_time": lookback_time,
"repo": f"{repo_full_name}%",
}

log = logging.getLogger(__name__)
log.info(
"[extract] Fetching commits in time range: repo=%s lookback=%sh",
repo_full_name,
lookback_hours,
)
t0 = time.perf_counter()
for attempt in RetryWithBackoff():
with attempt:
res = CHCliFactory().client.query(query, parameters=params)
commits = [(Sha(row[0]), row[1]) for row in res.result_rows]
dt = time.perf_counter() - t0
log.info("[extract] Commits fetched: %d commits in %.2fs", len(commits), dt)
return commits
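The commits returned here feed directly into the job fetch and the signal builders shown earlier. A rough wiring sketch, assuming a datasource instance is available (the extractor's real constructor and attributes are not shown in this diff):

from datetime import datetime
from typing import Dict, List, Tuple

Sha = str  # the real code uses a dedicated Sha type; a plain str keeps the sketch short

def extract_flow(datasource, repo: str, workflows: List[str], lookback_hours: int):
    """Commits first (so shas without jobs are kept), then jobs restricted to those shas."""
    commits: List[Tuple[Sha, datetime]] = datasource.fetch_commits_in_time_range(
        repo_full_name=repo,
        lookback_hours=lookback_hours,
    )
    jobs = datasource.fetch_jobs_for_workflows(
        repo_full_name=repo,
        workflows=workflows,
        lookback_hours=lookback_hours,
        head_shas=[sha for sha, _ in commits],
    )
    # The builders receive the ordered (sha, timestamp) pairs and use dict(commits)
    # to stamp every SignalCommit, including commits that have no jobs yet.
    commit_timestamps: Dict[Sha, datetime] = dict(commits)
    return commits, jobs, commit_timestamps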

def fetch_jobs_for_workflows(
self, *, repo_full_name: str, workflows: Iterable[str], lookback_hours: int
self,
*,
repo_full_name: str,
workflows: Iterable[str],
lookback_hours: int,
head_shas: List[Sha],
) -> List[JobRow]:
"""
Fetch recent workflow job rows for the given workflows within the lookback window.
Fetch workflow job rows for the given head_shas and workflows.

Returns rows ordered by push timestamp desc, then by workflow run/job identity.
Returns rows ordered by head_sha (following the order of head_shas), then by started_at ASC.
"""
lookback_time = datetime.now() - timedelta(hours=lookback_hours)

workflow_filter = ""
params: Dict[str, Any] = {
"lookback_time": lookback_time,
"repo": repo_full_name,
"head_shas": [str(s) for s in head_shas],
}
workflow_list = list(workflows)
if workflow_list:
@@ -55,13 +100,6 @@ def fetch_jobs_for_workflows(
# the extractor and downstream logic rely on the KG-adjusted value so
# that pending jobs can also be recognized as failures-in-progress.
query = f"""
WITH push_dedup AS (
SELECT head_commit.id AS sha, max(head_commit.timestamp) AS ts
FROM default.push
WHERE head_commit.timestamp >= {{lookback_time:DateTime}}
AND ref = 'refs/heads/main'
GROUP BY sha
)
SELECT
wf.head_sha,
wf.workflow_name,
@@ -76,19 +114,20 @@ def fetch_jobs_for_workflows(
wf.created_at,
tupleElement(wf.torchci_classification_kg,'rule') AS rule
FROM default.workflow_job AS wf FINAL
INNER JOIN push_dedup p ON wf.head_sha = p.sha
WHERE wf.repository_full_name = {{repo:String}}
AND wf.head_sha IN {{head_shas:Array(String)}}
AND wf.created_at >= {{lookback_time:DateTime}}
AND (wf.name NOT LIKE '%mem_leak_check%' AND wf.name NOT LIKE '%rerun_disabled_tests%')
{workflow_filter}
ORDER BY p.ts DESC, wf.started_at ASC, wf.head_sha, wf.run_id, wf.run_attempt, wf.name
ORDER BY wf.head_sha, wf.started_at ASC, wf.run_id, wf.run_attempt, wf.name
"""

log = logging.getLogger(__name__)
log.info(
"[extract] Fetching jobs: repo=%s workflows=%s lookback=%sh",
"[extract] Fetching jobs: repo=%s workflows=%s commits=%d lookback=%sh",
repo_full_name,
",".join(workflow_list) if workflow_list else "<all>",
len(head_shas),
lookback_hours,
)
t0 = time.perf_counter()
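Taken together, these changes mean a freshly pushed commit with no workflow jobs still surfaces on every signal, carrying only its push timestamp. A toy illustration of that end state, using plain dicts in place of the real SignalCommit objects:

from datetime import datetime

commits = [
    ("new_sha", datetime(2024, 1, 1, 13, 0)),  # just pushed, no jobs yet
    ("old_sha", datetime(2024, 1, 1, 12, 0)),
]
jobs_by_sha = {"old_sha": ["linux-test / shard 1"]}  # assumed job rows

commit_timestamps = dict(commits)
commit_objs = []
for sha, _ in commits:  # newest -> older, straight from fetch_commits_in_time_range
    events = jobs_by_sha.get(sha, [])  # stand-in for the real event building
    commit_objs.append(
        {"head_sha": sha, "timestamp": commit_timestamps[sha], "events": events}
    )

assert commit_objs[0] == {
    "head_sha": "new_sha",
    "timestamp": datetime(2024, 1, 1, 13, 0),
    "events": [],
}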