In [None]:
import torch, json, sys
from transformers import AutoModel, AutoConfig
from huggingface_hub import hf_hub_download
from typing import Dict, List, Tuple

In [None]:
class UnifiedIdMapper:
    """Bidirectional mapping between dataset ("old") IDs and the model's
    contiguous ("new") vocabulary.

    New IDs are assigned by sorting: nodes receive 0..len(nodes)-1 in ascending
    old-ID order, then edges receive len(nodes).. in ascending old-ID order.
    ID lookups return ``(id, is_edge)`` tuples.
    """

    def __init__(self, nodes: Dict[int, str], edges: Dict[int, str]) -> None:
        """
        Args:
            nodes: old node ID -> label. Keys may be ints or numeric strings
                (JSON object keys are always strings); they are converted to int.
            edges: old edge ID -> label, same key convention as ``nodes``.

        Raises:
            ValueError: if an old ID appears in both ``nodes`` and ``edges``.
                The flat old-ID lookup tables cannot hold both entries, and the
                node's new ID would silently vanish from ``new_to_old``.
        """
        # JSON object keys are always str, so normalise them to int.
        nodes = {int(k): v for k, v in nodes.items()}
        edges = {int(k): v for k, v in edges.items()}

        overlap = nodes.keys() & edges.keys()
        if overlap:
            raise ValueError(
                f"Old IDs present as both node and edge (first few): {sorted(overlap)[:10]}"
            )

        self.nodes = nodes
        self.edges = edges

        # Nodes occupy new IDs [0, len(nodes)); edges are shifted past them.
        shift = len(nodes)
        self.old_to_new: Dict[int, Tuple[int, bool]] = {}
        for new_id, old_id in enumerate(sorted(nodes)):
            self.old_to_new[old_id] = (new_id, False)
        for new_id, old_id in enumerate(sorted(edges)):
            self.old_to_new[old_id] = (new_id + shift, True)

        # Reverse mapping: new_id -> (old_id, is_edge)
        self.new_to_old: Dict[int, Tuple[int, bool]] = {
            new_id: (old_id, is_edge)
            for old_id, (new_id, is_edge) in self.old_to_new.items()
        }

        # Label maps
        self.old_id_to_label: Dict[int, str] = {**nodes, **edges}
        self.new_id_to_label: Dict[int, str] = {
            new_id: self.old_id_to_label[old_id]
            for old_id, (new_id, _) in self.old_to_new.items()
        }

        self.label_to_old_ids: Dict[str, List[Tuple[int, bool]]] = {}
        self.label_to_new_ids: Dict[str, List[Tuple[int, bool]]] = {}
        for old_id, (new_id, is_edge) in self.old_to_new.items():
            label = self.old_id_to_label.get(old_id)
            # Labels stored as JSON null are not indexed by label.
            if label is None:
                continue
            self.label_to_old_ids.setdefault(label, []).append((old_id, is_edge))
            self.label_to_new_ids.setdefault(label, []).append((new_id, is_edge))

    @classmethod
    def from_file(cls, mapper_path: str) -> "UnifiedIdMapper":
        """Build a mapper from a JSON file with top-level "nodes" and "edges" objects."""
        # JSON is UTF-8 by spec; do not rely on the platform default encoding.
        with open(mapper_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return cls(data["nodes"], data["edges"])

    def map_old_id(self, old_id: int) -> Tuple[int, bool]:
        """Return ``(new_id, is_edge)`` for an old ID; raises KeyError if unknown."""
        return self.old_to_new[old_id]

    def map_new_id(self, new_id: int) -> Tuple[int, bool]:
        """Return ``(old_id, is_edge)`` for a new ID; raises KeyError if unknown."""
        return self.new_to_old[new_id]

    def label_from_old_id(self, old_id: int) -> str:
        """Return the label of an old ID; raises KeyError if unknown."""
        return self.old_id_to_label[old_id]

    def label_from_new_id(self, new_id: int) -> str:
        """Return the label of a new ID; raises KeyError if unknown."""
        return self.new_id_to_label[new_id]

    def old_ids_from_label(self, label: str) -> List[Tuple[int, bool]]:
        """Return all ``(old_id, is_edge)`` pairs carrying ``label`` (empty list if none)."""
        return self.label_to_old_ids.get(label, [])

    def new_ids_from_label(self, label: str) -> List[Tuple[int, bool]]:
        """Return all ``(new_id, is_edge)`` pairs carrying ``label`` (empty list if none)."""
        return self.label_to_new_ids.get(label, [])

In [None]:
class ModelWrapper:
    """Next-token prediction over the unified node/edge vocabulary.

    Converts a context of old (dataset) IDs to the model's new IDs via a
    UnifiedIdMapper, runs the model, and maps the full predicted distribution
    back to old IDs and labels.
    """

    def __init__(self, mapper_path, model, device="cuda"):
        """
        Args:
            mapper_path: Path to a mapper JSON readable by ``UnifiedIdMapper.from_file``.
            model: A torch module. It is called with a ``(1, seq_len)`` LongTensor of
                new IDs, and its output is indexed as ``logits[0, -1, :]``.
            device: Target device: a string such as "cuda", "cuda:1", "mps" or "cpu",
                or a ``torch.device``. Falls back to CPU when the requested
                accelerator is unavailable.
        """
        # Load Mapper
        print(f"Loading mapper from {mapper_path}...")
        self.mapper = UnifiedIdMapper.from_file(mapper_path)

        # set model
        self.model = model

        # Set device
        self.device = self._resolve_device(device)

        print(f"Moving model to {self.device}...")
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _resolve_device(device):
        """Return a usable ``torch.device`` for ``device``.

        Falls back to CPU when the CUDA or MPS backend is requested but unavailable.
        """
        requested = torch.device(device)
        if requested.type == "cuda" and not torch.cuda.is_available():
            print("CUDA not available, switching to CPU.")
            return torch.device("cpu")
        if requested.type == "mps":
            mps_backend = getattr(torch.backends, "mps", None)
            if mps_backend is None or not mps_backend.is_available():
                print("MPS not available, switching to CPU.")
                return torch.device("cpu")
        return requested

    def predict(self, old_ids_context):
        """
        Args:
            old_ids_context: Non-empty iterable of old IDs defining the context.
                Every ID must exist in the mapper (KeyError otherwise).
        Returns:
            List of ``(prob, old_id, label)`` for every output class, sorted by
            probability descending. Classes the mapper does not know are
            reported as ``(prob, -1, "<PAD/UNK>")``.
        Raises:
            ValueError: if the context is empty.
        """
        context = list(old_ids_context)
        if not context:
            raise ValueError("old_ids_context must contain at least one ID")

        # 1. Convert context list of old IDs to new IDs
        input_ids = [self.mapper.map_old_id(old_id)[0] for old_id in context]

        # 2. Run inference (batch size = 1)
        model_input = torch.tensor([input_ids], dtype=torch.long, device=self.device)

        with torch.no_grad():
            logits = self.model(model_input)
            # Upcast before the softmax. With bfloat16 weights the logits may be
            # bf16, and a bf16 softmax over thousands of classes keeps only ~3
            # significant digits, which causes ties and distorts the ranking.
            last_token_logits = logits[0, -1, :].float()
            probs = torch.softmax(last_token_logits, dim=-1)

        # 3. Sort by probability descending; the indices are the new IDs
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)

        # 4. Map each new ID back to (old_id, label)
        results = []
        for prob, new_id in zip(sorted_probs.tolist(), sorted_indices.tolist()):
            try:
                old_id, _ = self.mapper.map_new_id(new_id)
                label = self.mapper.label_from_new_id(new_id)
            except KeyError:
                # Output indices not in the mapper (e.g. padding tokens)
                results.append((prob, -1, "<PAD/UNK>"))
                continue
            results.append((prob, old_id, label))

        return results

In [None]:
# Load the model with trust_remote_code=True. This executes the repository's own
# configuration/modeling Python files locally, so pin the exact commit that was
# reviewed instead of silently pulling whatever code is newest on the Hub.
model_id = "crab27/llama3-edge"
# Commit hash of the snapshot this notebook was run against (see the cache path
# printed when the mapper is loaded). Update deliberately after reviewing new code.
MODEL_REVISION = "3489e83582028c59150450867e78b3ddc9f9781e"

# This loads the model and the custom configuration for hf
model = AutoModel.from_pretrained(
    model_id,
    revision=MODEL_REVISION,
    trust_remote_code=True,
    dtype=torch.bfloat16,
)

# Load the UnifiedIdMapper from the same pinned revision so it matches the weights
mapper_path = hf_hub_download(
    repo_id=model_id,
    filename="unified_id_mapper.json",
    revision=MODEL_REVISION,
)

# Initialize the wrapper
wrapper = ModelWrapper(mapper_path, model)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/529 [00:00<?, ?B/s]

configuration_llama_edge.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/crab27/llama3-edge:
- configuration_llama_edge.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


modeling_llama_edge.py: 0.00B [00:00, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/crab27/llama3-edge:
- modeling_llama_edge.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 6 files:   0%|          | 0/6 [00:00<?, ?it/s]

model-00003-of-00006.safetensors:   0%|          | 0.00/4.83G [00:00<?, ?B/s]

model-00006-of-00006.safetensors:   0%|          | 0.00/3.48G [00:00<?, ?B/s]

model-00002-of-00006.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00004-of-00006.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00001-of-00006.safetensors:   0%|          | 0.00/4.93G [00:00<?, ?B/s]

model-00005-of-00006.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

unified_id_mapper.json: 0.00B [00:00, ?B/s]

Loading mapper from /root/.cache/huggingface/hub/models--crab27--llama3-edge/snapshots/3489e83582028c59150450867e78b3ddc9f9781e/unified_id_mapper.json...
Moving model to cuda...


In [None]:
# Input instance taken from our training data
input_ids = [108, 112, 117, 349, 421, 608, 761, 765, 805, 912, 930, 937, 940, 1076, 1095, 1125, 1133, 1188, 1510, 1948, 1958, 47178924]
# Our gold label. The label is looked up from the mapper so ID and label cannot
# drift apart: "/location/location/contains" is the label of 47178924 (the last
# context ID), not of the target 47185647.
target_edge_id = 47185647
target_edge_label = wrapper.mapper.label_from_old_id(target_edge_id)
TOP_K = 10

# The wrapper returns one item per output class (9942 for this model),
# each in the format (prob, old_id, label), sorted by probability descending.
predictions = wrapper.predict(input_ids)
print(f"Input old IDs: {input_ids}")
print(f"Target edge old ID: {target_edge_id}, Label: {target_edge_label}")

# Top-k IDs the model predicts
print(f"Top {TOP_K} Predictions:")
for rank, (prob, pred_id, pred_label) in enumerate(predictions[:TOP_K], start=1):
    print(f"  Rank {rank}: ID {pred_id}, Label: {pred_label}, Probability: {prob:.6f}")

# Check whether the gold label is in the top k
top_k_ids = [pred_id for _, pred_id, _ in predictions[:TOP_K]]
if target_edge_id in top_k_ids:
    print(f"Target edge old ID {target_edge_id} found in top {TOP_K} predictions.")
else:
    print(f"Target edge old ID {target_edge_id} NOT found in top {TOP_K} predictions.")
print("-" * 100)

# Note: since the output vector covers all 9942 classes (edges + nodes), the model
# may assign high probability to node (class) IDs. During inference you could help
# the model by masking out node IDs, as well as edges that are impossible given the
# current graph. The model has no knowledge of the Freebase ontology (structure).


Input old IDs: [108, 112, 117, 349, 421, 608, 761, 765, 805, 912, 930, 937, 940, 1076, 1095, 1125, 1133, 1188, 1510, 1948, 1958, 47178924]
Target edge old ID: 47185647, Label: /location/location/contains
Top 10 Predictions:
  Rank 1: ID 47185647, Label: /location/mailing_address/state_province_region-/location/mailing_address/citytown, Probability: 1.000000
  Rank 2: ID 47185782, Label: /finance/exchange_rate/target_of_exchange-/finance/exchange_rate/source_of_exchange, Probability: 0.000179
  Rank 3: ID 47178924, Label: /location/location/contains, Probability: 0.000051
  Rank 4: ID 47181636, Label: /people/ethnicity/geographic_distribution, Probability: 0.000035
  Rank 5: ID 47181637, Label: /people/ethnicity/languages_spoken, Probability: 0.000028
  Rank 6: ID 47185587, Label: /award/award_nomination/nominated_for-/award/award_nomination/ceremony, Probability: 0.000019
  Rank 7: ID 47181635, Label: /language/language_family/geographic_distribution, Probability: 0.000017
  Rank 8: ID

In [None]:
# More examples to test out (each dict is one context -> target instance; only the last expression in this cell is displayed)
{"context_ids": [108, 110, 112, 115, 117, 234, 349, 421, 541, 582, 608, 831, 905, 912, 913, 940, 994, 995, 996, 1019, 1076, 1077, 1095, 1125, 1181, 1188, 1251, 1275, 1278, 1287, 1291, 1304, 1393, 1415, 1542, 1685, 1721, 1790, 1942, 1948, 1958, 1999, 2000, 2025, 47185647, 47185648], "context_labels": ["/location/citytown", "/location/country", "/location/location", "/location/region", "/film/film_location", "/book/book_subject", "/business/business_location", "/business/employer", "/military/military_combatant", "/fictional_universe/fictional_setting", "/location/administrative_division", "/biology/organism_classification", "/wine/appellation", "/wine/wine_region", "/wine/wine_sub_region", "/government/governmental_jurisdiction", "/location/it_region", "/location/it_province", "/location/it_comune", "/organization/organization_member", "/location/statistical_region", "/organization/organization_founder", "/organization/organization_scope", "/location/dated_location", "/food/ingredient", "/travel/travel_destination", "/symbols/name_source", "/exhibitions/exhibition_subject", "/exhibitions/exhibition_venue", "/olympics/olympic_host_city", "/olympics/olympic_participating_country", "/food/beer_country_region", "/symbols/flag_referent", "/sports/sports_team_location", "/biology/breed_origin", "/food/food", "/sports/sport_country", "/meteorology/forecast_zone", "/media_common/netflix_genre", "/periodicals/newspaper_circulation_area", "/law/court_jurisdiction_area", "/medicine/drug_ingredient", "/medicine/drug_dosage_flavor", "/location/capital_of_administrative_division", "/location/mailing_address/state_province_region-/location/mailing_address/citytown", "/location/mailing_address/state_province_region-/location/mailing_address/country"], "target_edge_id": 47178924, "target_edge_label": "/location/location/contains"}
{"context_ids": [108, 110, 112, 115, 117, 234, 349, 421, 541, 582, 608, 831, 905, 912, 913, 940, 994, 995, 996, 1019, 1076, 1077, 1095, 1125, 1181, 1188, 1251, 1275, 1278, 1287, 1291, 1304, 1393, 1415, 1542, 1685, 1721, 1790, 1942, 1948, 1958, 1999, 2000, 2025, 47178924, 47185648], "context_labels": ["/location/citytown", "/location/country", "/location/location", "/location/region", "/film/film_location", "/book/book_subject", "/business/business_location", "/business/employer", "/military/military_combatant", "/fictional_universe/fictional_setting", "/location/administrative_division", "/biology/organism_classification", "/wine/appellation", "/wine/wine_region", "/wine/wine_sub_region", "/government/governmental_jurisdiction", "/location/it_region", "/location/it_province", "/location/it_comune", "/organization/organization_member", "/location/statistical_region", "/organization/organization_founder", "/organization/organization_scope", "/location/dated_location", "/food/ingredient", "/travel/travel_destination", "/symbols/name_source", "/exhibitions/exhibition_subject", "/exhibitions/exhibition_venue", "/olympics/olympic_host_city", "/olympics/olympic_participating_country", "/food/beer_country_region", "/symbols/flag_referent", "/sports/sports_team_location", "/biology/breed_origin", "/food/food", "/sports/sport_country", "/meteorology/forecast_zone", "/media_common/netflix_genre", "/periodicals/newspaper_circulation_area", "/law/court_jurisdiction_area", "/medicine/drug_ingredient", "/medicine/drug_dosage_flavor", "/location/capital_of_administrative_division", "/location/location/contains", "/location/mailing_address/state_province_region-/location/mailing_address/country"], "target_edge_id": 47185647, "target_edge_label": "/location/mailing_address/state_province_region-/location/mailing_address/citytown"}
{"context_ids": [108, 112, 117, 349, 421, 608, 761, 765, 805, 912, 930, 937, 940, 1076, 1095, 1125, 1133, 1188, 1510, 1948, 1958, 47185647], "context_labels": ["/location/citytown", "/location/location", "/film/film_location", "/business/business_location", "/business/employer", "/location/administrative_division", "/award/award_nominated_work", "/award/award_winning_work", "/music/composition", "/wine/wine_region", "/location/australian_state", "/government/political_district", "/government/governmental_jurisdiction", "/location/statistical_region", "/organization/organization_scope", "/location/dated_location", "/people/ethnicity", "/travel/travel_destination", "/geology/rock_type", "/periodicals/newspaper_circulation_area", "/law/court_jurisdiction_area", "/location/mailing_address/state_province_region-/location/mailing_address/citytown"], "target_edge_id": 47178924, "target_edge_label": "/location/location/contains"}
{"context_ids": [108, 112, 117, 349, 421, 608, 761, 765, 805, 912, 930, 937, 940, 1076, 1095, 1125, 1133, 1188, 1510, 1948, 1958, 47178924], "context_labels": ["/location/citytown", "/location/location", "/film/film_location", "/business/business_location", "/business/employer", "/location/administrative_division", "/award/award_nominated_work", "/award/award_winning_work", "/music/composition", "/wine/wine_region", "/location/australian_state", "/government/political_district", "/government/governmental_jurisdiction", "/location/statistical_region", "/organization/organization_scope", "/location/dated_location", "/people/ethnicity", "/travel/travel_destination", "/geology/rock_type", "/periodicals/newspaper_circulation_area", "/law/court_jurisdiction_area", "/location/location/contains"], "target_edge_id": 47185647, "target_edge_label": "/location/mailing_address/state_province_region-/location/mailing_address/citytown"}
{"context_ids": [108, 112, 117, 217, 234, 250, 349, 421, 453, 469, 528, 582, 628, 763, 764, 940, 1076, 1095, 1125, 1188, 1235, 1373, 1415, 1640, 1715, 1810, 1819, 1913, 1948, 2025, 47183681, 47185555], "context_labels": ["/location/citytown", "/location/location", "/film/film_location", "/basketball/basketball_team", "/book/book_subject", "/tv/tv_actor", "/business/business_location", "/business/employer", "/sports/sports_team", "/sports/sports_league", "/venture_capital/venture_investor", "/fictional_universe/fictional_setting", "/organization/organization", "/award/award_presenting_organization", "/award/award_winner", "/government/governmental_jurisdiction", "/location/statistical_region", "/organization/organization_scope", "/location/dated_location", "/travel/travel_destination", "/sports/professional_sports_team", "/location/place_with_neighborhoods", "/sports/sports_team_location", "/location/hud_foreclosure_area", "/location/hud_county_place", "/business/business_operation", "/business/customer", "/tv/tv_subject", "/periodicals/newspaper_circulation_area", "/location/capital_of_administrative_division", "/location/hud_county_place/place", "/sports/sports_league_participation/team-/sports/sports_league_participation/league"], "target_edge_id": 47179849, "target_edge_label": "/sports/sports_team/location"}
{"context_ids": [108, 112, 117, 217, 234, 250, 349, 421, 453, 469, 528, 582, 628, 763, 764, 940, 1076, 1095, 1125, 1188, 1235, 1373, 1415, 1640, 1715, 1810, 1819, 1913, 1948, 2025, 47179849, 47185555], "context_labels": ["/location/citytown", "/location/location", "/film/film_location", "/basketball/basketball_team", "/book/book_subject", "/tv/tv_actor", "/business/business_location", "/business/employer", "/sports/sports_team", "/sports/sports_league", "/venture_capital/venture_investor", "/fictional_universe/fictional_setting", "/organization/organization", "/award/award_presenting_organization", "/award/award_winner", "/government/governmental_jurisdiction", "/location/statistical_region", "/organization/organization_scope", "/location/dated_location", "/travel/travel_destination", "/sports/professional_sports_team", "/location/place_with_neighborhoods", "/sports/sports_team_location", "/location/hud_foreclosure_area", "/location/hud_county_place", "/business/business_operation", "/business/customer", "/tv/tv_subject", "/periodicals/newspaper_circulation_area", "/location/capital_of_administrative_division", "/sports/sports_team/location", "/sports/sports_league_participation/team-/sports/sports_league_participation/league"], "target_edge_id": 47183681, "target_edge_label": "/location/hud_county_place/place"}
{"context_ids": [108, 112, 117, 217, 234, 250, 349, 421, 453, 469, 528, 582, 628, 763, 764, 940, 1076, 1095, 1125, 1188, 1235, 1373, 1415, 1640, 1715, 1810, 1819, 1913, 1948, 2025, 47179849, 47183681], "context_labels": ["/location/citytown", "/location/location", "/film/film_location", "/basketball/basketball_team", "/book/book_subject", "/tv/tv_actor", "/business/business_location", "/business/employer", "/sports/sports_team", "/sports/sports_league", "/venture_capital/venture_investor", "/fictional_universe/fictional_setting", "/organization/organization", "/award/award_presenting_organization", "/award/award_winner", "/government/governmental_jurisdiction", "/location/statistical_region", "/organization/organization_scope", "/location/dated_location", "/travel/travel_destination", "/sports/professional_sports_team", "/location/place_with_neighborhoods", "/sports/sports_team_location", "/location/hud_foreclosure_area", "/location/hud_county_place", "/business/business_operation", "/business/customer", "/tv/tv_subject", "/periodicals/newspaper_circulation_area", "/location/capital_of_administrative_division", "/sports/sports_team/location", "/location/hud_county_place/place"], "target_edge_id": 47185555, "target_edge_label": "/sports/sports_league_participation/team-/sports/sports_league_participation/league"}

{'context_ids': [108,
  112,
  117,
  217,
  234,
  250,
  349,
  421,
  453,
  469,
  528,
  582,
  628,
  763,
  764,
  940,
  1076,
  1095,
  1125,
  1188,
  1235,
  1373,
  1415,
  1640,
  1715,
  1810,
  1819,
  1913,
  1948,
  2025,
  47179849,
  47183681],
 'context_labels': ['/location/citytown',
  '/location/location',
  '/film/film_location',
  '/basketball/basketball_team',
  '/book/book_subject',
  '/tv/tv_actor',
  '/business/business_location',
  '/business/employer',
  '/sports/sports_team',
  '/sports/sports_league',
  '/venture_capital/venture_investor',
  '/fictional_universe/fictional_setting',
  '/organization/organization',
  '/award/award_presenting_organization',
  '/award/award_winner',
  '/government/governmental_jurisdiction',
  '/location/statistical_region',
  '/organization/organization_scope',
  '/location/dated_location',
  '/travel/travel_destination',
  '/sports/professional_sports_team',
  '/location/place_with_neighborhoods',
  '/sports/sports_team_lo