In [None]:
import ast
import os
import pickle
from collections import Counter

import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

# Load the article dataset.
# The CSV stores `meshroot` as the string repr of a Python list, so parse each
# cell back into a real list with ast.literal_eval (safe: literals only).
data = pd.read_csv('../data/articles.csv')
data["meshroot"] = data["meshroot"].apply(ast.literal_eval)

# Collect every distinct MeSH root label, in first-seen order.
# Each label gets its own balanced positive/negative dataset below.
all_labels = [label for sublist in data['meshroot'] for label in sublist]
label_counts = Counter(all_labels)  # label -> number of articles carrying it
labels = list(label_counts)  # dict keys keep insertion (first-occurrence) order

# Build one balanced binary dataset per label.
# Each dataset holds up to MAX_POSITIVES abstracts that carry the label, plus an
# equal number of abstracts that do not (targeted with the string "0").
# X_all[i] / y_all[i] hold the embedding vectors and targets for labels[i].
MAX_POSITIVES = 10_000  # cap on positive examples per label
SEED = 42               # makes the negative sampling reproducible

embedder = SentenceTransformer("neuml/pubmedbert-base-embeddings")
rng = np.random.default_rng(SEED)
X_all = []
y_all = []


def _embed(texts):
    """Embed texts in one batched call; return a list of 1-D vectors ([] if no texts)."""
    texts = list(texts)
    return list(embedder.encode(texts)) if texts else []


for label in tqdm(labels):
    # Compute membership once; its complement gives the negative pool.
    has_label = data["meshroot"].apply(lambda c: label in c)
    positive_samples = data[has_label].head(MAX_POSITIVES)
    negative_pool = data[~has_label]

    n_pos = len(positive_samples)
    if len(negative_pool) > 0:
        # Sample without replacement when the pool is big enough, so negatives
        # are not duplicated; fall back to replacement only for small pools.
        negative_samples = negative_pool.sample(
            n=n_pos, replace=len(negative_pool) < n_pos, random_state=rng
        )
    else:
        # Every article carries this label: there are no negatives to draw.
        negative_samples = negative_pool

    # Batched encoding: far fewer model calls than encoding one text at a time.
    # NOTE(review): assumes abstractText holds strings (no NaN) — confirm.
    X = _embed(positive_samples["abstractText"]) + _embed(negative_samples["abstractText"])
    y = [label] * n_pos + ['0'] * len(negative_samples)  # "0" marks a negative

    X_all.append(X)
    y_all.append(y)


In [None]:
# Persist the per-label embeddings (X_all) and targets (y_all) built above.
# NOTE(review): `path` was never defined in this notebook, so this cell raised
# NameError on a fresh kernel; defaulting to the data directory — confirm.
# NOTE(review): files are named *_title.pkl but the embeddings come from the
# `abstractText` column; names kept as-is for downstream consumers.
path = '../data/'
os.makedirs(path, exist_ok=True)

with open(f'{path}X_title.pkl', 'wb') as f:
    pickle.dump(X_all, f)
with open(f'{path}y_title.pkl', 'wb') as f:
    pickle.dump(y_all, f)