In [None]:
import weave
import pandas as pd
from ydata_profiling import ProfileReport

from ydnpd import load_dataset
from ydnpd.agent.core import CasualModelingAgentMachine, LLMSession
from ydnpd.agent.specifications import SPECIFICATION_V0
from ydnpd.agent.data_config import CENSUS_DATASET_METADATA, OLD_CENSUS_DATASET_METADATA
from ydnpd.agent.utils import sample_dataset, metadata_to_pandera_schema

# Default seed for the final row draw in produce_mixture_dataset.
RANDOM_STATE = 42

# Weave project name; the client is initialised once, before any agent runs.
WEAVE_PROJECT = "data_gen_agent"
weave.init(WEAVE_PROJECT)

def produce_dataset(num_samples, metadata=None, specification=None):
    """Run the causal-modeling agent once and sample a dataset from its model.

    Parameters
    ----------
    num_samples : int
        Number of rows requested from ``sample_dataset``.
    metadata : dict, optional
        Dataset metadata handed to ``LLMSession``. Its ``"schema"`` entry is
        converted with ``metadata_to_pandera_schema`` and the result is passed
        to ``sample_dataset``. Defaults to ``CENSUS_DATASET_METADATA``.
    specification : optional
        Agent specification handed to ``LLMSession``. Defaults to
        ``SPECIFICATION_V0``.

    Returns
    -------
    tuple
        ``(df, code, error)``. On success ``error`` is ``None``. If any
        ``Exception`` is raised, ``df`` and ``code`` are ``None`` and ``error``
        is the exception, so callers can retry instead of crashing.
    """
    try:
        # Defaults are resolved inside the try so that a failure here is
        # reported through the returned error, like any other failure.
        if metadata is None:
            metadata = CENSUS_DATASET_METADATA
        if specification is None:
            specification = SPECIFICATION_V0

        llm_sess = LLMSession(specification=specification, metadata=metadata)

        # The machine object itself is not used afterwards. Constructing it is
        # kept for its side effects: the code below reads llm_sess.context["model"]
        # and ["code"], presumably filled in by the agent run. TODO confirm.
        _ = CasualModelingAgentMachine(llm_sess)

        pandera_schema = metadata_to_pandera_schema(metadata["schema"])
        df = sample_dataset(llm_sess.context["model"], num_samples, pandera_schema)
        code = llm_sess.context["code"]
    except Exception as e:  # best-effort: the caller decides whether to retry
        return None, None, e

    return df, code, None


def produce_mixture_dataset(num_samples, num_datasets, random_state=RANDOM_STATE, max_attempts=None):
    """Generate ``num_datasets`` agent datasets and mix them into one frame.

    ``produce_dataset`` is called repeatedly until ``num_datasets`` calls have
    succeeded. Failed attempts are recorded and retried. The successful frames
    are concatenated, and ``num_samples`` rows are drawn from the union without
    replacement.

    Parameters
    ----------
    num_samples : int
        Rows requested from each ``produce_dataset`` call, and rows in the
        returned mixture.
    num_datasets : int
        Number of successful generations to mix; must be >= 1.
    random_state : int
        Seed for the final ``DataFrame.sample`` draw.
    max_attempts : int or None
        Upper bound on total ``produce_dataset`` calls (successes + failures).
        ``None`` (the default) retries indefinitely, as before. Set it to avoid
        looping forever when every attempt fails.

    Returns
    -------
    tuple
        ``(mixture_df, (dfs, codes, errors))``: the mixed frame, the
        individual successful frames, their generated code, and the
        exceptions from failed attempts.

    Raises
    ------
    ValueError
        If ``num_datasets`` < 1 (nothing to concatenate).
    RuntimeError
        If ``max_attempts`` is reached before ``num_datasets`` successes.
    """
    if num_datasets < 1:
        # pd.concat([]) would fail anyway; fail early with a clearer message.
        raise ValueError(f"num_datasets must be >= 1, got {num_datasets}")

    dfs = []
    codes = []
    errors = []

    while len(dfs) < num_datasets:
        attempts = len(dfs) + len(errors)
        if max_attempts is not None and attempts >= max_attempts:
            last_error = errors[-1] if errors else None
            raise RuntimeError(
                f"produced only {len(dfs)}/{num_datasets} datasets "
                f"after {attempts} attempts; last error: {last_error!r}"
            ) from last_error

        print(f"attempt {attempts + 1}: {len(dfs)}/{num_datasets} datasets produced")

        df, code, error = produce_dataset(num_samples)
        if error is None:
            dfs.append(df)
            codes.append(code)
        else:
            errors.append(error)
            print(f"attempt {attempts + 1} failed: {error!r}")

    # Draw num_samples rows from the union without replacement. pandas raises
    # ValueError if the frames together hold fewer than num_samples rows.
    mixture_df = (pd.concat(dfs)
                  .sample(num_samples, replace=False, random_state=random_state)
                  .reset_index(drop=True))

    return mixture_df, (dfs, codes, errors)

In [None]:
# Generate a single agent dataset. With one dataset, the final sample without
# replacement only reorders its rows (assuming it has exactly num_samples rows).
# NOTE(review): 23_006 presumably matches the row count of the acs/national
# reference data loaded below. Confirm.
mixture_df, details = produce_mixture_dataset(num_samples=23_006, num_datasets=1)

In [None]:
# mixture_df.to_csv("acs-mix.csv")

In [None]:
# Reference dataset to compare the agent-generated mixture against.
# NOTE(review): load_dataset appears to return a sequence whose first element
# is the DataFrame. TODO confirm. An earlier version of this cell ran
# os.chdir("..") first; restore that if the data path is resolved relative to
# the working directory.
national_df = load_dataset("acs/national")[0]

In [None]:
# Side-by-side profile of the reference data and the agent-generated mixture.
national_report = ProfileReport(national_df, title="National")
agent_report = ProfileReport(mixture_df, title="Agent")
national_report.compare(agent_report)