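To start this Jupyter Dash app, run all the cells below. The last cell launches the app: as configured (`mode='inline'`) the app appears directly in that cell's output; if you switch it to `mode='external'`, click the **temporary** URL printed at the end of the last cell to open the app instead.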

In [None]:
!pip install -q jupyter-dash==0.3.0rc1 dash-bootstrap-components transformers

[K     |████████████████████████████████| 51kB 4.5MB/s 
[K     |████████████████████████████████| 194kB 11.6MB/s 
[K     |████████████████████████████████| 1.5MB 16.2MB/s 
[K     |████████████████████████████████| 81kB 11.0MB/s 
[K     |████████████████████████████████| 2.9MB 33.0MB/s 
[K     |████████████████████████████████| 890kB 56.1MB/s 
[K     |████████████████████████████████| 1.0MB 50.1MB/s 
[K     |████████████████████████████████| 3.5MB 52.8MB/s 
[K     |████████████████████████████████| 194kB 54.2MB/s 
[K     |████████████████████████████████| 1.8MB 50.8MB/s 
[K     |████████████████████████████████| 358kB 49.8MB/s 
[?25h  Building wheel for dash (setup.py) ... [?25l[?25hdone
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Building wheel for dash-renderer (setup.py) ... [?25l[?25hdone
  Building wheel for dash-core-components (setup.py) ... [?25l[?25hdone
  Building wheel for dash-html-components (setup.py) ... [?25l[?25hdone
  Building 

In [None]:
import time

import dash
import dash_html_components as html
import dash_core_components as dcc
import dash_bootstrap_components as dbc
from dash.dependencies import Input, Output, State
from jupyter_dash import JupyterDash
from transformers import BartTokenizer, BartForConditionalGeneration
import torch

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# Pretrained DistilBART checkpoint and its matching tokenizer (downloaded on first use).
pretrained = "sshleifer/distilbart-xsum-12-6"
model = BartForConditionalGeneration.from_pretrained(pretrained)
tokenizer = BartTokenizer.from_pretrained(pretrained)

# Inference-only setup: FP16 weights when running on the GPU (the CPU keeps
# full precision), then move to the chosen device and switch to eval mode.
if device == "cuda":
    model = model.half()
model = model.to(device).eval()

Device: cuda


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1434.0, style=ProgressStyle(description…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=611201041.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=898822.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=26.0, style=ProgressStyle(description_w…




In [None]:
# Define app
app = JupyterDash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
server = app.server

# Fixed height of the controls card. The summary textarea subtracts the same
# amount from 75vh so the left column roughly tracks the height of the
# 75vh original-text area on the right.
CONTROLS_HEIGHT_PX = 275


def _form_group(*children):
    """Group a label and its control with Bootstrap's standard bottom spacing.

    Stands in for ``dbc.FormGroup``, which dash-bootstrap-components 1.0
    removed; ``mb-3`` gives the same 1rem bottom margin as the old
    ``form-group`` class, so this renders the same on dbc 0.x and 1.x.
    """
    return html.Div(list(children), className="mb-3")


controls = dbc.Card(
    [
        _form_group(
            dbc.Label("Output Length (# Tokens)"),
            dcc.Slider(
                id="max-length",
                min=10,
                max=50,
                value=30,
                marks={i: str(i) for i in range(10, 51, 10)},
            ),
        ),
        _form_group(
            dbc.Label("Beam Size"),
            dcc.Slider(
                id="num-beams",
                min=2,
                max=6,
                value=4,
                marks={i: str(i) for i in [2, 4, 6]},
            ),
        ),
        _form_group(
            # The spinner wraps the button and the timing text, so it shows
            # while the summarize callback is updating "time-taken".
            dbc.Spinner(
                [
                    dbc.Button("Summarize", id="button-run"),
                    html.Div(id="time-taken"),
                ]
            )
        ),
    ],
    body=True,
    style={"height": f"{CONTROLS_HEIGHT_PX}px"},
)


# Define Layout: controls + summary on the left (5/12), source text on the right (7/12).
app.layout = dbc.Container(
    fluid=True,
    children=[
        html.H1("Dash Automatic Summarization (with DistilBART)"),
        html.Hr(),
        dbc.Row(
            [
                dbc.Col(
                    width=5,
                    children=[
                        controls,
                        dbc.Card(
                            body=True,
                            children=[
                                _form_group(
                                    dbc.Label("Summarized Content"),
                                    dcc.Textarea(
                                        id="summarized-content",
                                        style={
                                            "width": "100%",
                                            "height": f"calc(75vh - {CONTROLS_HEIGHT_PX}px)",
                                        },
                                    ),
                                )
                            ],
                        ),
                    ],
                ),
                dbc.Col(
                    width=7,
                    children=[
                        dbc.Card(
                            body=True,
                            children=[
                                _form_group(
                                    dbc.Label("Original Text (Paste here)"),
                                    dcc.Textarea(
                                        id="original-text",
                                        style={"width": "100%", "height": "75vh"},
                                    ),
                                )
                            ],
                        )
                    ],
                ),
            ]
        ),
    ],
)

In [None]:
@app.callback(
    [Output("summarized-content", "value"), Output("time-taken", "children")],
    [
        Input("button-run", "n_clicks"),
        Input("max-length", "value"),
        Input("num-beams", "value"),
    ],
    [State("original-text", "value")],
)
def summarize(n_clicks, max_len, num_beams, original_text):
    """Summarize the pasted text with the loaded model and report the elapsed time.

    Fires when the "Summarize" button is clicked or either slider changes; the
    original text is read as ``State``, so editing it alone does not trigger a run.

    Args:
        n_clicks: Click count of the "Summarize" button (used only as a trigger).
        max_len: Maximum summary length in tokens, from the "max-length" slider.
        num_beams: Beam-search width, from the "num-beams" slider.
        original_text: Contents of the "original-text" textarea; may be None.

    Returns:
        A ``(summary, status)`` tuple for the "summarized-content" textarea and
        the "time-taken" div. ``("", "Did not run")`` when there is no text.
    """
    # Missing, empty, or whitespace-only input: nothing meaningful to summarize.
    if not original_text or not original_text.strip():
        return "", "Did not run"

    start = time.perf_counter()

    # Cap the encoder input at 1024 tokens (presumably the model's maximum
    # position count — confirm against the checkpoint config). Truncation is
    # requested explicitly rather than left to the tokenizer's implicit default.
    inputs = tokenizer.batch_encode_plus(
        [original_text], max_length=1024, truncation=True, return_tensors="pt"
    )
    inputs = inputs.to(device)

    # Generate Summary
    summary_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs.get("attention_mask"),
        num_beams=num_beams,
        max_length=max_len,
        early_stopping=True,
    )
    # One input text, one returned sequence: decode only the first row.
    summary = tokenizer.decode(
        summary_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    elapsed = time.perf_counter() - start
    time_taken = f"Summarized on {device} in {elapsed:.2f}s"

    return summary, time_taken

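Run the cell below to start your Jupyter Dash app. With `mode='inline'` the app is displayed directly in the cell's output; with `mode='external'`, click the **temporary** URL it prints to access the app.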

In [None]:
# JupyterDash display mode: "inline" embeds the app in this cell's output,
# while "external" prints a link that opens it in a separate browser tab.
run_mode = "inline"
app.run_server(mode=run_mode)

<IPython.core.display.Javascript object>