In [None]:
import os
import glob
import argparse
from torch.utils.data import DataLoader, random_split
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from model import *
import json
import numpy as np
from ate import*

In [None]:
def load_target(file_path='/Users/asus/Desktop/SCM/target.csv'):
    """Load the target-domain dataset from a comma-separated file without a header.

    Each row is read as ``t, y, x_1, ..., x_k``: column 0 becomes the treatment
    ``t``, column 1 the outcome ``y`` and the remaining columns the covariates ``x``.
    A leading UTF-8 byte-order mark, if present, is ignored.

    Args:
        file_path: path to the CSV file.

    Returns:
        Tuple ``(x, t, y)`` of float arrays shaped ``(n, k)``, ``(n, 1)`` and ``(n, 1)``.
    """
    from io import StringIO

    # 'utf-8-sig' drops a leading BOM exactly like the previous manual '\ufeff' check.
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        content = f.read()

    # ndmin=2 keeps a single-row file two-dimensional so the column slicing works.
    data = np.loadtxt(StringIO(content), delimiter=',', ndmin=2)
    x = data[:, 2:]
    t = data[:, 0].reshape(-1, 1)
    y = data[:, 1][:, None]
    return x, t, y


In [None]:
def load_source(file_path='/Users/asus/Desktop/SCM/source.csv'):
    """Load the source-domain dataset from a comma-separated file without a header.

    Each row is read as ``t, y, x_1, ..., x_k``: column 0 becomes the treatment
    ``t``, column 1 the outcome ``y`` and the remaining columns the covariates ``x``.
    A leading UTF-8 byte-order mark, if present, is ignored.

    Args:
        file_path: path to the CSV file.

    Returns:
        Tuple ``(x, t, y)`` of float arrays shaped ``(n, k)``, ``(n, 1)`` and ``(n, 1)``.
    """
    from io import StringIO

    # 'utf-8-sig' drops a leading BOM exactly like the previous manual '\ufeff' check.
    with open(file_path, 'r', encoding='utf-8-sig') as f:
        content = f.read()

    # ndmin=2 keeps a single-row file two-dimensional so the column slicing works.
    data1 = np.loadtxt(StringIO(content), delimiter=',', ndmin=2)
    x = data1[:, 2:]
    t = data1[:, 0].reshape(-1, 1)
    y = data1[:, 1][:, None]
    return x, t, y


In [None]:
def get_estimate(q_t0, q_t1, g, t, y_dragon, index, eps, truncate_level=0.01):
    """Compute back-door-adjustment and TMLE estimates from model predictions.

    Args:
        q_t0, q_t1: predicted outcomes under control / treatment.
        g: predicted propensity scores.
        t: observed treatment assignments.
        y_dragon: observed outcomes.
        index: sample indices; accepted for interface compatibility, not used.
        eps: epsilon values; accepted for interface compatibility, not used.
        truncate_level: forwarded as ``truncate_level`` to every estimator.

    Returns:
        ``(psi_n, psi_tmle, initial_loss, final_loss, g_loss, ipw_n, dr_n)`` where
        ``psi_n`` comes from ``psi_naive``, ``ipw_n``/``dr_n`` from ``psi_weighting``
        and the rest from ``psi_tmle_cont_outcome``.
    """
    shared_kwargs = dict(truncate_level=truncate_level)

    naive_estimate = psi_naive(q_t0, q_t1, g, t, y_dragon, **shared_kwargs)
    ipw_estimate, dr_estimate = psi_weighting(q_t0, q_t1, g, t, y_dragon, **shared_kwargs)

    tmle_results = psi_tmle_cont_outcome(q_t0, q_t1, g, t, y_dragon, **shared_kwargs)
    tmle_estimate, _tmle_std, _eps_hat, loss_before, loss_after, propensity_loss = tmle_results

    return (naive_estimate, tmle_estimate, loss_before, loss_after,
            propensity_loss, ipw_estimate, dr_estimate)

In [None]:
class EarlyStopper:
    """Signal when a monitored validation loss has stopped improving.

    ``early_stop`` resets its counter whenever the loss reaches a new minimum and
    increments it whenever the loss exceeds ``best + min_delta``; it returns True
    once the counter reaches ``patience``. Losses within ``min_delta`` of the best
    leave the counter unchanged.
    """

    def __init__(self, patience=1, min_delta=0):
        # Number of worsening calls tolerated before stopping.
        self.patience = patience
        # Slack above the best loss that does not count as worsening.
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        """Record ``validation_loss`` and return True if training should stop."""
        if validation_loss < self.min_validation_loss:
            # New best: remember it and forget earlier bad epochs.
            self.min_validation_loss = validation_loss
            self.counter = 0
            return False

        worse_than_tolerated = validation_loss > (self.min_validation_loss + self.min_delta)
        if not worse_than_tolerated:
            return False

        self.counter += 1
        return self.counter >= self.patience

In [None]:
def _split_output(yt_hat, t, y, y_scaler, x, index):
    """     yt_hat: Generated prediction
            t: Binary treatment assignments
            y: Treatment outcomes
            y_scaler: Scaled treatment outcomes
            x: Covariates
    """
    yt_hat = yt_hat.detach().cpu().numpy()
    q_t0 = y_scaler.inverse_transform(yt_hat[:, 0].reshape(-1, 1).copy())
    q_t1 = y_scaler.inverse_transform(yt_hat[:, 1].reshape(-1, 1).copy())
    g = yt_hat[:, 2].copy()

    if yt_hat.shape[1] == 4:
        eps = yt_hat[:, 3][0]
    else:
        eps = np.zeros_like(yt_hat[:, 2])

    y = y_scaler.inverse_transform(y.copy())
    var = "average propensity for treated: {} and untreated: {}".format(g[t.squeeze() == 1.].mean(),
                                                                        g[t.squeeze() == 0.].mean())
    print(var)

    return {'q_t0': q_t0, 'q_t1': q_t1, 'g': g, 't': t, 'y': y, 'x': x, 'index': index, 'eps': eps}

In [None]:
def train(train_loader, net, optimizer, criterion, valid_loader=None, l1_reg=None):
    """Run one training epoch and, optionally, one validation pass.

    Args:
        train_loader: iterable of ``(inputs, labels)`` batches used for optimisation.
        net: the ``torch.nn.Module`` being trained.
        optimizer: optimizer over ``net``'s parameters.
        criterion: callable ``criterion(outputs, labels)`` returning a scalar loss tensor.
        valid_loader: optional iterable of ``(inputs, labels)`` batches; when given,
            the mean validation loss is computed without gradient tracking.
        l1_reg: optional coefficient of an L1 penalty on all of ``net``'s
            parameters, added to both training and validation losses.

    Returns:
        ``(mean_train_loss, mean_valid_loss)`` as detached tensors;
        ``mean_valid_loss`` is ``None`` when ``valid_loader`` is ``None``.
        When validation runs, ``net`` is left in eval mode (callers rely on this
        to predict right after the last epoch).
    """
    # Move batches to wherever the model lives instead of guessing from cuda availability.
    device = next(net.parameters()).device

    def _l1_penalty():
        return l1_reg * sum(p.abs().sum() for p in net.parameters())

    # The previous call's validation pass switched the model to eval mode; without
    # this, every epoch after the first would train in eval mode (relevant if the
    # model has dropout / batch-norm layers).
    net.train()
    total_loss = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(net(inputs), labels)
        if l1_reg is not None:
            # Out-of-place add: avoids mutating the autograd output of criterion.
            loss = loss + _l1_penalty()
        loss.backward()
        optimizer.step()
        # Detach so the autograd graph of every batch is not kept alive all epoch.
        total_loss += loss.detach()

    valid_loss = None
    if valid_loader is not None:
        net.eval()
        valid_loss = 0.0
        with torch.no_grad():
            for inputs, labels in valid_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                loss = criterion(net(inputs), labels)
                if l1_reg is not None:
                    loss = loss + _l1_penalty()
                valid_loss += loss
        valid_loss = valid_loss / len(valid_loader)
    return total_loss / len(train_loader), valid_loss

In [None]:
def _run_training_phase(train_loader, valid_loader, net, optimizer, scheduler, loss_fn,
                        epochs, patience, l1_reg):
    """Train ``net`` for up to ``epochs`` epochs, stopping early on validation loss.

    Each epoch calls ``train`` once; ``scheduler.step`` receives the validation loss
    unless early stopping triggered in that epoch.

    Returns:
        ``(train_loss, val_loss, epoch)`` from the last epoch that ran.
    """
    early_stopper = EarlyStopper(patience=patience, min_delta=0.)
    train_loss, val_loss, epoch = 0, None, -1
    for epoch in range(epochs):
        train_loss, val_loss = train(train_loader, net, optimizer, loss_fn,
                                     valid_loader=valid_loader, l1_reg=l1_reg)
        if early_stopper.early_stop(val_loss):
            break
        scheduler.step(val_loss)
    return train_loss, val_loss, epoch


def _summarize_split(all_dicts, yt, yt_hat, device, truncate_level=0.01):
    """Build the estimate/loss summary dict for one data split (train or test)."""
    psi_n, _psi_tmle, initial_loss, _final_loss, g_loss, ipw_n, dr_n = get_estimate(
        all_dicts['q_t0'].reshape(-1, 1), all_dicts['q_t1'].reshape(-1, 1),
        all_dicts['g'].reshape(-1, 1), all_dicts['t'].reshape(-1, 1),
        all_dicts['y'].reshape(-1, 1), all_dicts['index'].reshape(-1, 1),
        all_dicts['eps'].reshape(-1, 1), truncate_level=truncate_level)
    return {
        'psi_n': psi_n,
        'classification_mse': g_loss,
        'ipw_n': ipw_n,
        'dr_n': dr_n,
        'regression_loss': regression_loss(torch.tensor(yt).to(device), yt_hat).cpu().detach(),
        'BCE': binary_classification_loss(torch.tensor(yt).float().to(device), yt_hat).cpu().detach().numpy(),
        'regression_mse': initial_loss,
        'index': all_dicts['index'],
    }


def train_and_predict_dragons(t, y_unscaled, x, net, seed=0, targeted_regularization=True, output_dir='',
                              knob_loss=dragonnet_loss_binarycross, ratio=1., dragon='', val_split=0.2, batch_size=64,
                              lr=1e-3, l1_reg=None):
    """Train ``net`` (Adam phase then SGD phase) and summarise estimates on train/test splits.

    Args:
        t: treatment array, shape ``(n, 1)``.
        y_unscaled: outcome array, shape ``(n, 1)``; standardised internally.
        x: covariate array, shape ``(n, k)``.
        net: model to train; must expose ``representation_block``, ``t_predictions``,
            ``t0_head`` and ``t1_head`` sub-modules.
        seed: seed for torch, numpy, ``random`` and the train/test split.
        targeted_regularization: wrap ``knob_loss`` with ``make_tarreg_loss`` if True.
        output_dir, dragon: unused; kept for interface compatibility.
        knob_loss: base loss function.
        ratio: test fraction for the train/test split (0 means train and test are
            the same samples); also forwarded to ``make_tarreg_loss``.
        val_split: fraction of the training split held out for validation.
        batch_size: training mini-batch size.
        lr: Adam learning rate (SGD uses ``lr * 0.01``).
        l1_reg: optional L1 penalty coefficient passed to ``train``.

    Returns:
        ``(test_outputs, train_outputs, net, train_dict, test_dict)``.
    """
    import random
    import time

    from torch import optim
    from torch.utils.data import TensorDataset

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    y_scaler = StandardScaler()
    y = y_scaler.fit_transform(y_unscaled)

    if targeted_regularization:
        loss = make_tarreg_loss(ratio=ratio, dragonnet_loss=knob_loss)
    else:
        loss = knob_loss

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    if ratio == 0:
        train_index = np.arange(x.shape[0])
        test_index = train_index
    else:
        train_index, test_index = train_test_split(np.arange(x.shape[0]), test_size=ratio, random_state=seed)
        print(f'test_index {test_index}')

    x_train, x_test = x[train_index], x[test_index]
    y_train, y_test = y[train_index], y[test_index]
    t_train, t_test = t[train_index], t[test_index]

    yt_train = np.concatenate([y_train, t_train], 1)
    yt_test = np.concatenate([y_test, t_test], 1)

    train_dataset = TensorDataset(torch.from_numpy(x_train).float().to(device),
                                  torch.from_numpy(yt_train).float().to(device))
    # val_split is the validation fraction; previously the two sizes were swapped,
    # so only val_split of the data was used for fitting.
    val_size = int(val_split * len(train_dataset))
    fit_size = len(train_dataset) - val_size
    train_set, valid_set = random_split(train_dataset, [fit_size, val_size])
    train_loader = DataLoader(train_set, batch_size=batch_size)
    valid_loader = DataLoader(valid_set, batch_size=500)

    start_time = time.time()

    epochs1 = 100
    epochs2 = 300

    optimizer_Adam = optim.Adam([{'params': net.representation_block.parameters()},
                                 {'params': net.t_predictions.parameters()},
                                 {'params': net.t0_head.parameters(), 'weight_decay': 0.01},
                                 {'params': net.t1_head.parameters(), 'weight_decay': 0.01}], lr=lr)
    optimizer_SGD = optim.SGD([{'params': net.representation_block.parameters()},
                               {'params': net.t_predictions.parameters()},
                               {'params': net.t0_head.parameters(), 'weight_decay': 0.01},
                               {'params': net.t1_head.parameters(), 'weight_decay': 0.01}], lr=lr * 0.01, momentum=0.9)
    scheduler_Adam = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer_Adam, mode='min', factor=0.5, patience=5,
                                                          threshold=1e-8, cooldown=0, min_lr=0)
    scheduler_SGD = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer_SGD, mode='min', factor=0.5, patience=5,
                                                         threshold=0, cooldown=0, min_lr=0)

    # Adam training run
    train_loss, val_loss, epoch = _run_training_phase(train_loader, valid_loader, net, optimizer_Adam,
                                                      scheduler_Adam, loss, epochs1, patience=2, l1_reg=l1_reg)
    print(f"Adam loss: train -- {train_loss}, validation -- {val_loss}, epoch {epoch}")

    # SGD training run
    train_loss, val_loss, epoch = _run_training_phase(train_loader, valid_loader, net, optimizer_SGD,
                                                      scheduler_SGD, loss, epochs2, patience=40, l1_reg=l1_reg)
    print(f"SGD loss: train --  {train_loss}, validation -- {val_loss},  epoch {epoch}")

    elapsed_time = time.time() - start_time
    print("***************************** elapsed_time is: ", elapsed_time)

    # Inference only: no autograd graph is needed for the predictions.
    with torch.no_grad():
        yt_hat_test = net(torch.from_numpy(x_test).float().to(device))
        yt_hat_train = net(torch.from_numpy(x_train).float().to(device))

    test_all_dicts = _split_output(yt_hat_test, t_test, y_test, y_scaler, x_test, test_index)
    train_all_dicts = _split_output(yt_hat_train, t_train, y_train, y_scaler, x_train, train_index)
    test_outputs = [test_all_dicts]
    train_outputs = [train_all_dicts]

    train_dict = _summarize_split(train_all_dicts, yt_train, yt_hat_train, device)
    # test_dict was previously never built, so the return raised NameError.
    test_dict = _summarize_split(test_all_dicts, yt_test, yt_hat_test, device)
    return test_outputs, train_outputs, net, train_dict, test_dict

In [None]:
import torch

def convert_to_serializable(obj):
    """Recursively convert ``obj`` into a form ``json.dump`` can write.

    - dicts: values are converted recursively (keys are left unchanged);
    - lists and tuples: converted element-wise and returned as lists;
    - torch tensors: nested Python lists, or a Python scalar when 0-dimensional;
    - numpy arrays: nested Python lists via ``tolist``;
    - numpy scalars of any type (float16/32/64, int8..int64, uint*, bool_, ...):
      the matching Python scalar via ``item``;
    - anything else is returned unchanged.
    """
    if isinstance(obj, dict):
        return {k: convert_to_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        # Tuples may contain numpy/torch values too, and JSON renders them as lists anyway.
        return [convert_to_serializable(i) for i in obj]
    if isinstance(obj, torch.Tensor):
        return obj.tolist() if obj.dim() > 0 else obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        # Covers every numpy scalar type, not only float32/64 and int32/64.
        return obj.item()
    return obj

In [None]:
def run_SCM(data_base_dir='/Users/asus/Desktop/SCM', output_dir='/Users/asus/Desktop/SCM',
            knob_loss=dragonnet_loss_binarycross, ratio=1., dragon='', lr2=1e-3, l1_reg=1e-3, batchsize2=64,
            truth=2):
    """Pre-train on the source data, transfer to the target data and save the ATE errors.

    Steps:
      1. load ``source.csv`` and ``target.csv`` from ``data_base_dir``;
      2. train a fresh ``TarNet`` (``dragon == 'tarnet'``) or ``DragonNet`` (otherwise)
         on the source data (no test split);
      3. copy its detached parameters into ``TarNet_transfer`` / ``DragonNet_transfer``
         (for any other ``dragon`` value the source-trained net itself is reused);
      4. train on the target data with a 50/50 train/test split;
      5. record absolute errors of ``psi_n``, ``dr_n`` and ``ipw_n`` against ``truth``
         and write all results to ``./TCL-params/TCL-experiments_transfer_{dragon}.json``.

    Args:
        data_base_dir: directory containing ``source.csv`` and ``target.csv``.
        output_dir: currently unused; results are always written under ``./TCL-params/``.
        knob_loss: base loss passed to ``train_and_predict_dragons``.
        ratio: currently unused; the split ratios are fixed below.
        dragon: ``'tarnet'`` or ``'dragonnet'``.
        lr2, l1_reg, batchsize2: learning rate, L1 coefficient and batch size for
            the target-data training.
        truth: true ATE the estimates are compared against (default 2, the value
            that was previously hard-coded).
    """
    print("the dragon is {}".format(dragon))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Honour data_base_dir (it used to be ignored in favour of the loaders' defaults).
    x_t, t_t, y_t = load_target(os.path.join(data_base_dir, 'target.csv'))
    x_s, t_s, y_s = load_source(os.path.join(data_base_dir, 'source.csv'))

    final_output = []
    for is_targeted_regularization in [False]:
        print("Is targeted regularization: {}".format(is_targeted_regularization))
        net = TarNet(x_s.shape[1]).to(device) if dragon == 'tarnet' else DragonNet(x_s.shape[1]).to(device)

        # Train source data
        _, _, net, _, _ = train_and_predict_dragons(t_s, y_s, x_s, net, seed=42,
                                                    targeted_regularization=is_targeted_regularization,
                                                    knob_loss=knob_loss, ratio=0, dragon=dragon,
                                                    val_split=0.3, batch_size=64, lr=1e-3)

        # Snapshot the source-trained weights to initialise the transfer model.
        parm = {}
        for name, param in net.named_parameters():
            param.grad = None
            parm[name] = param.detach()

        if dragon == 'tarnet':
            print('I am here making tarnet')
            net = TarNet_transfer(x_s.shape[1], parm).to(device)
        elif dragon == 'dragonnet':
            print("I am here making dragonnet")
            net = DragonNet_transfer(x_s.shape[1], parm).to(device)

        # Train target data
        test_outputs, train_output, net, train_dict, test_dict = train_and_predict_dragons(
            t_t, y_t, x_t, net, seed=42, targeted_regularization=is_targeted_regularization,
            knob_loss=knob_loss, ratio=0.5, dragon=dragon, val_split=0.3, batch_size=batchsize2, lr=lr2, l1_reg=l1_reg)

        # Calculate errors
        for result_dict in (train_dict, test_dict):
            result_dict['index'] = result_dict['index'].tolist()
            result_dict['err'] = abs(truth - result_dict['psi_n']).mean()
            result_dict['dr_err'] = abs(truth - result_dict['dr_n']).mean()
            result_dict['ipw_error'] = abs(truth - result_dict['ipw_n']).mean()

        # Suffix keys by split and collapse every non-index value to a Python scalar.
        merged = {f'{k}_train': v.item() if 'index' not in k else v for k, v in train_dict.items()}
        merged.update({f'{k}_test': v.item() if 'index' not in k else v for k, v in test_dict.items()})
        final_output.append(merged)

    # Save results
    serializable_output = convert_to_serializable(final_output)
    os.makedirs('./TCL-params/', exist_ok=True)
    with open(f'./TCL-params/TCL-experiments_transfer_{dragon}.json', 'w') as fp:
        json.dump(serializable_output, fp, indent=2)

In [None]:
def turn_knob(data_base_dir='/Users/asus/Desktop/SCM/', knob='dragonnet',
              output_base_dir='', lr=1e-3, l1reg=1e-4, batchsize=16):
    """Run the SCM transfer experiment for the model family named by ``knob``.

    Only ``'dragonnet'`` and ``'tarnet'`` trigger a run (with ``dragon=knob``);
    any other value does nothing. ``lr``, ``l1reg`` and ``batchsize`` are forwarded
    as the target-data training settings.
    """
    # Per-model output directory: <output_base_dir>/<knob>.
    output_dir = os.path.join(output_base_dir, knob)

    if knob not in ('dragonnet', 'tarnet'):
        return

    run_SCM(data_base_dir=data_base_dir, output_dir=output_dir, dragon=knob,
            lr2=lr, l1_reg=l1reg, batchsize2=batchsize)

In [None]:
def main():
    """Parse command-line options and launch ``turn_knob`` with them.

    ``parse_known_args`` tolerates unrecognised arguments — presumably so this also
    runs inside a Jupyter kernel, which passes its own arguments (confirm).
    """
    parser = argparse.ArgumentParser()

    option_specs = (
        ('--data_base_dir', dict(type=str, help="path to directory LBIDD", default="/Users/asus/Desktop/SCM")),
        ('--knob', dict(type=str, default='tarnet', help="dragonnet or tarnet")),
        ('--output_base_dir', dict(type=str, help="directory to save the output", default="/Users/asus/Desktop/SCM")),
        ('--transfer_lr', dict(type=float, default=0.01)),
        ('--l1reg', dict(type=float, default=0.1)),
        ('--batchsize', dict(type=int, default=64)),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    args, _unknown = parser.parse_known_args()
    turn_knob(data_base_dir=args.data_base_dir, knob=args.knob, output_base_dir=args.output_base_dir,
              lr=args.transfer_lr, l1reg=args.l1reg, batchsize=args.batchsize)


if __name__ == '__main__':
    main()