In [1]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import precision_recall_fscore_support
from gensim.models import KeyedVectors
from TorchCRF import CRF
from collections import Counter
from tqdm import tqdm

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cpu


## Load dataset

In [2]:
from utils.fonctions import (
    load_jnlpba_dataset,
    load_ncbi_dataset,
    prepare_ncbi_for_ner
)

jnlpba_data, _ = load_jnlpba_dataset("./datasets/JNLPBA")

ncbi_raw = load_ncbi_dataset("./datasets/NCBI-Corpus/")
ncbi_data = prepare_ncbi_for_ner(ncbi_raw)

print("JNLPBA:", len(jnlpba_data))
print("NCBI:", len(ncbi_data))

Chargement du dataset JNLPBA depuis: ./datasets/JNLPBA
- sentences: 50421 phrases
- Classes: ['B-DNA', 'I-DNA', 'B-cell_line', 'I-cell_line', 'B-protein', 'I-protein', 'B-cell_type', 'I-cell_type', 'B-RNA', 'I-RNA', 'O']
Chargement du dataset NCBI depuis: ./datasets/NCBI-Corpus/
Documents chargés: 793
Exemple d'entités dans le premier document: 2
JNLPBA: 50421
NCBI: 7526


## NORMALIZATION

In [3]:
def normalize_dataset(dataset):
    out = []
    for sent in dataset:
        s = []
        for w, l in sent:
            w = w.strip() if w.strip() else "<UNK>"
            l = l.strip()
            s.append((w, l))
        out.append(s)
    return out

jnlpba_data = normalize_dataset(jnlpba_data)
ncbi_data   = normalize_dataset(ncbi_data)

## Merge Datasets

In [4]:
all_data = jnlpba_data + ncbi_data
random.shuffle(all_data)

print("Total sentences:", len(all_data))

Total sentences: 57947


## VOCABULARIES

In [5]:
from collections import Counter

def build_word_vocab(data):
    counter = Counter(w for sent in data for w, _ in sent)
    vocab = {"<PAD>": 0, "<UNK>": 1}
    for w in counter:
        vocab[w] = len(vocab)
    return vocab

def build_char_vocab(data):
    chars = set(c for sent in data for w, _ in sent for c in w)
    vocab = {"<PAD>": 0, "<UNK>": 1}
    for c in chars:
        vocab[c] = len(vocab)
    return vocab

def build_label_vocab(data):
    labels = sorted(set(l for sent in data for _, l in sent))
    vocab = {"<PAD>": 0}
    for l in labels:
        vocab[l] = len(vocab)
    return vocab

word_vocab  = build_word_vocab(all_data)
char_vocab  = build_char_vocab(all_data)
label_vocab = build_label_vocab(all_data)

print("Word vocab:", len(word_vocab))
print("Char vocab:", len(char_vocab))
print("Labels:", label_vocab)

Word vocab: 26994
Char vocab: 88
Labels: {'<PAD>': 0, 'B-CompositeMention': 1, 'B-DNA': 2, 'B-DiseaseClass': 3, 'B-Modifier': 4, 'B-RNA': 5, 'B-SpecificDisease': 6, 'B-cell_line': 7, 'B-cell_type': 8, 'B-protein': 9, 'I-CompositeMention': 10, 'I-DNA': 11, 'I-DiseaseClass': 12, 'I-Modifier': 13, 'I-RNA': 14, 'I-SpecificDisease': 15, 'I-cell_line': 16, 'I-cell_type': 17, 'I-protein': 18, 'O': 19}


## Rare Entity manipulation

In [6]:
# import torch
# import torch.nn as nn
# from collections import Counter

# # Count the occurrences of each label
# label_counts = Counter()
# for sent in all_data:  # your encoded dataset
#     label_counts.update(sent["labels"])

# # Compute inverse frequency weights
# total = sum(label_counts.values())
# class_weights = [total / (label_counts[i] + 1e-8) for i in range(len(label_vocab))]
# class_weights = torch.FloatTensor(class_weights).to(device)

# # Use with CrossEntropyLoss
# criterion = nn.CrossEntropyLoss(weight=class_weights)

### Class-specific sampling / oversampling

In [7]:
rare_labels = ["B-SpecificDisease", "B-CompositeMention", "B-DiseaseClass"]
rare_indices = [i for i, s in enumerate(all_data)
                if any(lbl in rare_labels for _, lbl in s)]

# Oversample rare sentences
oversample_dataset = all_data + [all_data[i] for i in rare_indices]


In [8]:
# from torch.utils.data import DataLoader, WeightedRandomSampler

# # Example weights: inverse frequency of sentences containing rare entities
# sample_weights = [2.0 if i in rare_indices else 1.0 for i in range(len(oversample_dataset))]
# sampler = WeightedRandomSampler(sample_weights, num_samples=len(oversample_dataset), replacement=True)

# train_loader = DataLoader(oversample_dataset, batch_size=32, sampler=sampler)


### Data augmentation

In [9]:
from collections import defaultdict

entity_dict = defaultdict(list)
for sentence in all_data:
    for token, label in sentence:
        if label.startswith("B-"):
            entity_dict[label].append(token)

In [10]:
import random

def augment_entity(sentence, entity_dict):
    new_sentence = []
    for token, label in sentence:
        if label.startswith("B-") and label in entity_dict and entity_dict[label]:
            # randomly replace with another entity of the same type
            token = random.choice(entity_dict[label])
        new_sentence.append((token, label))
    return new_sentence

data_augmented = [augment_entity(s, entity_dict) for s in all_data]

## ENCODING

In [11]:
MAX_CHAR_LEN = 30

def encode(sent):
    words, chars, labels = [], [], []
    for w, l in sent:
        words.append(word_vocab.get(w, word_vocab["<UNK>"]))

        c = [char_vocab.get(ch, char_vocab["<UNK>"]) for ch in w][:MAX_CHAR_LEN]
        c += [char_vocab["<PAD>"]] * (MAX_CHAR_LEN - len(c))
        chars.append(c)

        labels.append(label_vocab[l])

    return {
        "words": words,
        "chars": chars,
        "labels": labels
    }

## Encode + Split

In [12]:
encoded_all = [encode(s) for s in data_augmented]
random.shuffle(encoded_all)

split = int(0.9 * len(encoded_all))
train_data = encoded_all[:split]
val_data   = encoded_all[split:]

print("Train:", len(train_data))
print("Val:", len(val_data))

Train: 52152
Val: 5795


### Compute rare indices and sample weights only for training set

In [13]:
from torch.utils.data import DataLoader, WeightedRandomSampler

rare_labels = ["B-SpecificDisease", "B-CompositeMention", "B-DiseaseClass"]
label_vocab_inv = {v: k for k, v in label_vocab.items()}


rare_indices_train = [
    i for i, s in enumerate(train_data)
    if any(label_vocab_inv[lbl] in rare_labels for lbl in s["labels"])
]

sample_weights = [2.0 if i in rare_indices_train else 1.0 for i in range(len(train_data))]

sampler = WeightedRandomSampler(sample_weights, num_samples=len(train_data), replacement=True)


## DATASET & COLLATE

In [14]:
class NERDataset(Dataset):
    def __init__(self, data): self.data = data
    def __len__(self): return len(self.data)
    def __getitem__(self, i): return self.data[i]

def collate(batch):
    max_len = max(len(x["words"]) for x in batch)
    pad = lambda x, v: x + [v] * (max_len - len(x))
    words = [pad(x["words"], 0) for x in batch]
    labels = [pad(x["labels"], 0) for x in batch]
    chars = [x["chars"] + [[0]*MAX_CHAR_LEN]*(max_len-len(x["chars"])) for x in batch]
    mask = [[1]*len(x["words"]) + [0]*(max_len-len(x["words"])) for x in batch]

    return (
        torch.tensor(words),
        torch.tensor(chars),
        torch.tensor(labels),
        torch.tensor(mask, dtype=torch.bool)
    )

train_loader = DataLoader(NERDataset(train_data), batch_size=32, sampler=sampler, collate_fn=collate)
val_loader   = DataLoader(NERDataset(val_data), 32, False, collate_fn=collate)

## Word Embedding with  BioWordVec

In [15]:
kv = KeyedVectors.load("./embeddings/biowordvec.gensim", mmap="r")
EMB_DIM = kv.vector_size

emb_matrix = np.random.normal(0, 0.6, (len(word_vocab), EMB_DIM))
for w, i in word_vocab.items():
    if w in kv:
        emb_matrix[i] = kv[w]

word_embeddings = nn.Embedding.from_pretrained(
    torch.tensor(emb_matrix, dtype=torch.float),
    freeze=False
)

## Char Encoders

In [16]:
class CharCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(len(char_vocab), 30, padding_idx=0)
        self.convs = nn.ModuleList([
            nn.Conv1d(30, 50, k, padding=k//2) for k in (3,4,5)
        ])

    def forward(self, x):
        B, T, L = x.shape
        x = self.emb(x).view(B*T, L, -1).transpose(1,2)
        feats = [torch.max(torch.relu(conv(x)), 2)[0] for conv in self.convs]
        return torch.cat(feats, 1).view(B, T, -1)

class CharBiLSTM(nn.Module):
    def __init__(self):
        super().__init__()
        self.emb = nn.Embedding(len(char_vocab), 30, padding_idx=0)
        self.lstm = nn.LSTM(30, 50, bidirectional=True, batch_first=True)

    def forward(self, x):
        B, T, L = x.shape
        x = self.emb(x).view(B*T, L, -1)
        _, (h, _) = self.lstm(x)
        return torch.cat([h[0], h[1]], 1).view(B, T, -1)

## MANHATTAN ATTENTION

In [17]:
class ManhattanAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.W = nn.Linear(dim, 1, bias=False)

    def forward(self, h, mask):
        B, T, D = h.shape
        hi = h.unsqueeze(2).expand(B, T, T, D)
        hj = h.unsqueeze(1).expand(B, T, T, D)
        dist = torch.abs(hi - hj).sum(-1)
        score = -self.W(hj).squeeze(-1) * dist
        score = score.masked_fill(~mask.unsqueeze(1), -1e9)
        alpha = torch.softmax(score, -1)
        ctx = torch.matmul(alpha, h)
        return torch.cat([h, ctx], -1)

## FULL MODEL

In [18]:
class BioNER(nn.Module):
    def __init__(self):
        super().__init__()
        self.word_emb = word_embeddings
        self.char_cnn = CharCNN()
        self.char_lstm = CharBiLSTM()

        concat_dim = EMB_DIM + 150 + 100
        self.embed_fc = nn.Linear(concat_dim, 200)

        self.bilstm = nn.LSTM(200, 256, bidirectional=True, batch_first=True)
        self.attn = ManhattanAttention(512)
        self.fc = nn.Linear(1024, len(label_vocab))
        self.crf = CRF(len(label_vocab), batch_first=True)

    def forward(self, w, c, mask, labels=None):
        we = self.word_emb(w)
        ce1 = self.char_cnn(c)
        ce2 = self.char_lstm(c)

        x = torch.cat([we, ce1, ce2], -1)
        x = self.embed_fc(x)

        h, _ = self.bilstm(x)
        h = self.attn(h, mask)
        emissions = self.fc(h)

        if labels is not None:
            return -self.crf(emissions, labels, mask)
        return self.crf.decode(emissions, mask)

## TRAINING & METRICS

In [19]:
def metrics(y_true, y_pred):
    p, r, f, _ = precision_recall_fscore_support(y_true, y_pred, average="micro")
    acc = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
    return acc, p, r, f

def evaluate(model, loader):
    model.eval()
    loss, yt, yp = 0, [], []
    with torch.no_grad():
        for w, c, l, m in loader:
            w, c, l, m = w.to(device), c.to(device), l.to(device), m.to(device)
            loss += model(w, c, m, l).item()
            preds = model(w, c, m)
            for p, g, mask in zip(preds, l, m):
                L = mask.sum().item()
                yp.extend(p[:L])
                yt.extend(g[:L].tolist())
    acc, p, r, f = metrics(yt, yp)
    return loss/len(loader), acc, p, r, f

## TRAIN LOOP

In [20]:
import torch
import torch.optim as optim

model = BioNER().to(device)
opt = optim.Adam(model.parameters(), lr=2e-4)

best_val_acc = 0
wait = 0
patience = 5
num_epochs = 30

for epoch in range(num_epochs):
    # ---- training ----
    model.train()
    tr_loss = 0.0
    n_batches = 0

    for w, c, l, m in train_loader:
        w, c, l, m = w.to(device), c.to(device), l.to(device), m.to(device)

        opt.zero_grad()
        loss = model(w, c, m, l)  # forward with labels
        loss.backward()
        opt.step()

        tr_loss += loss.item()
        n_batches += 1

    tr_loss /= n_batches  # average training loss

    # ---- validation ----
    model.eval()
    val_loss = 0.0
    yt_val, yp_val = [], []
    n_val_batches = 0

    with torch.no_grad():
        for w, c, l, m in val_loader:
            w, c, l, m = w.to(device), c.to(device), l.to(device), m.to(device)
            
            # loss computation
            loss = model(w, c, m, l)
            val_loss += loss.item()
            n_val_batches += 1

            # predictions
            preds = model(w, c, m)
            for p, g, mask in zip(preds, l, m):
                L = mask.sum().item()
                yp_val.extend(p[:L])
                yt_val.extend(g[:L].tolist())

    val_loss /= n_val_batches
    val_acc = (torch.tensor(yp_val)[:len(yt_val)] == torch.tensor(yt_val)).float().mean().item()

    print(f"Epoch {epoch+1:02d} | Train L:{tr_loss:.4f} || Val L:{val_loss:.4f} Acc:{val_acc:.4f}")

    # ---- early stopping ----
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        wait = 0
        torch.save(model.state_dict(), "best_bioner.pt")
    else:
        wait += 1
        if wait >= patience:
            print("Early stopping")
            break

Epoch 01 | Train L:312.0564 || Val L:175.5166 Acc:0.9243
Epoch 02 | Train L:147.3941 || Val L:132.0257 Acc:0.9391
Epoch 03 | Train L:110.6399 || Val L:115.4803 Acc:0.9426
Epoch 04 | Train L:87.0613 || Val L:91.4574 Acc:0.9536
Epoch 05 | Train L:71.6095 || Val L:79.8151 Acc:0.9579
Epoch 06 | Train L:58.9517 || Val L:70.9379 Acc:0.9608
Epoch 07 | Train L:49.2656 || Val L:63.9244 Acc:0.9638
Epoch 08 | Train L:42.0047 || Val L:59.0211 Acc:0.9667
Epoch 09 | Train L:34.8477 || Val L:54.8994 Acc:0.9688
Epoch 10 | Train L:28.8983 || Val L:52.3155 Acc:0.9710
Epoch 11 | Train L:24.3814 || Val L:51.9269 Acc:0.9710
Epoch 12 | Train L:20.1358 || Val L:47.9246 Acc:0.9741
Epoch 13 | Train L:16.6448 || Val L:45.0855 Acc:0.9761
Epoch 14 | Train L:13.6066 || Val L:44.4027 Acc:0.9759
Epoch 15 | Train L:11.1613 || Val L:43.8076 Acc:0.9772
Epoch 16 | Train L:9.0042 || Val L:48.0408 Acc:0.9758
Epoch 17 | Train L:7.4325 || Val L:44.1408 Acc:0.9780
Epoch 18 | Train L:6.0605 || Val L:43.4781 Acc:0.9793
Epoch 1

In [21]:
def bio_to_spans(labels, id2label):
    """
    Convert a BIO tag sequence into entity spans.
    Returns: list of (entity_type, start, end)
    """
    spans = []
    start = None
    ent_type = None

    for i, lab_id in enumerate(labels):
        label = id2label[lab_id]

        if label == "O":
            if ent_type is not None:
                spans.append((ent_type, start, i - 1))
                ent_type = None
            continue

        tag, typ = label.split("-", 1)

        if tag == "B":
            if ent_type is not None:
                spans.append((ent_type, start, i - 1))
            ent_type = typ
            start = i

        elif tag == "I":
            if ent_type != typ:
                # BIO violation → start new entity
                if ent_type is not None:
                    spans.append((ent_type, start, i - 1))
                ent_type = typ
                start = i

    if ent_type is not None:
        spans.append((ent_type, start, len(labels) - 1))

    return spans





from collections import Counter

def entity_level_metrics(y_true, y_pred, id2label):
    """
    y_true, y_pred: lists of label-id sequences (one per sentence)
    """
    gold_spans = []
    pred_spans = []

    for gt, pr in zip(y_true, y_pred):
        gold_spans.extend(bio_to_spans(gt, id2label))
        pred_spans.extend(bio_to_spans(pr, id2label))

    gold_set = set(gold_spans)
    pred_set = set(pred_spans)

    tp = len(gold_set & pred_set)
    fp = len(pred_set - gold_set)
    fn = len(gold_set - pred_set)

    precision = tp / (tp + fp + 1e-8)
    recall    = tp / (tp + fn + 1e-8)
    f1        = 2 * precision * recall / (precision + recall + 1e-8)

    return precision, recall, f1




def entity_metrics_by_type(y_true, y_pred, id2label):
    gold_by_type = Counter()
    pred_by_type = Counter()
    correct_by_type = Counter()

    for gt, pr in zip(y_true, y_pred):
        gold_spans = bio_to_spans(gt, id2label)
        pred_spans = bio_to_spans(pr, id2label)

        gold_set = set(gold_spans)
        pred_set = set(pred_spans)

        for ent, _, _ in gold_set:
            gold_by_type[ent] += 1
        for ent, _, _ in pred_set:
            pred_by_type[ent] += 1
        for ent, _, _ in gold_set & pred_set:
            correct_by_type[ent] += 1

    results = {}
    for ent in gold_by_type:
        p = correct_by_type[ent] / (pred_by_type[ent] + 1e-8)
        r = correct_by_type[ent] / (gold_by_type[ent] + 1e-8)
        f = 2 * p * r / (p + r + 1e-8)
        results[ent] = (p, r, f)

    return results





import torch

# ----- 1. Load your label mapping -----
id2label = {v: k for k, v in label_vocab.items()}  # assuming you have label_vocab dict

# ----- 2. Load your trained model -----
model = BioNER().to(device)
model.load_state_dict(torch.load("best_bioner.pt", map_location=device))
model.eval()

# ----- 3. Define evaluation using entity-level metrics -----
def evaluate_entity_level(model, loader, id2label):
    all_gold, all_pred = [], []

    with torch.no_grad():
        for w, c, l, m in loader:
            w, c, l, m = w.to(device), c.to(device), l.to(device), m.to(device)
            preds = model(w, c, m)

            for p, g, mask in zip(preds, l, m):
                L = mask.sum().item()
                all_pred.append(p[:L])
                all_gold.append(g[:L].tolist())

    # Use entity-level metric function
    p, r, f = entity_level_metrics(all_gold, all_pred, id2label)
    per_type = entity_metrics_by_type(all_gold, all_pred, id2label)

    return p, r, f, per_type

# ----- 4. Run evaluation -----
precision, recall, f1, per_entity = evaluate_entity_level(model, val_loader, id2label)

print(f"Entity-level Precision: {precision:.4f}")
print(f"Entity-level Recall:    {recall:.4f}")
print(f"Entity-level F1:        {f1:.4f}\n")

print("Per-entity metrics:")
for ent, (p, r, f) in per_entity.items():
    print(f"{ent:12s} P:{p:.3f} R:{r:.3f} F1:{f:.3f}")


Entity-level Precision: 0.9411
Entity-level Recall:    0.9387
Entity-level F1:        0.9399

Per-entity metrics:
DNA          P:0.901 R:0.883 F1:0.892
protein      P:0.948 R:0.924 F1:0.936
cell_type    P:0.888 R:0.912 F1:0.900
cell_line    P:0.913 R:0.833 F1:0.871
SpecificDisease P:0.716 R:0.685 F1:0.700
RNA          P:0.900 R:0.892 F1:0.896
DiseaseClass P:0.511 R:0.442 F1:0.474
Modifier     P:0.685 R:0.570 F1:0.622
CompositeMention P:0.571 R:0.308 F1:0.400
