# 🚀 Sign Language Recognition - Model Serving on Colab

Notebook này giúp bạn deploy model đã trained lên Colab và expose ra public URL qua ngrok.

## 1. Setup & Install Dependencies

In [None]:
# Mount Google Drive so the trained model files stored there can be read
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install the serving dependencies (web server, multipart upload parsing,
# ngrok client, and nested event-loop support); -q keeps pip output quiet.
!pip install fastapi uvicorn python-multipart pyngrok nest-asyncio -q
print("‚úÖ Dependencies installed!")

In [None]:
# IMPORTANT: register a free ngrok account at https://ngrok.com,
# then copy your authtoken and paste it below.
NGROK_AUTH_TOKEN = "YOUR_NGROK_AUTH_TOKEN"  # <-- REPLACE WITH YOUR OWN TOKEN

from pyngrok import ngrok
ngrok.set_auth_token(NGROK_AUTH_TOKEN)
print("‚úÖ Ngrok configured!")

## 2. Define Model Architecture (same as training)

In [None]:
import os
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
import math
from torchvision import models

import warnings
warnings.filterwarnings('ignore')

NUM_CLASSES = 100  # number of output classes the classifier head is built with
TARGET_FRAMES = 16  # default number of frames sampled from each video

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence.

    The encoding table is precomputed once for ``max_len`` positions and
    stored as the (non-trainable) buffer ``pe`` of shape
    ``(1, max_len, d_model)``. Even feature columns hold sines and odd
    columns hold cosines.

    Args:
        d_model: Feature dimension of the sequence. Both even and odd values
            are supported.
        max_len: Number of positions to precompute. Inputs passed to
            ``forward`` must not be longer than this.
        dropout: Dropout probability applied after the encoding is added.
    """

    def __init__(self, d_model, max_len=64, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd (cosine) column than even
        # (sine) columns, so trim the angles to match instead of failing on a
        # shape mismatch. For an even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus its position encodings, followed by dropout.

        Args:
            x: Tensor of shape ``(batch, seq_len, d_model)`` with
                ``seq_len <= max_len``.
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


class AttentionPooling(nn.Module):
    """Collapse a sequence into one vector using learned attention weights.

    A two-layer scorer (Linear -> Tanh -> Linear) assigns one scalar score to
    every time step; the scores are softmax-normalised along the sequence
    axis and used to take a weighted sum of the inputs.

    Args:
        dim: Feature dimension of the incoming sequence.
    """

    def __init__(self, dim):
        super().__init__()
        bottleneck = dim // 4
        self.attention = nn.Sequential(
            nn.Linear(dim, bottleneck),
            nn.Tanh(),
            nn.Linear(bottleneck, 1),
        )

    def forward(self, x):
        """Pool ``x`` of shape ``(batch, seq_len, dim)`` down to ``(batch, dim)``."""
        scores = self.attention(x)
        # Normalise over the sequence axis so the weights of each sample sum to 1.
        weights = torch.softmax(scores, dim=1)
        return (weights * x).sum(dim=1)


class ConvNeXtTransformer(nn.Module):
    """Video classifier: per-frame ConvNeXt-Tiny features + Transformer encoder.

    Each frame is passed through the ConvNeXt-Tiny feature extractor
    (constructed with ImageNet-1K weights) and average-pooled to a
    768-dimensional vector. The resulting frame sequence receives positional
    encodings, goes through a 2-layer Transformer encoder, is reduced with
    attention pooling, and is classified by a LayerNorm/Dropout/Linear head.

    Args:
        num_classes: Size of the output layer.
        hidden_size: Accepted for signature compatibility; never read here.
        resnet_pretrained_weights: Accepted for signature compatibility;
            never read here.
    """

    def __init__(self, num_classes=100, hidden_size=256, resnet_pretrained_weights=None):
        super().__init__()

        # Frame-level feature extractor.
        backbone = models.convnext_tiny(weights=models.ConvNeXt_Tiny_Weights.IMAGENET1K_V1)
        self.cnn = backbone.features
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.feature_dim = 768

        # Temporal modelling over the sequence of frame features.
        self.pos_encoder = PositionalEncoding(d_model=self.feature_dim, max_len=64, dropout=0.1)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=self.feature_dim,
                nhead=8,
                dim_feedforward=self.feature_dim * 4,
                dropout=0.3,
                activation='gelu',
                batch_first=True,
                norm_first=True,
            ),
            num_layers=2,
        )

        self.attention_pool = AttentionPooling(self.feature_dim)

        # Classification head.
        self.fc = nn.Sequential(
            nn.LayerNorm(self.feature_dim),
            nn.Dropout(0.4),
            nn.Linear(self.feature_dim, num_classes),
        )

        self._init_weights()

    def _init_weights(self):
        """Re-initialise the Linear layers of the transformer and attention pool.

        Weights are drawn from a truncated normal (std=0.02) and biases are
        zeroed. The backbone and the classification head are left untouched.
        """
        for block in (self.transformer, self.attention_pool):
            for layer in block.modules():
                if not isinstance(layer, nn.Linear):
                    continue
                nn.init.trunc_normal_(layer.weight, std=0.02)
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Classify a batch of clips.

        Args:
            x: Tensor of shape ``(B, T, C, H, W)``.

        Returns:
            Logits of shape ``(B, num_classes)``.
        """
        batch, steps, channels, height, width = x.shape

        # Fold time into the batch axis so the CNN sees individual frames.
        per_frame = self.cnn(x.view(batch * steps, channels, height, width))
        sequence = self.pool(per_frame).view(batch, steps, self.feature_dim)

        sequence = self.transformer(self.pos_encoder(sequence))
        clip_vector = self.attention_pool(sequence)
        return self.fc(clip_vector)

## 3. Load Model & Label Mapping

In [None]:
# ========== PATH CONFIGURATION ==========
# Change these paths to match where your files live on Google Drive.

MODEL_PATH = "/content/drive/MyDrive/OlympicAI/augmented_balanced_convnexttransformer_best_model.pth"
LABEL_MAPPING_PATH = "/content/drive/MyDrive/OlympicAI/dataset/label_mapping.pkl"

# =========================================

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"üñ•Ô∏è Using device: {DEVICE}")

In [None]:
# Load model: build the architecture, then overwrite its parameters with the
# checkpoint stored at MODEL_PATH (expected to be a state_dict).
# hidden_size is passed here but the constructor does not read it.
model = ConvNeXtTransformer(num_classes=NUM_CLASSES, hidden_size=256)
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model = model.to(DEVICE)
model.eval()  # inference mode: disables dropout
print(f"‚úÖ Model loaded from {MODEL_PATH}")

# Load label mapping
# NOTE(review): pickle.load can execute arbitrary code; only open files you trust.
with open(LABEL_MAPPING_PATH, 'rb') as f:
    label_mapping = pickle.load(f)
# Invert the mapping so a predicted index can be turned back into its label
# (presumably label_mapping is {label name: class index} - confirm against training code).
idx_to_label = {v: k for k, v in label_mapping.items()}
print(f"‚úÖ Loaded {len(idx_to_label)} classes")

## 4. Video Preprocessing Functions

In [None]:
# Per-channel statistics consumed by normalize_frames (ImageNet mean/std
# according to that function's docstring); one value per colour channel.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def read_video_bytes(video_bytes):
    """Decode an uploaded video into a tensor of RGB frames.

    OpenCV reads from a file path, so the bytes are first written to a
    temporary ``.mp4`` file. The capture handle is always released and the
    temporary file always removed, even when decoding fails part-way.

    Args:
        video_bytes: Raw contents of the uploaded video file.

    Returns:
        A tensor of every decoded frame stacked along the first axis, in RGB
        channel order, i.e. ``(num_frames, H, W, 3)``.

    Raises:
        ValueError: If no frame could be decoded from the data.
    """
    import tempfile

    # delete=False: the file must outlive this `with` so OpenCV can open it by name.
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp:
        tmp.write(video_bytes)
        tmp_path = tmp.name

    frames = []
    try:
        cap = cv2.VideoCapture(tmp_path)
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # OpenCV decodes to BGR; convert to RGB.
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            cap.release()
    finally:
        os.unlink(tmp_path)  # Delete temp file

    if not frames:
        raise ValueError("Could not read any frames from video")
    return torch.from_numpy(np.stack(frames, axis=0))


def downsample_frames(frames, target_frames=TARGET_FRAMES):
    """Pick exactly ``target_frames`` frames and bring them to 224x224.

    Clips with enough frames are sampled at evenly spaced positions; shorter
    clips keep every frame and are padded by repeating the last one.

    Args:
        frames: Tensor indexed as ``(num_frames, H, W, C)``.
        target_frames: Number of frames in the result.

    Returns:
        Tensor of shape ``(target_frames, 224, 224, C)``. When resizing was
        needed the result is cast to ``torch.uint8``; otherwise the input
        dtype is kept.
    """
    available = frames.shape[0]
    if available < target_frames:
        # Too short: keep everything, then repeat the final frame index.
        keep = torch.arange(available)
        filler = keep[-1].repeat(target_frames - available)
        picks = torch.cat([keep, filler])
    else:
        picks = torch.linspace(0, available - 1, target_frames).long()

    clip = frames[picks]

    if clip.shape[1] == 224 and clip.shape[2] == 224:
        return clip

    # F.interpolate works on channel-first float tensors.
    channel_first = clip.permute(0, 3, 1, 2).float()
    resized = F.interpolate(channel_first, size=(224, 224), mode='bilinear', align_corners=False)
    return resized.permute(0, 2, 3, 1).to(torch.uint8)


def normalize_frames(frames):
    """Scale pixel values to [0, 1] and standardise each channel with MEAN/STD.

    Args:
        frames: Tensor laid out as ``(T, H, W, C)`` with three channels.

    Returns:
        Float tensor laid out as ``(T, C, H, W)``.
    """
    # Shape the per-channel statistics so they broadcast over (T, C, H, W).
    channel_mean = torch.tensor(MEAN).view(1, 3, 1, 1)
    channel_std = torch.tensor(STD).view(1, 3, 1, 1)

    scaled = frames.float() / 255.0
    channel_first = scaled.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T, C, H, W)
    return (channel_first - channel_mean) / channel_std


def preprocess_video(video_bytes):
    """Turn raw uploaded video bytes into a model-ready batch of one clip.

    Runs decoding, frame sampling/resizing and normalisation in sequence.

    Returns:
        Tensor of shape ``(1, T, C, H, W)``.
    """
    decoded = read_video_bytes(video_bytes)
    sampled = downsample_frames(decoded)
    normalized = normalize_frames(sampled)
    # Prepend the batch axis the model expects.
    return normalized.unsqueeze(0)

## 5. Create FastAPI Server

In [None]:
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional
import uvicorn
import nest_asyncio

# Apply nest_asyncio to allow running in Colab: the notebook already has an
# event loop running, and uvicorn needs to start its own on top of it.
nest_asyncio.apply()

app = FastAPI(
    title="ü§ü Sign Language Recognition API",
    description="API for Vietnamese Sign Language Recognition using ConvNeXt-Transformer",
    version="1.0.0"
)

# Enable CORS so browser front-ends on any origin can call the API.
# Credentials stay disabled: the FastAPI CORS documentation states that
# wildcard origins/methods/headers may not be combined with
# allow_credentials=True, and this API uses neither cookies nor auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Response models
# Response body of POST /predict: the single most probable class.
class PredictionResponse(BaseModel):
    label: str  # class name looked up in idx_to_label
    confidence: float  # softmax probability of that class (rounded to 4 decimals by the endpoint)
    label_idx: int  # integer class index predicted by the model


# One entry of the POST /predict/topk result; same fields as PredictionResponse.
class TopKPrediction(BaseModel):
    label: str  # class name looked up in idx_to_label
    confidence: float  # softmax probability of that class (rounded to 4 decimals by the endpoint)
    label_idx: int  # integer class index


# Response body of POST /predict/topk.
class PredictionResponseTopK(BaseModel):
    predictions: List[TopKPrediction]  # entries in the order returned by torch.topk


@app.get("/")
async def root():
    """Landing endpoint: a status message plus a short directory of routes."""
    endpoints = {
        "predict": "POST /predict",
        "predict_topk": "POST /predict/topk?k=5",
        "health": "GET /health",
    }
    return dict(
        message="ü§ü Sign Language Recognition API is running!",
        docs="/docs",
        endpoints=endpoints,
    )


@app.get("/health")
async def health_check():
    """Report service status together with the device and class count in use.

    ``model_loaded`` is a constant ``True``; it is not probed at request time.
    """
    report = {"status": "healthy", "model_loaded": True}
    report["device"] = DEVICE
    report["num_classes"] = NUM_CLASSES
    return report


@app.post("/predict", response_model=PredictionResponse)
async def predict(file: UploadFile = File(...)):
    """
    Upload a video file and get the predicted sign language label.

    Supported formats: mp4, avi, mov, mkv

    Returns the most probable label, its softmax confidence (rounded to
    4 decimals) and its class index.

    Errors:
    - 400 if the file name is missing or has an unsupported extension, or if
      the video cannot be decoded.
    - 500 for any other failure during preprocessing or inference.
    """
    # Validate file type. The filename can be absent on some uploads, which
    # must yield a 400 rather than an AttributeError.
    filename = file.filename or ""
    if not filename.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
        raise HTTPException(status_code=400, detail="Invalid file format. Use mp4, avi, mov, or mkv")

    try:
        # Read video bytes
        video_bytes = await file.read()

        # Preprocess. A ValueError here means the upload itself is unusable
        # (e.g. no decodable frames), which is the client's error, not ours.
        try:
            frames = preprocess_video(video_bytes)
        except ValueError as e:
            raise HTTPException(status_code=400, detail=f"Error processing video: {e}") from e
        frames = frames.to(DEVICE)

        # Inference. NOTE: this runs synchronously inside the async handler,
        # so other requests wait until it finishes.
        with torch.no_grad():
            outputs = model(frames)
            probs = F.softmax(outputs, dim=1)
            confidence, predicted = probs.max(1)

        label_idx = predicted.item()
        return PredictionResponse(
            label=idx_to_label[label_idx],
            confidence=round(confidence.item(), 4),
            label_idx=label_idx
        )

    except HTTPException:
        # Already carries the intended status code; do not rewrap as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing video: {str(e)}") from e


@app.post("/predict/topk", response_model=PredictionResponseTopK)
async def predict_topk(file: UploadFile = File(...), k: int = 5):
    """
    Upload a video file and get top-k predicted sign language labels.

    - **k**: Number of top predictions to return (default: 5). Must be at
      least 1; values larger than the number of classes are capped.

    Errors:
    - 400 if `k` is not positive, the file name is missing or has an
      unsupported extension, or the video cannot be decoded.
    - 500 for any other failure during preprocessing or inference.
    """
    if k < 1:
        raise HTTPException(status_code=400, detail="k must be a positive integer")

    # The filename can be absent on some uploads, which must yield a 400
    # rather than an AttributeError.
    filename = file.filename or ""
    if not filename.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
        raise HTTPException(status_code=400, detail="Invalid file format. Use mp4, avi, mov, or mkv")

    try:
        video_bytes = await file.read()

        # A ValueError here means the upload itself is unusable (e.g. no
        # decodable frames), which is the client's error, not ours.
        try:
            frames = preprocess_video(video_bytes)
        except ValueError as e:
            raise HTTPException(status_code=400, detail=f"Error processing video: {e}") from e
        frames = frames.to(DEVICE)

        # NOTE: inference runs synchronously inside the async handler, so
        # other requests wait until it finishes.
        with torch.no_grad():
            outputs = model(frames)
            probs = F.softmax(outputs, dim=1)
            # Cap k by the real number of logits so topk can never be asked
            # for more entries than exist.
            top_probs, top_indices = torch.topk(probs, k=min(k, probs.size(1)), dim=1)

        predictions = [
            TopKPrediction(
                label=idx_to_label[idx],
                confidence=round(prob, 4),
                label_idx=idx
            )
            for prob, idx in zip(top_probs[0].tolist(), top_indices[0].tolist())
        ]
        return PredictionResponseTopK(predictions=predictions)

    except HTTPException:
        # Already carries the intended status code; do not rewrap as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing video: {str(e)}") from e


@app.get("/labels")
async def get_labels():
    """Return the configured class count and every key of the label mapping."""
    label_names = [*label_mapping]
    return dict(total_classes=NUM_CLASSES, labels=label_names)

## 6. Start Server with Ngrok Tunnel 🚀

In [None]:
# Start ngrok tunnel to the local port the API server will listen on
PORT = 8000
tunnel = ngrok.connect(PORT)
# Current pyngrok returns an NgrokTunnel object, not a URL string; formatting
# it directly would print its description rather than a usable address, so
# take its public_url attribute. The getattr fallback keeps this working if
# the installed pyngrok returns a plain string.
public_url = getattr(tunnel, "public_url", tunnel)

print("="*60)
print("üéâ SERVER IS RUNNING!")
print("="*60)
print(f"\nüåê Public URL: {public_url}")
print(f"üìö API Docs:   {public_url}/docs")
print(f"üìä Health:     {public_url}/health")
print("\n" + "="*60)
print("üìù Usage Example (with curl):")
print(f'   curl -X POST "{public_url}/predict" -F "file=@your_video.mp4"')
print("="*60)
print("\n‚ö†Ô∏è  Keep this cell running! Press Ctrl+C to stop.")
print("\n")

In [None]:
# Run the server (this will block - keep running!)
# Listens on all interfaces on PORT, the same port the ngrok tunnel was opened for.
uvicorn.run(app, host="0.0.0.0", port=PORT)

## 📌 Alternative: Use localtunnel (if ngrok doesn't work)

Nếu ngrok không hoạt động, bạn có thể dùng localtunnel:

In [None]:
# # Uncomment this cell to use localtunnel instead of ngrok
# !npm install -g localtunnel
# 
# # Run the server in the background
# import subprocess
# subprocess.Popen(["python", "-c", f"""
# import uvicorn
# uvicorn.run(app, host='0.0.0.0', port={PORT})
# """])
# 
# # Create the tunnel
# !lt --port 8000

## 🧪 Test API (Optional)

Chạy cell này từ một notebook khác hoặc terminal để test:

In [None]:
# # Test code - run this after the server has started
# import requests
# 
# API_URL = "YOUR_NGROK_URL"  # Paste the URL from the output above
# 
# # Health check
# response = requests.get(f"{API_URL}/health")
# print("Health:", response.json())
# 
# # Predict
# with open("test_video.mp4", "rb") as f:
#     response = requests.post(
#         f"{API_URL}/predict",
#         files={"file": f}
#     )
# print("Prediction:", response.json())