In [4]:
import tensorflow as tf
from conner import ConNER

In [5]:
def train_model(train_texts, epochs=300, batch_size=16, save_dir=None):
    """Train ConNER."""
    print("Initializing model...")
    model = ConNER()

    print("Preparing data...")
    inputs, labels = model.prepare_data(train_texts)

    # Create dataset
    dataset = tf.data.Dataset.from_tensor_slices((inputs, labels))
    dataset = dataset.shuffle(1000).batch(batch_size)

    # Create optimizer and loss
    optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, weight_decay=0.01)
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    # Metrics
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

    # Training step
    @tf.function
    def train_step(x, y):
        with tf.GradientTape() as tape:
            # Forward pass
            logits = model(x, training=True)

            # Calculate loss
            mask = tf.cast(x["attention_mask"], tf.float32)
            loss = loss_fn(y, logits, sample_weight=mask)

        # Calculate gradients
        gradients = tape.gradient(loss, model.trainable_variables)

        # Apply gradients
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # Update metrics
        train_accuracy.update_state(y, logits, sample_weight=mask)

        return loss

    print("\nStarting training...")
    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")

        # Reset metrics
        train_accuracy.reset_state()
        total_loss = 0
        num_batches = 0

        for x, y in dataset:
            loss = train_step(x, y)
            total_loss += loss
            num_batches += 1

        # Print metrics
        avg_loss = total_loss / num_batches
        print(f"Average Loss: {avg_loss:.4f}")
        print(f"Accuracy: {train_accuracy.result():.4f}")

    if save_dir:
        model.save_model(save_dir)

    return model

In [6]:
# Load training data
print("Loading training data...")
with open("data/train.txt") as f:
    train_texts = f.readlines()

print(f"Loaded training samples")

save_path = "saved_models/conner"

# Train model
model = train_model(train_texts, epochs=300, batch_size=16, save_dir=save_path)
print("\nModel training completed successfully!")

Loading training data...
Loaded training samples
Initializing model...


Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'bert.embeddings.position_ids', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.bias']
- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFBertModel were initialized from the PyTorch model.
If your task is similar to the 

Preparing data...

Starting training...

Epoch 1/300


2024-11-28 07:56:09.866125: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Average Loss: 0.2282
Accuracy: 0.2139

Epoch 2/300
Average Loss: 0.2008
Accuracy: 0.2763

Epoch 3/300
Average Loss: 0.1769
Accuracy: 0.3317

Epoch 4/300
Average Loss: 0.1635
Accuracy: 0.3784

Epoch 5/300
Average Loss: 0.1463
Accuracy: 0.4532

Epoch 6/300
Average Loss: 0.1325
Accuracy: 0.5077

Epoch 7/300
Average Loss: 0.1218
Accuracy: 0.5644

Epoch 8/300
Average Loss: 0.1140
Accuracy: 0.6038

Epoch 9/300
Average Loss: 0.1087
Accuracy: 0.6483

Epoch 10/300
Average Loss: 0.1038
Accuracy: 0.6725

Epoch 11/300
Average Loss: 0.1023
Accuracy: 0.6880

Epoch 12/300
Average Loss: 0.0961
Accuracy: 0.7168

Epoch 13/300
Average Loss: 0.0958
Accuracy: 0.7186

Epoch 14/300
Average Loss: 0.0936
Accuracy: 0.7316

Epoch 15/300
Average Loss: 0.0926
Accuracy: 0.7343

Epoch 16/300
Average Loss: 0.0912
Accuracy: 0.7495

Epoch 17/300


2024-11-28 07:56:15.737248: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Average Loss: 0.0910
Accuracy: 0.7564

Epoch 18/300
Average Loss: 0.0872
Accuracy: 0.7679

Epoch 19/300
Average Loss: 0.0885
Accuracy: 0.7676

Epoch 20/300
Average Loss: 0.0857
Accuracy: 0.7752

Epoch 21/300
Average Loss: 0.0843
Accuracy: 0.7816

Epoch 22/300
Average Loss: 0.0829
Accuracy: 0.7825

Epoch 23/300
Average Loss: 0.0835
Accuracy: 0.7773

Epoch 24/300
Average Loss: 0.0836
Accuracy: 0.7822

Epoch 25/300
Average Loss: 0.0821
Accuracy: 0.7773

Epoch 26/300
Average Loss: 0.0811
Accuracy: 0.7825

Epoch 27/300
Average Loss: 0.0798
Accuracy: 0.7831

Epoch 28/300
Average Loss: 0.0783
Accuracy: 0.7928

Epoch 29/300
Average Loss: 0.0784
Accuracy: 0.7898

Epoch 30/300
Average Loss: 0.0772
Accuracy: 0.7925

Epoch 31/300
Average Loss: 0.0766
Accuracy: 0.7958

Epoch 32/300
Average Loss: 0.0748
Accuracy: 0.7895

Epoch 33/300
Average Loss: 0.0754
Accuracy: 0.7873

Epoch 34/300
Average Loss: 0.0743
Accuracy: 0.8016

Epoch 35/300
Average Loss: 0.0729
Accuracy: 0.8010

Epoch 36/300
Average Loss

2024-11-28 07:56:27.184951: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Average Loss: 0.0637
Accuracy: 0.8158

Epoch 50/300
Average Loss: 0.0630
Accuracy: 0.8125

Epoch 51/300
Average Loss: 0.0627
Accuracy: 0.8140

Epoch 52/300
Average Loss: 0.0632
Accuracy: 0.8137

Epoch 53/300
Average Loss: 0.0612
Accuracy: 0.8185

Epoch 54/300
Average Loss: 0.0605
Accuracy: 0.8155

Epoch 55/300
Average Loss: 0.0589
Accuracy: 0.8213

Epoch 56/300
Average Loss: 0.0586
Accuracy: 0.8219

Epoch 57/300
Average Loss: 0.0593
Accuracy: 0.8228

Epoch 58/300
Average Loss: 0.0588
Accuracy: 0.8234

Epoch 59/300
Average Loss: 0.0582
Accuracy: 0.8216

Epoch 60/300
Average Loss: 0.0582
Accuracy: 0.8222

Epoch 61/300
Average Loss: 0.0564
Accuracy: 0.8279

Epoch 62/300
Average Loss: 0.0578
Accuracy: 0.8270

Epoch 63/300
Average Loss: 0.0566
Accuracy: 0.8307

Epoch 64/300
Average Loss: 0.0548
Accuracy: 0.8288

Epoch 65/300
Average Loss: 0.0549
Accuracy: 0.8304

Epoch 66/300
Average Loss: 0.0557
Accuracy: 0.8279

Epoch 67/300
Average Loss: 0.0548
Accuracy: 0.8325

Epoch 68/300
Average Loss

2024-11-28 07:56:50.968705: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Average Loss: 0.0413
Accuracy: 0.8700

Epoch 114/300
Average Loss: 0.0426
Accuracy: 0.8622

Epoch 115/300
Average Loss: 0.0403
Accuracy: 0.8782

Epoch 116/300
Average Loss: 0.0413
Accuracy: 0.8664

Epoch 117/300
Average Loss: 0.0404
Accuracy: 0.8700

Epoch 118/300
Average Loss: 0.0409
Accuracy: 0.8700

Epoch 119/300
Average Loss: 0.0398
Accuracy: 0.8746

Epoch 120/300
Average Loss: 0.0410
Accuracy: 0.8694

Epoch 121/300
Average Loss: 0.0396
Accuracy: 0.8785

Epoch 122/300
Average Loss: 0.0407
Accuracy: 0.8713

Epoch 123/300
Average Loss: 0.0396
Accuracy: 0.8761

Epoch 124/300
Average Loss: 0.0394
Accuracy: 0.8752

Epoch 125/300
Average Loss: 0.0407
Accuracy: 0.8755

Epoch 126/300
Average Loss: 0.0402
Accuracy: 0.8749

Epoch 127/300
Average Loss: 0.0388
Accuracy: 0.8734

Epoch 128/300
Average Loss: 0.0381
Accuracy: 0.8806

Epoch 129/300
Average Loss: 0.0387
Accuracy: 0.8755

Epoch 130/300
Average Loss: 0.0386
Accuracy: 0.8785

Epoch 131/300
Average Loss: 0.0387
Accuracy: 0.8770

Epoch 1

2024-11-28 07:57:44.309752: I tensorflow/core/framework/local_rendezvous.cc:405] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Average Loss: 0.0292
Accuracy: 0.9112

Epoch 242/300
Average Loss: 0.0303
Accuracy: 0.9043

Epoch 243/300
Average Loss: 0.0307
Accuracy: 0.9025

Epoch 244/300
Average Loss: 0.0310
Accuracy: 0.9015

Epoch 245/300
Average Loss: 0.0308
Accuracy: 0.9028

Epoch 246/300
Average Loss: 0.0297
Accuracy: 0.9022

Epoch 247/300
Average Loss: 0.0305
Accuracy: 0.9058

Epoch 248/300
Average Loss: 0.0315
Accuracy: 0.9015

Epoch 249/300
Average Loss: 0.0305
Accuracy: 0.9025

Epoch 250/300
Average Loss: 0.0309
Accuracy: 0.9061

Epoch 251/300
Average Loss: 0.0303
Accuracy: 0.9031

Epoch 252/300
Average Loss: 0.0296
Accuracy: 0.9037

Epoch 253/300
Average Loss: 0.0295
Accuracy: 0.9085

Epoch 254/300
Average Loss: 0.0294
Accuracy: 0.9109

Epoch 255/300
Average Loss: 0.0294
Accuracy: 0.9109

Epoch 256/300
Average Loss: 0.0297
Accuracy: 0.9073

Epoch 257/300
Average Loss: 0.0314
Accuracy: 0.8991

Epoch 258/300
Average Loss: 0.0300
Accuracy: 0.9082

Epoch 259/300
Average Loss: 0.0298
Accuracy: 0.9103

Epoch 2