In [1]:
import os, sys
import pandas as pd
import torch
# Show the CUDA version string reported by the installed torch build.
print(torch.version.cuda)

# Detect the runtime environment
def in_notebook():
    """Return True when running under an IPython kernel, False otherwise.

    Looks up ``get_ipython`` in this module's globals and calls it when it is
    present; the result is whether the returned shell's ``config`` contains an
    ``IPKernelApp`` entry. When ``get_ipython`` is absent the result is False.
    """
    module_globals = globals()
    if 'get_ipython' in module_globals:
        shell = module_globals['get_ipython']()
    else:
        shell = None
    # The getattr default covers both `shell is None` and a shell that has
    # no `config` attribute.
    return 'IPKernelApp' in getattr(shell, 'config', {})

# Work out which directory to put on the import path, and the run mode.
if in_notebook():
    # Notebook run: the parent of the current working directory is used.
    notebook_dir = os.getcwd()
    src_path = os.path.abspath(os.path.join(notebook_dir, '..'))
    RUN_MODE = 'train'
else:
    # Script run: use the parent of the directory containing this file.
    # Resolve __file__ to an absolute path *before* taking dirname twice:
    # for a relative __file__ such as 'script.py', dirname() collapses to ''
    # and abspath('') would yield the working directory, not its parent.
    src_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    # RUN_MODE was previously bound only in the notebook branch, so a script
    # run could hit NameError at the later `RUN_MODE == 'train'` check.
    # Keep any value that is already defined; otherwise fall back to a
    # non-'train' value, which selects the load-saved-weights path.
    # NOTE(review): if a later star import defines RUN_MODE it still wins.
    RUN_MODE = globals().get('RUN_MODE', 'eval')

if src_path not in sys.path:
    sys.path.append(src_path)

from src.utils import *
from src.model_utils import *
from src.setup import *
from ite_setup import *
from metrics import *
from ganite_mod import Ganite
from ganite_mod.datasets import load
from ganite_mod.utils.metrics import sqrt_PEHE_with_diff

12.1


In [2]:
# Read the tab-separated table; the 'ID' column becomes the index.
df = pd.read_csv(
    f'{DATA}/imputed/EXIT_SEP_clean_imputed.tsv.gz',
    index_col='ID',
    sep='\t',
)
features, _, _, treatment, outcomes = get_ite_features()
current_outcome = outcomes[0]  # prediction target

# Sample 70% of the rows (fixed seed) for training; every row whose index
# label was not sampled goes to the test frame.
df_train = df.sample(frac=0.7, random_state=19960816)
df_test = df.loc[~df.index.isin(df_train.index)].copy()

# Build the (X, W, Y) triples for the full, train and test frames, in that order.
(X, W, Y), (X_train, W_train, Y_train), (X_test, W_test, Y_test) = (
    load_data(subset) for subset in (df, df_train, df_test)
)

In [None]:
# modified GANITE
model = Ganite(dim_in=X.shape[1],
               binary_y=True,
               dim_hidden=300,
               alpha = 0.3,
               beta = 0.3,
               depth = 3,
               minibatch_size = 200,
               num_iterations=2500,
               num_discr_iterations=3,
               )


if RUN_MODE == 'train':
    # Fit on the training split and persist the learned weights.
    model = model.fit(X_train, W_train, Y_train)
    torch.save(model.state_dict(), f"{MODELS}/GANITE.pth")
else:
    # NOTE(review): training writes GANITE.pth but this branch reads
    # GANITE_best.pth -- confirm the "best" checkpoint is produced elsewhere.
    # map_location="cpu" lets a checkpoint saved on a GPU be read on a
    # CPU-only machine; load_state_dict copies the values into the model's
    # existing parameters.
    model.load_state_dict(
        torch.load(f"{MODELS}/GANITE_best.pth", weights_only=True, map_location="cpu")
    )
    print("模型参数已加载！")

# Switch to evaluation mode on BOTH paths. Previously only the load branch
# did this, so a freshly trained model could be used for inference while
# still in training mode.
model.eval()

In [5]:
# Evaluate on the held-out test set
# Inference only: no autograd graph is needed.
with torch.no_grad():
    Y_1_test, Y_0_test, ITE_test = model(X_test)

# Convert explicitly to NumPy before storing in the frame instead of relying
# on pandas' implicit handling of torch tensors.
df_test['potential_y1'] = Y_1_test.detach().cpu().numpy()
df_test['potential_y0'] = Y_0_test.detach().cpu().numpy()
df_test['ITE'] = ITE_test.detach().cpu().numpy()

# Prediction for the arm each row actually received: potential_y1 where the
# treatment column equals 1, potential_y0 otherwise (vectorised replacement
# for the former row-wise apply).
df_test['y_pred_observed'] = df_test['potential_y1'].where(
    df_test[treatment] == 1, df_test['potential_y0']
)

# ATE from observed outcomes, ATE from predicted observed outcomes (both via
# RCT_ATE), and the mean of the predicted individual effects.
ATE_test = RCT_ATE(df_test[treatment], df_test[current_outcome])
ATE_pred_ob = RCT_ATE(df_test[treatment], df_test['y_pred_observed'])
ATE_pred = df_test['ITE'].mean()

print(f'实际ATE: {ATE_test:.4f}, 预测实际ATE: {ATE_pred_ob:.4f}, ATE误差: {ATE_test - ATE_pred_ob:.4f}, 预测组间ATE: {ATE_pred:.4f}')

实际ATE: -0.0357, 预测实际ATE: -0.0872, ATE误差: 0.0514, 预测组间ATE: -0.0820


In [None]:
# NOTE(review): bare string literal with no effect on later cells; it looks
# like a result line kept from an earlier run for comparison -- confirm, and
# consider moving it into a markdown cell.
'实际ATE: -0.0357, 预测实际ATE: -0.0515, ATE误差: 0.0158, 预测组间ATE: -0.0356'