# Microsoft LIDA approach with LangChain

The [paper](https://arxiv.org/abs/2303.02927) __LIDA: Automatic Generation of Visualizations and Infographics using Large Language Models__ by [Victor Dibia](https://github.com/victordibia) shows how LLMs can be used to generate data visualizations automatically. Unfortunately, the accompanying [GitHub](https://github.com/microsoft/lida) repository is not up to date and cannot be run with the supported LLMs because of version conflicts between its dependencies.

To support different language models, [Victor Dibia](https://github.com/victordibia) also implemented his own abstraction framework, [LLMX](https://github.com/victordibia/llmx), which unfortunately no longer works with the newer version of the OpenAI API.

For this reason, I have copied the relevant parts of the framework into this notebook so that each step of the LIDA process can be controlled directly. The LLM calls are handled by LangChain.

## Install required packages

In [4]:
!pip install --quiet langchain langchain-community
!pip install --quiet openai plotly

### Import required modules and classes

In [5]:
import ast
import getpass
import importlib
import json
import logging
import os
import re
import warnings

import matplotlib.pyplot as plt
import pandas as pd
from langchain.prompts import ChatPromptTemplate
from langchain_community.chat_models import AzureChatOpenAI

In [6]:
logger = logging.getLogger(__name__)

### Enter your Azure OpenAI GPT-4 key and endpoint URL

In [7]:
print('Enter AZURE_OPENAI_API_KEY:')
os.environ["AZURE_OPENAI_API_KEY"] = getpass.getpass()

Enter AZURE_OPENAI_API_KEY:
··········


In [8]:
print('Enter AZURE_OPENAI_ENDPOINT:')
os.environ["AZURE_OPENAI_ENDPOINT"] = getpass.getpass()

Enter AZURE_OPENAI_ENDPOINT:
··········


### Some helper functions

In [9]:
def clean_code_snippet(code_string):
    # Extract code snippet using regex
    cleaned_snippet = re.search(r'```(?:\w+)?\s*([\s\S]*?)\s*```', code_string)

    if cleaned_snippet:
        cleaned_snippet = cleaned_snippet.group(1)
    else:
        cleaned_snippet = code_string

    # remove non-printable characters
    # cleaned_snippet = re.sub(r'[\x00-\x1F]+', ' ', cleaned_snippet)

    return cleaned_snippet
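The regex keeps only the body of the first fenced block in a reply and falls back to the raw text when no fence is present. A standalone check of that behaviour, reusing the same pattern with a hypothetical `extract_fenced` helper and two made-up replies:

```python
import re

# Same pattern as clean_code_snippet: optional language tag after the opening fence,
# non-greedy capture of the body, surrounding whitespace trimmed.
FENCE_PATTERN = r'```(?:\w+)?\s*([\s\S]*?)\s*```'


def extract_fenced(text: str) -> str:
    match = re.search(FENCE_PATTERN, text)
    return match.group(1) if match else text


fenced_reply = '```json\n{"dataset_name": "demo"}\n```'
plain_reply = '{"dataset_name": "demo"}'

print(extract_fenced(fenced_reply))  # {"dataset_name": "demo"}
print(extract_fenced(plain_reply))   # {"dataset_name": "demo"}
```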

In [10]:
def clean_column_names(df):
    # create a copy of the dataframe to avoid modifying the original data
    cleaned_df = df.copy()

    # iterate over column names in the dataframe
    for col in cleaned_df.columns:
        # check if column name contains any special characters or spaces
        if re.search('[^0-9a-zA-Z_]', col):
            # replace special characters and spaces with underscores
            new_col = re.sub('[^0-9a-zA-Z_]', '_', col)
            # rename the column in the cleaned dataframe
            cleaned_df.rename(columns={col: new_col}, inplace=True)

    # return the cleaned dataframe
    return cleaned_df
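Applied to headers like the ones in the churn file, the substitution produces the names seen later in the notebook (`Int_l_Plan`, `Churn_`). A minimal sketch of the same rule on an empty DataFrame; the raw header strings are assumed for illustration:

```python
import re

import pandas as pd

# Assumed raw headers, as they might appear in the churn file before cleaning.
raw = pd.DataFrame(columns=["Account Length", "Int'l Plan", "Churn?", "State"])

# Same rule as clean_column_names: any character outside [0-9a-zA-Z_] becomes "_".
cleaned = raw.rename(columns=lambda col: re.sub(r"[^0-9a-zA-Z_]", "_", col))

print(cleaned.columns.tolist())  # ['Account_Length', 'Int_l_Plan', 'Churn_', 'State']
```

Because the mapping is many-to-one, two distinct headers such as `Day Mins` and `Day-Mins` would both become `Day_Mins`; neither `clean_column_names` nor this sketch detects such collisions.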

In [11]:
def read_dataframe(file_location):
    file_extension = file_location.split('.')[-1]
    if file_extension == 'json':
        try:
            df = pd.read_json(file_location, orient='records')
        except ValueError:
            df = pd.read_json(file_location, orient='table')
    elif file_extension == 'csv':
        df = pd.read_csv(file_location)
    elif file_extension in ['xls', 'xlsx']:
        df = pd.read_excel(file_location)
    elif file_extension == 'parquet':
        df = pd.read_parquet(file_location)
    elif file_extension == 'feather':
        df = pd.read_feather(file_location)
    elif file_extension == "tsv":
        df = pd.read_csv(file_location, sep="\t")
    elif file_extension == "txt":
        df = pd.read_csv(file_location)
    else:
        raise ValueError('Unsupported file type')

    # clean column names and down-sample large datasets
    cleaned_df = clean_column_names(df)
    if len(df) > 4500:
        logger.info("Dataframe has more than 4500 rows. We will sample 4500 rows.")
        cleaned_df = cleaned_df.sample(4500)

    # write the cleaned DataFrame back to the original file, but only if something
    # changed and the file is local (a remote URL cannot be overwritten)
    changed = cleaned_df.columns.tolist() != df.columns.tolist() or len(df) > 4500
    if changed and os.path.isfile(file_location):
        if file_extension in ['csv', 'txt']:
            cleaned_df.to_csv(file_location, index=False)
        elif file_extension == 'tsv':
            cleaned_df.to_csv(file_location, sep="\t", index=False)
        elif file_extension in ['xls', 'xlsx']:
            cleaned_df.to_excel(file_location, index=False)
        elif file_extension == 'parquet':
            cleaned_df.to_parquet(file_location, index=False)
        elif file_extension == 'feather':
            # to_feather has no index argument; reset the (sampled) index instead
            cleaned_df.reset_index(drop=True).to_feather(file_location)
        elif file_extension == 'json':
            with open(file_location, 'w') as f:
                f.write(cleaned_df.to_json(orient='records'))
        else:
            raise ValueError('Unsupported file type')

    return cleaned_df
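`read_dataframe` dispatches on the file extension with an `if`/`elif` chain. An alternative, shown only as a sketch, keeps the extension-to-reader mapping in a dictionary so that supporting a new format means adding one entry. `READERS` and `read_by_extension` are hypothetical names, and the JSON `orient='table'` fallback, column cleaning and write-back are left out:

```python
import os
import tempfile

import pandas as pd

# Hypothetical extension -> reader table mirroring the branches of read_dataframe.
READERS = {
    "csv": pd.read_csv,
    "txt": pd.read_csv,
    "tsv": lambda path: pd.read_csv(path, sep="\t"),
    "json": lambda path: pd.read_json(path, orient="records"),
    "xls": pd.read_excel,
    "xlsx": pd.read_excel,
    "parquet": pd.read_parquet,
    "feather": pd.read_feather,
}


def read_by_extension(path: str) -> pd.DataFrame:
    extension = path.rsplit(".", 1)[-1]
    reader = READERS.get(extension)
    if reader is None:
        raise ValueError(f"Unsupported file type: {extension}")
    return reader(path)


# Usage: round-trip a small tab-separated file through a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    tsv_path = os.path.join(tmp, "sample.tsv")
    pd.DataFrame({"State": ["KS", "OH"], "Day_Calls": [110, 123]}).to_csv(
        tsv_path, sep="\t", index=False
    )
    sample_df = read_by_extension(tsv_path)

print(sample_df)
```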

In [12]:
def check_type(dtype: str, value):
    """Cast value to right type to ensure it is JSON serializable"""
    if "float" in str(dtype):
        return float(value)
    elif "int" in str(dtype):
        return int(value)
    else:
        return value

In [13]:
def get_column_properties(df: pd.DataFrame, n_samples: int = 3) -> list[dict]:
  """Get properties of each column in a pandas DataFrame"""
  properties_list = []
  for column in df.columns:
      dtype = df[column].dtype
      properties = {}
      if dtype in [int, float, complex]:
          properties["dtype"] = "number"
          properties["std"] = check_type(dtype, df[column].std())
          properties["min"] = check_type(dtype, df[column].min())
          properties["max"] = check_type(dtype, df[column].max())

      elif dtype == bool:
          properties["dtype"] = "boolean"
      elif dtype == object:
          # Check if the string column can be cast to a valid datetime
          try:
              with warnings.catch_warnings():
                  warnings.simplefilter("ignore")
                  pd.to_datetime(df[column], errors='raise')
                  properties["dtype"] = "date"
          except (ValueError, TypeError):
              # Check if the string column has a limited number of values
              if df[column].nunique() / len(df[column]) < 0.5:
                  properties["dtype"] = "category"
              else:
                  properties["dtype"] = "string"
      elif isinstance(dtype, pd.CategoricalDtype):
          properties["dtype"] = "category"
      elif pd.api.types.is_datetime64_any_dtype(df[column]):
          properties["dtype"] = "date"
      else:
          properties["dtype"] = str(dtype)

      # add min max if dtype is date
      if properties["dtype"] == "date":
          try:
              properties["min"] = df[column].min()
              properties["max"] = df[column].max()
          except TypeError:
              cast_date_col = pd.to_datetime(df[column], errors='coerce')
              properties["min"] = cast_date_col.min()
              properties["max"] = cast_date_col.max()
      # Add additional properties to the output dictionary
      nunique = df[column].nunique()
      if "samples" not in properties:
          non_null_values = df[column][df[column].notnull()].unique()
          n_samples = min(n_samples, len(non_null_values))
          samples = pd.Series(non_null_values).sample(n_samples, random_state=42).tolist()
          properties["samples"] = samples
      properties["num_unique_values"] = nunique
      properties["semantic_type"] = ""
      properties["description"] = ""
      properties_list.append({"column": column, "properties": properties})

  return properties_list
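For object columns that `pd.to_datetime` cannot parse, `get_column_properties` falls back to a cardinality rule: if fewer than half of the values are distinct, the column is labelled `category`, otherwise `string`. The rule in isolation, as a hypothetical `is_low_cardinality` helper (the date check is left out):

```python
import pandas as pd


def is_low_cardinality(series: pd.Series, threshold: float = 0.5) -> bool:
    """Hypothetical helper: the category-vs-string rule from get_column_properties."""
    return series.nunique() / len(series) < threshold


plans = pd.Series(["yes", "no", "no", "yes", "no", "no"])  # 2 unique / 6 rows
phones = pd.Series(["382-4657", "371-7191", "358-1921"])  # every value unique
tie = pd.Series(["yes", "no", "no", "yes"])  # 2 unique / 4 rows = 0.5

print(is_low_cardinality(plans))   # True  -> "category"
print(is_low_cardinality(phones))  # False -> "string"
print(is_low_cardinality(tie))     # False: the comparison is strict
```

Because the ratio depends on the number of rows, a two-valued column such as `Int_l_Plan` is a `category` on the full data set but would be a `string` in a four-row sample containing both values.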

In [14]:
def preprocess_code(code: str) -> str:
    """Preprocess code to remove any preamble and explanation text"""

    code = code.replace("<imports>", "")
    code = code.replace("<stub>", "")
    code = code.replace("<transforms>", "")

    # remove all text after chart = plot(data)
    if "chart = plot(data)" in code:
        # print(code)
        index = code.find("chart = plot(data)")
        if index != -1:
            code = code[: index + len("chart = plot(data)")]

    if "```" in code:
        pattern = r"```(?:\w+\n)?([\s\S]+?)```"
        matches = re.findall(pattern, code)
        if matches:
            code = matches[0]
        # code = code.replace("```", "")
        # return code

    if "import" in code:
        # return only text after the first import statement
        index = code.find("import")
        if index != -1:
            code = code[index:]

    code = code.replace("```", "")
    if "chart = plot(data)" not in code:
        code = code + "\nchart = plot(data)"
    return code

In [15]:
def get_globals_dict(code_string, data):
    # Parse the code string into an AST
    tree = ast.parse(code_string)
    # Extract the names of the imported modules and their aliases
    imported_modules = []
    for node in tree.body:
        if isinstance(node, ast.Import):
            for alias in node.names:
                module = importlib.import_module(alias.name)
                imported_modules.append((alias.name, alias.asname, module))
        elif isinstance(node, ast.ImportFrom):
            module = importlib.import_module(node.module)
            for alias in node.names:
                obj = getattr(module, alias.name)
                imported_modules.append(
                    (f"{node.module}.{alias.name}", alias.asname, obj)
                )

    # Import the required modules into a dictionary
    globals_dict = {}
    for module_name, alias, obj in imported_modules:
        if alias:
            globals_dict[alias] = obj
        else:
            globals_dict[module_name.split(".")[-1]] = obj

    ex_dicts = {"pd": pd, "data": data, "plt": plt}
    globals_dict.update(ex_dicts)
    return globals_dict

In [16]:
chat_llm = AzureChatOpenAI(
    temperature=0,
    openai_api_version="2023-05-15",
    deployment_name="gpt-4",
)

### Generating the data SUMMARY

In [17]:
#filename = "https://raw.githubusercontent.com/uwdata/draco/master/data/cars.csv"
filename = "https://raw.githubusercontent.com/manavpatel1092/EDA-of-Telecom-Churn-rate/master/churn1.txt"

In [18]:
data = read_dataframe(filename)

In [20]:
data.head()

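|   | State | Account_Length | Area_Code | Phone | Int_l_Plan | VMail_Plan | VMail_Message | Day_Mins | Day_Calls | Day_Charge | ... | Eve_Calls | Eve_Charge | Night_Mins | Night_Calls | Night_Charge | Intl_Mins | Intl_Calls | Intl_Charge | CustServ_Calls | Churn_ |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | KS | 128 | 415 | 382-4657 | no | yes | 25 | 265.1 | 110 | 45.07 | ... | 99 | 16.78 | 244.7 | 91 | 11.01 | 10.0 | 3 | 2.7 | 1 | False. |
| 1 | OH | 107 | 415 | 371-7191 | no | yes | 26 | 161.6 | 123 | 27.47 | ... | 103 | 16.62 | 254.4 | 103 | 11.45 | 13.7 | 3 | 3.7 | 1 | False. |
| 2 | NJ | 137 | 415 | 358-1921 | no | no | 0 | 243.4 | 114 | 41.38 | ... | 110 | 10.3 | 162.6 | 104 | 7.32 | 12.2 | 5 | 3.29 | 0 | False. |
| 3 | OH | 84 | 408 | 375-9999 | yes | no | 0 | 299.4 | 71 | 50.9 | ... | 88 | 5.26 | 196.9 | 89 | 8.86 | 6.6 | 7 | 1.78 | 2 | False. |
| 4 | OK | 75 | 415 | 330-6626 | yes | no | 0 | 166.7 | 113 | 28.34 | ... | 122 | 12.61 | 186.9 | 121 | 8.41 | 10.1 | 3 | 2.73 | 3 | False. |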


In [21]:
data_properties = get_column_properties(data, 3)

In [22]:
data_properties

[{'column': 'State',
  'properties': {'dtype': 'category',
   'samples': ['DC', 'WA', 'MS'],
   'num_unique_values': 51,
   'semantic_type': '',
   'description': ''}},
 {'column': 'Account_Length',
  'properties': {'dtype': 'number',
   'std': 39,
   'min': 1,
   'max': 243,
   'samples': [172, 189, 44],
   'num_unique_values': 212,
   'semantic_type': '',
   'description': ''}},
 {'column': 'Area_Code',
  'properties': {'dtype': 'number',
   'std': 42,
   'min': 408,
   'max': 510,
   'samples': [415, 408, 510],
   'num_unique_values': 3,
   'semantic_type': '',
   'description': ''}},
 {'column': 'Phone',
  'properties': {'dtype': 'string',
   'samples': ['352-6573', '369-4377', '392-2555'],
   'num_unique_values': 3333,
   'semantic_type': '',
   'description': ''}},
 {'column': 'Int_l_Plan',
  'properties': {'dtype': 'category',
   'samples': ['yes', 'no'],
   'num_unique_values': 2,
   'semantic_type': '',
   'description': ''}},
 {'column': 'VMail_Plan',
  'properties': {'dtype'

In [23]:
system_prompt = """
You are an experienced data analyst that can annotate datasets. Your instructions are as follows:
1. ALWAYS generate the name of the dataset and the dataset_description
2. ALWAYS generate a field description.
3. ALWAYS generate a semantic_type (a single word) for each field given its values e.g. company, city, number, supplier, location, gender, longitude, latitude, url, ip address, zip code, email, etc
You must return an updated JSON dictionary without any preamble or explanation.
"""

base_summary = {
    "file_name": filename,
    "fields": data_properties
}

chat_template = ChatPromptTemplate.from_messages(
    [
        ("system", "{system_prompt}."),
        ("human", "Annotate the dictionary below. Only return a JSON object. {base_summary}"),
    ]
)

messages = chat_template.format_messages(system_prompt=system_prompt, base_summary=base_summary)

In [24]:
print(messages)

[SystemMessage(content='\nYou are an experienced data analyst that can annotate datasets. Your instructions are as follows:\n1. ALWAYS generate the name of the dataset and the dataset_description\n2. ALWAYS generate a field description.\n3. ALWAYS generate a semantic_type (a single word) for each field given its values e.g. company, city, number, supplier, location, gender, longitude, latitude, url, ip address, zip code, email, etc\nYou must return an updated JSON dictionary without any preamble or explanation.\n.'), HumanMessage(content="Annotate the dictionary below. Only return a JSON object. {'file_name': 'https://raw.githubusercontent.com/manavpatel1092/EDA-of-Telecom-Churn-rate/master/churn1.txt', 'fields': [{'column': 'State', 'properties': {'dtype': 'category', 'samples': ['DC', 'WA', 'MS'], 'num_unique_values': 51, 'semantic_type': '', 'description': ''}}, {'column': 'Account_Length', 'properties': {'dtype': 'number', 'std': 39, 'min': 1, 'max': 243, 'samples': [172, 189, 44],
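As the printed messages show, the summary dictionary is inserted through its Python representation (single-quoted keys such as `{'file_name': ...}`), which is not valid JSON even though the system prompt asks the model to return an updated JSON dictionary. One option is to serialise the summary first with `json.dumps(..., default=str)`; `default=str` covers values such as `pd.Timestamp` (possible for `date` columns) that `json` cannot encode natively. A sketch with a made-up summary of the same shape:

```python
import json

import pandas as pd

# A made-up summary with the same shape as base_summary.
summary = {
    "file_name": "example.csv",
    "fields": [
        {"column": "Signup", "properties": {"dtype": "date", "min": pd.Timestamp("2020-01-01")}},
        {"column": "State", "properties": {"dtype": "category", "samples": ["KS", "OH"]}},
    ],
}

as_repr = str(summary)  # the form that ends up in the prompt (single quotes)
as_json = json.dumps(summary, default=str)  # default=str turns the Timestamp into text

print(as_repr.startswith("{'file_name'"))  # True
print(json.loads(as_json)["fields"][0]["properties"]["min"])  # 2020-01-01 00:00:00
```

In the notebook, `base_summary=json.dumps(base_summary, default=str)` could then be passed to `format_messages` in place of the dictionary.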

In [25]:
res = chat_llm.invoke(messages)

In [26]:
try:
  json_string = clean_code_snippet(res.content)
  enriched_summary = json.loads(json_string)
except json.decoder.JSONDecodeError:
  error_msg = f"The model did not return a valid JSON object while attempting to generate an enriched data summary. Consider using a default summary or a larger model with higher max token length. | {res.content}"
  logger.info(error_msg)
  # fall back to the un-annotated summary so the following cells still run
  enriched_summary = base_summary


In [27]:
data_summary = enriched_summary
data_summary["field_names"] = data.columns.tolist()
data_summary["file_name"] = filename

In [28]:
data_summary

{'dataset_name': 'Telecom Churn Rate Dataset',
 'dataset_description': 'This dataset contains information about telecom customers, their usage statistics, and whether they have churned or not.',
 'fields': [{'column': 'State',
   'properties': {'dtype': 'category',
    'samples': ['DC', 'WA', 'MS'],
    'num_unique_values': 51,
    'semantic_type': 'state',
    'description': 'The state where the customer resides.'}},
  {'column': 'Account_Length',
   'properties': {'dtype': 'number',
    'std': 39,
    'min': 1,
    'max': 243,
    'samples': [172, 189, 44],
    'num_unique_values': 212,
    'semantic_type': 'number',
    'description': "The length of the customer's account in days."}},
  {'column': 'Area_Code',
   'properties': {'dtype': 'number',
    'std': 42,
    'min': 408,
    'max': 510,
    'samples': [415, 408, 510],
    'num_unique_values': 3,
    'semantic_type': 'area_code',
    'description': "The area code of the customer's phone number."}},
  {'column': 'Phone',
   'pro

### Generating the GOAL(s)

In [29]:
goal_system_prompt = """
You are a an experienced data analyst who can generate a given number of insightful GOALS about data, when given a summary of the data,
and a specified persona. The VISUALIZATIONS YOU RECOMMEND MUST FOLLOW VISUALIZATION BEST PRACTICES (e.g., must use bar charts instead of pie charts for
comparing quantities) AND BE MEANINGFUL (e.g., plot longitude and latitude on maps where appropriate). They must also be relevant to the specified persona.
Each goal must include a question, a visualization (THE VISUALIZATION MUST REFERENCE THE EXACT COLUMN FIELDS FROM THE SUMMARY), and a rationale (JUSTIFICATION FOR WHICH dataset
FIELDS ARE USED and what we will learn from the visualization). Each goal MUST mention the exact fields from the dataset summary above
"""

goal_format_instruction_prompt = """
THE OUTPUT MUST BE A CODE SNIPPET OF A VALID LIST OF JSON OBJECTS. IT MUST USE THE FOLLOWING FORMAT:

```[
    { "index": 0,  "question": "What is the distribution of X", "visualization": "histogram of X", "rationale": "This tells about "} ..
    ]
```
THE OUTPUT SHOULD ONLY USE THE JSON FORMAT ABOVE.
"""

goal_user_prompt = """
The number of GOALS to generate is {number_of_goals}. The goals should be based on the data summary below, \n\n.
"""

goal_chat_template = ChatPromptTemplate.from_messages(
    [
        ("system", "{goal_system_prompt}."),
        ("human", "The number of GOALS to generate is {number_of_goals}. The goals should be based on the data summary below, \n\n {data_summary} \n\n {goal_format_instruction_prompt}. The generated {number_of_goals} goals are: \n"),
    ]
)

goal_messages = goal_chat_template.format_messages(
    goal_system_prompt=goal_system_prompt,
    number_of_goals=10,
    data_summary=data_summary,
    goal_format_instruction_prompt=goal_format_instruction_prompt
)
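`goal_format_instruction_prompt` contains literal JSON braces, and it is passed in as a template variable rather than written into the template text. LangChain's default `f-string` template format follows Python's `str.format` rules, under which literal braces in the template itself are read as placeholders unless they are doubled (`{{` / `}}`). The same behaviour, shown with plain `str.format` so no LangChain is needed:

```python
format_hint = '[{ "index": 0, "question": "What is the distribution of X" }]'

# 1. Literal braces inside the template itself are parsed as a replacement field.
inline_template = "Use this format: " + format_hint
try:
    inline_template.format()
    inline_ok = True
except (KeyError, ValueError, IndexError):
    inline_ok = False

# 2. Doubling the braces escapes them.
escaped_template = "Use this format: " + format_hint.replace("{", "{{").replace("}", "}}")

# 3. Passing the text as a variable value (what the notebook does) needs no escaping,
#    because substituted values are not parsed again.
variable_template = "Use this format: {format_instructions}"

print(inline_ok)  # False
print(escaped_template.format() == variable_template.format(format_instructions=format_hint))  # True
```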

In [30]:
print(goal_messages)

[SystemMessage(content='\nYou are a an experienced data analyst who can generate a given number of insightful GOALS about data, when given a summary of the data,\nand a specified persona. The VISUALIZATIONS YOU RECOMMEND MUST FOLLOW VISUALIZATION BEST PRACTICES (e.g., must use bar charts instead of pie charts for\ncomparing quantities) AND BE MEANINGFUL (e.g., plot longitude and latitude on maps where appropriate). They must also be relevant to the specified persona.\nEach goal must include a question, a visualization (THE VISUALIZATION MUST REFERENCE THE EXACT COLUMN FIELDS FROM THE SUMMARY), and a rationale (JUSTIFICATION FOR WHICH dataset\nFIELDS ARE USED and what we will learn from the visualization). Each goal MUST mention the exact fields from the dataset summary above\n.'), HumanMessage(content='The number of GOALS to generate is 10. The goals should be based on the data summary below, \n\n {\'dataset_name\': \'Telecom Churn Rate Dataset\', \'dataset_description\': \'This datase

In [31]:
goal_res = chat_llm.invoke(goal_messages)

In [32]:
goal_res.content

'```[\n    { "index": 0,  "question": "What is the distribution of account lengths?", "visualization": "Histogram of Account_Length", "rationale": "This tells us about the range and frequency of account lengths, which can help us understand the typical duration a customer stays with the telecom company." },\n    { "index": 1,  "question": "What is the churn rate per state?", "visualization": "Bar chart of Churn_ grouped by State", "rationale": "This can help us identify if there are any geographical patterns to customer churn." },\n    { "index": 2,  "question": "How does the number of customer service calls relate to churn?", "visualization": "Box plot of CustServ_Calls for each category of Churn_", "rationale": "This can help us understand if there is a relationship between the number of customer service calls and customer churn." },\n    { "index": 3,  "question": "What is the relationship between day minutes and day charge?", "visualization": "Scatter plot of Day_Mins vs Day_Charge

In [33]:
try:
  json_string = clean_code_snippet(goal_res.content)
  result = json.loads(json_string)
  if isinstance(result, dict):
    result = [result]
except json.decoder.JSONDecodeError:
  logger.info(f"Error decoding JSON: {goal_res.content}")
  print(f"Error decoding JSON: {goal_res.content}")
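The cell above accepts any JSON the model returns. Since the later cells index into `result` and read `question` and `visualization`, a light check that every goal carries the four keys requested in `goal_format_instruction_prompt` can catch malformed replies early. A sketch with a hypothetical `validate_goals` helper and sample goals:

```python
REQUIRED_GOAL_KEYS = ("index", "question", "visualization", "rationale")


def validate_goals(parsed) -> list[dict]:
    """Hypothetical helper: normalise decoded goal JSON into a list of complete goals."""
    if isinstance(parsed, dict):  # a single goal object, as handled in the cell above
        parsed = [parsed]
    if not isinstance(parsed, list):
        return []
    return [
        goal
        for goal in parsed
        if isinstance(goal, dict) and all(key in goal for key in REQUIRED_GOAL_KEYS)
    ]


sample = [
    {"index": 0, "question": "What is the distribution of account lengths?",
     "visualization": "Histogram of Account_Length", "rationale": "..."},
    {"index": 1, "question": "What is the churn rate per state?"},  # incomplete
]

goals = validate_goals(sample)
print(len(goals))  # 1
```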

In [34]:
result

[{'index': 0,
  'question': 'What is the distribution of account lengths?',
  'visualization': 'Histogram of Account_Length',
  'rationale': 'This tells us about the range and frequency of account lengths, which can help us understand the typical duration a customer stays with the telecom company.'},
 {'index': 1,
  'question': 'What is the churn rate per state?',
  'visualization': 'Bar chart of Churn_ grouped by State',
  'rationale': 'This can help us identify if there are any geographical patterns to customer churn.'},
 {'index': 2,
  'question': 'How does the number of customer service calls relate to churn?',
  'visualization': 'Box plot of CustServ_Calls for each category of Churn_',
  'rationale': 'This can help us understand if there is a relationship between the number of customer service calls and customer churn.'},
 {'index': 3,
  'question': 'What is the relationship between day minutes and day charge?',
  'visualization': 'Scatter plot of Day_Mins vs Day_Charge',
  'ratio

In [35]:
result[7]

{'index': 7,
 'question': 'How does the number of voicemail messages relate to churn?',
 'visualization': 'Box plot of VMail_Message for each category of Churn_',
 'rationale': 'This can help us understand if there is a relationship between the number of voicemail messages and customer churn.'}

### Visualize

In [36]:
# In LIDA, the template comes from self.scaffold.get_template(goal, library); here it is hard-coded for plotly.

library = "plotly"

library_template = """
import plotly.express as px
<imports>
def plot(data: pd.DataFrame):
    fig = <stub> # only modify this section

    return fig
chart = plot(data) # variable data already contains the data to be plotted and should not be loaded again.  Always include this line. No additional code beyond this line.
"""

viz_chat_template = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a helpful assistant highly skilled in writing PERFECT code for visualizations. Given some code template, you complete the template to generate a visualization given the dataset and the goal described. The code you write MUST FOLLOW VISUALIZATION BEST PRACTICES i.e. meet the specified goal, apply the right transformation, use the right visualization type, use the right data encoding, and use the right aesthetics (e.g., ensure axes are legible). The transformations you apply MUST be correct and the fields you use MUST be correct. The visualization CODE MUST BE CORRECT and MUST NOT CONTAIN ANY SYNTAX OR LOGIC ERRORS (e.g., it must consider the field types and use them correctly). You MUST first generate a brief plan for how you would solve the task e.g. what transformations you would apply e.g. if you need to construct a new column, what fields you would use, what visualization type you would use, what aesthetics you would use, etc. The dataset summary is : {data_summary}. If the solution requires a single value (e.g. max, min, median, first, last etc), ALWAYS add a line (axvline or axhline) to the chart, ALWAYS with a legend containing the single value (formatted with 0.2F). If using a <field> where semantic_type=date, YOU MUST APPLY the following transform before using that column i) convert date fields to date types using data[''] = pd.to_datetime(data[<field>], errors='coerce'), ALWAYS use errors='coerce' ii) drop the rows with NaT values data = data[pd.notna(data[<field>])] iii) convert field to right time format for plotting. ALWAYS make sure the x-axis labels are legible (e.g., rotate when needed). Solve the task carefully by completing ONLY the <imports> AND <stub> section. Given the dataset summary, the plot(data) method should generate a {library} chart ({goal_visualization}) that addresses this goal: {goal_question}. DO NOT WRITE ANY CODE TO LOAD THE DATA. The data is already loaded and available in the variable data. If calculating metrics such as mean, median, mode, etc. ALWAYS use the option 'numeric_only=True' when applicable and available, AVOID visualizations that require nbformat library. DO NOT include fig.show(). The plot method must return a plotly figure object (fig). Think step by step. \n."),
        ("human", "Always add a legend with various colors where appropriate. The visualization code MUST only use data fields that exist in the dataset (field_names) or fields that are transformations based on existing field_names). Only use variables that have been defined in the code or are in the dataset summary. You MUST return a FULL PYTHON PROGRAM ENCLOSED IN BACKTICKS ``` that starts with an import statement. DO NOT add any explanation. \n\n THE GENERATED CODE SOLUTION SHOULD BE CREATED BY MODIFYING THE SPECIFIED PARTS OF THE TEMPLATE BELOW \n\n {library_template} \n\n.The FINAL COMPLETED CODE BASED ON THE TEMPLATE above is ... \n\n")
    ]
)

viz_messages = viz_chat_template.format_messages(
    library=library,
    library_template=library_template,
    data_summary=data_summary,
    goal_question=result[7]['question'],
    goal_visualization=result[7]['visualization']
)


In [37]:
viz_res = chat_llm.invoke(viz_messages)


In [38]:
viz_res.content

"```python\nimport plotly.express as px\nimport pandas as pd\n\ndef plot(data: pd.DataFrame):\n    fig = px.box(data, x='Churn_', y='VMail_Message', color='Churn_', \n                 labels={'Churn_':'Churn', 'VMail_Message':'Number of Voicemail Messages'},\n                 title='Box plot of VMail_Message for each category of Churn')\n    return fig\nchart = plot(data)\n```"

In [39]:
# strip the markdown fence and anything outside the generated program
code = preprocess_code(viz_res.content)


In [40]:
ex_locals = get_globals_dict(code, data)

In [41]:
exec(code, ex_locals)

In [42]:
chart = ex_locals["chart"]

In [43]:
chart
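The three cells above run the generated program with `exec`, using `get_globals_dict` to pre-populate the namespace with `data`, `pd` and `plt`; the program's final line `chart = plot(data)` then leaves its result in that same dictionary, from where it is read back. The mechanism in isolation, with a made-up stand-in program that returns a summary table instead of a Plotly figure so it runs without plotly:

```python
import pandas as pd

# Stand-in for the model's reply after the fences have been stripped (illustrative only).
generated_code = """
import pandas as pd

def plot(data: pd.DataFrame):
    return data.groupby("Churn_")["VMail_Message"].mean()

chart = plot(data)
"""

demo_data = pd.DataFrame({"Churn_": ["False.", "True.", "False."], "VMail_Message": [25, 0, 15]})

# The dict acts as the program's globals: `data` is injected, `chart` is written back.
demo_namespace = {"pd": pd, "data": demo_data}
exec(generated_code, demo_namespace)

demo_chart = demo_namespace["chart"]
print(demo_chart)
```

Using a separate dictionary keeps the generated program from overwriting the notebook's own variables.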

### Visualization Explainer

In [49]:
explainer_system_prompt = """
You are a helpful assistant highly skilled in providing helpful, structured explanations of visualization of the plot(data: pd.DataFrame) method in the provided code. You divide the code into sections and provide a description of each section and an explanation. The first section should be named "accessibility" and describe the physical appearance of the chart (colors, chart type etc), the goal of the chart, as well the main insights from the chart.
You can explain code across the following 3 dimensions:
1. accessibility: the physical appearance of the chart (colors, chart type etc), the goal of the chart, as well the main insights from the chart.
2. transformation: This should describe the section of the code that applies any kind of data transformation (filtering, aggregation, grouping, null value handling etc)
3. visualization: step by step description of the code that creates or modifies the presented visualization.
"""

explainer_format_instructions = """
Your output MUST be perfect JSON in THE FORM OF A VALID LIST of JSON OBJECTS WITH PROPERLY ESCAPED SPECIAL CHARACTERS e.g.,

```[
    {"section": "accessibility", "code": "None", "explanation": ".."}  , {"section": "transformation", "code": "..", "explanation": ".."}  ,  {"section": "visualization", "code": "..", "explanation": ".."}
    ] ```

The code part of the dictionary must come from the supplied code and should cover the explanation. The explanation part of the dictionary must be a string. The section part of the dictionary must be one of "accessibility", "transformation", "visualization" with no repetition. THE LIST MUST HAVE EXACTLY 3 JSON OBJECTS [{}, {}, {}].  THE GENERATED JSON  MUST BE A LIST IE START AND END WITH A SQUARE BRACKET.
"""

explainer_chat_template = ChatPromptTemplate.from_messages(
    [
        ("system", "{system_prompt}"),
        ("assistant", "The code to be explained is {code}.\n=======\n"),
        ("human", "{format_instructions}.\n\n. The structured explanation for the code above is \n\n")
    ]
)

explainer_messages = explainer_chat_template.format_messages(
    system_prompt=explainer_system_prompt,
    format_instructions=explainer_format_instructions,
    code=code
)

In [50]:
explainer_res = chat_llm.invoke(explainer_messages)

In [69]:
explainer_res.content

'```[\n    {"section": "accessibility", "code": "fig = px.box(data, x=\'Churn_\', y=\'VMail_Message\', color=\'Churn_\', labels={\'Churn_\':\'Churn\', \'VMail_Message\':\'Number of Voicemail Messages\'}, title=\'Box plot of VMail_Message for each category of Churn\')", "explanation": "The code generates a box plot, which is a standardized way of displaying the distribution of data based on a five number summary (\'minimum\', first quartile (Q1), median, third quartile (Q3), and \'maximum\'). The color of the boxes is determined by the \'Churn_\' column. The x-axis represents the \'Churn_\' categories and the y-axis represents the \'VMail_Message\' values. The labels are customized to be more descriptive. The goal of the chart is to visualize the distribution of the \'VMail_Message\' for each category of \'Churn_\'."}, \n    {"section": "transformation", "code": "None", "explanation": "There is no data transformation in the provided code. The function takes a pandas DataFrame as input a

In [72]:
explanation = clean_code_snippet(explainer_res.content)

In [73]:
try:
  exp = json.loads(explanation)
except Exception as e:
  print("Error parsing completion", explainer_res.content, str(e))
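`explainer_format_instructions` asks for exactly three objects whose `section` values are `accessibility`, `transformation` and `visualization`, with no repetition. A hypothetical helper that checks that contract and indexes the sections by name before they are printed:

```python
EXPECTED_SECTIONS = {"accessibility", "transformation", "visualization"}


def sections_by_name(items: list[dict]) -> dict[str, dict]:
    """Hypothetical helper: index explainer output by section, enforcing the requested format."""
    names = [item.get("section") for item in items]
    # three entries whose names equal the expected set implies no repetition
    if len(names) != 3 or set(names) != EXPECTED_SECTIONS:
        raise ValueError(f"Unexpected sections: {names}")
    return {item["section"]: item for item in items}


sample = [
    {"section": "accessibility", "code": "None", "explanation": "A box plot ..."},
    {"section": "transformation", "code": "None", "explanation": "No transformation."},
    {"section": "visualization", "code": "fig = px.box(...)", "explanation": "Creates the box plot."},
]

by_name = sections_by_name(sample)
print(by_name["transformation"]["explanation"])  # No transformation.
```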

In [78]:
for e in exp:
  print(e['explanation'])

The code generates a box plot, which is a standardized way of displaying the distribution of data based on a five number summary ('minimum', first quartile (Q1), median, third quartile (Q3), and 'maximum'). The color of the boxes is determined by the 'Churn_' column. The x-axis represents the 'Churn_' categories and the y-axis represents the 'VMail_Message' values. The labels are customized to be more descriptive. The goal of the chart is to visualize the distribution of the 'VMail_Message' for each category of 'Churn_'.
There is no data transformation in the provided code. The function takes a pandas DataFrame as input and directly uses it for plotting without any filtering, aggregation, grouping, or null value handling.
The code starts by importing the necessary libraries, plotly.express and pandas. Then, it defines a function named 'plot' that takes a pandas DataFrame as an argument. Inside this function, it uses the 'box' function from plotly.express to create a box plot. The 'box'