In [7]:
import warnings
warnings.filterwarnings('ignore')  # Suppress all warnings

In [8]:
from transformers import TFBertForSequenceClassification, BertTokenizer
import tensorflow as tf
import gradio as gr

In [9]:
# Load the fine-tuned model and tokenizer from the directory
model = TFBertForSequenceClassification.from_pretrained('fine_tuned_bert_model')
tokenizer = BertTokenizer.from_pretrained('fine_tuned_bert_model')

Some layers from the model checkpoint at fine_tuned_bert_model were not used when initializing TFBertForSequenceClassification: ['dropout_37']
- This IS expected if you are initializing TFBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the layers of TFBertForSequenceClassification were initialized from the model checkpoint at fine_tuned_bert_model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertForSequenceClassification for predictions without further training.


In [10]:
# Define the category mapping
category_mapping = {0: 'World', 1: 'Sports', 2: 'Business', 3: 'Sci/Tech'}

In [11]:
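# Predict the topic of a single text string
def predict_topic(text):
    # Tokenize into TensorFlow tensors, truncating to BERT's 512-token limit
    inputs = tokenizer(text, return_tensors="tf", truncation=True, padding=True, max_length=512)
    outputs = model(inputs)
    # logits has shape (1, num_labels); take the highest-scoring label id for the single input
    prediction = int(tf.argmax(outputs.logits, axis=-1).numpy()[0])
    return category_mapping[prediction]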
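`outputs.logits` has shape `(batch_size, num_labels)`, which is `(1, 4)` here because a single string is tokenized. `tf.argmax(..., axis=-1)` picks the highest-scoring label id in each row, and `.numpy()[0]` takes the result for the first (and only) row. Below is a minimal pure-Python sketch of the same indexing, using made-up logit values. `argmax_last_axis` is a hypothetical helper for illustration, not a TensorFlow function.

```python
# Same id-to-name mapping as category_mapping in the notebook
category_mapping = {0: 'World', 1: 'Sports', 2: 'Business', 3: 'Sci/Tech'}

def argmax_last_axis(batch_logits):
    """Return the index of the largest value in each row (like axis=-1)."""
    return [max(range(len(row)), key=row.__getitem__) for row in batch_logits]

# Hypothetical logits for a batch of one input: shape (1, 4)
logits = [[0.3, 4.1, -1.2, 0.5]]

label_ids = argmax_last_axis(logits)  # one label id per input row
prediction = label_ids[0]             # equivalent of .numpy()[0]
print(category_mapping[prediction])   # prints "Sports"
```

Only the index of the largest logit matters for the predicted topic, so no softmax is needed at this step.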

In [12]:
# Wrapper that formats the prediction for display in the Gradio interface
def interface(text):
    predicted_topic = predict_topic(text)
    return f'Predicted Topic: {predicted_topic}'
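This wrapper returns a formatted string for the `outputs="text"` component. If you instead wanted Gradio's `"label"` output, which displays per-class confidences, the function could return a dict mapping each topic name to its probability. The sketch below builds that dict from one row of logits with a softmax; the name `topic_confidences` and the sample logits are hypothetical, and using it would also require changing `outputs="text"` to `outputs="label"` in the `gr.Interface` call.

```python
import math

# Same id-to-name mapping as category_mapping in the notebook
category_mapping = {0: 'World', 1: 'Sports', 2: 'Business', 3: 'Sci/Tech'}

def topic_confidences(logits):
    """Map each topic name to its softmax probability for one input.

    `logits` is a single row of raw scores in label-id order (0..3).
    """
    m = max(logits)                           # shift for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return {category_mapping[i]: e / total for i, e in enumerate(exps)}

# Hypothetical logits for one input
scores = topic_confidences([0.3, 4.1, -1.2, 0.5])
for name, p in sorted(scores.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{name}: {p:.3f}")
```

In the notebook, the row of logits could come from `outputs.logits.numpy()[0].tolist()` inside `predict_topic`.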

In [13]:
# Gradio interface with a text input and a text output
gr.Interface(fn=interface,
             inputs="text",
             outputs="text",
             title="Topic Classification with BERT",
             description="Enter text to classify its topic as World, Sports, Business, or Sci/Tech.",
             examples=[
                 ["The soccer team is preparing for the upcoming World Cup with intense training sessions."],
                 ["The stock market is experiencing a significant downturn."],
                 ["NASA launched a new satellite to monitor climate change."],
                 ["The latest smartphone has groundbreaking features and design."]
             ]
).launch()

Running on local URL:  http://127.0.0.1:7863

To create a public link, set `share=True` in `launch()`.


