In [11]:
import dspy
from dspy.teleprompt import *
from dspy.evaluate import Evaluate

os.environ['TOGETHER_API_KEY'] = ''
os.environ['TOGETHER_API_BASE'] = 'https://api.together.xyz/v1'

In [12]:
# graphs is a list of strings where each string is the python code
# graph_descriptions is a list of strings where each string is the description for its corresponding graph at the same index in 'graphs'
# data_descriptions is a list of strings where each string is the description of the data

data_descriptions = [
    "Fields: sepal_length (float), sepal_width (float), petal_length (float), petal_width (float), species (string). Format: CSV",
    "Fields: survived (integer), pclass (integer), name (string), sex (string), age (float), sibsp (integer), parch (integer), ticket (string), fare (float), cabin (string), embarked (string). Format: CSV",
    "Fields: Month (string), Temperature (float), Rainfall (float), Humidity (float). Format: CSV",
    "Fields: species (string), island (string), bill_length_mm (float), bill_depth_mm (float), flipper_length_mm (float), body_mass_g (float), sex (string). Format: CSV",
    "Fields: fixed_acidity (float), volatile_acidity (float), citric_acid (float), residual_sugar (float), chlorides (float), free_sulfur_dioxide (float), total_sulfur_dioxide (float), density (float), pH (float), sulphates (float), alcohol (float), quality (integer). Format: CSV",
    "Fields: year (integer), month (string), passengers (integer). Format: CSV",
    "Fields: Product (string), Q1 Sales (integer), Q2 Sales (integer), Q3 Sales (integer), Q4 Sales (integer), Price (float). Format: CSV",
    "Fields: age (float), sex (integer), cp (integer), trestbps (integer), chol (integer), fbs (integer), restecg (integer), thalach (integer), exang (integer), oldpeak (float), slope (integer), ca (integer), thal (integer), target (integer). Format: CSV",
    "Fields: tconst (string), averageRating (float), numVotes (integer), titleType (string), primaryTitle (string), originalTitle (string), isAdult (string), startYear (integer), endYear (string), runtimeMinutes (string), genres (string). Format: TSV",
    "Fields: age (float), sex (integer), bmi (float), bp (float), s1 (float), s2 (float), s3 (float), s4 (float), s5 (float), s6 (float), target (float). Format: Dataset"
]


graph_descriptions = [
    "Scatter plot of sepal length vs sepal width, with a new column for sepal area.",
    "Bar plot of survival rate by passenger class, with imputed missing age values and categorized age groups.",
    "Line plot of monthly temperature, with a secondary y-axis bar chart for rainfall, and scatter plot for average humidity.",
    "Pair plot of penguin measurements, with imputed missing values.",
    "Histogram of wine quality, with a new column for categorized alcohol levels.",
    "Line plot of cumulative number of passengers over time.",
    "Grouped bar chart of quarterly sales for each product, with a line plot overlay for total annual sales.",
    "Box plot of cholesterol levels by chest pain type, with outlier detection and removal.",
    "Bar plot of average movie ratings by year, with timestamp conversion and year extraction.",
    "Pair plot of selected features from the diabetes dataset, color-coded by target value quartiles."
]


graphs = [
    """
import seaborn as sns
import matplotlib.pyplot as plt

# Load the dataset
iris = sns.load_dataset('iris')

# Data preprocessing
iris['sepal_area'] = iris['sepal_length'] * iris['sepal_width']

# Plotting
sns.scatterplot(data=iris, x='sepal_length', y='sepal_width', hue='species', size='sepal_area', sizes=(20, 200))
plt.title('Sepal Length vs Sepal Width by Species')
plt.show()
    """,
    """
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

# Load the dataset
titanic = sns.load_dataset('titanic')

# Data preprocessing
titanic['age'] = titanic['age'].fillna(titanic['age'].median())
titanic['AgeGroup'] = pd.cut(titanic['age'], bins=[0, 12, 18, 35, 60, 80], labels=['Child', 'Teenager', 'Adult', 'Middle-aged', 'Senior'])

# Plotting
sns.barplot(data=titanic, x='pclass', y='survived', hue='AgeGroup')
plt.title('Survival Rate by Passenger Class and Age Group')
plt.show()
    """,
    """
import pandas as pd
import matplotlib.pyplot as plt

# Create example data
data = {
    'Month': ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'],
    'Temperature': [5, 7, 10, 15, 20, 25, 30, 28, 22, 15, 10, 6],
    'Rainfall': [60, 50, 55, 70, 80, 90, 100, 95, 85, 70, 60, 65],
    'Humidity': [80, 75, 70, 65, 60, 55, 50, 52, 58, 65, 70, 75]
}

# Save to CSV
df = pd.DataFrame(data)
df.to_csv('data.csv', index=False)

# Data preprocessing
data = pd.read_csv('data.csv')
monthly_avg_humidity = data.groupby('Month')['Humidity'].mean().reindex(['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'])

# Plotting
fig, ax1 = plt.subplots()

ax2 = ax1.twinx()
ax1.plot(data['Month'], data['Temperature'], 'r-', label='Temperature')
ax2.bar(data['Month'], data['Rainfall'], alpha=0.5, label='Rainfall')
ax2.scatter(monthly_avg_humidity.index, monthly_avg_humidity, color='b', label='Avg Humidity')

ax1.set_xlabel('Month')
ax1.set_ylabel('Temperature (C)', color='r')
ax2.set_ylabel('Rainfall (mm) and Avg Humidity (%)', color='b')
ax1.legend(loc='upper left')
ax2.legend(loc='upper right')

plt.title('Monthly Temperature, Rainfall, and Humidity')
plt.show()
    """,
    """
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

# Load the dataset
penguins = sns.load_dataset('penguins')

# Data preprocessing
penguins.fillna(penguins.mean(numeric_only=True), inplace=True)

# Plotting
sns.pairplot(data=penguins, hue='species')
plt.suptitle('Pair Plot of Penguin Measurements with Imputed Missing Values', y=1)
plt.show()
    """,
    """
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

# Load the dataset
red_wine = pd.read_csv('https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv', sep=';')
white_wine = pd.read_csv('https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-white.csv', sep=';')

# Data preprocessing
red_wine['alcohol_level'] = pd.cut(red_wine['alcohol'], bins=[0, 10, 12, 14], labels=['Low', 'Medium', 'High'])
white_wine['alcohol_level'] = pd.cut(white_wine['alcohol'], bins=[0, 10, 12, 14], labels=['Low', 'Medium', 'High'])

# Plotting
sns.histplot(red_wine['quality'], kde=True, color='red', label='Red Wine')
sns.histplot(white_wine['quality'], kde=True, color='yellow', label='White Wine')
plt.title('Wine Quality Distribution with Alcohol Levels')
plt.legend()
plt.show()
    """,
    """
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

# Load the dataset
flights = sns.load_dataset('flights')

# Data preprocessing
flights['cumulative_passengers'] = flights['passengers'].cumsum()

# Plotting
sns.lineplot(data=flights, x='year', y='cumulative_passengers')
plt.title('Cumulative Number of Passengers Over Time')
plt.show()
    """,
    """
import pandas as pd

# Create example data
data = {
    'Product': ['A', 'B', 'C', 'D', 'E'],
    'Q1 Sales': [150, 200, 300, 400, 250],
    'Q2 Sales': [160, 210, 320, 420, 260],
    'Q3 Sales': [170, 220, 340, 440, 270],
    'Q4 Sales': [180, 230, 360, 460, 280],
    'Price': [10, 15, 20, 25, 30]
}

# Save to CSV
df = pd.DataFrame(data)
df.to_csv('data.csv', index=False)

import pandas as pd
import matplotlib.pyplot as plt

# Data preprocessing
data = pd.read_csv('data.csv')
data.set_index('Product', inplace=True)
data['Total Sales'] = data[['Q1 Sales', 'Q2 Sales', 'Q3 Sales', 'Q4 Sales']].sum(axis=1)

# Plotting
ax = data[['Q1 Sales', 'Q2 Sales', 'Q3 Sales', 'Q4 Sales']].plot(kind='bar')
ax2 = ax.twinx()
ax2.plot(data.index, data['Total Sales'], 'r-', marker='o', label='Total Sales')
ax.set_xlabel('Product')
ax.set_ylabel('Quarterly Sales')
ax2.set_ylabel('Total Annual Sales')
plt.title('Quarterly Sales and Total Annual Sales by Product')
ax.legend(loc='upper left')
ax2.legend(loc='upper right')
plt.show()
    """,
    """
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Load the dataset
heart_disease = pd.read_csv('https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/processed.cleveland.data', header=None, na_values="?")
heart_disease.columns = ['age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg', 'thalach', 'exang', 'oldpeak', 'slope', 'ca', 'thal', 'target']

# Data preprocessing
heart_disease.dropna(inplace=True)  # Remove rows with missing values
Q1 = heart_disease['chol'].quantile(0.25)
Q3 = heart_disease['chol'].quantile(0.75)
IQR = Q3 - Q1
heart_disease = heart_disease[(heart_disease['chol'] >= (Q1 - 1.5 * IQR)) & (heart_disease['chol'] <= (Q3 + 1.5 * IQR))]

# Plotting
sns.boxplot(data=heart_disease, x='cp', y='chol')
plt.xlabel('Chest Pain Type')
plt.ylabel('Cholesterol Level')
plt.title('Cholesterol Levels by Chest Pain Type with Outliers Removed')
plt.show()
    """,
    """
import pandas as pd
import matplotlib.pyplot as plt

# Load the datasets
ratings = pd.read_csv('https://datasets.imdbws.com/title.ratings.tsv.gz', sep='\t', dtype={'tconst': 'str', 'averageRating': 'float64', 'numVotes': 'int64'})
movies = pd.read_csv('https://datasets.imdbws.com/title.basics.tsv.gz', sep='\t', dtype={'tconst': 'str', 'titleType': 'str', 'primaryTitle': 'str', 'originalTitle': 'str', 'isAdult': 'str', 'startYear': 'str', 'endYear': 'str', 'runtimeMinutes': 'str', 'genres': 'str'}, low_memory=False)

# Data preprocessing
movies['startYear'] = pd.to_numeric(movies['startYear'], errors='coerce')
ratings = ratings.merge(movies, on='tconst')
ratings = ratings[ratings['startYear'] > 0]  # Remove movies with invalid years
avg_ratings = ratings.groupby('startYear')['averageRating'].mean()

# Plotting
avg_ratings.plot(kind='bar', figsize=(20, 8), color='skyblue')
plt.xlabel('Year')
plt.ylabel('Average Rating')
plt.title('Average Movie Ratings by Year')
plt.show()
    """,
    """
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.datasets import load_diabetes

# Load the dataset
diabetes = load_diabetes()
diabetes_df = pd.DataFrame(diabetes.data, columns=diabetes.feature_names)
diabetes_df['target'] = diabetes.target

# Data preprocessing
diabetes_df['target_quartile'] = pd.qcut(diabetes_df['target'], q=4, labels=['Q1', 'Q2', 'Q3', 'Q4'])

# Plotting
sns.pairplot(data=diabetes_df, vars=['bmi', 'bp', 's1', 's2', 's3'], hue='target_quartile')
plt.suptitle('Pair Plot of Selected Features Color-coded by Target Quartiles', y=1.02)
plt.show()
    """
]


In [13]:
NUM_THREADS = 4

In [14]:
def makeDSPYExamples(graphs, data_descriptions, graph_descriptions):
    exampleList = []
    for i in range(len(graphs)):
        data_desc = data_descriptions[i]
        graph_desc = graph_descriptions[i]
        graph = graphs[i]
        exampleObj = dspy.Example(data_description=data_desc, graph_description=graph_desc, graph=graph).with_inputs("data_description", "graph_description")
        exampleList.append(exampleObj)
    return exampleList

class TextToDataManipulation(dspy.Module):
    def __init__(self):
        super().__init__()
        self.generate_data_manipulation = dspy.ChainOfThought(TextToDataManipulationSignature)

    def forward(self, data_description, graph_description):
        data_manipulation_code = self.generate_data_manipulation(data_description=data_description, graph_description=graph_description)
        return data_manipulation_code

class TextToDataManipulationSignature(dspy.Signature):
    """Generate data manipulation code from data and target graph descriptions"""
    data_description = dspy.InputField()
    graph_description = dspy.InputField()
    data_manipulation_code = dspy.OutputField()

class TextToGraphCode(dspy.Module):
    def __init__(self):
        super().__init__()
        self.generate_graph_code = dspy.ChainOfThought(TextToGraphCodeSignature)

    def forward(self, graph_description, data_description, data_manipulation_code):
        graph_code = self.generate_graph_code(graph_description=graph_description, data_description=data_description, data_manipulation_code=data_manipulation_code)
        return graph_code

class TextToGraphCodeSignature(dspy.Signature):
    """Generate graph code from a given description and data"""
    graph_description = dspy.InputField()
    data_description = dspy.InputField()
    data_manipulation_code = dspy.InputField()
    graph_code = dspy.OutputField()

class Analyzer(dspy.Module):
    def __init__(self):
        super().__init__()
        self.text_to_data_manipulation = TextToDataManipulation()
        self.text_to_graph_code = TextToGraphCode()

    def forward(self, graph_description, data_description):
        # data manipulation code
        data_manipulation_code_prediction = self.text_to_data_manipulation(data_description=data_description, graph_description=graph_description)
        
        # get the actual code from the Prediction object
        data_manipulation_code = data_manipulation_code_prediction.data_manipulation_code
        
        # graph code
        graph_code_prediction = self.text_to_graph_code(graph_description=graph_description, data_description=data_description, data_manipulation_code=data_manipulation_code)
        
        # get the actual graph code from the Prediction object
        graph_code = graph_code_prediction.graph_code
        
        # combine the data manipulation code and the graph code
        complete_code = f"{data_manipulation_code}\n{graph_code}"
        # print(complete_code)
        
        return complete_code

class AnalyzerSignature(dspy.Signature):
    """Generate python code to create a graph for the given description"""
    data_description = dspy.InputField()
    graph_description = dspy.InputField()
    graph_code = dspy.OutputField()

eval_LLM = dspy.Together(model="deepseek-ai/deepseek-llm-67b-chat")

class GraphAssessment(dspy.Signature):
    """Assess the quality of a graph along the specified dimensions."""
    data_description = dspy.InputField()
    graph_description = dspy.InputField()
    complete_code = dspy.InputField()
    assessment_result = dspy.OutputField(desc="Assessment result with scores for different categories")

def metric(example, pred, trace=None):
    data_description, graph_description, complete_code = example.data_description, example.graph_description, pred

    big_problems_check = "Does the complete code preprocess the data correctly and plot the correct data with appropriate axes labels? Is the graph complete and meaningful? Are there any errors with the code?"
    intermediate_problems_check = "Is the graph type appropriate and are the axis scales optimal? Are there any other problems with the graph?"
    trivial_problems_check = "Does the graph have good stylistic choices like colors, and are the title and axis names appropriate?"

    predictor = dspy.Predict(GraphAssessment)

    with dspy.context(lm=eval_LLM):
        big_problems = predictor(data_description=data_description, graph_description=graph_description, complete_code=complete_code, assessment_question=big_problems_check)
        intermediate_problems = predictor(data_description=data_description, graph_description=graph_description, complete_code=complete_code, assessment_question=intermediate_problems_check)
        trivial_problems = predictor(data_description=data_description, graph_description=graph_description, complete_code=complete_code, assessment_question=trivial_problems_check)

    # LLM eval
    big_problems_score = 1.0 if big_problems.assessment_result.lower() == 'yes' else 0.5
    intermediate_problems_score = 0.5 if intermediate_problems.assessment_result.lower() == 'yes' else 0.25
    trivial_problems_score = 0.25 if trivial_problems.assessment_result.lower() == 'yes' else 0.1

    # score out of 1.0
    final_score = big_problems_score + intermediate_problems_score + trivial_problems_score
    final_score = min(1.0, final_score)

    return final_score

# it is challenging trying to create a perfect metric to compare to w/o LLMs
# use LLM but also give it the gold standard code. tell it what is a problem and what isn't. maybe even create a DSPy program for this instead?

# def metric(example, pred, trace=None):
    # should compare to gold standard

In [15]:
allExamples = makeDSPYExamples(graphs, data_descriptions, graph_descriptions)
trainset = allExamples[:7]
testset = allExamples[7:]

In [16]:
print(testset[0])

Example({'data_description': 'Fields: age (float), sex (integer), cp (integer), trestbps (integer), chol (integer), fbs (integer), restecg (integer), thalach (integer), exang (integer), oldpeak (float), slope (integer), ca (integer), thal (integer), target (integer). Format: CSV', 'graph_description': 'Box plot of cholesterol levels by chest pain type, with outlier detection and removal.', 'graph': '\nimport pandas as pd\nimport matplotlib.pyplot as plt\nimport seaborn as sns\n\n# Load the dataset\nheart_disease = pd.read_csv(\'https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/processed.cleveland.data\', header=None, na_values="?")\nheart_disease.columns = [\'age\', \'sex\', \'cp\', \'trestbps\', \'chol\', \'fbs\', \'restecg\', \'thalach\', \'exang\', \'oldpeak\', \'slope\', \'ca\', \'thal\', \'target\']\n\n# Data preprocessing\nheart_disease.dropna(inplace=True)  # Remove rows with missing values\nQ1 = heart_disease[\'chol\'].quantile(0.25)\nQ3 = heart_disease[\'

In [17]:
llama = dspy.Together(model="meta-llama/Meta-Llama-3-70B-Instruct-Turbo")
dspy.configure(lm=llama)

In [18]:
teleprompter = LabeledFewShot(k=7)
compiled = teleprompter.compile(Analyzer(), trainset=trainset)

evaluater = Evaluate(devset=testset, metric=metric, num_threads=NUM_THREADS, display_progress=True, display_table=0)
evaluater(compiled)



KeyboardInterrupt: 

In [None]:
# Initialize the analyzer module
analyzer = Analyzer()

data_description = data_descriptions[9]
graph_description = graph_descriptions[9]

# Generate the data manipulation code
data_manipulation_code = analyzer.text_to_data_manipulation(data_description=data_description, graph_description=graph_description)
print("Data Manipulation Code:")
print(data_manipulation_code)

data_manipulation_code = data_manipulation_code.data_manipulation_code

# Generate the graph code
graph_code = analyzer.text_to_graph_code(graph_description=graph_description, data_description=data_description, data_manipulation_code=data_manipulation_code)
print("\nGraph Code:")
print(graph_code)

# Generate the complete code
complete_code = analyzer(graph_description=graph_description, data_description=data_description)
print("\nComplete Code:")
print(complete_code)