### Load model and tokenizer

In [0]:
import numpy as np
import tensorflow as tf
from tokenizers import Tokenizer
import pickle
import numpy as np

import json

class InstitutionModelCache:
    """Lazily populated, class-level cache of the institution tagger artifacts.

    All state lives on the class itself, so a single successful ``load()`` is
    shared by every caller in the same Python process.
    """

    # Directory holding the saved model, the tokenizer file and the pickled lookups.
    MODEL_DIR = "/Volumes/openalex/works/models/institution_tagger_v2"

    model = None  # object returned by tf.saved_model.load
    tokenizer = None  # tokenizers.Tokenizer built by Tokenizer.from_file
    idx_to_inst = None  # unpickled contents of affiliation_vocab.pkl
    inverted_idx_to_inst = None  # idx_to_inst with keys and values swapped
    full_affiliation_dict = None  # unpickled contents of full_affiliation_dict.pkl

    @classmethod
    def load(cls):
        """Load every artifact once; later calls return immediately.

        Artifacts are first read into local variables and only published on
        the class after all of them loaded successfully. ``cls.model`` is
        assigned last because it is the "already loaded" marker: if any step
        raises, the marker stays ``None`` and the next call retries instead of
        silently leaving ``tokenizer`` or the lookup dicts unset.

        Raises:
            Whatever the underlying loaders raise (missing files, corrupt
            pickles, ...); the cache is left untouched in that case.
        """
        if cls.model is not None:
            return

        base = cls.MODEL_DIR
        model = tf.saved_model.load(f"{base}/basic_model")
        tokenizer = Tokenizer.from_file(f"{base}/basic_model_tokenizer")

        # NOTE(security): pickle.load executes arbitrary code embedded in the
        # file. These pickles must only ever come from the trusted volume.
        with open(f"{base}/affiliation_vocab.pkl", "rb") as f:
            idx_to_inst = pickle.load(f)
        with open(f"{base}/full_affiliation_dict.pkl", "rb") as f:
            full_affiliation_dict = pickle.load(f)

        cls.tokenizer = tokenizer
        cls.idx_to_inst = idx_to_inst
        cls.inverted_idx_to_inst = {v: k for k, v in idx_to_inst.items()}
        cls.full_affiliation_dict = full_affiliation_dict
        # Publish the marker last so a partially filled cache is never
        # mistaken for a loaded one.
        cls.model = model

def infer_institutions_batch(
    raw_affiliations,
    model_predict_fn,
    tokenizer,
    inverted_idx_to_inst,
    full_affiliation_dict,
    max_len=128,
    top_n=5,
    batch_size=20
):
    """
    Processes affiliation strings in batches and returns top-N institution metadata.

    Args:
        raw_affiliations (List[str]): Input affiliation strings
        model_predict_fn: TensorFlow signature (e.g., model.signatures["serving_default"]);
            called with the keyword `tokenized_aff_string_input` and read via its "cls" output
        tokenizer: Loaded HuggingFace `tokenizers.Tokenizer`
        inverted_idx_to_inst (dict): Mapping from model index → institution_id
        full_affiliation_dict (dict): Mapping from institution_id → metadata
        max_len (int): Max token length; longer inputs are truncated, shorter ones zero-padded
        top_n (int): Number of top results to return. Values <= 0 yield an empty
            match list for every input and skip inference entirely.
        batch_size (int): Mini-batch size for inference

    Returns:
        List[List[dict]]: A list (one per input, in input order) of up to top-N
        dicts, best score first, each with the keys "id", "score",
        "display_name", "ror_id", "other_names" and "type". Metadata fields are
        None when the institution id is missing from `full_affiliation_dict`.
    """
    # Guard the degenerate request explicitly: the slice `argsort()[-top_n:]`
    # with top_n == 0 is `[-0:]`, i.e. *every* class instead of none.
    # (Written as a comparison rather than max()/min(), which the notebook's
    # wildcard pyspark.sql.functions import shadows.)
    if top_n <= 0:
        return [[] for _ in raw_affiliations]

    def tokenize_and_pad(text):
        # Truncate to max_len tokens, then right-pad with zeros to a fixed width.
        ids = np.array(tokenizer.encode(text).ids[:max_len], dtype=np.int64)
        padded = np.zeros(max_len, dtype=np.int64)
        padded[:len(ids)] = ids
        return padded

    results = []
    for start in range(0, len(raw_affiliations), batch_size):
        batch = raw_affiliations[start:start + batch_size]
        input_ids = np.array([tokenize_and_pad(aff) for aff in batch], dtype=np.int64)
        input_tensor = tf.convert_to_tensor(input_ids, dtype=tf.int64)

        logits = model_predict_fn(tokenized_aff_string_input=input_tensor)["cls"].numpy()

        for logit_vector in logits:
            # Descending order first, then keep the leading top_n entries;
            # identical to the classic `[-top_n:][::-1]` for every top_n >= 1.
            top_indices = logit_vector.argsort()[::-1][:top_n]
            top_matches = []
            for idx in top_indices:
                inst_id = inverted_idx_to_inst[idx]
                metadata = full_affiliation_dict.get(inst_id, {})
                top_matches.append({
                    "id": inst_id,
                    "score": float(logit_vector[idx]),
                    "display_name": metadata.get("display_name"),
                    "ror_id": metadata.get("ror_id"),
                    "other_names": metadata.get("final_names"),
                    "type": metadata.get("types"),
                })
            results.append(top_matches)

    return results

def process_affiliation_batch(buffer, predict_fn, tokenizer, inverted_idx_to_inst, full_affiliation_dict, batch_size, affiliation_column_name):
    """Run inference for one buffer of rows and yield one result dict per row.

    The affiliation text is read from ``row[affiliation_column_name]`` for each
    row in ``buffer``. Every yielded dict has the keys
    ``"raw_affiliation_string"`` (the input text) and ``"model_response"``
    (the list of match dicts returned by ``infer_institutions_batch``).
    """
    affiliation_texts = []
    for record in buffer:
        affiliation_texts.append(record[affiliation_column_name])

    predictions = infer_institutions_batch(
        raw_affiliations=affiliation_texts,
        model_predict_fn=predict_fn,
        tokenizer=tokenizer,
        inverted_idx_to_inst=inverted_idx_to_inst,
        full_affiliation_dict=full_affiliation_dict,
        batch_size=batch_size,
    )

    # Pair each input string with its (already list-of-dicts) prediction.
    yield from (
        {"raw_affiliation_string": text, "model_response": matches}
        for text, matches in zip(affiliation_texts, predictions)
    )

def process_partition(rows, affiliation_column_name = "raw_affiliation_string", batch_size = 20):
    """Stream a partition's rows through the model in chunks of ``batch_size``.

    Intended as the callable handed to ``mapPartitions``: the model artifacts
    are loaded (once per process) via ``InstitutionModelCache`` inside this
    function, and result dicts are yielded as each chunk is processed.
    """
    # Loading happens here, i.e. in whichever process iterates the partition.
    InstitutionModelCache.load()
    cache = InstitutionModelCache
    shared_args = (
        cache.model.signatures["serving_default"],
        cache.tokenizer,
        cache.inverted_idx_to_inst,
        cache.full_affiliation_dict,
        batch_size,
        affiliation_column_name,
    )

    pending = []
    for row in rows:
        pending.append(row)
        if len(pending) != batch_size:
            continue
        # Chunk is full: emit its results and start a fresh one.
        yield from process_affiliation_batch(pending, *shared_args)
        pending = []

    # Flush the final, shorter chunk (if any).
    if pending:
        yield from process_affiliation_batch(pending, *shared_args)

### Migrated, refactored and combined Jason's post-processing functions

In [0]:
from typing import List, Optional
from pyspark.sql.types import *
from pyspark.sql.functions import *

from affiliation_string_parsing import match_affiliation_to_institution_ids, process_current_affiliation_with_ids

@udf(returnType=ArrayType(LongType()))
def override_institution_ids(
    raw_affiliation_string: str,
    model_current_affs: Optional[List[int]]
) -> List[int]:
    """Compute an override list of institution ids for one affiliation string.

    String-based matches are concatenated with the model-provided ids, passed
    through ``process_current_affiliation_with_ids`` and stripped of ``-1``
    placeholders when more than one id remains.

    Returns:
        ``[]`` when the result contains the same set of ids as the model's
        list (nothing to override), ``[-1]`` when the result is empty, and the
        cleaned id list otherwise.
    """
    model_ids = model_current_affs or []

    # Ids found by string matching come first, model ids after them.
    matched_ids = match_affiliation_to_institution_ids(raw_affiliation_string) or []
    candidates = matched_ids + model_ids

    # Disambiguation rules applied to the combined candidate list.
    resolved = process_current_affiliation_with_ids(candidates, raw_affiliation_string)

    # A lone entry is kept as is; -1 is only dropped when other ids accompany it.
    if len(resolved) > 1:
        cleaned = [aff_id for aff_id in resolved if aff_id != -1]
    else:
        cleaned = resolved

    # No override needed when it would merely repeat the model's ids.
    if set(cleaned) == set(model_ids):
        return []

    return cleaned if cleaned else [-1]


### Load data for processing

In [0]:
# Up to 1,000,000 lookup rows that have neither institution_ids nor a
# model_response yet, spread over 48 partitions and cached for the inference step.
pending_affiliations_sql = """ 
        SELECT raw_affiliation_string
        FROM openalex.institutions.affiliation_strings_lookup 
        WHERE institution_ids IS NULL and model_response is NULL
        -- this really helps (if clustered by the column) with preventing Spark from sampling rows randomly and causing a shuffle
        ORDER BY raw_affiliation_string 
        LIMIT 1000000;
"""
df = spark.sql(pending_affiliations_sql)
df = df.repartition(48)
df = df.cache()

# Both actions below run the query, which triggers the repartitioning and caching.
pending_total = df.count()
print(f"Total rows to process: {pending_total}")
display(df)

# df = spark.createDataFrame([("Uniwersytet Mikołaja Kopernika w Toruniu, Klinika Medycyny Ratunkowej Collegium Medicum w Bydgoszczy",)], 
#                            ["raw_affiliation_string"])

### Run inference via `mapPartitions`

In [0]:
from pyspark.sql.types import *

# Fields of a single candidate match as produced by the inference step.
candidate_fields = [
    StructField("id", StringType(), nullable=False),
    StructField("score", DoubleType(), nullable=False),
    StructField("display_name", StringType(), nullable=True),
    StructField("ror_id", StringType(), nullable=True),
    StructField("other_names", ArrayType(StringType()), nullable=True),
    StructField("type", StringType(), nullable=True),
]

# One output row per affiliation string, carrying its list of candidates.
schema = StructType([
    StructField("raw_affiliation_string", StringType(), nullable=False),
    StructField("model_response", ArrayType(StructType(candidate_fields)), nullable=True),
])

# Inference runs partition by partition through process_partition.
res_rdd = df.rdd.mapPartitions(process_partition)
res_df = spark.createDataFrame(res_rdd, schema)
res_df = res_df.cache()

# Expose the results to the SQL cells under the name `model_results`.
res_df.createOrReplaceTempView("model_results")

inferred_total = res_df.count()
print(f"Total rows inferred: {inferred_total}")
display(res_df)

### Merge `model_response` to lookup table 

In [0]:
%sql
-- Copy the inferred model_response from the `model_results` temp view into the
-- lookup table. Only target rows that still have neither a model_response nor
-- institution_ids are eligible; rows are paired on the exact affiliation string.
MERGE INTO openalex.institutions.affiliation_strings_lookup AS target
USING model_results AS source
ON target.model_response IS NULL 
  AND target.institution_ids IS NULL 
  AND target.raw_affiliation_string = source.raw_affiliation_string
WHEN MATCHED THEN
  -- Tag updated rows with source = 'walden' and stamp the update time.
  UPDATE SET target.model_response = source.model_response, 
    target.source = 'walden',
    target.updated_datetime = current_timestamp();

### 

### Filter by `score > 0.1`, but keep the first `institution_id` if the filter removes all of them

In [0]:
%sql
-- Derive institution_ids from model_response for rows tagged source = 'walden'
-- and write them to rows whose institution_ids is still NULL.
with inst_ids AS (
  -- scored_institution_ids: ids of the candidates whose score exceeds 0.1
  -- model_ids: ids of all candidates, in the order stored in model_response
  SELECT
     raw_affiliation_string,
     transform(
          filter(model_response, 
            x -> x.score > 0.1
          ), 
          x -> x.id
        ) AS scored_institution_ids,
      model_response.id as model_ids
  FROM openalex.institutions.affiliation_strings_lookup
  WHERE model_response IS NOT NULL and source = 'walden'
),
selected_ids AS (
  -- Fall back to the first stored candidate when no score passes the threshold.
  -- NOTE(review): presumably that is the best-scoring candidate, assuming
  -- model_response was written in descending score order -- confirm.
  SELECT
    raw_affiliation_string,
    CASE WHEN size(scored_institution_ids) < 1 THEN SLICE(model_ids, 1, 1)
      ELSE scored_institution_ids END AS institution_ids
  FROM inst_ids
)
MERGE INTO openalex.institutions.affiliation_strings_lookup AS target
USING selected_ids AS source
ON target.institution_ids IS NULL
  AND target.source = 'walden'
  AND target.raw_affiliation_string = source.raw_affiliation_string
WHEN MATCHED THEN
  UPDATE SET target.institution_ids = source.institution_ids;
-- Ad-hoc check of the resulting list sizes (kept for reference, not executed):
-- select size(institution_ids), count(*) from selected_ids
-- group by size(institution_ids)
-- order by size(institution_ids)

### Execute migrated `institution_ids` override logic

In [0]:
# Rows that already have model output and institution_ids but no override yet.
override_candidates_sql = """
               SELECT * FROM openalex.institutions.affiliation_strings_lookup 
               WHERE model_response is not null 
                and institution_ids is not null
                and institution_ids_override is null
               """
df = spark.sql(override_candidates_sql)

# Compute the override column with the UDF, fed by the raw string and the
# institution_ids currently stored on the row.
override_column = override_institution_ids(col("raw_affiliation_string"), col("institution_ids"))
df_with_override = df.withColumn("institution_ids_override", override_column)

# No filtering happens here: every selected row is exposed to the MERGE below.
df_with_override.createOrReplaceTempView("override_updates")

# Write the computed overrides back, pairing rows on the affiliation string.
merge_overrides_sql = """
    MERGE INTO openalex.institutions.affiliation_strings_lookup AS target
    USING override_updates AS source
    ON target.model_response IS NOT NULL AND target.raw_affiliation_string = source.raw_affiliation_string
    WHEN MATCHED THEN
      UPDATE SET target.institution_ids_override = source.institution_ids_override
"""
df = spark.sql(merge_overrides_sql)
display(df)

#### `institution_ids` size distribution (similar to PROD)

In [0]:
%sql
-- Number of 'walden' rows per institution_ids list length.
-- NOTE(review): how rows with a NULL institution_ids are bucketed depends on
-- Spark's size()-of-NULL setting (-1 vs NULL) -- verify for this workspace.
SELECT size(institution_ids), count(*) 
FROM openalex.institutions.affiliation_strings_lookup
WHERE source = 'walden'
GROUP by size(institution_ids)

#### `institution_ids_override` size distribution after applying logic (5%+ affected)

In [0]:
%sql
-- Number of 'walden' rows per institution_ids_override list length.
-- NOTE(review): how rows with a NULL institution_ids_override are bucketed
-- depends on Spark's size()-of-NULL setting (-1 vs NULL) -- verify.
SELECT size(institution_ids_override), count(*) 
FROM openalex.institutions.affiliation_strings_lookup
WHERE source = 'walden'
GROUP by size(institution_ids_override)