In [None]:
# app.py

import gradio as gr
import torch
from transformers import AutoTokenizer
from model import Classifier

# Model and inference settings.
model_name = 'bert-base-uncased'
max_len = 160
class_names = ['negative', 'neutral', 'positive']

# Load the tokenizer and the fine-tuned classifier weights.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Derive the output size from class_names so the label list and the
# classifier head cannot drift apart.
model = Classifier(n_classes=len(class_names), model_name=model_name)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint cannot execute arbitrary code; it also silences the
# FutureWarning that recent torch versions emit when the argument is omitted.
# The result is passed straight to load_state_dict (not kept in a module-level
# name) so the loaded tensors are not held in memory a second time.
model.load_state_dict(
    torch.load("model/best_model_state.bin", map_location=device, weights_only=True)
)
model = model.to(device)
model.eval()  # switch to inference mode (e.g. disables dropout)

def predict_sentiment(review):
    """Classify a single review and return its sentiment label.

    Args:
        review: The review text (the string chosen in the Gradio input).

    Returns:
        One of the entries of ``class_names``.

    Raises:
        gr.Error: If ``review`` is None or blank (e.g. the form was submitted
            without a selection), so Gradio shows a readable message instead
            of a tokenizer stack trace.
    """
    if review is None or not review.strip():
        raise gr.Error("Please select or enter a review.")

    # Encode the text into fixed-length tensors: special tokens added, then
    # padded/truncated to max_len, returned as PyTorch tensors.
    # (tokenizer(...) replaces the deprecated encode_plus with the same options.)
    encoding = tokenizer(
        review,
        add_special_tokens=True,
        max_length=max_len,
        return_token_type_ids=False,
        padding='max_length',
        return_attention_mask=True,
        truncation=True,
        return_tensors='pt',
    )

    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
        # outputs presumably holds per-class scores shaped (1, n_classes) —
        # TODO confirm against Classifier.forward; take the best class index.
        _, prediction = torch.max(outputs, dim=1)

    # .item() turns the single-element index tensor into a plain int before
    # using it as a list index.
    return class_names[prediction.item()]

# Build the Gradio UI: the chosen sample review is passed to
# predict_sentiment and the returned label is shown as plain text.
# NOTE(review): the description asks users to *enter* a review, but a Dropdown
# only offers the fixed choices below — confirm whether free-text input
# (e.g. a Textbox) is intended.
_SAMPLE_REVIEWS = [
    "The app is excellent!! ",
    "Not Free. Free upto only 5 Habbits.",
    "This app is very bad",
]
_review_input = gr.Dropdown(choices=_SAMPLE_REVIEWS, label="Select a Review")

interface = gr.Interface(
    fn=predict_sentiment,
    inputs=_review_input,
    outputs="text",
    title="Sentiment Analysis",
    description="輸入一篇評論，判別其為正向、中立或負向",
)

# Start the local web server when executed as a script (or notebook cell).
if __name__ == "__main__":
    interface.launch()


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

  model.load_state_dict(torch.load("model/best_model_state.bin", map_location=device))


Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.
