In [1]:
!pip install fastapi uvicorn transformers torch

Collecting fastapi
  Downloading fastapi-0.115.4-py3-none-any.whl.metadata (27 kB)
Collecting uvicorn
  Downloading uvicorn-0.32.0-py3-none-any.whl.metadata (6.6 kB)
Collecting starlette<0.42.0,>=0.40.0 (from fastapi)
  Downloading starlette-0.41.2-py3-none-any.whl.metadata (6.0 kB)
Downloading fastapi-0.115.4-py3-none-any.whl (94 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m94.7/94.7 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading uvicorn-0.32.0-py3-none-any.whl (63 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading starlette-0.41.2-py3-none-any.whl (73 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m73.3/73.3 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: uvicorn, starlette, fastapi
Successfully installed fastapi-0.115.4 starlette-0.41.2 uvicorn-0.32.0


In [2]:
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import re

# Application object that the route decorator further down registers against
app = FastAPI()

# Choose the compute device first: CUDA when torch reports one, otherwise CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fetch the pretrained tokenizer and seq2seq model, placing the model on the
# chosen device as soon as it is loaded
model_name = "ag4sh1/Translate4Good"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)

# Acronym glossary: maps an acronym as written in the source text (key) to the
# form that must appear in the translated output (value).
# preprocess_text() masks every key with a placeholder before translation and
# postprocess_text() substitutes the corresponding value afterwards, so the
# model never gets to translate these terms itself.
# NOTE(review): the values look like Spanish/Portuguese-style forms — confirm
# against the model's actual target language.
ACRONYM_GLOSSARY = {
    "UN": "ONU",
    "UNESCO": "UNESCO",
    "UNICEF": "UNICEF",
    "UNDP": "PNUD",
    "UNHCR": "ACNUR",
    "WHO": "OMS",
    "FAO": "FAO",
    "ILO": "OIT",
    "IMF": "FMI",
    "WTO": "OMC",
    # Add more acronyms as needed (key: source form, value: translated form)
}

# Request body model: the JSON payload accepted by the /translate/ endpoint.
class TranslationRequest(BaseModel):
    # Source text to translate; declared without a default value.
    text: str

# Preprocess the input text by replacing acronyms with placeholders
def preprocess_text(text):
    """Mask every glossary acronym in ``text`` with a ``__ACRONYM__`` placeholder.

    Args:
        text: Source text that is about to be sent to the translation model.

    Returns:
        A ``(masked_text, placeholder_map)`` tuple. ``placeholder_map`` maps
        each placeholder string to the acronym it stands for and contains one
        entry per key of ``ACRONYM_GLOSSARY``, whether or not that acronym
        occurred in ``text``.
    """
    placeholder_map = {}
    for acronym in ACRONYM_GLOSSARY:
        # Double underscores make the placeholder unlikely to collide with
        # ordinary text.
        placeholder = f"__{acronym}__"
        # Escape the acronym so a glossary entry containing regex
        # metacharacters (e.g. ".") is still matched literally; \b restricts
        # the match to whole words.
        pattern = rf'\b{re.escape(acronym)}\b'
        # A callable replacement inserts the placeholder verbatim instead of
        # interpreting it as a replacement template.
        text = re.sub(pattern, lambda _match, _repl=placeholder: _repl, text)
        placeholder_map[placeholder] = acronym
    return text, placeholder_map

# Postprocess the translated text by replacing placeholders with correct translations
def postprocess_text(translated_text, placeholder_map):
    """Swap each placeholder in the model output for its glossary translation.

    Args:
        translated_text: Text produced by the translation model.
        placeholder_map: Mapping of placeholder string to the acronym it
            stands for, as returned by ``preprocess_text``.

    Returns:
        ``translated_text`` with every placeholder occurrence replaced by
        ``ACRONYM_GLOSSARY[acronym]``.

    Raises:
        KeyError: If a mapped acronym is missing from ``ACRONYM_GLOSSARY``.
    """
    restored = translated_text
    for marker in placeholder_map:
        target_form = ACRONYM_GLOSSARY[placeholder_map[marker]]
        restored = restored.replace(marker, target_form)
    return restored

# Translation function
def translate_text(text, max_length=100):
    """Translate ``text`` with the module-level tokenizer and seq2seq model.

    Args:
        text: Source text (already preprocessed by the caller if needed).
        max_length: Upper bound passed to ``model.generate`` for the length of
            the generated sequence.

    Returns:
        The decoded translation with special tokens stripped, or ``""`` when
        ``text`` is empty or whitespace-only.
    """
    # Nothing to translate: return early rather than asking the model to
    # generate from an empty prompt.
    if not text.strip():
        return ""

    inputs = tokenizer(text, return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():  # inference only, no gradients needed
        outputs = model.generate(
            inputs["input_ids"],
            # Forward the tokenizer's mask so generate() does not have to
            # infer which positions are real tokens.
            attention_mask=inputs.get("attention_mask"),
            max_length=max_length,
            num_beams=4,
            early_stopping=True,
        )
    # Only the first (and only) sequence in the batch is decoded.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Enhanced function to match capitalization and punctuation
def _match_case(pattern_word, word):
    """Return ``word`` recased to follow the capitalization of ``pattern_word``."""
    # A lone capital such as "I" or "A" is a capitalized word, not an
    # all-caps one, so it must not upper-case the whole translated word.
    if len(pattern_word) > 1 and pattern_word.isupper():
        return word.upper()
    if pattern_word[0].isupper():
        return word.capitalize()
    return word.lower()


def match_format(original_text, translated_text):
    """Re-shape a translation so its casing and punctuation follow the source.

    Words of the translation are paired position-by-position with the words
    of the original and recased to match them; punctuation marks and the
    presence of whitespace between tokens are taken from the original text.
    Punctuation produced by the translation itself is discarded.

    This is a purely positional heuristic: it does not know which translated
    word corresponds to which source word when word order differs.

    Args:
        original_text: The untranslated source text.
        translated_text: The translation to be reformatted.

    Returns:
        The reformatted translation. Translated words beyond the number of
        source words are kept (attached after the last aligned word, before
        any trailing source punctuation); source words beyond the number of
        translated words are skipped.
    """
    token_pattern = re.compile(r'(?P<word>\w+)|(?P<punct>[^\w\s])')
    original_tokens = list(token_pattern.finditer(original_text))
    # Only words are aligned. Translated punctuation must not occupy a word
    # slot, otherwise it would be emitted in place of a real word.
    translated_words = re.findall(r'\w+', translated_text)

    # Position of the final word in the original; surplus translated words
    # are attached there so they end up before trailing punctuation.
    last_word_pos = -1
    for pos, match in enumerate(original_tokens):
        if match.lastgroup == "word":
            last_word_pos = pos

    pieces = []
    next_word = 0
    for pos, match in enumerate(original_tokens):
        if match.lastgroup == "word":
            if next_word >= len(translated_words):
                continue  # translation has fewer words than the original
            chunk = _match_case(match.group(), translated_words[next_word])
            next_word += 1
            if pos == last_word_pos and next_word < len(translated_words):
                # Keep the rest of the translation instead of truncating it.
                chunk = " ".join([chunk, *translated_words[next_word:]])
                next_word = len(translated_words)
        else:
            # Use original punctuation exactly as it is
            chunk = match.group()

        # Reproduce the original spacing: a single space wherever the source
        # had whitespace directly before this token.
        if pieces and match.start() > 0 and original_text[match.start() - 1].isspace():
            pieces.append(" ")
        pieces.append(chunk)

    # Only reachable when the original contained no words at all: there is
    # nothing to align against, so the translated words are appended as-is.
    if next_word < len(translated_words):
        if pieces:
            pieces.append(" ")
        pieces.append(" ".join(translated_words[next_word:]))

    return "".join(pieces)

# POST /translate/ endpoint.
# Pipeline: mask glossary acronyms -> run the model -> restore the glossary
# translations -> copy the request text's punctuation and casing onto the result.
# Responds with {"translation": <str>}; any exception raised along the way is
# reported as HTTP 500 carrying the exception message as its detail.
@app.post("/translate/")
async def translate(request: TranslationRequest):
    try:
        # Step 1: swap acronyms for placeholders
        masked_text, placeholder_map = preprocess_text(request.text)
        # Step 2: model inference on the masked text
        model_output = translate_text(masked_text)
        # Step 3: put the glossary translations back in place of placeholders
        restored_text = postprocess_text(model_output, placeholder_map)
        # Step 4: align punctuation and casing with the original request text
        return {"translation": match_format(request.text, restored_text)}
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))

Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


tokenizer_config.json:   0%|          | 0.00/818 [00:00<?, ?B/s]

source.spm:   0%|          | 0.00/802k [00:00<?, ?B/s]

target.spm:   0%|          | 0.00/826k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.72M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/74.0 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/1.49k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/310M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/288 [00:00<?, ?B/s]