In [1]:
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install transformers datasets fastapi uvicorn[standard] scikit-learn matplotlib


Looking in indexes: https://download.pytorch.org/whl/cu118
Collecting httptools>=0.6.3 (from uvicorn[standard])
  Downloading httptools-0.7.1-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl.metadata (3.5 kB)
Collecting uvloop>=0.15.1 (from uvicorn[standard])
  Downloading uvloop-0.22.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (4.9 kB)
Collecting watchfiles>=0.13 (from uvicorn[standard])
  Downloading watchfiles-1.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.9 kB)
Downloading httptools-0.7.1-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl (517 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m517.7/517.7 kB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading uvloop-0.22.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl (4.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.4/4

In [6]:
#加载 IMDb 数据集
from datasets import load_dataset
dataset = load_dataset("imdb")

# 划分训练 / 测试集
train_ds = dataset['train']
test_ds = dataset['test']

# 取较小的子集加速训练（可取消）
train_ds = train_ds.shuffle(seed=42).select(range(20000))
test_ds = test_ds.shuffle(seed=42).select(range(5000))


In [3]:
#BERT Tokenizer 预处理
from transformers import AutoTokenizer

checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=256)

train_tokenized = train_ds.map(tokenize, batched=True)
test_tokenized = test_ds.map(tokenize, batched=True)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

In [7]:
#定义模型
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint,
    num_labels=2
)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
#训练配置 & 训练
from transformers import TrainingArguments, Trainer
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
import os
os.environ["WANDB_DISABLED"] = "true"
def compute_metrics(pred):
    logits, labels = pred
    preds = np.argmax(logits, axis=-1)
    return {
        'accuracy': accuracy_score(labels, preds),
        'f1': f1_score(labels, preds)
    }

training_args = TrainingArguments(
    output_dir="./results",
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    learning_rate=2e-5,
    weight_decay=0.01,
    logging_steps=50,
    logging_dir="./logs"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tokenized,
    eval_dataset=test_tokenized,
    compute_metrics=compute_metrics
)

trainer.train()


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Step,Training Loss
50,0.1729
100,0.1927
150,0.0889
200,0.0767
250,0.119


TrainOutput(global_step=250, training_loss=0.13003800678253175, metrics={'train_runtime': 183.9592, 'train_samples_per_second': 21.744, 'train_steps_per_second': 1.359, 'total_flos': 526222110720000.0, 'train_loss': 0.13003800678253175, 'epoch': 2.0})

In [10]:
eval_results = trainer.evaluate()
eval_results


{'eval_loss': 0.41147688031196594,
 'eval_accuracy': 0.904,
 'eval_f1': 0.9047619047619048,
 'eval_runtime': 7.0977,
 'eval_samples_per_second': 70.446,
 'eval_steps_per_second': 4.509,
 'epoch': 2.0}

In [11]:
model.save_pretrained("./model")
tokenizer.save_pretrained("./model")


('./model/tokenizer_config.json',
 './model/special_tokens_map.json',
 './model/vocab.txt',
 './model/added_tokens.json',
 './model/tokenizer.json')

In [12]:
%%writefile inference.py
from fastapi import FastAPI
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch

# Load model
model_path = "./model"
model = AutoModelForSequenceClassification.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

app = FastAPI(title="IMDb Sentiment API")

@app.get("/")
def home():
    return {"message": "IMDb Sentiment Analysis API is running!"}

@app.post("/predict")
def predict(text: str):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probabilities = torch.softmax(logits, dim=1)
        pred = torch.argmax(probabilities, dim=1).item()
        label = "Positive" if pred == 1 else "Negative"
        score = probabilities[0][pred].item()
    return {"label": label, "confidence": round(score, 4)}


Writing inference.py


In [None]:
!pip install fastapi uvicorn[standard]


In [16]:
!uvicorn inference:app --host 0.0.0.0 --port 8000 --reload


[32mINFO[0m:     Will watch for changes in these directories: ['/content']
[32mINFO[0m:     Uvicorn running on [1mhttp://0.0.0.0:8000[0m (Press CTRL+C to quit)
[32mINFO[0m:     Started reloader process [[36m[1m10985[0m] using [36m[1mWatchFiles[0m
2025-12-08 11:06:11.271374: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1765191971.292202   10987 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1765191971.298417   10987 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1765191971.314388   10987 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W

In [15]:
from google.colab.output import eval_js
eval_js("google.colab.kernel.proxyPort(8000)")


'https://8000-gpu-t4-s-23xocdquw1s32-a.europe-west4-2.prod.colab.dev'

In [18]:
%%writefile app.html
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>IMDb Sentiment Demo</title>
    <style>
        body { font-family: Arial; background: #f4f4f4; text-align:center; padding: 50px; }
        textarea { width: 60%; height: 120px; padding: 10px; font-size:16px; }
        button { padding: 10px 20px; font-size: 18px; margin-top: 10px; cursor: pointer; }
        #result { margin-top:20px; font-size:20px; font-weight:bold; }
    </style>
</head>
<body>

<h2>🎬 IMDb Sentiment Analysis</h2>

<textarea id="inputText" placeholder="Type your movie review here..."></textarea><br>
<button onclick="predict()">Analyze Sentiment</button>

<div id="result"></div>

<script>
async function predict() {
    let text = document.getElementById("inputText").value;
    let response = await fetch("/predict?text=" + encodeURIComponent(text), {
        method: "POST"
    });
    let data = await response.json();
    let result = document.getElementById("result");
    result.innerHTML = `Result: <span style="color:${data.label == 'Positive' ? 'green':'red'}">${data.label}</span><br>
                        Confidence: ${(data.confidence * 100).toFixed(2)}%`;
}
</script>

</body>
</html>


Overwriting app.html


In [None]:
!uvicorn inference:app --host 0.0.0.0 --port 8000 --reload


[32mINFO[0m:     Will watch for changes in these directories: ['/content']
[32mINFO[0m:     Uvicorn running on [1mhttp://0.0.0.0:8000[0m (Press CTRL+C to quit)
[32mINFO[0m:     Started reloader process [[36m[1m24820[0m] using [36m[1mWatchFiles[0m
2025-12-08 12:03:04.028483: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1765195384.062128   24826 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1765195384.073223   24826 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1765195384.106621   24826 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W