In [0]:
%load_ext autoreload
%autoreload 2
# Enables autoreload; learn more at https://docs.databricks.com/en/files/workspace-modules.html#autoreload-for-python-modules
# To disable autoreload, run %autoreload 0

In [0]:
from topic_predictor import create_model

In [0]:
import os
import pickle 
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, pipeline, AutoTokenizer
from sentence_transformers import SentenceTransformer

# Directory that holds every artifact of the v1 topic classifier.
model_path = "/Volumes/openalex/works/models/topic_classifier_v1/"
# Keras checkpoint handed to create_model() later in the notebook; derived from
# model_path so the two locations cannot drift apart.
weights_path = os.path.join(model_path, "model_checkpoint", "citation_part_only.keras")


def _load_pickled_artifact(filename, loaded_message):
    """Unpickle one artifact stored under ``model_path``.

    Args:
        filename: File name of the pickle, relative to ``model_path``.
        loaded_message: Progress message printed once the file has been read.

    Returns:
        The unpickled Python object.

    Raises:
        OSError: If the file cannot be opened.
        pickle.UnpicklingError: If the file content is not a valid pickle.
    """
    # NOTE: pickle.load can execute arbitrary code; only point model_path at
    # artifacts produced by a trusted pipeline.
    with open(os.path.join(model_path, filename), "rb") as f:
        artifact = pickle.load(f)
    print(loaded_message)
    return artifact


# Load the needed files (the variable names are used by later cells).
target_vocab = _load_pickled_artifact("target_vocab.pkl", "Loaded target vocab")
inv_target_vocab = _load_pickled_artifact("inv_target_vocab.pkl", "Loaded inverse target vocab")
citation_feature_vocab = _load_pickled_artifact("citation_feature_vocab.pkl", "Loaded citation features vocab.")
gold_to_label_mapping = _load_pickled_artifact("gold_to_id_mapping_dict.pkl", "Loaded gold citation mapping")
gold_dict = _load_pickled_artifact("gold_citations_dict.pkl", "Loaded gold citation L1")
non_gold_dict = _load_pickled_artifact("non_gold_citations_dict.pkl", "Loaded non-gold citation L1")

# Load the tokenizer and embedding model
emb_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
print("Loaded SentenceTransformer")
language_model_name = "OpenAlex/bert-base-multilingual-cased-finetuned-openalex-topic-classification-title-abstract"
# The previous `truncate=True` keyword was dropped: it is not a tokenizer option
# and was silently ignored. Truncation has to be requested when the tokenizer is
# called, e.g. tokenizer(text, truncation=True).
tokenizer = AutoTokenizer.from_pretrained(language_model_name)
print("Loaded tokenizer")

In [0]:
# Build the prediction model from the checkpoint and wrap it for XLA-compiled calls.
num_targets = len(target_vocab)
# NOTE(review): the "+ 2" presumably reserves ids beyond the vocabulary itself
# (e.g. padding / unknown) -- confirm against topic_predictor.create_model.
num_citation_features = len(citation_feature_vocab) + 2
pred_model = create_model(num_targets, num_citation_features, weights_path, topk=5)
print("✅ Model created.")
xla_predict = tf.function(pred_model, jit_compile=True)

# Language model, loaded with output_hidden_states=True and frozen (not trainable).
language_model = TFAutoModelForSequenceClassification.from_pretrained(
    language_model_name,
    output_hidden_states=True,
)
language_model.trainable = False
xla_predict_lang_model = tf.function(language_model, jit_compile=True)

# Disabled warm-up call: it would push one blank example through xla_predict.
# It relies on create_input_feature and np, neither of which is imported in
# this notebook, so it stays commented out.
# _ = xla_predict(create_input_feature([[101, 102] + [0]*510,
#                                       [1, 1] + [0]*510,
#                                       [1]+[0]*15,
#                                       [1]+[0]*127,
#                                       np.zeros(384, dtype=np.float32)]))
print("✅ Model initialized")


# Disabled: earlier attempt at saving the full model (`model` is not defined here).
# model.save("/dbfs/models/citation_part_only_full.keras")
# print("✅ Full model saved.")

In [0]:
%fs mkdirs /tmp/topic_classifier_v1

In [0]:
# Persist the complete model (architecture + weights) as a single .keras file.
# Presumably the intended final location is
# /Volumes/openalex/works/models/topic_classifier_v1/full_model.keras
# NOTE(review): this cell writes to /dbfs/tmp, while a later cell loads the model
# from the /Volumes path above; the copy between the two is not part of this
# notebook. The directory created by `%fs mkdirs /tmp/topic_classifier_v1` is
# not used by this save either -- confirm the intended workflow.
full_model_tmp_path = "/dbfs/tmp/full_model.keras"
pred_model.save(full_model_tmp_path)
print("✅ Full Keras model saved.")

# Disabled alternative: export in TensorFlow SavedModel format.
#pred_model.save("/Volumes/openalex/works/models/topic_classifier_v1/tf_savedmodel", save_format="tf")
#print("✅ Saved model saved.")

In [0]:
%fs ls /Volumes/openalex/works/models/topic_classifier_v1/model_checkpoint/

In [0]:
from keras.models import load_model

# Reload the full model from the volume to check that the saved file is usable.
# Recorded sizes (presumably bytes): full model 39,654,502; checkpoint 39,650,086.
full_model_volume_path = "/Volumes/openalex/works/models/topic_classifier_v1/full_model.keras"
test_model = load_model(full_model_volume_path)


In [0]:
# Print the weight shapes of every layer whose name contains "output_layer".
output_layers = (lyr for lyr in test_model.layers if "output_layer" in lyr.name)
for lyr in output_layers:
    weight_shapes = [tensor.shape for tensor in lyr.get_weights()]
    print(f"{lyr.name}: {weight_shapes}")