In [36]:
# =========================
# Semantic Similarity w/ BERT (SNLI)
# =========================

# ---- MUST be first ----
import os
os.environ["KERAS_BACKEND"] = "tensorflow"

# ---- Imports ----
import math
import numpy as np
import pandas as pd
import tensorflow as tf
import tf_keras as keras
from tf_keras import layers
from transformers import BertTokenizerFast, TFBertModel

# ---- Config ----
# Token budget per encoded sentence pair; also the width of the model inputs.
max_length = 128
# Number of sentence pairs per batch produced by the data generator.
batch_size = 32
# Epochs per training phase (feature extraction, then fine-tuning).
epochs = 2
# Class names in index order. This list is the single source of truth:
# the class count and the name -> id mapping are derived from it so the
# three can never drift out of sync.
labels = ["contradiction", "entailment", "neutral"]
num_classes = len(labels)
lab2id = {name: idx for idx, name in enumerate(labels)}

# ---- Data: SNLI (100k train for demo) ----
# (Uncomment these two lines if running in Colab/Notebook)
# !curl -LO https://raw.githubusercontent.com/MohamadMerchant/SNLI/master/data.tar.gz
# !tar -xvzf data.tar.gz

def _prepare_split(frame, shuffle):
    """Clean one SNLI split and attach an integer ``label`` column.

    Rows with any missing value are dropped, rows whose ``similarity`` is
    not one of ``labels`` are discarded, the rows are optionally shuffled
    (fixed seed 42), the index is rebuilt, and ``similarity`` is mapped to
    its integer id via ``lab2id``.
    """
    cleaned = frame.dropna(axis=0)
    cleaned = cleaned[cleaned.similarity.isin(labels)]
    if shuffle:
        cleaned = cleaned.sample(frac=1.0, random_state=42)
    cleaned = cleaned.reset_index(drop=True)
    cleaned["label"] = cleaned["similarity"].map(lab2id)
    return cleaned


def _one_hot(frame):
    """Return the ``label`` column of *frame* as a float32 one-hot matrix."""
    encoded = keras.utils.to_categorical(frame.label, num_classes=num_classes)
    return encoded.astype("float32")


# Only the first 100k training rows are read to keep the demo fast.
train_df = pd.read_csv("SNLI_Corpus/snli_1.0_train.csv", nrows=100000)
valid_df = pd.read_csv("SNLI_Corpus/snli_1.0_dev.csv")
test_df = pd.read_csv("SNLI_Corpus/snli_1.0_test.csv")

# Sizes are reported as loaded, i.e. before any row is filtered out.
print(f"Total train samples : {train_df.shape[0]}")
print(f"Total validation samples: {valid_df.shape[0]}")
print(f"Total test samples: {test_df.shape[0]}")

# Train and validation are shuffled; the test split keeps its file order.
train_df = _prepare_split(train_df, shuffle=True)
valid_df = _prepare_split(valid_df, shuffle=True)
test_df = _prepare_split(test_df, shuffle=False)

# One-hot targets (we'll use categorical_crossentropy).
y_train = _one_hot(train_df)
y_val = _one_hot(valid_df)
y_test = _one_hot(test_df)

# ---- Generator ----
class BertSemanticDataGenerator(keras.utils.Sequence):
    """Keras ``Sequence`` that tokenizes sentence pairs batch by batch.

    Every batch is encoded on the fly with the ``bert-base-uncased`` fast
    tokenizer and padded/truncated to ``max_length`` tokens.

    Args:
        sentence_pairs: Array-like of sentence pairs; callers in this file
            pass an array of shape (N, 2) holding strings.
        labels: Optional array-like of targets aligned with
            ``sentence_pairs`` (converted to float32). ``None`` for
            inference.
        batch_size: Number of pairs per batch.
        shuffle: Whether to reshuffle the sample order at construction and
            after every epoch.
        include_targets: Whether ``__getitem__`` should return targets.
            Targets are only returned when ``labels`` is also provided.
        max_length: Sequence length the tokenizer pads/truncates to.
        seed: Seed for the shuffling RNG. The RNG is created once, so each
            epoch draws a new permutation while the overall run stays
            reproducible.

    Returns (from ``__getitem__``):
        ``([input_ids, attention_mask, token_type_ids], labels)`` when
        ``include_targets`` is true and labels were supplied, otherwise just
        ``[input_ids, attention_mask, token_type_ids]``. All three inputs
        are int32 arrays.

    Raises:
        ValueError: If ``labels`` is given but its length differs from the
            number of sentence pairs.
    """

    def __init__(self, sentence_pairs, labels=None, batch_size=batch_size,
                 shuffle=True, include_targets=True, max_length=max_length,
                 seed=42):
        super().__init__()
        self.sentence_pairs = np.asarray(sentence_pairs)
        self.labels = None if labels is None else np.asarray(labels, dtype="float32")
        if self.labels is not None and len(self.labels) != len(self.sentence_pairs):
            raise ValueError(
                f"labels has {len(self.labels)} rows but sentence_pairs has "
                f"{len(self.sentence_pairs)}"
            )
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.include_targets = include_targets
        self.max_length = max_length
        self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", do_lower_case=True)
        self.indexes = np.arange(len(self.sentence_pairs))
        # One persistent generator: re-creating a seeded RNG every epoch
        # would re-apply the same permutation instead of drawing a new one.
        self._rng = np.random.default_rng(seed)
        self.on_epoch_end()

    def __len__(self):
        """Number of batches per epoch (the last batch may be smaller)."""
        return math.ceil(len(self.sentence_pairs) / self.batch_size)

    def __getitem__(self, idx):
        """Tokenize and return batch number ``idx``."""
        sl = slice(idx * self.batch_size, (idx + 1) * self.batch_size)
        batch_idx = self.indexes[sl]
        pairs = self.sentence_pairs[batch_idx]

        enc = self.tokenizer.batch_encode_plus(
            pairs.tolist(),
            add_special_tokens=True,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_token_type_ids=True,
            return_tensors="np",
        )

        # Cast to int32 to match the dtype of the model's Input layers.
        input_ids = enc["input_ids"].astype("int32")
        attention_mask = enc["attention_mask"].astype("int32")
        token_type_ids = enc["token_type_ids"].astype("int32")
        inputs = [input_ids, attention_mask, token_type_ids]

        if self.include_targets and self.labels is not None:
            return inputs, self.labels[batch_idx]
        return inputs

    def on_epoch_end(self):
        """Reshuffle the sample order in place when shuffling is enabled."""
        if self.shuffle:
            self._rng.shuffle(self.indexes)

# ---- Strategy (safe single-device) ----
gpus = tf.config.list_physical_devices("GPU")
# Memory growth must be configured identically on every visible GPU, so it
# is enabled on all of them rather than only the first one.
for gpu in gpus:
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except (RuntimeError, ValueError) as err:
        # Best effort: this fails once the runtime is already initialized
        # (e.g. when the cell is re-run). Report it instead of hiding it.
        print(f"Could not enable memory growth on {gpu.name}: {err}")
# Single-device strategy: first GPU when one is available, otherwise CPU.
strategy = tf.distribute.OneDeviceStrategy("/GPU:0" if gpus else "/CPU:0")
print("Strategy:", strategy)

# ---- Model ----
with strategy.scope():
    # Three parallel int32 inputs, each max_length tokens wide.
    input_ids = layers.Input(shape=(max_length,), dtype=tf.int32, name="input_ids")
    attention_mask = layers.Input(shape=(max_length,), dtype=tf.int32, name="attention_mask")
    token_type_ids = layers.Input(shape=(max_length,), dtype=tf.int32, name="token_type_ids")

    # The TF model is populated from the PyTorch checkpoint (from_pt=True)
    # to avoid safetensors-handle issues.
    bert = TFBertModel.from_pretrained("bert-base-uncased", from_pt=True)
    # Phase one is pure feature extraction: the backbone stays frozen.
    bert.trainable = False

    bert_outputs = bert(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        training=False,
    )
    sequence_output = bert_outputs.last_hidden_state
    pooled_output = bert_outputs.pooler_output  # not used by the head below

    # Head: BiLSTM over the token states, then average- and max-pooling
    # over the sequence axis, concatenated (pooled_output could be added to
    # this concat as well).
    recurrent = layers.Bidirectional(layers.LSTM(64, return_sequences=True))
    encoded = recurrent(sequence_output)
    pooled_views = [
        layers.GlobalAveragePooling1D()(encoded),
        layers.GlobalMaxPooling1D()(encoded),
    ]
    features = layers.Concatenate()(pooled_views)
    features = layers.Dropout(0.3)(features)
    # Softmax output: these are class probabilities, not raw logits.
    probabilities = layers.Dense(num_classes, activation="softmax")(features)

    model = keras.Model(
        inputs=[input_ids, attention_mask, token_type_ids],
        outputs=probabilities,
        name="bert_semantic_similarity",
    )
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=2e-4),
        loss="categorical_crossentropy",  # targets are one-hot
        metrics=["accuracy"],
    )

model.summary()

# ---- Generators ----
# Raw (sentence1, sentence2) string arrays for the two splits.
train_pairs = train_df[["sentence1", "sentence2"]].values.astype("str")
valid_pairs = valid_df[["sentence1", "sentence2"]].values.astype("str")

# Training batches are reshuffled every epoch; validation order is fixed.
train_data = BertSemanticDataGenerator(
    train_pairs, y_train, batch_size=batch_size, shuffle=True, max_length=max_length
)
valid_data = BertSemanticDataGenerator(
    valid_pairs, y_val, batch_size=batch_size, shuffle=False, max_length=max_length
)

# Feature-extraction phase. A single worker without multiprocessing keeps
# the HF tokenizer out of pickling trouble.
fit_options = dict(
    epochs=epochs,
    workers=1,
    use_multiprocessing=False,
    verbose=2,
)
history = model.fit(train_data, validation_data=valid_data, **fit_options)

# ---- Fine-tune (optional) ----
with strategy.scope():
    # Unfreeze the backbone through its own handle rather than by position
    # in model.layers, which depends on how Keras orders the graph.
    bert.trainable = True
    # A change to `trainable` only takes effect after recompiling. The
    # learning rate is much lower here because the BERT weights are now
    # being updated as well.
    model.compile(
        optimizer=keras.optimizers.Adam(1e-5),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
model.summary()

# Second training phase on the same generators, end to end.
history_ft = model.fit(
    train_data,
    validation_data=valid_data,
    epochs=epochs,
    workers=1,
    use_multiprocessing=False,
    verbose=2,
)

# ---- Evaluate ----
# Held-out evaluation: test pairs are served in file order (no shuffling).
test_pairs = test_df[["sentence1", "sentence2"]].values.astype("str")
test_data = BertSemanticDataGenerator(
    sentence_pairs=test_pairs,
    labels=y_test,
    batch_size=batch_size,
    shuffle=False,
    max_length=max_length,
)
model.evaluate(test_data, verbose=1)

# ---- Inference helper ----
def check_similarity(sentence1, sentence2):
    """Classify the relation between two sentences with the trained model.

    Both arguments are converted to ``str``, encoded as a single-pair batch
    through ``BertSemanticDataGenerator`` and passed to ``model.predict``.

    Returns:
        A ``(label, confidence)`` tuple: the name of the highest-scoring
        class from ``labels`` and its probability formatted as a
        percentage string with two decimals (e.g. ``"94.54%"``).
    """
    single_pair = np.array([[str(sentence1), str(sentence2)]])
    generator = BertSemanticDataGenerator(
        single_pair,
        labels=None,
        batch_size=1,
        shuffle=False,
        include_targets=False,
    )
    model_inputs = generator[0]  # the only batch: three (1, max_length) arrays
    probabilities = model.predict(model_inputs, verbose=0)[0]
    best = int(np.argmax(probabilities))
    confidence = f"{probabilities[best]*100:.2f}%"
    return labels[best], confidence

# Quick sanity checks
# Three hand-picked sentence pairs, printed as (label, confidence) tuples.
for first_sentence, second_sentence in (
    ("Two women are observing something together.",
     "Two women are standing with their eyes closed."),
    ("A smiling costumed woman is holding an umbrella",
     "A happy woman in a fairy costume holds an umbrella"),
    ("A soccer game with multiple males playing",
     "Some men are playing a sport"),
):
    print(check_similarity(first_sentence, second_sentence))


Total train samples : 100000
Total validation samples: 10000
Total test samples: 10000
Strategy: <tensorflow.python.distribute.one_device_strategy.OneDeviceStrategy object at 0x7a7b54532bd0>


Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already

Model: "bert_semantic_similarity"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_ids (InputLayer)      [(None, 128)]                0         []                            
                                                                                                  
 attention_mask (InputLayer  [(None, 128)]                0         []                            
 )                                                                                                
                                                                                                  
 token_type_ids (InputLayer  [(None, 128)]                0         []                            
 )                                                                                                
                                                                           



3122/3122 - 2581s - loss: 0.4511 - accuracy: 0.8276 - val_loss: 0.3494 - val_accuracy: 0.8702 - 2581s/epoch - 827ms/step
Epoch 2/2
3122/3122 - 2525s - loss: 0.2753 - accuracy: 0.9048 - val_loss: 0.3535 - val_accuracy: 0.8743 - 2525s/epoch - 809ms/step
('contradiction', '81.68%')
('neutral', '97.38%')
('entailment', '94.54%')
