# Pandas example

This example showcases hooking up an LLM to answer questions and generate Plotly Express plots over a Pandas DataFrame.

This uses the Titanic survival dataset from HuggingFace (see https://huggingface.co/datasets/julien-c/titanic-survival). To download it, use the HuggingFace `datasets` library. Alternatively, substitute any local dataset of your own for the DataFrame.

In [4]:
from datasets import load_dataset

df = load_dataset("julien-c/titanic-survival")["train"].to_pandas()
df

Using custom data configuration julien-c--titanic-survival-1f1c84bb27ca5f78
Found cached dataset csv (/Users/zschillaci/.cache/huggingface/datasets/julien-c___csv/julien-c--titanic-survival-1f1c84bb27ca5f78/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)
100%|██████████| 1/1 [00:00<00:00, 611.95it/s]


Unnamed: 0,Survived,Pclass,Name,Sex,Age,Siblings/Spouses Aboard,Parents/Children Aboard,Fare
0,0,3,Mr. Owen Harris Braund,male,22.0,1,0,7.2500
1,1,1,Mrs. John Bradley (Florence Briggs Thayer) Cum...,female,38.0,1,0,71.2833
2,1,3,Miss. Laina Heikkinen,female,26.0,0,0,7.9250
3,1,1,Mrs. Jacques Heath (Lily May Peel) Futrelle,female,35.0,1,0,53.1000
4,0,3,Mr. William Henry Allen,male,35.0,0,0,8.0500
...,...,...,...,...,...,...,...,...
882,0,2,Rev. Juozas Montvila,male,27.0,0,0,13.0000
883,1,1,Miss. Margaret Edith Graham,female,19.0,0,0,30.0000
884,0,3,Miss. Catherine Helen Johnston,female,7.0,1,2,23.4500
885,1,1,Mr. Karl Howell Behr,male,26.0,0,0,30.0000


In [5]:
from langchain import OpenAI, PandasDataFrameChain

In [6]:
llm = OpenAI(temperature=0)

In [9]:
df_chain = PandasDataFrameChain.from_llm(llm=llm, dataframe=df, verbose=True)

## Ask direct questions on the dataset
These are questions where the output is expected to be a single value (e.g. a float or a string).

In [10]:
output = df_chain("How many people survived?")
output

> Entering new PandasDataFrameChain chain...
Question: How many people survived?
Code: df['Survived'].sum()
Result: 342
> Finished chain.


{'query': 'How many people survived?',
 'code': ["df['Survived'].sum()"],
 'result': 342}

The chain returns both the generated code in the `code` field and the Python output from the code execution in the `result` field.
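As a rough illustration of how the two fields relate, the sketch below re-evaluates a generated expression against a DataFrame. The source does not show how the chain executes code internally, so the use of `eval` here is an assumption, and the three-row DataFrame is a made-up stand-in for the Titanic data.

```python
import pandas as pd

# Hypothetical stand-in for the Titanic DataFrame (three made-up rows).
df = pd.DataFrame({"Survived": [0, 1, 1], "Age": [22.0, 38.0, 26.0]})

# Expression copied from the `code` field of the output above.
output_code = ["df['Survived'].sum()"]

# Re-evaluate the expression with only `df` in scope and no builtins.
# Evaluating model-generated code is unsafe outside a sandbox; this is
# only a sketch of the idea, not the chain's actual mechanism.
generated = output_code[0]
recomputed = eval(generated, {"__builtins__": {}}, {"df": df})
print(generated, "->", recomputed)
```

Keeping the generated code alongside the result makes it possible to audit an answer before trusting it.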

In [11]:
output = df_chain("How many people under 30 died?")
output

> Entering new PandasDataFrameChain chain...
Question: How many people under 30 died?
Code: df[df['Age'] < 30]['Survived'].value_counts()[0]
Result: 305
> Finished chain.


{'query': 'How many people under 30 died?',
 'code': ["df[df['Age'] < 30]['Survived'].value_counts()[0]"],
 'result': 305}
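The generated expression relies on `value_counts()[0]`, which looks up the label `0` (did not survive) in the counts. A more explicit way to express the same question is to combine two boolean masks, sketched below on a small made-up DataFrame (the rows are hypothetical, not taken from the dataset).

```python
import pandas as pd

# Hypothetical rows; the last passenger has a missing age.
df = pd.DataFrame({
    "Survived": [0, 1, 0, 1, 0],
    "Age": [22.0, 38.0, 26.0, 19.0, None],
})

under_30 = df["Age"] < 30      # missing ages compare as False
died = df["Survived"] == 0

# Explicit form: count rows where both conditions hold.
count_mask = int((under_30 & died).sum())

# Form used by the generated code: label lookup on the value counts.
count_vc = int(df[under_30]["Survived"].value_counts()[0])

print(count_mask, count_vc)
```

The mask form also returns 0 when nobody in the subset died, whereas the label lookup would raise a `KeyError` in that case.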

In [13]:
output = df_chain("What was the average fare in 1st class?")
output

> Entering new PandasDataFrameChain chain...
Question: What was the average fare in 1st class?
Code: df[df['Pclass'] == 1]['Fare'].mean()
Result: 84.1546875
> Finished chain.


{'query': 'What was the average fare in 1st class?',
 'code': ["df[df['Pclass'] == 1]['Fare'].mean()"],
 'result': 84.1546875}

## Filter or transform the dataset
These are operations which return a Pandas DataFrame or Series object after applying some filtering or transformation function.
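To make the two return types concrete, here is a minimal sketch on a made-up four-row DataFrame (hypothetical values, not the Titanic data). It shows that `drop_duplicates` returns a new DataFrame without modifying the original, and that a grouped mean of one column returns a Series.

```python
import pandas as pd

# Hypothetical data with one exact duplicate row.
df = pd.DataFrame({
    "Pclass": [1, 1, 3, 3],
    "Sex": ["female", "female", "male", "male"],
    "Fare": [80.0, 80.0, 7.25, 8.05],
})

# Returns a new DataFrame; `df` itself keeps all four rows.
deduped = df.drop_duplicates()

# Returns a Series indexed by (Pclass, Sex).
by_group = df.groupby(["Pclass", "Sex"])["Fare"].mean()

print(type(deduped).__name__, len(deduped), len(df))
print(type(by_group).__name__)
```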

In [14]:
output = df_chain("Remove duplicates")
output

> Entering new PandasDataFrameChain chain...
Question: Remove duplicates
Code: df.drop_duplicates()
Result:      Survived  Pclass                                               Name  \
0           0       3                             Mr. Owen Harris Braund   
1           1       1  Mrs. John Bradley (Florence Briggs Thayer) Cum...   
2           1       3                              Miss. Laina Heikkinen   
3           1       1        Mrs. Jacques Heath (Lily May Peel) Futrelle   
4           0       3                            Mr. William Henry Allen   
..        ...     ...                                                ...   
882         0       2                               Rev. Juozas Montvila   
883         1       1                        Miss. Margaret Edith Graham   
884         0       3                     Miss. Catherine Helen Johnston   
885         1       1                               Mr. Karl Howell Behr   

{'query': 'Remove duplicates',
 'code': ['df.drop_duplicates()'],
 'result':      Survived  Pclass                                               Name  \
 0           0       3                             Mr. Owen Harris Braund   
 1           1       1  Mrs. John Bradley (Florence Briggs Thayer) Cum...   
 2           1       3                              Miss. Laina Heikkinen   
 3           1       1        Mrs. Jacques Heath (Lily May Peel) Futrelle   
 4           0       3                            Mr. William Henry Allen   
 ..        ...     ...                                                ...   
 882         0       2                               Rev. Juozas Montvila   
 883         1       1                        Miss. Margaret Edith Graham   
 884         0       3                     Miss. Catherine Helen Johnston   
 885         1       1                               Mr. Karl Howell Behr   
 886         0       3                                 Mr. Patrick Dooley   


In [15]:
output = df_chain("Average fare by class and gender")
output

> Entering new PandasDataFrameChain chain...
Question: Average fare by class and gender
Code: df.groupby(['Pclass', 'Sex'])['Fare'].mean()
Result: Pclass  Sex   
1       female    106.125798
        male       67.226127
2       female     21.970121
        male       19.741782
3       female     16.118810
        male       12.695466
Name: Fare, dtype: float64
> Finished chain.


{'query': 'Average fare by class and gender',
 'code': ["df.groupby(['Pclass', 'Sex'])['Fare'].mean()"],
 'result': Pclass  Sex   
 1       female    106.125798
         male       67.226127
 2       female     21.970121
         male       19.741782
 3       female     16.118810
         male       12.695466
 Name: Fare, dtype: float64}

In [17]:
output = df_chain("Remove men under the age of 30 and sort by fare")
output

> Entering new PandasDataFrameChain chain...
Question: Remove men under the age of 30 and sort by fare
Code: df[(df['Sex'] != 'male') | (df['Age'] >= 30)].sort_values(by='Fare')
Result:      Survived  Pclass                               Name     Sex   Age  \
728         0       2                Mr. Robert J Knight    male  41.0   
594         0       3                 Mr. Alfred Johnson    male  49.0   
178         0       3                 Mr. Lionel Leonard    male  36.0   
630         0       1       Mr. William Henry Marsh Parr    male  30.0   
802         0       1              Mr. Thomas Jr Andrews    male  39.0   
..        ...     ...                                ...     ...   ...   
435         0       1                   Mr. Mark Fortune    male  64.0   
339         1       1      Miss. Alice Elizabeth Fortune  female  24.0   
257         1       1                    Miss. Anna Ward  female  35.0   
733         1    

{'query': 'Remove men under the age of 30 and sort by fare',
 'code': ["df[(df['Sex'] != 'male') | (df['Age'] >= 30)].sort_values(by='Fare')"],
 'result':      Survived  Pclass                               Name     Sex   Age  \
 728         0       2                Mr. Robert J Knight    male  41.0   
 594         0       3                 Mr. Alfred Johnson    male  49.0   
 178         0       3                 Mr. Lionel Leonard    male  36.0   
 630         0       1       Mr. William Henry Marsh Parr    male  30.0   
 802         0       1              Mr. Thomas Jr Andrews    male  39.0   
 ..        ...     ...                                ...     ...   ...   
 435         0       1                   Mr. Mark Fortune    male  64.0   
 339         1       1      Miss. Alice Elizabeth Fortune  female  24.0   
 257         1       1                    Miss. Anna Ward  female  35.0   
 733         1       1              Mr. Gustave J Lesurer    male  35.0   
 676         1       
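The generated filter keeps rows that are "not male OR aged 30 and over", which is the De Morgan rewrite of "drop rows that are male AND under 30". The two forms agree whenever `Age` is present, but they can differ if `Age` contains missing values, because any comparison against a missing value is False. The sketch below illustrates this on made-up rows; whether the actual dataset has missing ages is not stated in this example.

```python
import pandas as pd

# Hypothetical rows; the last male passenger has a missing age.
df = pd.DataFrame({
    "Sex": ["male", "male", "female", "male"],
    "Age": [22.0, 41.0, 24.0, None],
    "Fare": [7.25, 0.0, 263.0, 8.05],
})

# Form used by the generated code: keep non-males or anyone 30 and over.
keep_or = df[(df["Sex"] != "male") | (df["Age"] >= 30)].sort_values(by="Fare")

# Negated form: drop only rows that are male and known to be under 30.
keep_not = df[~((df["Sex"] == "male") & (df["Age"] < 30))].sort_values(by="Fare")

print(keep_or.index.tolist())
print(keep_not.index.tolist())
```

The first form silently drops the male passenger with an unknown age, while the negated form keeps him, so the choice depends on how missing ages should be treated.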

## Directly generate Plotly figures
If you ask for a plot, the generated `df.plot` code will be automatically translated into the equivalent Plotly Express code.

In [19]:
output = df_chain("Plot the fare of people under 30 versus their age, colored by sex")
output["result"]

> Entering new PandasDataFrameChain chain...
Question: Plot the fare of people under 30 versus their age, colored by sex
Code: px.scatter(df.loc[df['Age'] < 30], x='Age', y='Fare', color='Sex', color_continuous_scale='viridis')
Result: Figure({
    'data': [{'hovertemplate': 'Sex=male<br>Age=%{x}<br>Fare=%{y}<extra></extra>',
              'legendgroup': 'male',
              'marker': {'color': '#636efa', 'symbol': 'circle'},
              'mode': 'markers',
              'name': 'male',
              'orientation': 'v',
              'showlegend': True,
              'type': 'scatter',
              'x': array([22., 27.,  2., ..., 25., 27., 26.]),
              'xaxis': 'x',
              'y': array([ 7.25  ,  8.4583, 21.075 , ...,  7.05  , 13.    , 30.    ]),
              'yaxis': 'y'},
             {'hovertemplate': 'Sex=female<br>Age=%{x}<br>Fare=%{y}<extra></extra>',
              'legendgroup': 'female',
              'ma

In [20]:
output = df_chain("Plot the average fare per class")
output["result"]

> Entering new PandasDataFrameChain chain...
Question: Plot the average fare per class
Code: px.line(df.groupby('Pclass').mean().reset_index(), x='Pclass', y='Fare')
Result: Figure({
    'data': [{'hovertemplate': 'Pclass=%{x}<br>Fare=%{y}<extra></extra>',
              'legendgroup': '',
              'line': {'color': '#636efa', 'dash': 'solid'},
              'marker': {'symbol': 'circle'},
              'mode': 'lines',
              'name': '',
              'orientation': 'v',
              'showlegend': False,
              'type': 'scatter',
              'x': array([1, 2, 3]),
              'xaxis': 'x',
              'y': array([84.1546875 , 20.66218315, 13.70770739]),
              'yaxis': 'y'}],
    'layout': {'legend': {'tracegroupgap': 0},
               'margin': {'t': 60},
               'template': '...',
               'xaxis': {'anchor': 'y', 'domain': [0.0, 1.0], 'title': {'text': 'Pclass'}},
               '


The default value of numeric_only in DataFrameGroupBy.mean is deprecated. In a future version, numeric_only will default to False. Either specify numeric_only or select only columns which should be valid for the function.
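The warning above comes from calling `.mean()` on a grouped DataFrame that still contains non-numeric columns such as `Name` and `Sex`. One way to avoid it, sketched here on a made-up three-row DataFrame, is to select the column before aggregating or to pass `numeric_only=True` explicitly. This is a suggested manual rewrite, not something the chain is documented to do.

```python
import pandas as pd

# Hypothetical data with one non-numeric column.
df = pd.DataFrame({
    "Pclass": [1, 1, 3],
    "Name": ["A", "B", "C"],
    "Fare": [80.0, 90.0, 8.0],
})

# Option 1: select the column of interest before taking the mean.
avg_fare = df.groupby("Pclass")["Fare"].mean().reset_index()

# Option 2: keep the DataFrame-wide mean but restrict it to numeric columns.
avg_numeric = df.groupby("Pclass").mean(numeric_only=True).reset_index()

print(avg_fare)
print(avg_numeric)
```

Either result can be passed to `px.line` or `px.bar` with `x='Pclass'` and `y='Fare'` as in the generated code.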



You can even specify the plot type you would like!

In [21]:
output = df_chain("Plot the average fare per class (bar)")
output["result"]

> Entering new PandasDataFrameChain chain...
Question: Plot the average fare per class (bar)
Code: px.bar(df.groupby('Pclass').mean().reset_index(), x='Pclass', y='Fare', barmode='group')
Result: Figure({
    'data': [{'alignmentgroup': 'True',
              'hovertemplate': 'Pclass=%{x}<br>Fare=%{y}<extra></extra>',
              'legendgroup': '',
              'marker': {'color': '#636efa', 'pattern': {'shape': ''}},
              'name': '',
              'offsetgroup': '',
              'orientation': 'v',
              'showlegend': False,
              'textposition': 'auto',
              'type': 'bar',
              'x': array([1, 2, 3]),
              'xaxis': 'x',
              'y': array([84.1546875 , 20.66218315, 13.70770739]),
              'yaxis': 'y'}],
    'layout': {'barmode': 'group',
               'legend': {'tracegroupgap': 0},
               'margin': {'t': 60},
               'template': '...',
          


The default value of numeric_only in DataFrameGroupBy.mean is deprecated. In a future version, numeric_only will default to False. Either specify numeric_only or select only columns which should be valid for the function.

