In [None]:
#from dotenv import load_dotenv
import guidance
import os

#load_dotenv()

# Connection settings for the Azure OpenAI client built below.
endpoint = "https://msri-openai-ifaq.azure-api.net"
# os.getenv returns None when OPENAI_API_KEY is not set in the environment.
api_key = os.getenv("OPENAI_API_KEY")

# LLM handle that the suggester calls in the following cells receive as `llm`.
llm = guidance.models.AzureOpenAI(
    model="gpt-4",
    azure_endpoint=endpoint,
    api_key=api_key,
    version="2023-03-15-preview",  # presumably the Azure API version string -- confirm
)

In [None]:
from typing import Dict, Tuple, List

# Names of all variables in the sea-ice example. This list is what the cells
# below pass to the suggesters as `variables` / `factors_list`.
sea_ice_variables: List[str] = [
    "geopotential_heights",
    "relative_humidity",
    "sea_level_pressure",
    "zonal_wind_at_10_meters",
    "meridional_wind_at_10_meters",
    "sensible_plus_latent_heat_flux",
    "total_precipitation",
    "total_cloud_cover",
    "total_cloud_water_path",
    "surface_net_shortwave_flux",
    "surface_net_longwave_flux",
    "northern_hemisphere_sea_ice_extent",
]

# The pair of variables whose relationship is being studied.
treatment: str = "surface_net_longwave_flux"
outcome: str = "northern_hemisphere_sea_ice_extent"

# Ground-truth confounders of the relationship between
# surface_net_longwave_flux and northern_hemisphere_sea_ice_extent.
sea_ice_confounders: List[str] = ["total_precipitation"]
   
# Directed edges of the sea-ice graph as (source, target) name pairs
# (presumably cause -> effect; confirm against the data's provenance).
# The edges are written as a source -> targets mapping and flattened below.
# Dicts keep insertion order, so the resulting list is ordered exactly as the
# mapping reads from top to bottom.
sea_ice_relationships: List[Tuple[str, str]] = [
    (source, target)
    for source, targets in {
        "surface_net_longwave_flux": [
            "northern_hemisphere_sea_ice_extent",
        ],
        "geopotential_heights": [
            "surface_net_longwave_flux",
            "relative_humidity",
            "sea_level_pressure",
        ],
        "relative_humidity": [
            "total_cloud_cover",
            "total_cloud_water_path",
            "total_precipitation",
            "surface_net_longwave_flux",
        ],
        "sea_level_pressure": [
            "relative_humidity",
            "geopotential_heights",
            "zonal_wind_at_10_meters",
            "northern_hemisphere_sea_ice_extent",
            "sensible_plus_latent_heat_flux",
            "meridional_wind_at_10_meters",
        ],
        "zonal_wind_at_10_meters": [
            "northern_hemisphere_sea_ice_extent",
            "sensible_plus_latent_heat_flux",
        ],
        "meridional_wind_at_10_meters": [
            "northern_hemisphere_sea_ice_extent",
            "sensible_plus_latent_heat_flux",
        ],
        "sensible_plus_latent_heat_flux": [
            "northern_hemisphere_sea_ice_extent",
            "sea_level_pressure",
            "zonal_wind_at_10_meters",
            "meridional_wind_at_10_meters",
            "total_precipitation",
            "total_cloud_cover",
            "total_cloud_water_path",
        ],
        "total_precipitation": [
            "northern_hemisphere_sea_ice_extent",
            "relative_humidity",
            "sensible_plus_latent_heat_flux",
            "surface_net_longwave_flux",
            "total_cloud_cover",
            "total_cloud_water_path",
        ],
        "total_cloud_water_path": [
            "total_precipitation",
            "sensible_plus_latent_heat_flux",
            "relative_humidity",
            "surface_net_longwave_flux",
            "surface_net_shortwave_flux",
        ],
        "total_cloud_cover": [
            "total_precipitation",
            "sensible_plus_latent_heat_flux",
            "relative_humidity",
            "surface_net_longwave_flux",
            "surface_net_shortwave_flux",
        ],
        "surface_net_shortwave_flux": [
            "northern_hemisphere_sea_ice_extent",
        ],
        "northern_hemisphere_sea_ice_extent": [
            "sea_level_pressure",
            "zonal_wind_at_10_meters",
            "meridional_wind_at_10_meters",
            "sensible_plus_latent_heat_flux",
            "surface_net_shortwave_flux",
            "surface_net_longwave_flux",
        ],
    }.items()
    for target in targets
]

## Helpers

Model type - the type of LLM used.
By default, it is set to completion models.

Relationship strategy - the type of request made to the LLM (requesting a parent, a child, or a pairwise relationship).

In [None]:
#from suggesters import ModelType, RelationshipStrategy
#model_type = ModelType.Completion
#relationship_strategy = RelationshipStrategy.Parent

## Model

In [None]:
import pywhyllm

m = pywhyllm.SimpleModelSuggester()

# Author's note: returns a dictionary with how many times that edge was suggested.
model_edges = m.suggest_pairwise_relationship(llm, treatment, outcome)

# Author's notes: returns a dictionary with how many times a confounder/edge
# with confounder was suggested; suggest_relationships calls suggest_confounders.
# Only the first two entries of sea_ice_variables are passed here.
relationships = m.suggest_relationships(
    llm,
    sea_ice_variables[:2],
)


In [None]:
# Ask the model suggester for confounders of treatment -> outcome, offering
# the full sea-ice variable list.
confounders = m.suggest_confounders(
    llm,
    variables=sea_ice_variables,
    treatment=treatment,
    outcome=outcome,
)

In [None]:
# Bare expression: renders the result of m.suggest_confounders as this cell's output.
confounders

## Identifier

In [None]:
from suggesters import IdentifierSuggester

i = IdentifierSuggester()

# Author's note: calls the modeler's suggest_confounders in the background.
backdoor = i.suggest_backdoor(
    treatment=treatment,
    outcome=outcome,
    factors_list=sea_ice_variables,
    llm=llm,
)

# Author's note: suggests mediators.
front_door = i.suggest_frontdoor(
    treatment=treatment,
    outcome=outcome,
    factors_list=sea_ice_variables,
    llm=llm,
)

# Author's note: suggests instrumental variables.
ivs = i.suggest_ivs(
    treatment=treatment,
    outcome=outcome,
    factors_list=sea_ice_variables,
    llm=llm,
)

## Validator

In [None]:
# RelationshipStrategy must be imported explicitly: the strategy aliases below
# reference it, and without this import the cell raises NameError on a fresh
# kernel (Restart & Run All).
from suggesters import RelationshipStrategy, ValidationSuggester

v = ValidationSuggester()

# Author's note: suggests latent confounders.
latent_confounders = v.suggest_latent_confounders(
    treatment=treatment,
    outcome=outcome,
    factors_list=sea_ice_variables,
    llm=llm,
)

# Author's note: suggests negative controls.
negative_controls = v.suggest_negative_controls(
    treatment=treatment,
    outcome=outcome,
    factors_list=sea_ice_variables,
    llm=llm,
)

# Short aliases for the three relationship strategies; only `parent` is used
# in this cell, the other two are kept as notebook globals.
parent = RelationshipStrategy.Parent
child = RelationshipStrategy.Child
pairwise = RelationshipStrategy.Pairwise

# The relationship strategy chooses how the model gets critiqued. The edges
# being critiqued are `model_edges`, produced earlier by
# m.suggest_pairwise_relationship.
edges, critiqued_edges = v.critique_graph(
    edges=model_edges,
    treatment=treatment,
    outcome=outcome,
    factors_list=sea_ice_variables,
    llm=llm,
    relationship_strategy=parent,
)