In [78]:
import polars as pl
import pandas as pd
import re
from tqdm import tqdm
import numpy as np
import networkx as nx
import random

extract_hard = False
extract_medium = False
extract_easy = True

In [2]:
# import 2 sources
df_concepts = pl.read_parquet("D:/finetune_sbert/pre_with_expression.parquet")
df_relations = pd.read_csv("D:/finetune_sbert/connectivity_2025-04-01.csv")
df_relations['src.id'] = df_relations['src.id'].astype(str)
df_relations['dst.id'] = df_relations['dst.id'].astype(str)
df_relations = pl.from_pandas(df_relations)

  df_relations = pd.read_csv("D:/finetune_sbert/connectivity_2025-04-01.csv")


In [3]:
list_sem_tag = df_concepts['sem_tag'] .unique().to_list() + ["attribute"]
escaped_tags = [re.escape(tag) for tag in list_sem_tag]
pattern = r"\s*\((" + "|".join(escaped_tags) + r")\)"


In [4]:
df_concepts = (((df_concepts.filter(pl.col("concept_type") == "SCT_PRE")
               .filter(pl.col("status") == "defined")
               .with_columns(pl.col("expression").str.replace_all(r"\d+\|", "|").alias("expression_cleaned")) # remove digits
               ).with_columns(pl.col("expression_cleaned").str.replace_all(pattern, "").alias("expression_cleaned_no_semtag")))
               .with_columns(pl.col("n.label").str.replace_all(pattern, "").alias("label_no_semtag"))
               .unique())


In [5]:
df_relations_is_a = df_relations.filter(pl.col("type(r)") == "IS_A")
df_concepts_pos = df_concepts.select(pl.col("label_no_semtag"), pl.col("expression_cleaned_no_semtag"), pl.col("id") )
df_relations_is_a_def = ((df_relations_is_a
 .select(pl.col("src.id"),pl.col("dst.id"))
 .join(df_concepts_pos, left_on= "src.id", right_on="id", how="left")
 .drop_nulls()
 .rename({"expression_cleaned_no_semtag": "src.expression_cleaned_no_semtag","label_no_semtag": "src.label_no_semtag"}))
 .join(df_concepts_pos, left_on= "dst.id", right_on="id", how="left")
 .drop_nulls()
 .rename({"expression_cleaned_no_semtag": "dst.expression_cleaned_no_semtag","label_no_semtag": "dst.label_no_semtag"})
 ).unique().select(pl.col("src.id"),pl.col("dst.id"))

In [12]:
df_concepts.write_csv("D:/finetune_sbert_new/concept_info/concept_info_pre_fully_def_all.csv")

In [30]:
df_concepts.filter(pl.col("id") == "113091000")['term'].to_list()

['MRI',
 'MRI - Magnetic resonance imaging',
 'Magnetic resonance imaging (procedure)',
 'Magnetic resonance imaging',
 'NMR - Nuclear magnetic resonance',
 'Magnetic resonance study',
 'MR - Magnetic resonance']

# build dicts
- id_to_exp
- id_to_label (both FSN and synonym)

In [70]:
id_to_expr = dict(zip(df_concepts["id"], df_concepts["expression_cleaned_no_semtag"]))

df_concept_fsn_syn = df_concepts.group_by("id").agg([
    pl.concat_list([pl.col("label_no_semtag"), pl.col("term")])
      .list.explode()      # Flatten all nested lists into individual elements
      .unique()            # Get unique values             
    .alias("merged_terms")
])
# df_concept_fsn_syn['label_no_semtag'].to_list()
id_to_label_syn = dict(zip(df_concepts["id"], df_concept_fsn_syn["merged_terms"].to_list()))



# samples

In [73]:
def hard_triplet(df_concepts, df_relations_is_a_def):
    id_concept = df_concepts["id"].unique()
    seen = set()
    rows_hard = []

    # Build parent -> children and child -> parents maps
    child_to_parents = {}
    parent_to_children = {}

    for row in df_relations_is_a_def.iter_rows(named=True):
        child = row['src.id']
        parent = row['dst.id']
        child_to_parents.setdefault(child, []).append(parent)
        parent_to_children.setdefault(parent, []).append(child)

    for id_anchor in tqdm(id_concept):
        if id_anchor not in child_to_parents:
            continue

        for parent in child_to_parents[id_anchor]:
            siblings = parent_to_children.get(parent, [])
            for sibling_id in siblings:
                if sibling_id == id_anchor:
                    continue
                pair = (id_anchor, sibling_id)
                if pair in seen:
                    continue
                seen.add(pair)
                try:
                    rows_hard.append({
                        "anchor" : id_anchor,
                        "positive": id_anchor,
                        "negative": sibling_id,
                        "level": "hard"
                    })
                except KeyError:
                    continue  # skip missing values

    return rows_hard
rows_hard = hard_triplet(df_concepts, df_relations_is_a_def)

100%|██████████| 145567/145567 [00:15<00:00, 9132.91it/s] 


In [79]:
def medium_triplet(df_concepts, df_relations_is_a_def):
    df_relations_is_a_def = df_relations_is_a_def.unique()
    src = df_relations_is_a_def['src.id'].to_list()
    dst = df_relations_is_a_def['dst.id'].to_list()

    edges = list(zip(src, dst))
    G = nx.Graph()
    G.add_edges_from(edges)
    
    id_concept = df_concepts["id"].unique()
    id_to_semtag = dict(zip(df_concepts["id"], df_concepts["sem_tag"]))

    rows_medium = []

    for id_anchor in tqdm(id_concept):
        if id_anchor not in G:
            continue

        anchor_tag = id_to_semtag.get(id_anchor)
        if anchor_tag is None:
            continue

        nearby_concepts = nx.single_source_shortest_path_length(G, id_anchor, cutoff=3)
        level_3_candidates = [
            cid for cid, dist in nearby_concepts.items()
            if dist == 3 and id_to_semtag.get(cid) == anchor_tag
        ]
        if not level_3_candidates:
            continue

        sampled_negatives = random.sample(level_3_candidates, min(100, len(level_3_candidates)))

        for candidate_id in sampled_negatives:
            try:
                rows_medium.append({
                    "anchor": id_anchor,
                    "positive": id_anchor,
                    "negative": candidate_id,
                    "level": "medium"
                })
            except KeyError:
                continue  # skip missing data

    return rows_medium

rows_medium = medium_triplet(df_concepts, df_relations_is_a_def)


100%|██████████| 145567/145567 [02:42<00:00, 898.36it/s] 


In [82]:
def easy_triplet(df_concepts):
    id_concept = df_concepts["id"].unique()
    
    rows_easy = []

    # Build fast access lookup tables
    id_to_top_cat = dict(zip(df_concepts["id"], df_concepts["top_category"]))
  
    # Pre-group concept IDs by top category
    top_cat_to_ids = {}
    for id_, top_cat in id_to_top_cat.items():
        top_cat_to_ids.setdefault(top_cat, []).append(id_)

    # Flatten all IDs once for sampling
    all_ids = set(id_concept)

    for id_anchor in tqdm(id_concept):
        anchor_top_cat = id_to_top_cat.get(id_anchor)
        if anchor_top_cat is None:
            continue

        # Get IDs that are NOT in the same top category
        in_same_cat = set(top_cat_to_ids.get(anchor_top_cat, []))
        negative_candidates = list(all_ids - in_same_cat)

        if not negative_candidates:
            continue

        sampled_negatives = np.random.choice(
            negative_candidates,
            size=min(50, len(negative_candidates)),
            replace=False
        )

        for neg_id in sampled_negatives:
            try:
                rows_easy.append({
                    "anchor": id_anchor,
                    "positive": id_anchor,
                    "negative": neg_id,
                    "level": "easy"
                })
            except KeyError:
                continue  # Skip missing values

    return rows_easy

rows_easy = easy_triplet(df_concepts)


  0%|          | 225/145567 [01:24<15:04:51,  2.68it/s] 


KeyboardInterrupt: 

In [None]:
rows_medium

[{'anchor': '239163008',
  'positive': '239163008',
  'negative': '1290786007',
  'level': 'medium'},
 {'anchor': '239163008',
  'positive': '239163008',
  'negative': '1163215007',
  'level': 'medium'},
 {'anchor': '239163008',
  'positive': '239163008',
  'negative': '359757004',
  'level': 'medium'},
 {'anchor': '239163008',
  'positive': '239163008',
  'negative': '398262004',
  'level': 'medium'},
 {'anchor': '239163008',
  'positive': '239163008',
  'negative': '87225004',
  'level': 'medium'},
 {'anchor': '239163008',
  'positive': '239163008',
  'negative': '735912006',
  'level': 'medium'},
 {'anchor': '239163008',
  'positive': '239163008',
  'negative': '30321000175106',
  'level': 'medium'},
 {'anchor': '239163008',
  'positive': '239163008',
  'negative': '297961003',
  'level': 'medium'},
 {'anchor': '239163008',
  'positive': '239163008',
  'negative': '56940005',
  'level': 'medium'},
 {'anchor': '239163008',
  'positive': '239163008',
  'negative': '723166008',
  'leve