In [13]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
import pdb
plt.rcParams.update({'font.size': 18})

In [14]:
def plot_df(df, filename_base, z_name):
    unique_combinations = df[['text_col', 'lm_library', 'lm_name']].drop_duplicates()
    min_Cy = df.Cy.values.min()
    max_Cy = df.Cy.values.max()
    min_Cd = df.Cd.values.min()
    max_Cd = df.Cd.values.max()
    # Loop through each unique combination and subset the DataFrame
    for _, row in unique_combinations.iterrows():
        subset = df[
            (df['text_col'] == row['text_col']) &
            (df['lm_library'] == row['lm_library']) &
            (df['lm_name'] == row['lm_name'])
        ]
        if not subset.empty:
            try:
                df_pivot = subset.pivot(index='Cy', columns='Cd', values=z_name)
            except:
                # print(filename_base)
                # print(subset)
                print()

            # Create X, Y, Z values for contour plot
            Cy_grid = df_pivot.index.values  # X values
            Cd_grid = df_pivot.columns.values  # Y values
            Z = df_pivot.values  # Z values (ovb)

            # Create meshgrid for X and Y
            Cy_grid, Cd_grid = np.meshgrid(Cy_grid, Cd_grid)

            plt.figure(figsize=(8, 6))
            contour = plt.contour(Cy_grid, Cd_grid, Z, levels=10, cmap='viridis')
            contour_zero = plt.contour(Cy_grid, Cd_grid, Z, levels=[0], colors='red', linewidths=2)
            plt.clabel(contour, inline=True, fontsize=14)
            plt.clabel(contour_zero, inline=True, fontsize=14)
            plt.scatter(0, 0, color='black', marker='^', s=100, zorder=5)  # Add a triangle at (0, 0)
            plt.text(0, 0, 'unadjusted', color='black', fontsize=18, ha='left', va='bottom')
            plt.xlim(min_Cy - 0.01, max_Cy)
            plt.ylim(min_Cd - 0.01, max_Cd)
            # plt.colorbar(contour)
            plt.xlabel('Cy', fontsize=18)
            plt.ylabel('Cd', fontsize=18)
            # plt.title('{} ({}-{}-{})'.format(z_name, row['text_col'],
            #                                                 row['lm_library'],
            #                                                 row['lm_name']))
            if (row['text_col'] == 'comment') and (row['lm_library'] == 'sentecon_empath') and ('gbm' in filename_base) and ('type1' in filename_base):
                print('extra')
                pt_names = ['healing', 'movement', 'exercise', 'science', 'masked']
                alignment = ['left', 'left', 'left', 'left', 'right']
                pts = [(0.194, 0.057), (0.184, 0.265), (0.187, 0.078), (0.191, 0.164), (0.482, 0.203)]
                for i in range(len(pt_names)):
                    
                    plt.scatter(*pts[i], color='red', marker='^', s=100, zorder=5)  # Add a triangle at (0, 0)
                    plt.text(*pts[i], pt_names[i], color='black', fontsize=18, ha=alignment[i], va='bottom')

                # plt.scatter(0.752, 0.227, color='red', marker='^', s=100, zorder=5)  # Add a triangle at (0, 0)
                # plt.text(0.752, 0.227, 'llm_prompt_features', color='black', fontsize=10, ha='right', va='bottom', zorder=10)
                # plt.xlim(min_Cy - 0.01, 0.5)
                # plt.ylim(min_Cd - 0.01, 0.5)
                # plt.show()
            plt.savefig('./plots/tirzepatide/contour/{}_{}-{}-{}_{}.png'.format(
                z_name, row['text_col'], row['lm_library'], row['lm_name'], filename_base)
            )
            plt.close()
            # plt.show()

In [15]:
results_dir = './results/tirzepatide/text_covs/norm_none'
for filename in os.listdir(results_dir):
    if 'CyCdgrid' in filename:
        df = pd.read_csv(os.path.join(results_dir, filename))
        filename_base = filename.split('.csv')[0]
        # print(filename_base)
        filename_base = filename_base.split('all_reps_RV0bound_')[1]
        print(filename_base)
        plot_df(df, filename_base, 'theta_minus')

treatgbm_outgbm_type2_nudr_no_ood_cv5_CyCdgrid
treatgbm_outgbm_type3_nudr_no_ood_cv5_CyCdgrid
treatmlp_outmlp_type2_nudr_no_ood_trunc0.1_cv5_CyCdgrid
treatgbm_outgbm_type3_nudr_no_ood_trunc0.1_cv5_CyCdgrid
treatmlp_outmlp_type3_nudr_no_ood_trunc0.1_cv5_CyCdgrid
treatgbm_outgbm_type2_nudr_no_ood_trunc0.1_cv5_CyCdgrid
treatgbm_outgbm_type1_nudr_no_ood_trunc0.1_cv5_CyCdgrid

extra
treatgbm_outgbm_type1_nudr_no_ood_cv5_CyCdgrid
extra
treatmlp_outmlp_type1_nudr_no_ood_trunc0.1_cv5_CyCdgrid
type1_nudr_no_ood_cv5_CyCdgrid
type1_nudr_no_ood_trunc0.1_cv5_CyCdgrid
type3_nudr_no_ood_trunc0.1_cv5_CyCdgrid
type2_nudr_no_ood_trunc0.1_cv5_CyCdgrid
treatmlp_outmlp_type1_nudr_no_ood_cv5_CyCdgrid
treatmlp_outmlp_type2_nudr_no_ood_cv5_CyCdgrid
type3_nudr_no_ood_cv5_CyCdgrid
treatmlp_outmlp_type3_nudr_no_ood_cv5_CyCdgrid
type2_nudr_no_ood_cv5_CyCdgrid
