Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions .github/workflows/update_test_file_ratings.yml
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,26 @@ jobs:
user_email: "test-infra@pytorch.org"
user_name: "Pytorch Test Infra"
commit_message: "Updating TD heuristic: historical edited files"

update-ec2-pricing:
runs-on: linux.large
steps:
- name: Checkout test-infra repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

- name: Install Dependencies
run: python3 -m pip install boto3==1.19.12 PyYAML==6.0

- name: Generate EC2 pricing data
run: |
python3 tools/torchci/test_insights/ec2_pricing.py

- name: Compress pricing file
run: |
gzip ec2_pricing.json

- name: Upload pricing file to S3
run: |
aws s3 cp ec2_pricing.json.gz s3://ossci-metrics/ec2_pricing.json.gz \
--content-encoding gzip \
--content-type application/json
115 changes: 115 additions & 0 deletions tools/torchci/test_insights/ec2_pricing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
#!/usr/bin/env python3
"""
EC2 Pricing Map Generator

Get pricing info for EC2 instances by reading .github/scale-config.yml and
fetching current AWS pricing data.
"""

import json
from functools import lru_cache
from typing import Optional

import requests
import yaml


@lru_cache
def _get_scale_config() -> dict:
"""Load scale-config.yml and return as a dictionary."""
with open(".github/scale-config.yml", "r") as f:
config = yaml.safe_load(f)
return config


def get_ec2_instance_for_label(label: str) -> dict[str, Optional[str]]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a note that this wouldn't work with H100 and teams are using this nowadays to run tests, e.g. https://github.com/pytorch/pytorch/blob/main/.github/workflows/h100-symm-mem.yml. Handling that is probably out of scope, but it's something we want to keep in mind I guess

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Anyway, this seems like an expected setup anyway because MacOS is also not here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not really sure how to calculate h100 or mac cost atm since if I understand correctly, it's a reservation so there isn't really a per hour or per second cost

"""Get EC2 instance type for a given GitHub Actions runner label from scale-config.yml."""
config = _get_scale_config()

runner_info = config.get("runner_types", {})

if label in runner_info:
return {
"ec2_instance": runner_info[label].get("instance_type", None),
"os": runner_info[label].get("os", "linux"),
} # Default to linux if not specified
return {"ec2_instance": None, "os": None}


@lru_cache
def get_all_pricing_data() -> dict:
"""Fetch the entire EC2 pricing data from AWS pricing API. Cached for efficiency."""
price_list_url = "https://pricing.us-east-1.amazonaws.com/offers/v1.0/aws/AmazonEC2/current/us-east-1/index.json"
response = requests.get(price_list_url)
response.raise_for_status()
return response.json()


@lru_cache
def get_price_for_ec2_instance(instance_type, os_type="linux") -> Optional[float]:
"""Fetch on-demand price for EC2 instance type using AWS public pricing data. Returns None if not found."""

# Map os_type to AWS pricing API values
operating_system = "Windows" if os_type.lower() == "windows" else "Linux"

# Get the cached pricing data
pricing_data = get_all_pricing_data()

# Search through the products to find matching instance
for product_sku, product_data in pricing_data.get("products", {}).items():
attributes = product_data.get("attributes", {})

if (
attributes.get("instanceType") == instance_type
and attributes.get("location") == "US East (N. Virginia)"
and attributes.get("operatingSystem") == operating_system
and attributes.get("preInstalledSw") == "NA"
and attributes.get("tenancy") == "Shared"
and attributes.get("usagetype", "").startswith("BoxUsage")
):
# Found the product, now get the pricing terms
terms = (
pricing_data.get("terms", {}).get("OnDemand", {}).get(product_sku, {})
)

for term_data in terms.values():
price_dimensions = term_data.get("priceDimensions", {})
for price_data in price_dimensions.values():
price_per_unit = price_data.get("pricePerUnit", {}).get("USD")
if price_per_unit:
return float(price_per_unit)

print(f"No pricing found for {instance_type} ({operating_system})")
return None


@lru_cache
def get_price_for_label(label: str) -> Optional[float]:
"""Get the on-demand price for the EC2 instance type associated with the given GitHub Actions runner label."""
instance_info = get_ec2_instance_for_label(label)
instance_type = instance_info["ec2_instance"]
os_type = instance_info["os"]
if instance_type is not None:
return get_price_for_ec2_instance(instance_type, os_type)
return None


if __name__ == "__main__":
# Example usage
info = []
scale_config = _get_scale_config()
for runner_label in scale_config.get("runner_types", {}):
price = get_price_for_label(runner_label)
info.append(
{
"label": runner_label,
"price_per_hour": price,
"instance_type": get_ec2_instance_for_label(runner_label)[
"ec2_instance"
],
}
)
with open("ec2_pricing.json", "w") as f:
for line in info:
json.dump(line, f)
f.write("\n")
24 changes: 6 additions & 18 deletions tools/torchci/test_insights/file_report_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
boto3 = None # type: ignore[assignment]

from torchci.clickhouse import query_clickhouse
from torchci.test_insights.ec2_pricing import get_price_for_label


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -74,21 +75,6 @@ def __init__(self, dry_run: bool = True):
"""Initialize the generator with the test owners file path"""
self.dry_run = dry_run

@lru_cache
def load_runner_costs(self) -> Dict[str, float]:
"""Load runner costs from the S3 endpoint"""
logger.debug("Fetching EC2 pricing data from S3...")
with urllib.request.urlopen(self.EC2_PRICING_URL) as response:
compressed_data = response.read()

decompressed_data = gzip.decompress(compressed_data)
pricing_data = {}
for line in decompressed_data.decode("utf-8").splitlines():
if line.strip():
line_json = json.loads(line)
pricing_data[line_json[0]] = float(line_json[2])
return pricing_data

@lru_cache
def load_test_owners(self) -> List[Dict[str, Any]]:
"""Load the test owner labels JSON file from S3"""
Expand All @@ -105,10 +91,12 @@ def load_test_owners(self) -> List[Dict[str, Any]]:

def get_runner_cost(self, runner_label: str) -> float:
"""Get the cost per hour for a given runner"""
runner_costs = self.load_runner_costs()
if runner_label.startswith("lf."):
runner_label = runner_label[3:]
return runner_costs.get(runner_label, 0.0)
cost = get_price_for_label(runner_label)
if cost is None:
return 0.0
return cost

def _get_first_suitable_sha(self, shas: list[dict[str, Any]]) -> Optional[str]:
"""Get the first suitable SHA from a list of SHAs."""
Expand Down Expand Up @@ -282,7 +270,7 @@ def _get_runner_label_from_job_info(self, job_info: Dict[str, Any]) -> str:
for label in job_labels:
if label.startswith("lf."):
label = label[3:]
if label in self.load_runner_costs():
if get_price_for_label(label) is not None:
return label

return "unknown"
Expand Down