In [13]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Embedding, Bidirectional, LSTM, Dense, TimeDistributed, Layer

# ----------- Custom CRF Layer -----------
class CRF(Layer):
    """Pass-through tagging head that stands in for a CRF layer.

    The layer returns its input scores unchanged and owns a
    ``(num_tags, num_tags)`` ``transitions`` weight.

    NOTE(review): this is not a true CRF yet. ``transitions`` is created but
    never read by ``call``, ``get_loss`` or ``viterbi_decode``; the loss is
    per-token softmax cross-entropy and decoding is a per-token argmax.

    Args:
        num_tags: Number of distinct tag ids; sets the size of ``transitions``.
        **kwargs: Forwarded to the base Keras ``Layer`` constructor.
    """

    def __init__(self, num_tags, **kwargs):
        super(CRF, self).__init__(**kwargs)
        self.num_tags = num_tags
        # Custom layers discard an incoming Keras mask unless they opt in.
        # Opting in lets a padding mask produced upstream (e.g. by an
        # Embedding with mask_zero=True) stay attached to this layer's output.
        self.supports_masking = True

    def build(self, input_shape):
        """Create the tag-to-tag transition score matrix."""
        # Transition matrix
        self.transitions = self.add_weight(
            shape=(self.num_tags, self.num_tags),
            initializer="glorot_uniform",
            trainable=True,
            name="transitions"
        )
        super(CRF, self).build(input_shape)

    def call(self, logits):
        """Return the per-token tag scores unchanged."""
        return logits  # raw scores

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created.

        Without ``num_tags`` in the config, a saved model containing this
        layer cannot be rebuilt from its serialized form.
        """
        config = super(CRF, self).get_config()
        config.update({"num_tags": self.num_tags})
        return config

    def get_loss(self, y_true, y_pred):
        """Per-token sparse categorical cross-entropy on raw scores.

        Args:
            y_true: Integer tag ids.
            y_pred: Unnormalized scores (logits) with the tag axis last.

        Returns:
            The loss tensor produced by
            ``tf.keras.losses.sparse_categorical_crossentropy``.
        """
        return tf.keras.losses.sparse_categorical_crossentropy(y_true, y_pred, from_logits=True)

    def viterbi_decode(self, logits):
        """Pick the highest-scoring tag id at every position (greedy decode).

        Despite the name, no dynamic programming is performed and
        ``transitions`` is not consulted. A true CRF decode would combine the
        logits with the transition scores.

        Args:
            logits: Scores with the tag axis last.

        Returns:
            Tensor of argmax tag ids with the last axis removed.
        """
        return tf.argmax(logits, axis=-1)   # simplified greedy decode
        

# ----------- Model Definition -----------
max_len = 100   # fixed (padded) sequence length fed to the Input layer
n_words = 5000  # vocabulary size used as the Embedding input dimension
n_tags = 17     # number of tag classes scored per token

# NOTE(review): `input` shadows the Python builtin of the same name for the rest
# of the kernel session. The name is kept because other cells may refer to it;
# rename it (e.g. to `inputs`) once that has been checked.
input = Input(shape=(max_len,))

# `input_length` is intentionally not passed: it is deprecated in current Keras
# and redundant here, since the Input shape above already fixes the length.
# mask_zero=True makes token id 0 act as padding.
x = Embedding(input_dim=n_words, output_dim=50, mask_zero=True)(input)
x = Bidirectional(LSTM(units=50, return_sequences=True))(x)
# One score per tag for every timestep (no activation: these are raw logits).
logits = TimeDistributed(Dense(n_tags))(x)

crf = CRF(n_tags)
out = crf(logits)

model = Model(inputs=input, outputs=out)
model.compile(optimizer='adam',
              loss=crf.get_loss,
              metrics=['accuracy'])





In [14]:
# Example id -> tag lookup (ids follow list order) and a sample sentence.
idx2tag = dict(enumerate(("O", "B-PER", "I-PER", "B-LOC", "I-LOC")))  # example
words = "John lives in New York".split()


In [15]:
# Dummy logits (model.predict would give this). A seeded generator keeps the
# printed table identical on every Restart & Run All.
# NOTE: the last axis has len(idx2tag) classes rather than the model's tag
# count; this only works because viterbi_decode is a plain per-token argmax.
rng = np.random.default_rng(seed=0)
dummy_logits = rng.random((1, len(words), len(idx2tag)))

# Decode using CRF layer; take the single batch element -> (seq_len,)
pred_ids = crf.viterbi_decode(dummy_logits).numpy()[0]

# Map ids → tags (cast so the dict lookup uses plain Python ints)
pred_tags = [idx2tag[int(i)] for i in pred_ids]
assert len(pred_tags) == len(words), "expected exactly one tag per word"

# Show table
for w, t in zip(words, pred_tags):
    print(f"{w:10} {t}")


John       B-PER
lives      O
in         I-PER
New        O
York       B-LOC
