[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/sajitheranda/sign_language_web/blob/main/sign_backend/Signify_Backend.ipynb)

# Model

## Setup

In [None]:
!git clone --branch classifier https://github.com/chamodAchintha/Signify.git

In [None]:
%cd /content/Signify/Stochastic-Transformer-Networks

In [None]:
!pip install -r requirements.txt

## Load Models & Configs

In [None]:
%cd /content/Signify/Stochastic-Transformer-Networks

In [None]:
# Translation
# Fetch two Google Drive files into SavedModels/translation/.
# NOTE(review): presumably the translation config YAML and the best_model.pth
# checkpoint that later cells load from this folder — confirm against the Drive IDs.
!gdown 1UbS59BVb7tRzrbKDT64k3otoMRcYiWek -O /content/Signify/Stochastic-Transformer-Networks/SavedModels/translation/
!gdown 1dJ0AE_W9reoXrCkzGvmq3spKgEYMrMx2 -O /content/Signify/Stochastic-Transformer-Networks/SavedModels/translation/

# Classification
# Fetch two Google Drive files into SavedModels/classification/.
# NOTE(review): presumably the classification config YAML and its checkpoint — confirm.
!gdown 1UzRJLMTWK9e7fYDUfkynnyPmzCA5EdHP -O /content/Signify/Stochastic-Transformer-Networks/SavedModels/classification/
!gdown 1W-PP3CPf5VVhQ54a0zB5LFbMZ3SS3eln -O /content/Signify/Stochastic-Transformer-Networks/SavedModels/classification/

In [None]:
import os
import logging
from sys import platform

import torch
import torch.nn.functional as F
from transformers import MBart50TokenizerFast
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm

from signjoey.helpers import load_config, make_logger
from signjoey.classification_model import ClassificationModel
from signjoey.sinhala_sentence.translation_model import SinhalaSignTranslationModel
from signjoey.sinhala_sentence.search import greedy_decode


In [None]:
# Configure the notebook logger once. The handler check keeps a re-run of this
# cell from attaching a second set of handlers.
logger = logging.getLogger(__name__)
if not logger.handlers:
    logger.setLevel(logging.DEBUG)

    # File handler: records DEBUG and above in /content/inference.log.
    fh = logging.FileHandler("/content/inference.log")
    formatter = logging.Formatter("%(asctime)s %(message)s")
    fh.setFormatter(formatter)
    fh.setLevel(logging.DEBUG)
    logger.addHandler(fh)

    # On Linux, additionally attach a stream handler (INFO and above) to the
    # root logger, using the same message format.
    if platform == "linux":
        sh = logging.StreamHandler()
        sh.setFormatter(formatter)
        sh.setLevel(logging.INFO)
        logging.getLogger("").addHandler(sh)

    logger.info("Hello! This is Joey-NMT.")

In [None]:
# Run on CUDA when PyTorch reports it as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
logger.info(f"Device: {device}")

## Translation

In [None]:
%cd /content/Signify/Stochastic-Transformer-Networks

In [None]:
# Load the translation configuration from its YAML file.
translation_config_path = "/content/Signify/Stochastic-Transformer-Networks/SavedModels/translation/translation_config.yaml"
translation_config = load_config(translation_config_path)
t_train_config = translation_config['training']

# mBART-50 tokenizer; the target language falls back to "si_LK" when the
# config's data section does not define 'tgt_language'.
tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50",
    tgt_lang=translation_config['data'].get('tgt_language', "si_LK"),
)

# Load model
translation_model = SinhalaSignTranslationModel(translation_config, logger)

# Restore the weights stored under 'model_state_dict' in best_model.pth. The
# loaded dict is consumed inline so no extra reference to it is kept around.
translation_checkpoint_path = os.path.join(
    translation_config["training"]["model_dir"],
    'best_model.pth',
)
translation_model.load_state_dict(
    torch.load(translation_checkpoint_path, map_location=device)['model_state_dict']
)

# Move the model to the selected device and switch it to evaluation mode.
translation_model.to(device)
translation_model.eval()

logger.info(f"Translation Model loaded for Inference from {translation_checkpoint_path}.")

In [None]:
def prepare_inference_translation_input(keypoints: torch.Tensor, encoder_seq_len: int = 40):
    """Pad a keypoint sequence to the encoder length and build its padding mask.

    Args:
        keypoints: Tensor whose first dimension indexes frames. It must be
            2-D ``[frames, features]`` whenever padding is required.
        encoder_seq_len: Number of frames the returned sequence is padded to.

    Returns:
        A dict with two batched tensors:
        ``"keypoints"``: float32, ``[1, encoder_seq_len, features]``, real
        frames first followed by all-zero padding frames;
        ``"keypoints_mask"``: int64, ``[1, encoder_seq_len]``, 1 for real
        frames and 0 for padding.

    Raises:
        ValueError: If ``keypoints`` has more frames than ``encoder_seq_len``.
    """
    frame_count = keypoints.size(0)
    pad_len = encoder_seq_len - frame_count
    if pad_len < 0:
        raise ValueError(f"frame_count ({frame_count}) > encoder_seq_len ({encoder_seq_len})")

    if pad_len > 0:
        # new_zeros inherits dtype and device from the input, so the concat
        # also works when the caller passes a tensor that already lives on
        # the GPU (a plain torch.zeros would be created on the default device).
        padding = keypoints.new_zeros((pad_len, keypoints.size(1)))
        padded_keypoints = torch.cat([keypoints, padding], dim=0)
    else:
        padded_keypoints = keypoints

    # Build the mask on the same device as the data: ones for real frames.
    keypoints_mask = torch.zeros(encoder_seq_len, dtype=torch.long, device=keypoints.device)
    keypoints_mask[:frame_count] = 1

    # Add batch dimension
    return {
        "keypoints": padded_keypoints.unsqueeze(0).float(),    # [1, seq_len, features]
        "keypoints_mask": keypoints_mask.unsqueeze(0).long()   # [1, seq_len]
    }


In [None]:
def get_translation_prediction(keypoints):
    """Translate one keypoint sequence into text with greedy decoding.

    Args:
        keypoints: Tensor of per-frame keypoints passed to
            ``prepare_inference_translation_input`` (frames on the first
            dimension).

    Returns:
        A dict ``{'text': <decoded string>}`` holding the first decoded
        sequence with special tokens removed and surrounding whitespace
        stripped.

    Raises:
        ValueError: Propagated from ``prepare_inference_translation_input``
            when the sequence has more frames than the encoder length.
    """
    data_config = translation_config['data']

    with torch.no_grad():
        # Preprocess the caller's tensor exactly once. The previous version
        # read the notebook global `sample_input` here and therefore ignored
        # the argument entirely.
        inputs = prepare_inference_translation_input(keypoints)
        padded_keypoints = inputs['keypoints'].to(device)
        keypoints_mask = inputs['keypoints_mask'].to(device)

        # Perform greedy decoding
        encoder_output = translation_model.encode(padded_keypoints, keypoints_mask)[0]
        decoded_sequences = greedy_decode(
            src_mask=keypoints_mask,
            # Same "si_LK" fallback as the tokenizer setup, so a config without
            # 'tgt_language' does not raise KeyError here.
            bos_index=tokenizer.lang_code_to_id.get(data_config.get('tgt_language', "si_LK")),
            eos_index=tokenizer.eos_token_id,
            max_output_length=data_config['max_sent_length'],
            decoder=translation_model.decoder,
            encoder_output=encoder_output,
            device=device
        )

        return {'text' : tokenizer.batch_decode(decoded_sequences, skip_special_tokens=True)[0].strip()}


In [None]:
# Sample Test
# Smoke-test the translation path on random values: 31 frames x 204 values.
sample_input = torch.rand(31, 204)
get_translation_prediction(sample_input)

## Classification

In [None]:
%cd /content/Signify/Stochastic-Transformer-Networks

In [None]:
# Load the classification configuration from its YAML file.
classification_config_path = "/content/Signify/Stochastic-Transformer-Networks/SavedModels/classification/classification_config.yaml"
classification_config = load_config(classification_config_path)

# Load the checkpoint
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects (the checkpoint also carries the label encoder read below), so only
# load checkpoint files from a trusted source.
classification_checkpoint_path = os.path.join(
    classification_config["training"]["model_dir"],
    'best_model.pth',
)
classification_checkpoint = torch.load(
    classification_checkpoint_path,
    map_location=device,
    weights_only=False,
)

# Encoder used later to map predicted class indices back to labels.
label_encoder = classification_checkpoint['label_encoder']

# Load model
classification_model = ClassificationModel(classification_config, logger)
classification_model.load_state_dict(classification_checkpoint['model_state_dict'])

# Move the model to the selected device and switch it to evaluation mode.
classification_model.to(device)
classification_model.eval()

logger.info(f"Classification model loaded for inference from {classification_checkpoint_path}.")

In [None]:
def prepare_classification_input(keypoints: torch.Tensor):
    """Pad a keypoint sequence to the classifier length and build its mask.

    The target length is read from ``classification_config['data']['seq_length']``.

    Args:
        keypoints: Tensor whose first dimension indexes frames. It must be
            2-D ``[frames, features]`` whenever padding is required.

    Returns:
        A dict with two batched tensors:
        ``"keypoints"``: float32, ``[1, seq_length, features]``, real frames
        first followed by all-zero padding frames;
        ``"mask"``: bool, ``[1, seq_length]``, True for real frames and False
        for padding.

    Raises:
        ValueError: If ``keypoints`` has more frames than ``seq_length``.
    """
    seq_length = classification_config['data']['seq_length']
    frame_count = keypoints.size(0)
    pad_len = seq_length - frame_count
    if pad_len < 0:
        raise ValueError(f"frame_count ({frame_count}) > seq_length ({seq_length})")

    if pad_len > 0:
        # new_zeros inherits dtype and device from the input, so the concat
        # also works when the caller passes a tensor that already lives on
        # the GPU (a plain torch.zeros would be created on the default device).
        padding = keypoints.new_zeros((pad_len, keypoints.size(1)))
        padded_keypoints = torch.cat([keypoints, padding], dim=0)
    else:
        padded_keypoints = keypoints

    # Build the mask on the same device as the data: True for real frames.
    mask = torch.zeros(seq_length, dtype=torch.bool, device=keypoints.device)
    mask[:frame_count] = True

    return {
        "keypoints": padded_keypoints.unsqueeze(0).float(),  # [1, seq_length, features]
        "mask": mask.unsqueeze(0)                             # [1, seq_length]
    }


In [None]:
def get_classification_prediction(keypoints, top_k: int = 3):
    """Classify one keypoint sequence and return the most probable labels.

    Args:
        keypoints: Tensor of per-frame keypoints passed to
            ``prepare_classification_input`` (frames on the first dimension).
        top_k: Maximum number of ranked predictions to return. Clamped to the
            number of classes the model outputs, so a model with fewer than
            ``top_k`` classes no longer makes ``torch.topk`` fail.

    Returns:
        A dict mapping rank (1 = most probable) to a ``(label, probability)``
        tuple, with the probability rounded to 4 decimals.

    Raises:
        ValueError: If ``top_k`` is smaller than 1, or propagated from
            ``prepare_classification_input`` when the sequence is too long.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be at least 1, got {top_k}")

    with torch.no_grad():
        inputs = prepare_classification_input(keypoints)
        data = inputs['keypoints'].to(device)
        # [1, seq_length] -> [1, 1, seq_length]
        mask = inputs['mask'].unsqueeze(1).expand(-1, 1, -1).to(device)

        output = classification_model(data, mask)

        # Softmax and top-k both run over dim 1.
        # NOTE(review): assumes the model output is [batch, num_classes] — confirm.
        probs = F.softmax(output, dim=1)
        k = min(top_k, probs.size(1))
        top_probs, top_indices = torch.topk(probs, k=k, dim=1)

        # Map the class indices of the single batch item back to labels.
        top_classes = label_encoder.inverse_transform(top_indices[0].cpu().numpy())

        # Numbered predictions
        return {
            rank: (gloss.item(), round(prob.item(), 4))
            for rank, (gloss, prob) in enumerate(zip(top_classes, top_probs[0]), start=1)
        }


In [None]:
# Test classification
# Smoke-test the classification path on random values: 31 frames x 204 values.
sample_input = torch.rand(31, 204)
get_classification_prediction(sample_input)

# Backend

In [None]:
!pip install fastapi uvicorn pyngrok nest-asyncio

In [None]:
from fastapi import FastAPI
from pydantic import BaseModel
import torch
import nest_asyncio
from pyngrok import ngrok
import uvicorn

In [None]:
# NOTE(review): this ngrok auth token is committed in plain text, so anyone who
# can read the notebook can use it. Consider rotating it and supplying the
# token from a secret or environment variable instead of hard-coding it here.
!ngrok authtoken 2xOLBPcSEHDtC7wUqTmy3xGiqwK_2zprxCpK5w1aj43KwpVvQ

In [None]:
# Request body accepted by the /sign_classify and /sign_translate endpoints.
class KeypointInput(BaseModel):
    # One inner list per frame; the endpoints convert the nested list into a
    # float32 tensor before running a model.
    keypoints: list[list[float]]  # Tensor as nested list: [frames, 204]

# ----- Start FastAPI app -----
app = FastAPI()


def _predict_from_request(predict, input_data):
    """Convert the request body to a tensor, run ``predict`` and map bad input to HTTP 422.

    Args:
        predict: Callable that takes the keypoint tensor and returns the
            JSON-serialisable prediction.
        input_data: Parsed ``KeypointInput`` request body.

    Returns:
        Whatever ``predict`` returns.

    Raises:
        HTTPException: Status 422 when the keypoints are ragged, empty, or
            rejected by the preprocessing with a ``ValueError`` (for example
            too many frames). Previously these cases surfaced as HTTP 500.
    """
    # Function-scope import so this block does not depend on other cells
    # importing HTTPException.
    from fastapi import HTTPException

    try:
        keypoints_tensor = torch.tensor(input_data.keypoints, dtype=torch.float32)
    except ValueError as err:
        # torch.tensor rejects nested lists whose rows differ in length.
        raise HTTPException(status_code=422, detail=f"Malformed keypoints: {err}") from err

    if keypoints_tensor.dim() != 2 or keypoints_tensor.numel() == 0:
        raise HTTPException(
            status_code=422,
            detail="keypoints must be a non-empty [frames, features] array",
        )

    try:
        return predict(keypoints_tensor)
    except ValueError as err:
        # The preprocessing helpers raise ValueError when the clip has more
        # frames than the model's sequence length.
        raise HTTPException(status_code=422, detail=str(err)) from err


@app.post("/sign_classify")
def classify_sign(input_data: KeypointInput):
    """Return the ranked sign classification for the posted keypoints."""
    return _predict_from_request(get_classification_prediction, input_data)


@app.post("/sign_translate")
def translate_sign(input_data: KeypointInput):
    """Return the text translation for the posted keypoints."""
    return _predict_from_request(get_translation_prediction, input_data)

# ----- Start ngrok -----
# The tunnel and the API server below must use the same local port.
SERVER_PORT = 8000
ngrok_tunnel = ngrok.connect(SERVER_PORT)
print("Public URL:", ngrok_tunnel.public_url)

# Allow running uvicorn inside Colab
nest_asyncio.apply()
uvicorn.run(app, host="0.0.0.0", port=SERVER_PORT)
