Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions e2e_playwright/bokeh_chart_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import pandas as pd
import streamlit as st
from bokeh.plotting import figure
from chart_types import CHART_TYPES
import numpy as np
import pandas as pd

from streamlit_bokeh import streamlit_bokeh

np.random.seed(0)
Expand Down Expand Up @@ -188,9 +189,9 @@ def lorenz(xyz, t):
line_width=1.5,
)
elif chart == "linear_cmap":
from numpy.random import standard_normal
from bokeh.transform import linear_cmap
from bokeh.util.hex import hexbin
from numpy.random import standard_normal

x = standard_normal(50000)
y = standard_normal(50000)
Expand Down Expand Up @@ -279,6 +280,6 @@ def lorenz(xyz, t):
p.legend.location = "top_left"
p.legend.orientation = "horizontal"

streamlit_bokeh(p, use_container_width=False)
streamlit_bokeh(p, use_container_width=False, key="chart_1")

streamlit_bokeh(p, use_container_width=True)
streamlit_bokeh(p, use_container_width=True, key="chart_2")
129 changes: 88 additions & 41 deletions streamlit_bokeh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,53 +12,86 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.metadata
import json
import os
import re
from typing import TYPE_CHECKING
import streamlit.components.v1 as components
import importlib.metadata

import bokeh
import streamlit as st
from bokeh.embed import json_item

if TYPE_CHECKING:
from bokeh.plotting.figure import Figure


# Create a _RELEASE constant. We'll set this to False while we're developing
# the component, and True when we're ready to package and distribute it.
# (This is, of course, optional - there are innumerable ways to manage your
# release process.)
_DEV = os.environ.get("DEV", False)
_RELEASE = not _DEV

# Declare a Streamlit component. `declare_component` returns a function
# that is used to create instances of the component. We're naming this
# function "_component_func", with an underscore prefix, because we don't want
# to expose it directly to users. Instead, we will create a custom wrapper
# function, below, that will serve as our component's public API.

# It's worth noting that this call to `declare_component` is the
# *only thing* you need to do to create the binding between Streamlit and
# your component frontend. Everything else we do in this file is simply a
# best practice.

if not _RELEASE:
_component_func = components.declare_component(
# We give the component a simple, descriptive name ("streamlit_bokeh"
# does not fit this bill, so please choose something better for your
# own component :)
"streamlit_bokeh",
# Pass `url` here to tell Streamlit that the component will be served
# by the local dev server that you run via `npm run start`.
# (This is useful while your component is in development.)
url="http://localhost:3001",

def _version_ge(a: str, b: str) -> bool:
"""
Return True if version string a is greater than or equal to b.

The comparison extracts up to three numeric components from each version
string (major, minor, patch) and compares them as integer tuples.
Non-numeric suffixes (for example, 'rc1', 'dev') are ignored.

Parameters
----------
a : str
The left-hand version string.
b : str
The right-hand version string to compare against.

Returns
-------
bool
True if a >= b, otherwise False.
"""

def parse(v: str) -> tuple[int, int, int]:
nums = [int(x) for x in re.findall(r"\d+", v)[:3]]
while len(nums) < 3:
nums.append(0)
return nums[0], nums[1], nums[2]

return parse(a) >= parse(b)


_STREAMLIT_VERSION = importlib.metadata.version("streamlit")

# If streamlit version is >= 1.51.0 use Custom Component v2 API, otherwise use
# Custom Component v1 API
# _IS_USING_CCV2 = _version_ge(_STREAMLIT_VERSION, "1.51.0")
# Temporarily setting this to False, will be updated in next PR.
_IS_USING_CCV2 = False

# Version-gated component registration
if _IS_USING_CCV2:
_component_func = st.components.v2.component(
"streamlit-bokeh.streamlit_bokeh",
js="v2/index-*.mjs",
html="<div class='stBokehContainer'></div>",
)
else:
# When we're distributing a production version of the component, we'll
# replace the `url` param with `path`, and point it to the component's
# build directory:
parent_dir = os.path.dirname(os.path.abspath(__file__))
build_dir = os.path.join(parent_dir, "frontend/build")
_component_func = components.declare_component("streamlit_bokeh", path=build_dir)
if not _RELEASE:
_component_func = st.components.v1.declare_component(
"streamlit_bokeh",
url="http://localhost:3001",
)
else:
parent_dir = os.path.dirname(os.path.abspath(__file__))
build_dir = os.path.join(parent_dir, "frontend/build")
_component_func = st.components.v1.declare_component(
"streamlit_bokeh", path=build_dir
)

if TYPE_CHECKING:
from bokeh.plotting.figure import Figure

__version__ = importlib.metadata.version("streamlit_bokeh")
REQUIRED_BOKEH_VERSION = "3.8.0"
Expand Down Expand Up @@ -112,14 +145,28 @@ def streamlit_bokeh(
f"{REQUIRED_BOKEH_VERSION}` to install the correct version."
)

# Call through to our private component function. Arguments we pass here
# will be sent to the frontend, where they'll be available in an "args"
# dictionary.
_component_func(
figure=json.dumps(json_item(figure)),
use_container_width=use_container_width,
bokeh_theme=theme,
key=key,
)
if _IS_USING_CCV2:
# Call through to our private component function.
_component_func(
key=key,
data={
"figure": json.dumps(json_item(figure)),
"bokeh_theme": theme,
"use_container_width": use_container_width,
},
isolate_styles=False,
)

return None
else:
# Call through to our private component function. Arguments we pass here
# will be sent to the frontend, where they'll be available in an "args"
# dictionary.
_component_func(
figure=json.dumps(json_item(figure)),
use_container_width=use_container_width,
bokeh_theme=theme,
key=key,
)

return None
return None
140 changes: 140 additions & 0 deletions streamlit_bokeh/frontend/src/v2/index.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
/**
* Copyright (c) Snowflake Inc. (2025)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { beforeEach, describe, expect, test } from "vitest"

import {
getChartDataGenerator,
getChartDimensions,
setChartThemeGenerator,
} from "./index"
import { MinimalStreamlitTheme } from "./streamlit-theme"

describe("getChartDataGenerator", () => {
let getChartData: (figure: string) => {
data: object | null
hasChanged: boolean
}

beforeEach(() => {
getChartData = getChartDataGenerator()
})

test("should return parsed data and hasChanged true on first call", () => {
const figure = JSON.stringify({ key: "value" })
const result = getChartData(figure)

expect(result).toEqual({ data: { key: "value" }, hasChanged: true })
})

test("should return hasChanged false for the same figure", () => {
const figure = JSON.stringify({ key: "value" })
getChartData(figure)
const result = getChartData(figure)

expect(result).toEqual({ data: { key: "value" }, hasChanged: false })
})

test("should return hasChanged true for a different figure", () => {
getChartData(JSON.stringify({ key: "value" }))
const newFigure = JSON.stringify({ key: "newValue" })
const result = getChartData(newFigure)

expect(result).toEqual({ data: { key: "newValue" }, hasChanged: true })
})
})

// Unit tests for setChartThemeGenerator
describe("setChartThemeGenerator", () => {
let setChartTheme: (
newTheme: string,
newAppTheme: MinimalStreamlitTheme
) => boolean

beforeEach(() => {
setChartTheme = setChartThemeGenerator()
})

test("should apply the theme when theme changes", () => {
const newTheme = "dark"
const newAppTheme: MinimalStreamlitTheme = {
textColor: "white",
backgroundColor: "black",
secondaryBackgroundColor: "gray",
font: "Source Pro",
}
const result = setChartTheme(newTheme, newAppTheme)
const { use_theme: useTheme } =
global.window.Bokeh.require("core/properties")

expect(result).toBe(true)
expect(useTheme).toHaveBeenCalled()
})

test("should not reapply the theme if it's the same", () => {
const newTheme = "dark"
const newAppTheme: MinimalStreamlitTheme = {
textColor: "white",
backgroundColor: "black",
secondaryBackgroundColor: "gray",
font: "Source Pro",
}
setChartTheme(newTheme, newAppTheme)
const result = setChartTheme(newTheme, newAppTheme)

expect(result).toBe(false)
})

test("should apply Streamlit theme when appropriate", () => {
const newTheme = "streamlit"
const newAppTheme: MinimalStreamlitTheme = {
textColor: "white",
backgroundColor: "black",
secondaryBackgroundColor: "gray",
font: "Source Pro",
}
const result = setChartTheme(newTheme, newAppTheme)

expect(result).toBe(true)
})
})

describe("getChartDimensions", () => {
test("should return default dimensions when no width/height attributes are provided", () => {
const plot = { attributes: {} }
const result = getChartDimensions(plot, false, document.documentElement)
expect(result).toEqual({ width: 400, height: 350 })
})

test("should return provided dimensions when width/height attributes are set", () => {
const plot = { attributes: { width: 800, height: 400 } }
const result = getChartDimensions(plot, false, document.documentElement)
expect(result).toEqual({ width: 800, height: 400 })
})

test("should calculate new dimensions based on container width", () => {
Object.defineProperty(document.documentElement, "clientWidth", {
configurable: true,
writable: true,
value: 1200, // Set the desired value
})

const plot = { attributes: { width: 800, height: 400 } }
const result = getChartDimensions(plot, true, document.documentElement)
expect(result.width).toBe(1200)
expect(result.height).toBeCloseTo(600)
})
})
Loading
Loading