# Convert to TFLite Model

In [None]:
import tensorflow as tf
from transformers import TFAutoModelForSequenceClassification, AutoTokenizer
import os

# Directory holding the fine-tuned checkpoint; both the model and the
# tokenizer below are loaded from it.
MODEL_DIR       = "/content/drive/MyDrive/Colab Notebooks/Actionable-Fine-Tune/mobilebert-finetuned-actionable-v2"
MAX_LEN         = 128                    # sequence length used during fine-tuning
# Output locations: the intermediate SavedModel directory and the final
# TFLite flatbuffer file (both relative to the current working directory).
SAVEDMODEL_DIR  = "mobilebert_savedmodel_f32"
TFLITE_OUT      = "mobilebert_float32.tflite"

# -----------------------------------------------------------
# 1️⃣  LOAD MODEL & TOKENIZER (PyTorch → TF)
# -----------------------------------------------------------
model = TFAutoModelForSequenceClassification.from_pretrained(
    MODEL_DIR,
    from_pt=True                        # load the PyTorch checkpoint weights into the TF model
)
# NOTE: the tokenizer is loaded here but not used by the export steps below.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

# -----------------------------------------------------------
# 2️⃣  EXPORT AS SAVEDMODEL WITH CORRECT SIGNATURE
# -----------------------------------------------------------
# Remove any previous export so stale files cannot end up in the new
# SavedModel. Done with the standard library rather than a shell command so
# the cell also runs outside IPython and handles paths containing spaces.
import shutil

if os.path.isdir(SAVEDMODEL_DIR):
    shutil.rmtree(SAVEDMODEL_DIR)
elif os.path.exists(SAVEDMODEL_DIR):
    # A plain file is in the way of the export directory; remove it too.
    os.remove(SAVEDMODEL_DIR)

@tf.function(input_signature=[
    tf.TensorSpec([None, MAX_LEN], tf.int32, name=tensor_name)
    for tensor_name in ("input_ids", "attention_mask", "token_type_ids")
])
def serving_fn(input_ids, attention_mask, token_type_ids):
    """Run the loaded model on one batch of encoded inputs.

    The input signature fixes every argument to an int32 tensor of shape
    ``[batch, MAX_LEN]`` with a free batch dimension.

    Args:
        input_ids: Token ids.
        attention_mask: Attention mask passed to the model unchanged.
        token_type_ids: Token type ids passed to the model unchanged.

    Returns:
        Whatever ``model(...)`` returns for these keyword inputs.
    """
    encoder_inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
    }
    return model(**encoder_inputs)

# Expose the traced function as the SavedModel's default serving signature;
# the TFLite converter below reads the model back from this directory.
export_signatures = {"serving_default": serving_fn}
tf.saved_model.save(model, SAVEDMODEL_DIR, signatures=export_signatures)
print("✅ SavedModel exported →", SAVEDMODEL_DIR)

# -----------------------------------------------------------
# 3️⃣  CONVERT TO FLOAT-32 TFLITE  (no quantisation)
# -----------------------------------------------------------
# Build the converter from the exported SavedModel directory. No optimisation
# or quantisation options are set on it, so converter defaults apply.
converter = tf.lite.TFLiteConverter.from_saved_model(SAVEDMODEL_DIR)
tflite_model = converter.convert()

# Persist the converted model bytes to disk.
with open(TFLITE_OUT, mode="wb") as tflite_file:
    tflite_file.write(tflite_model)

print(f"🎉  Float32 TFLite model written →  {TFLITE_OUT}")


All PyTorch model weights were used when initializing TFMobileBertForSequenceClassification.

All the weights of TFMobileBertForSequenceClassification were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFMobileBertForSequenceClassification for predictions without further training.


✅ SavedModel exported → mobilebert_savedmodel_f32
🎉  Float32 TFLite model written →  mobilebert_float32.tflite


# Test Inference

In [None]:
# ---------------------------------------------------------
# 0.  Imports & paths
# ---------------------------------------------------------
import tensorflow as tf
import numpy as np
import pandas as pd
from transformers import AutoTokenizer
import scipy.special, re
from tqdm import tqdm

# Checkpoint directory; only the tokenizer is loaded from it in this cell.
MODEL_DIR       = "/content/drive/MyDrive/Colab Notebooks/Actionable-Fine-Tune/mobilebert-finetuned-actionable-v2"
# TFLite flatbuffer to evaluate.
TFLITE_PATH     = "mobilebert_float32.tflite"
CSV_PATH        = "data.csv"  # must contain 'text' and 'label'
# Sequence length used both for tokenization (padding/truncation) and for
# resizing the interpreter inputs below.
MAX_LEN         = 128

# ---------------------------------------------------------
# 1.  Load tokenizer
# ---------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

# ---------------------------------------------------------
# 2.  Load TFLite model and resize inputs
# ---------------------------------------------------------
interpreter = tf.lite.Interpreter(model_path=TFLITE_PATH)
interpreter.allocate_tensors()

# Pin every model input to a single example of MAX_LEN tokens, then
# allocate again so the tensor buffers match the resized shapes.
single_example_shape = [1, MAX_LEN]
for input_detail in interpreter.get_input_details():
    interpreter.resize_tensor_input(input_detail["index"], single_example_shape)
interpreter.allocate_tensors()

# Read the tensor metadata only after the final allocation.
output_details = interpreter.get_output_details()
input_details  = interpreter.get_input_details()

def _base_key(tflite_name: str) -> str:
    return re.sub(r"^serving_default_|:\d+$", "", tflite_name)

# ---------------------------------------------------------
# 3.  Inference function
# ---------------------------------------------------------
def predict(text: str):
    """Classify a single text with the TFLite interpreter.

    The text is tokenized to exactly ``MAX_LEN`` tokens (padded or truncated),
    each interpreter input is fed the tokenizer field whose name matches the
    input's base name, and the first output tensor is treated as the logits.

    Args:
        text: Raw input string to classify.

    Returns:
        A ``(pred, probs)`` tuple: the argmax class index as an ``int`` and
        the softmax probabilities of the single example.

    Raises:
        KeyError: If an interpreter input has no matching tokenizer field.
    """
    enc = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=MAX_LEN,
        return_tensors="np",
    )

    for inp in input_details:
        idx   = inp["index"]
        dtype = inp["dtype"]
        key   = _base_key(inp["name"])
        try:
            tensor = enc[key]
        except KeyError as err:
            # Name the offending tensor instead of surfacing a bare key error.
            raise KeyError(
                f"TFLite input {inp['name']!r} has no matching tokenizer field {key!r}"
            ) from err

        if dtype in (np.int32, np.int64):
            value = tensor.astype(dtype)
        else:
            scale, zp = inp["quantization"]
            if scale == 0:
                # A zero scale means the tensor carries no quantization
                # parameters (e.g. a plain float input); dividing by it would
                # feed inf/NaN to the model, so only cast.
                value = tensor.astype(dtype)
            else:
                value = np.round(tensor.astype(np.float32) / scale + zp).astype(dtype)
        interpreter.set_tensor(idx, value)

    interpreter.invoke()

    out_info = output_details[0]
    raw = interpreter.get_tensor(out_info["index"])
    if out_info["dtype"] in (np.int8, np.uint8):
        # Quantized output: map the integers back to real-valued logits.
        scale, zp = out_info["quantization"]
        logits = (raw.astype(np.float32) - zp) * scale
    else:
        logits = raw.astype(np.float32)

    # [0] selects the only example in the batch of one.
    probs = scipy.special.softmax(logits, axis=-1)[0]
    pred  = int(np.argmax(probs))
    return pred, probs

# ---------------------------------------------------------
# 4.  Load dataset and evaluate accuracy
# ---------------------------------------------------------
df = pd.read_csv(CSV_PATH)
texts  = df["text"].astype(str).tolist()
labels = df["label"].astype(int).tolist()

total = len(texts)
if total == 0:
    # Fail with a clear message instead of a ZeroDivisionError further down.
    raise ValueError(f"{CSV_PATH} contains no rows to evaluate")

# Collect only the misclassified examples as (text, true label, prediction);
# the number of correct predictions follows from the total.
mistakes = []
for text, label in tqdm(zip(texts, labels), total=total):
    pred, _ = predict(text)
    if pred != label:
        mistakes.append((text, label, pred))

correct = total - len(mistakes)
acc = correct / total
print(f"\n✅ Accuracy on {total} examples: {acc * 100:.2f}%")

# Optional: preview at most five misclassified examples, each text cut to
# its first 60 characters.
print("\n❌ Sample misclassifications:")
for sample_text, true_label, predicted in mistakes[:5]:
    print(f"  • '{sample_text[:60]}...' → true={true_label}, pred={predicted}")


100%|██████████| 1301/1301 [02:59<00:00,  7.24it/s]


✅ Accuracy on 1301 examples: 96.00%

❌ Sample misclassifications:
  • 'What’s playing at the cinema?...' → true=1, pred=0
  • 'How many steps have I taken today?...' → true=1, pred=0
  • 'Where’s the nearest gas station?...' → true=1, pred=0
  • 'What time does the show start?...' → true=0, pred=1
  • 'Is it windy outside?...' → true=1, pred=0



