# Imports

In [3]:
import os
os.environ["TRANSFORMERS_NO_TF"] = "1"

from transformers import AutoTokenizer, AutoModelForSequenceClassification


In [4]:
from pathlib import Path
import json

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score

from datasets import Dataset

DATA_PATH = Path("../data/processed/iclr2024.json")
MODEL_NAME = "bert-base-uncased"
RANDOM_SEED = 42


# Load Data & Inspect

In [10]:
def load_jsonl(path: Path):
    records = []
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            records.append(json.loads(line))
    return records

records = load_jsonl(DATA_PATH)
len(records), records[0]


(7404,
 {'venue': 'ICLR.cc',
  'year': '2024',
  'paper_id': 'cXs5md5wAq',
  'title': 'Modelling Microbial Communities with Graph Neural Networks',
  'abstract': 'Understanding the interactions and interplay of microorganisms is a great challenge with many applications in medical and environmental settings. In this work, we model bacterial communities directly from their genomes using graph neural networks (GNNs). GNNs leverage the inductive bias induced by the set nature of bacteria, enforcing permutation invariance and granting combinatorial generalization. We propose to learn the dynamics implicitly by directly predicting community relative abundance profiles at steady state, thus escaping the need for growth curves. On two real-world datasets, we show for the first time generalization to unseen bacteria and different community structures. \nTo investigate the prediction results more deeply, we created a simulation for flexible data generation and analyze effects of bacteria interac

In [11]:
examples = []

for r in records:
    title = (r.get("title") or "").strip()
    abstract = (r.get("abstract") or "").strip()
    label = r.get("label")

    # Skip if something critical is missing
    if not title or not abstract:
        continue
    if label is None:
        continue

    text = f"{title} {abstract}".strip()
    examples.append({"text": text, "label": int(label)})

len(examples)


7404

In [12]:
import collections

label_counts = collections.Counter(e["label"] for e in examples)
label_counts, {k: v / len(examples) for k, v in label_counts.items()}

(Counter({0: 5143, 1: 2261}), {0: 0.69462452728255, 1: 0.30537547271745})

# Prepare Training Data

In [13]:
def make_splits(examples, test_size=0.2, val_size=0.2, seed=RANDOM_SEED):
    texts = [e["text"] for e in examples]
    labels = [e["label"] for e in examples]

    X_train, X_temp, y_train, y_temp = train_test_split(
        texts, labels,
        test_size=test_size,
        random_state=seed,
        stratify=labels,
    )

    val_ratio = val_size / (1.0 - test_size)
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp,
        test_size=1 - val_ratio,
        random_state=seed,
        stratify=y_temp,
    )

    def to_hf(texts, labels):
        return Dataset.from_dict({"text": texts, "label": labels})

    return (
        to_hf(X_train, y_train),
        to_hf(X_val, y_val),
        to_hf(X_test, y_test),
    )

train_ds, val_ds, test_ds = make_splits(examples)
len(train_ds), len(val_ds), len(test_ds)

(5923, 370, 1111)

In [14]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def tokenize(batch):
    return tokenizer(
        batch["text"],
        truncation=True,
        padding="max_length",
        max_length=256,
    )

tokenized_train = train_ds.map(tokenize, batched=True)
tokenized_val = val_ds.map(tokenize, batched=True)
tokenized_test = test_ds.map(tokenize, batched=True)

for ds in [tokenized_train, tokenized_val, tokenized_test]:
    ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Map:   0%|          | 0/5923 [00:00<?, ? examples/s]

Map:   0%|          | 0/370 [00:00<?, ? examples/s]

Map:   0%|          | 0/1111 [00:00<?, ? examples/s]

# Build Model

In [15]:
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=2,
)

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # softmax for positive class (label=1)
    exp_logits = np.exp(logits - np.max(logits, axis=1, keepdims=True))
    probs = exp_logits[:, 1] / exp_logits.sum(axis=1)

    preds = (probs >= 0.5).astype(int)

    metrics = {
        "accuracy": accuracy_score(labels, preds),
        "f1": f1_score(labels, preds),
    }
    try:
        metrics["auroc"] = roc_auc_score(labels, probs)
    except ValueError:
        metrics["auroc"] = float("nan")
    return metrics

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
