Skip to content

Commit

Permalink
feat: fixed dataframe generation & categories
Browse files Browse the repository at this point in the history
  • Loading branch information
kirangadhave committed Nov 8, 2023
1 parent fe47576 commit 031e546
Show file tree
Hide file tree
Showing 57 changed files with 2,044 additions and 2,007 deletions.
79 changes: 13 additions & 66 deletions examples/test_ext_widget.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,104 +11,51 @@
"source": [
"##### import pandas as pd\n",
"import persist_ext as PR\n",
"from persist_ext import plot\n",
"import altair as alt\n",
"from vega_datasets import data\n",
"import pandas as pd\n",
"\n",
"PR.dev.DEV = True"
"# PR.dev.DEV = True"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "7ea56cd8-7c40-45a8-b448-3b08857d8283",
"execution_count": 2,
"id": "18ce82e2-afa1-4b98-88da-02057f0a50ef",
"metadata": {
"__GENERATED_DATAFRAMES__": "ᯢ粠 ",
"trrack_graph": "ᯡ࠽䈌ʀ匦㢠⹰Ҁ㬠ᰤ烅͑瀶䀳V⤠ᤨˣᢡಥャÌ✕ၑ䋶⡵暷䥈kᑄᄉ妐ࢷዋ⃩䗛᙮ẜʦ墠ǁ#〨ㅣ࢒′䈤7㺰ᠠ湐戎Ⴆ寐痠˴リ$疒‣㐡㨦䎇㦲ʠ既๕壡‡犀⒚⥓ଭޯ⃅朳⠬Ö᡻ㆨ䐸昩ۍ媨⃓ࢱ̴ච᫼岄༏ѤƜܝ暠᧑䮶悮檩ǡ᪁㗖o#䖃栃̆璌Ⅴ渮➃גм⥊拢䋐૰⢤厊ᘭʊԤ憠佞娈ؽ憣惰㐬ൠ㧶ⷄ䒱䁲ᗐ播ౠ₭伂䥐ᱞѣᅈᨽࣦヸ๗௧焤慂‥㤤ℳ䏭ᎂ牜ㅨⓈ┕ವ剳㪁䠠 "
},
"outputs": [],
"source": [
"def get_chart(df):\n",
" selection = alt.selection_interval(name=\"selector\", encodings=[\"x\", \"y\"])\n",
" \n",
" \n",
" return alt.Chart(df).mark_point().encode(\n",
" x=\"Miles_per_Gallon\",\n",
" y=\"Weight_in_lbs\",\n",
" color=alt.condition(selection, \"Origin:O\", alt.value(\"gray\")),\n",
" opacity=alt.condition(selection, alt.value(0.7), alt.value(0.2))\n",
" ).add_params(selection)\n",
"\n",
"def get_df(df):\n",
" return df\n",
"\n",
"def get_vis(df = False):\n",
" cars = data.cars()\n",
" if df:\n",
" return get_df(cars)\n",
" return get_chart(cars)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "0c1cc488-2a38-4ba3-a911-79c96702a8ef",
"metadata": {
"__GENERATED_DATAFRAMES__": "ᯢ粠 ",
"trrack_graph": "ᯡ࠽䈌ʀ匦㢠⹰Ǡ↡䱐ᣪĸ΀୧䴡氹ࠥ䔠ظ䐩斬円ᠱ䣾㏐灠ጠ᦬ተ¶⡨⇱㌄ᅎ⥦庳ୖ份㸛ڨǁ#㍧䌰穠ኤРᜰ͐üメ湁˛㾄ᬠט憴)挄䀦栢琭ٮ㍄ⴡ嬤ᚣㆢ䀣擠䤜䙦¡Զ䙨ⴀᠢ沀㚈㄰ᨸሪ厎Ⴂ時∬磶‹᥺ࡋ帱ࠥ砹窺䊈氻Ȼ⩄ܨ洶₠氡Ḡ݅俏䜫廑ܵ㮚ᳱ曫䦉ㅵപA斀惰׉搬峮㕁ⴷၭ䍢僠ҏඉ孨䤸Պᄭ傱䐤୸渹ጣ䐁䑲ʭ⇰澕ᶥ愐籡ѭ玀wၤቛ峒ੁ❧ᑸ媦ւ煺ŗ℡٠ "
"__has_persist_output": true,
"trrack_graph": "ᯡ࠽䈌ʀ匦㢠⹰۠⬺Ҁᶠู倶䶠⚕Т惛֤ѐé䔂̸ῠ✠㍚ĺ炵扪2攩њᙜˤ撩⣷劂涽䯏Áဠ㑀粀๨ɰ䂠̇刦;䘸෤့垎堠媨᪠Ԅ牠Έƕ稦ܳሣή䠷⋑䅠Ȅ恨䰴❇ʥ,无恲㣱稠嬸ᖽѤݦҢ啺䑁榨ㄝὅ䀺晷⠒璦䈡㼚㻆潋ૡ۾ण䉔悊嘠传ς痻挪挍Ɫ䖽㜖㩳㈨ᱲ⮧ⓥ㎐傁㓘噚㝅ᜢ䞿⅃䫒姒⤘त䃂绖૽㒳扐≱㈹慃࠿擐屲اߡࣄ㤩◂ᣬᬣ䆠礜₸⎨Î罀⢕㧄瑼䖬⧈岊䁬桩ᴰ搠 "
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Displaying an interactive Vega-Altair Chart\n",
"{'Int64': ['__id_column', 'Cylinders', 'Weight_in_lbs', 'Year'], 'string': ['Name', '__annotations'], 'float64': ['Miles_per_Gallon', 'Displacement', 'Horsepower', 'Acceleration'], 'int64': ['__id_column', 'Cylinders', 'Weight_in_lbs', 'Year'], 'category': ['Origin'], 'bool': ['__selected', '__selected_intent']}\n",
"['Origin']\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "46bdf3d6cabd4940b5e6091b50c8f05d",
"model_id": "512f4240edf4456ab54ff328353191f2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"PersistWidget(data_values=[{'__id_column': 1, 'Name': 'chevrolet chevelle malibu', 'Miles_per_Gallon': 18.0, '"
"PersistWidget(data_values=[{'__id_column': '1', 'Name': 'chevrolet chevelle malibu', 'Miles_per_Gallon': 18.0,…"
]
},
"execution_count": 6,
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"c = get_vis()\n",
"PR.Persist(c)\n",
"# PR.Persist(data=get_vis(True))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "a1ab8430-7e1c-4c0e-a139-63f2c7be6d0d",
"metadata": {
"trrack_graph": "ᯡ࠽䈌ʀ匦㢠⹰ಀ㌠ᰡ〽䖡堠〫⏧㻥ロf戤瀪怩䛑ĸᥦ䁓ᬗ爩䚄䨠ዪन㪌㱥▐⥵傂曳宿Áᄠ㑀ŒĬ㒀൨ɰ䂠࿆̂;䘸ᷤၷ㠏␠妨᪠Ř岰ú}㺡䧬撨⟱⌥ಌば™ᠲᶿ཯⚂ᆁㆹӬ皝M䘦撊ᄢᦢ⇊㚂࡬戤㽦䶀ۗᔺϚ戁㻟㧇䒂䇐⏖≠灭᡺䶠ᏠČ徑℘恱帋ࡑ⯗凨溌䄠䥨稈墬㺌ㄱ䜹䔩⒂᧑നい䘬ԡ䍐瞆ࣣñ沏ඉ奩स᥌嬭䣱䐨୸溉ጫ䀰⑒७⇰䱶Υ挑籡Ἤ᎘⠠ᗰᘨ任㙎⏦ᒴ⅕⅄㔄皪䤳٠ "
},
"outputs": [],
"source": [
"cars = data.cars()\n",
"cars[\"Test\"] = None\n",
"cars.Test = cars.Test.astype('category')"
"a = PR.PersistTable(data.cars())\n",
"a"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ec318485-5519-4a61-8f19-f8b13e710069",
"id": "50bad02a-7eb6-4451-8099-9fc167915d9d",
"metadata": {
"trrack_graph": "ᯡ࠽䈌ʀ匦㢠⹰ˠᬡؤ〸悖ఠ㠥冀ঠᆮ扳ᔩŘ旆㢠ౙ䭁昸͐摁⒴E䨲࢔洘䕪⭵埌ⓥ巇㝾ކ桠øぱ䁙ဩ™+佨ఠ㜸プࡢ⳻姠Ɗႊ)洄䀦栢瘔ڮ卄᭞䜨ᆭㆢ䀣擠䡸㐦⪻ԥࡣ䱇栢沁嘸ᄰظሢ嶊Ⴀ晃䐔ᑶ‹०C娶ࠥ簈稺ɘ欤୤␮ཱིˎ堢㰠༓㜿మϙㄪᆓ塘ロ䡀煩ἣԠ૧怦㈰冨ᢘ㨺ᗦᜉ監Άᅾ࢐䉗ୠ悀ਸ਼ނ⵶䤢⌢☆ǂᑑ\"⭋䕄憐ȩ๢傈㩃೥じⸯƧ焤慜‥㠤䆝䷭ᆂ牼ㅠⓈ䔕ᒵ割‡䠠 "
"trrack_graph": "ᯡ࠽䈌ʀ匦㢠⹰ૠㆺ͐රʐ怶䀱墴ɒࠥ䀣ጻС撍ᑭ倹䕠瘴嘇䔼䴲⠠䭈⑁棒瀱ፉ⭾䦫ᬶ⻝紧Ҡǁ#〨㙃ࢊ′䈤7㺰ᠠ湐戎Ⴆ寐皠ˬリ$疒‣㐡㨦䎇㦲ڠ暢๕壡‡犀⒒㸾઻$撇ჰ䯨ؠ嬸ⷑ呤ಆҡ囹䑀妨䢣ㄭ䀺慼㈩瞤∡㸦ƶ晑䴶悡檩ǣ᪀痖o#䖃栃͡璌⃤渮➃גТ╌㓣ᎃ䥩Ҕ๖ᕑְ䲈憠炳ෘ氬๠ᇘʴ愇Ⳗ⤲⌡⑶ᇂᙑ¢᭛䵀悐簱੠凈㨣೥ㇸ䴿Bࡤ䒠଄ࢤ✯٥䟩㣅º匲ᨾಳ储⍀ "
},
"outputs": [],
"source": []
Expand Down
6 changes: 3 additions & 3 deletions package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "persist_ext",
"version": "1.5.3",
"version": "1.5.4",
"description": "PersIst is a JupyterLab extension to enable persistent interactive visualizations in JupyterLab notebooks.",
"keywords": [
"jupyter",
Expand Down Expand Up @@ -101,7 +101,8 @@
"react-spring": "^9.7.2",
"react-vega": "^7.6.0",
"uuid": "^9.0.0",
"vega-lite": "^5.14.1"
"vega-lite": "^5.14.1",
"vega-tooltip": "^0.33.0"
},
"devDependencies": {
"@anywidget/vite": "^0.1.1",
Expand All @@ -112,7 +113,6 @@
"@types/d3": "^7.4.0",
"@types/json-schema": "^7.0.11",
"@types/lodash": "^4.14.197",
"@types/react": "^18.0.26",
"@types/react-addons-linked-state-mixin": "^0.14.22",
"@types/uuid": "^9.0.3",
"@typescript-eslint/eslint-plugin": "^6.1.0",
Expand Down
70 changes: 14 additions & 56 deletions persist_ext/internals/__init__.py
Original file line number Diff line number Diff line change
@@ -1,62 +1,20 @@
import persist_ext.internals.data.get_generated_df as df
import altair as alt
from persist_ext.internals.data.idfy import ID_COLUMN
import persist_ext.internals.vis as vis
from persist_ext.internals.data.prepare import prepare
import persist_ext.internals.plot as plot
from persist_ext.internals.utils import dev
from persist_ext.internals.utils.logger import Out, logger
from persist_ext.internals.widgets.persist_output.widget import (
DEFAULT_DATA_ACCESSOR,
PersistWidget,
)
from persist_ext.internals.widgets.vegalite_chart.utils import (
pop_data_defs_from_charts_recursive,
from persist_ext.internals.widgets.persist_output.widget import PersistWidget
from persist_ext.internals.widgets.persist_output.wrappers import (
Persist,
PersistChart,
PersistTable,
)

dev.DEV = False


def Persist(
    chart=None,
    data=None,
    df_name="persist_df",
    id_column=ID_COLUMN,
    data_accessor=DEFAULT_DATA_ACCESSOR,
):
    """Create a ``PersistWidget`` for an Altair chart and/or a dataset.

    Args:
        chart: Vega-Altair chart to display. When given, a chart widget is
            returned and ``chart.data`` is overwritten with the prepared
            dataset.
        data: Dataset to display. On its own it is shown as a data table;
            together with ``chart`` it becomes the chart's dataset.
        df_name: Forwarded unchanged to ``PersistWidget``.
        id_column: Column name handed to ``prepare`` and ``PersistWidget``.
        data_accessor: Forwarded unchanged to ``PersistWidget``.

    Returns:
        A ``PersistWidget`` wrapping the chart (if ``chart`` is given) or the
        prepared dataset (if only ``data`` is given).

    Raises:
        ValueError: If neither ``chart`` nor ``data`` is provided, or if only
            ``chart`` is provided and it has no top-level ``data``.
    """
    if chart is None and data is None:
        raise ValueError(
            "Need a valid vega altair chart and/or dataframe to be provided."
        )

    # If visualizing charts
    if chart is not None:
        if data is None:  # if data is not passed explicitly
            chart_data = getattr(chart, "data", alt.Undefined)
            if chart_data is alt.Undefined:  # chart has no top-level data
                raise ValueError(
                    "\nCannot infer dataset from vega altair specification. "
                    "The data might be specified in subcharts.\n"
                    "Persist does not support such charts.\n"
                    "Please provide data at the top, or pass in the dataset "
                    "explicitly as second argument.\n"
                )
            chart_data = prepare(chart_data, id_column)
            chart.data = chart_data
        else:  # if data is passed
            # NOTE(review): presumably strips data definitions nested in the
            # chart so the explicit dataset is the only one -- confirm
            # against the helper's implementation.
            chart = pop_data_defs_from_charts_recursive(chart, [])
            chart.data = prepare(data, id_column)

        print("Displaying an interactive Vega-Altair Chart")
        return PersistWidget(
            chart, df_name=df_name, id_column=id_column, data_accessor=data_accessor
        )

    if data is not None:  # if only showing dataframe
        data = prepare(data, id_column)
        print("Displaying an interactive DataTable")
        return PersistWidget(
            data, df_name=df_name, id_column=id_column, data_accessor=data_accessor
        )


__all__ = ["vis", "logger", "Out", "df", "prepare", "dev", "Persist"]
__all__ = [
"plot",
"dev",
"PersistWidget",
"Persist",
"PersistChart",
"PersistTable",
]
1 change: 1 addition & 0 deletions persist_ext/internals/data/generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def set(self, name: str, data: DataFrame, override):
raise KeyError(f"Already exists dataframe named '{name}'")

self.__dataframe_map[name] = data

return True

def has(self, name: str) -> bool:
Expand Down
1 change: 0 additions & 1 deletion persist_ext/internals/data/idfy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ def idfy_dataframe(df, id_column):
ids = df.index + 1
df.insert(0, id_column, ids)
df[id_column] = df[id_column].apply(str)

else:
if df[id_column].unique().size != df.shape[0]:
raise Exception(f"Column '{id_column}' already exists, but not unique")
Expand Down
12 changes: 11 additions & 1 deletion persist_ext/internals/data/process_generate_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,18 @@ def process_generate_dataset(df, keep_selection_columns=False, keep_id_col=False
if not keep_id_col:
cols_to_remove.append(ID_COLUMN)

def process_selection_column(data):
    """Swap the brush-selection column for ``is_selected`` when rows are selected.

    The ``SELECTED_COLUMN_BRUSH`` column is always dropped from the returned
    frame. Its values are re-attached under the name ``is_selected`` only
    when their sum is greater than zero.
    """
    brush_flags = data[SELECTED_COLUMN_BRUSH]
    trimmed = data.drop(columns=[SELECTED_COLUMN_BRUSH])

    # Nothing selected: return the frame without any selection column.
    if not (brush_flags.sum() > 0):
        return trimmed

    trimmed["is_selected"] = brush_flags
    return trimmed

df = process_selection_column(df)

if not keep_selection_columns:
cols_to_remove.append(SELECTED_COLUMN_BRUSH)
cols_to_remove.append(SELECTED_COLUMN_INTENT)

df.drop(columns=cols_to_remove, inplace=True)
Expand Down
5 changes: 5 additions & 0 deletions persist_ext/internals/plot/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from persist_ext.internals.plot.barchart import barchart
from persist_ext.internals.plot.scatterplot import scatterplot


__all__ = ["scatterplot", "barchart"]
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
import altair as alt
from persist_ext.internals.data.idfy import ID_COLUMN

from persist_ext.internals.vis.plot_helpers import base_altair_plot
from persist_ext.internals.widgets.trrackable_output.output_with_trrack_widget import (
OutputWithTrrackWidget,
)
from persist_ext.internals.widgets.vegalite_chart.vegalite_chart_widget import (
VegaLiteChartWidget,
)
from persist_ext.internals.plot.plot_helpers import base_altair_plot
from persist_ext.internals.widgets.persist_output.wrappers import PersistChart


def barchart(
Expand All @@ -19,6 +15,7 @@ def barchart(
selection_type="interval",
height=400,
width=400,
id_column=ID_COLUMN,
):
"""
Args:
Expand All @@ -32,7 +29,9 @@ def barchart(
Returns:
altair chart object
"""
chart, data = base_altair_plot(data, height=height, width=width)
chart, data = base_altair_plot(
data, height=height, width=width, id_column=id_column
)

chart = chart.mark_bar()

Expand All @@ -51,23 +50,23 @@ def barchart(

encodings = [barchart_non_agg_axis]

x_encode = chart.encoding.x.to_dict()
y_encode = chart.encoding.y.to_dict()
x_encode = chart.encoding.x
y_encode = chart.encoding.y

is_binned = False
is_time_unit = False

if barchart_non_agg_axis == "x":
is_binned = "bin" in x_encode
is_time_unit = "timeUnit" in x_encode
is_ordinal_or_nominal = "type" in x_encode and x_encode["type"] in [
is_binned = hasattr(x_encode, "bin")
is_time_unit = hasattr(x_encode, "timeUnit")
is_ordinal_or_nominal = hasattr(x_encode, "type") and x_encode.type in [
"nominal",
"ordinal",
]
elif barchart_non_agg_axis == "y":
is_time_unit = "timeUnit" in y_encode
is_binned = "bin" in y_encode
is_ordinal_or_nominal = "type" in y_encode and y_encode["type"] in [
is_binned = hasattr(y_encode, "bin")
is_time_unit = hasattr(y_encode, "timeUnit")
is_ordinal_or_nominal = hasattr(y_encode, "type") and y_encode.type in [
"nominal",
"ordinal",
]
Expand Down Expand Up @@ -101,6 +100,4 @@ def barchart(
color=alt.condition(selection, alt.value("steelblue"), alt.value("gray"))
)

return OutputWithTrrackWidget(
body_widget=VegaLiteChartWidget(chart, data=data), data=data
)
return PersistChart(chart=chart, data=data)
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from persist_ext.internals.data.prepare import prepare


def base_altair_plot(data, height, width, *args, **kwargs):
data = prepare(data)
def base_altair_plot(data, height, width, id_column, *args, **kwargs):
data = prepare(data, id_column)

if data is False:
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import altair as alt
from persist_ext.internals.data.idfy import ID_COLUMN

from persist_ext.internals.vis.plot_helpers import base_altair_plot
from persist_ext.internals.widgets.trrackable_output.output_with_trrack_widget import (
OutputWithTrrackWidget,
)
from persist_ext.internals.widgets.vegalite_chart.vegalite_chart_widget import (
VegaLiteChartWidget,
)
from persist_ext.internals.plot.plot_helpers import base_altair_plot

from persist_ext.internals.widgets.persist_output.wrappers import PersistChart


def scatterplot(
Expand All @@ -19,6 +16,7 @@ def scatterplot(
selection_type="interval",
height=400,
width=400,
id_column=ID_COLUMN,
):
"""
Args:
Expand All @@ -32,7 +30,9 @@ def scatterplot(
Returns:
altair chart object
"""
chart, data = base_altair_plot(data, height=height, width=width)
chart, data = base_altair_plot(
data, height=height, width=width, id_column=id_column
)

if circle:
chart = chart.mark_circle()
Expand Down Expand Up @@ -63,6 +63,4 @@ def scatterplot(
color=alt.condition(selection, alt.value("steelblue"), alt.value("gray"))
)

return OutputWithTrrackWidget(
body_widget=VegaLiteChartWidget(chart=chart, data=data), data=data
)
return PersistChart(chart=chart, data=data)

0 comments on commit 031e546

Please sign in to comment.