In [1]:
import time

import dash
from dash import html, dcc
import dash_bootstrap_components as dbc
from dash.dependencies import Input, Output, State
import torch
from transformers import pipeline, MBartForConditionalGeneration, MBart50TokenizerFast
import nltk

# `punkt` is the NLTK sentence-tokenizer model; quiet=True keeps the download log out of the cell output.
nltk.download("punkt", quiet=True)
from nltk.tokenize import sent_tokenize

# Choose device and load model
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# A pipeline runs on CPU unless told otherwise, so pass the detected device explicitly; the
# callback reports `device` in its timing message. 0 = first CUDA device, -1 = CPU.
summarizer = pipeline(
    "summarization",
    model="facebook/bart-large-cnn",
    device=0 if device == "cuda" else -1,
)

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\P\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Device: cpu


In [2]:
# Single source of truth for the mBART-50 many-to-many translation checkpoint used by model and tokenizer.
MBART_CHECKPOINT = "facebook/mbart-large-50-many-to-many-mmt"
model = MBartForConditionalGeneration.from_pretrained(MBART_CHECKPOINT)
tokenizer = MBart50TokenizerFast.from_pretrained(MBART_CHECKPOINT)

In [3]:
# Create app
# Bootstrap theme plus the Google-hosted "Poetsen One" font used by the page title.
# NOTE: local CSS (e.g. `styles.css`) belongs in an `assets/` folder next to this file; Dash serves and
# injects everything in `assets/` automatically. A bare relative path in `external_stylesheets` is
# requested from the Dash page route instead, which does not serve it as a stylesheet.
external_stylesheets = [
    dbc.themes.BOOTSTRAP,
    "https://fonts.googleapis.com/css2?family=Poetsen+One&display=swap",
]
app = dash.Dash(__name__, external_stylesheets=external_stylesheets)
server = app.server  # underlying Flask app, e.g. for deployment behind a WSGI server

# Define Layout
# mBART-50 language codes offered as translation targets (display label -> code), in display order.
TARGET_LANGUAGES = {
    "Arabic": "ar_AR",
    "Czech": "cs_CZ",
    "German": "de_DE",
    "English": "en_XX",
    "Spanish": "es_XX",
    "Estonian": "et_EE",
    "Finnish": "fi_FI",
    "French": "fr_XX",
    "Gujarati": "gu_IN",
    "Hindi": "hi_IN",
    "Italian": "it_IT",
    "Japanese": "ja_XX",
    "Kazakh": "kk_KZ",
    "Korean": "ko_KR",
    "Lithuanian": "lt_LT",
    "Latvian": "lv_LV",
    "Burmese": "my_MM",
    "Nepali": "ne_NP",
    "Dutch": "nl_XX",
    "Romanian": "ro_RO",
    "Russian": "ru_RU",
    "Sinhala": "si_LK",
    "Turkish": "tr_TR",
    "Vietnamese": "vi_VN",
    "Chinese": "zh_CN",
    "Afrikaans": "af_ZA",
    "Azerbaijani": "az_AZ",
    "Bengali": "bn_IN",
    "Persian": "fa_IR",
    "Hebrew": "he_IL",
    "Croatian": "hr_HR",
    "Indonesian": "id_ID",
    "Georgian": "ka_GE",
    "Khmer": "km_KH",
    "Macedonian": "mk_MK",
    "Malayalam": "ml_IN",
    "Mongolian": "mn_MN",
    "Marathi": "mr_IN",
    "Polish": "pl_PL",
    "Pashto": "ps_AF",
    "Portuguese": "pt_XX",
    "Swedish": "sv_SE",
    "Swahili": "sw_KE",
    "Tamil": "ta_IN",
    "Telugu": "te_IN",
    "Thai": "th_TH",
    "Tagalog": "tl_XX",
    "Ukrainian": "uk_UA",
    "Urdu": "ur_PK",
    "Xhosa": "xh_ZA",
    "Galician": "gl_ES",
    "Slovene": "sl_SI",
}

# Bootstrap rows have 12 columns, so widths are capped at 12 (a full-width column/label).
# Dash expects camelCase keys in `style` dictionaries.
app.layout = dbc.Container(
    fluid=True,
    children=[
        html.H1("Text Summarization & Translation", style={"color": "red", "fontFamily": "Poetsen One"}),
        html.Hr(),
        # The spinner wraps the row holding `time-output`, so it shows while the callback updates it.
        dbc.Spinner(
            dbc.Row(
                [
                    dbc.Col(dbc.Button("RUN", id="button-run", color="warning"), width=2),
                    dbc.Col(
                        html.Div(id="time-output", style={"marginTop": "8px"}),
                        width=10,
                    ),
                ],
                style={"marginBottom": "15px"},
            )
        ),
        # Summary length control; the callback clamps the value to the advertised 30-150 range.
        dbc.Row(
            [
                dbc.Col(
                    [
                        dbc.Label("Maximum Summary Length (30-150)", width=12),
                        dbc.Input(id="max_length", type="number", value=130, min=30, max=150, style={"width": "25%"}),
                    ],
                    width=12,
                ),
            ],
            style={"marginBottom": "15px"},
        ),
        dbc.Row(
            [
                # Left: source language + input text.
                dbc.Col(
                    [
                        dbc.Label("Source Language:", width=12),
                        dbc.InputGroup(
                            [
                                dbc.Select(
                                    id="source-language",
                                    options=[
                                        {"label": "English", "value": "en_XX"},
                                    ],
                                    value="en_XX",
                                ),
                            ]
                        ),
                        dbc.Textarea(
                            id="source-text",
                            style={"marginTop": "15px", "height": "45vh"},
                        ),
                    ]
                ),
                # Right: target language + translated summary.
                dbc.Col(
                    [
                        dbc.Label("Target Language:", width=12),
                        dbc.InputGroup(
                            [
                                dbc.Select(
                                    id="target-language",
                                    options=[
                                        {"label": label, "value": code}
                                        for label, code in TARGET_LANGUAGES.items()
                                    ],
                                    value="th_TH",
                                ),
                            ]
                        ),
                        dbc.Textarea(
                            id="target-text",
                            style={"marginTop": "15px", "height": "45vh"},
                        ),
                    ]
                ),
            ]
        ),
    ],
)

@app.callback(
    [Output("target-text", "value"), Output("time-output", "children")],
    [
        Input("button-run", "n_clicks"),
        Input("source-language", "value"),
        Input("target-language", "value"),
        Input("max_length", "value"),
    ],
    [State("source-text", "value")],
)
def translate(n_clicks, src_lang, tgt_lang, max_set, src_text):
    """Summarize the source text with BART, then translate the summary with mBART-50.

    Triggered by the RUN button and by changes to the language selectors or the length
    input. The source text is read as State, so typing alone does not trigger a run.

    Args:
        n_clicks: RUN button click count (used only as a trigger).
        src_lang: mBART-50 language code assigned to the tokenizer as the source language.
        tgt_lang: mBART-50 language code to translate into; "en_XX" skips translation.
        max_set: Requested maximum summary length; clamped to 30-150, None means "not set".
        src_text: Text to summarize.

    Returns:
        Tuple of (output text for `target-text`, status message for `time-output`).
    """
    if not src_text or not src_text.strip():
        return "", "Did not run."
    if max_set is None:
        # The number input reports None while it is empty or holds an invalid value.
        return "", "Did not run."

    t0 = time.perf_counter()

    # The UI advertises 30-150 but the input does not enforce it; clamp here.
    max_set = max(30, min(150, int(max_set)))
    min_set = round(max_set * 0.25)

    # Text Summarization. truncation=True cuts inputs that exceed the model's maximum input
    # length instead of letting the model raise on over-long text.
    summary = summarizer(
        src_text,
        max_length=max_set,
        min_length=min_set,
        do_sample=False,
        truncation=True,
    )
    summary_text = summary[0]["summary_text"]

    if tgt_lang != "en_XX":
        tokenizer.src_lang = src_lang
        # Move inputs to wherever the model lives so this keeps working if the model is put on GPU.
        encoded = tokenizer(summary_text, return_tensors="pt").to(model.device)
        generated_tokens = model.generate(
            **encoded,
            forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang],
        )
        # batch_decode returns a list of strings (one per sequence); join it into one text.
        tgt_text = " ".join(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True))
    else:
        # Already a single string; joining it would space out every character.
        tgt_text = summary_text

    elapsed = time.perf_counter() - t0
    time_output = f"Translated on {device} in {elapsed:.2f}s"

    return tgt_text, time_output


if __name__ == "__main__":
    # `run_server` was deprecated in Dash 2.x and removed in Dash 3 in favour of `run`;
    # fall back to `run_server` only on older Dash versions that lack `run`.
    run_app = getattr(app, "run", None) or app.run_server
    # debug=True turns on Dash dev tools and the Flask debugger; do not expose it to other users.
    run_app(debug=True, port=4050)