In [1]:
from toy_classify import *
import ot
import os
import warnings
warnings.filterwarnings('ignore')

## pretrained AE

In [2]:
X_train_pre, y_train_pre, regions_train = generate_imbalanced_data(5000,5000, seed = 0)

In [None]:
X_tensor = torch.tensor(X_train_pre, dtype=torch.float32)
set_seed(0)
# Create a DataLoader
# Initialize and train the autoencoder
autoencoder = Autoencoder(input_dim=X_tensor.shape[1], latent_dim=3, hidden_dim=1024,num_layer=5)
autoencoder = train_autoencoder(autoencoder, X_tensor, num_epochs=2000, learning_rate=1e-4, device='cuda')


In [4]:
torch.save(autoencoder.state_dict(), f"saved_models/ae.pt")

## CoDSA with pretrained AE

In [None]:
list_result=[] # test and validation result for each hyperparameter combination
ce_tf=[] # cross entropy for transfer CoDSA
list_origin=[] # cross entropy for base model
W_list=[[] for j in range(10)] #W distance

n_minority= 800
n_majority = 3200
n_val=400
n_test=800
n_train= 4000

#hyperparameter sets
ratio_list = [x/10 for x in list(range(1,11))] #split ratio r
candidate_m_syn=  [int(n_train*round(x * 0.1, 2)) for x in range(21)] #synthetic sample size m/n
candidate_alpha_scale= [0.1, 0.2, 0.3,0.4, 0.5,0.6, 0.7,0.8, 0.9] #alpha
m=2*max(candidate_m_syn)

num_sim=10
num_split= len(ratio_list)



#Simulated data
X_total, y_total, regions_total = generate_imbalanced_data(
        n_minority + n_val + n_test, 
        n_majority + n_val + n_test, 
        seed=0
    )
y_total = y_total.reshape(-1, 1)    


#Create directories for saving models and synthetic data.
if not os.path.exists("saved_models_tf"):
    os.makedirs("saved_models_tf")
if not os.path.exists("synthetic_tf"):
    os.makedirs("synthetic_tf")
    
for j in range(10):
    set_seed(j)
    
    # Separate the minority and majority groups based on region labels.
    # Here, we assume that region label 0 indicates minority and 1 indicates majority.
    minority_idx = np.where(regions_total == 0)[0]
    majority_idx = np.where(regions_total == 1)[0]

    # Shuffle indices separately
    np.random.shuffle(minority_idx)
    np.random.shuffle(majority_idx)

    # For minority:
    min_train_idx = minority_idx[:n_minority]
    min_val_idx   = minority_idx[n_minority : n_minority + n_val]
    min_test_idx  = minority_idx[n_minority + n_val : n_minority + n_val + n_test]

    # For majority:
    maj_train_idx = majority_idx[:n_majority]
    maj_val_idx   = majority_idx[n_majority : n_majority + n_val]
    maj_test_idx  = majority_idx[n_majority + n_val : n_majority + n_val + n_test]

    # Combine directly for each set
    X_train_orig = np.vstack((X_total[min_train_idx], X_total[maj_train_idx]))
    y_train_orig = np.vstack((y_total[min_train_idx], y_total[maj_train_idx]))
    regions_train = np.concatenate((regions_total[min_train_idx], regions_total[maj_train_idx]))

    X_val = np.vstack((X_total[min_val_idx], X_total[maj_val_idx]))
    y_val = np.vstack((y_total[min_val_idx], y_total[maj_val_idx]))
    regions_val = np.concatenate((regions_total[min_val_idx], regions_total[maj_val_idx]))

    X_test = np.vstack((X_total[min_test_idx], X_total[maj_test_idx]))
    y_test = np.vstack((y_total[min_test_idx], y_total[maj_test_idx]))
    regions_test = np.concatenate((regions_total[min_test_idx], regions_total[maj_test_idx]))

    

    # Combine training X and y for diffusion training
    XY_train_orig = combine_XY(X_train_orig, y_train_orig)
    
    
    origin_model, mse_val_origin, mse_test_origin, y_val_opred, y_test_opred = train_and_evaluate(
    X_train_orig, y_train_orig, X_val, y_val, X_test, y_test, input_dim= X_train_orig.shape[1], batch_size = 1024,patience =20)
    
    list_origin.append([mse_val_origin, mse_test_origin])
    print([mse_val_origin, mse_test_origin])

    all_result=[]
    for split_ratio in ratio_list:

        n_diff1 = int(split_ratio * n_minority)
        n_diff2 = int(split_ratio * n_majority)

        # Shuffle indices for each group.
        indices_min = np.arange(n_minority)
        np.random.shuffle(indices_min)
        indices_maj = np.arange(n_majority)
        np.random.shuffle(indices_maj)

        # Extract diffusion subset using the shuffled indices.
        XY_diff = np.vstack((
            XY_train_orig[indices_min[:n_diff1], :],
            XY_train_orig[n_minority + indices_maj[:n_diff2], :]
        ))
        regions_diff = np.concatenate((
            regions_train[indices_min[:n_diff1]],
            regions_train[n_minority + indices_maj[:n_diff2]]
        ))
        
        if split_ratio==1:
            XY_reg=None
        else:
            XY_reg = np.vstack((XY_train_orig[n_diff1:n_minority,:],XY_train_orig[(n_minority+n_diff2):,:]))
            
            
            
        result=[]
        
        file_path = f"synthetic_tf/synthetic_X_seed{j}_ratio{split_ratio:.1f}.npy"
        if os.path.exists(file_path):          

            # Build file paths.
            file_X = f"synthetic_tf/synthetic_X_seed{j}_ratio{split_ratio:.1f}.npy"
            file_y = f"synthetic_tf/synthetic_y_seed{j}_ratio{split_ratio:.1f}.npy"

            # Load the synthetic data.
            X_syn_full = np.load(file_X)
            y_syn_full = np.load(file_y)
            
        else:
            with torch.no_grad():        
                U_diff=autoencoder.encoder(torch.tensor(XY_diff[:,:-1], dtype=torch.float32).to('cuda'))

            model = ConditionalDiffusionModel(text_dim=3, cond_dim=1, hidden_dim= 1024, time_embed_dim= 128,
                                              num_fc_blocks=10, dropout = 0.00001)
            sampler = DDIMSampler(device=device,noise_steps=1000)
            model = model.to(device)
            trained_model = train_conditional_diffusion(model, U_diff, regions_diff, sampler, num_epochs=3000, batch_size=128, seed=j,device='cuda')

            torch.save(trained_model.state_dict(), f"/home/shenx/tianx/DA/Toy_class/model_save/diffusion_tf_seed{j}_ratio{split_ratio:.1f}.pt")

            result=[]
            X_syn_full, y_syn_full = generate(trained_model,autoencoder.decoder,m,0.5,sampler,seed =j)   

            np.save(f"synthetic_tf/synthetic_X_seed{j}_ratio{split_ratio:.1f}.npy", X_syn_full)
            np.save(f"synthetic_tf/synthetic_y_seed{j}_ratio{split_ratio:.1f}.npy", y_syn_full)
        
        #Compute the wasserstein distance between synthetic sample and real sample.
        X_truth,y_truth,_=generate_imbalanced_data(int(m/2), int(m/2),seed=j)
        truth_samples=combine_XY(X_truth,y_truth.reshape(-1,1))
        gen_samples = combine_XY(X_syn_full, y_syn_full)     
        # Compute cost matrix (Euclidean distances)
        M = ot.dist(truth_samples, gen_samples, metric='euclidean')
        a = np.ones((m,)) / m
        b = np.ones((m,)) / m
        W_distance = ot.emd2(a, b, M)
        W_list[j].append(W_distance)
        print("Wasserstein distance:", W_distance)
        
        
        for m_syn in candidate_m_syn:   
           
            tmp=[]
            if m_syn==0 and split_ratio ==1:
                for alpha_scale in candidate_alpha_scale:
                    tmp.append([np.inf, np.inf])
                result.append(tmp)
                continue
            for alpha_scale in candidate_alpha_scale:
                if m_syn ==0:
                    X_train_combined = XY_reg[:,:-1]
                    y_train_combined = XY_reg[:,-1:]    
                else:
                    X_train_combined = np.vstack((X_syn_full[:int(m_syn*alpha_scale),], X_syn_full[int(m/2):int(m/2+m_syn*(1-alpha_scale)),]))
                    y_train_combined = np.vstack((y_syn_full[:int(m_syn*alpha_scale),], y_syn_full[int(m/2):int(m/2+m_syn*(1-alpha_scale)),]))

                    if split_ratio <1:
                        X_train_combined = np.vstack((X_train_combined,XY_reg[:,:-1]))
                        y_train_combined = np.vstack((y_train_combined,XY_reg[:,-1:]))
                

                # Train and evaluate on validation set
                _, mse_val, mse_test, _, _ = train_and_evaluate(
                    X_train_combined, y_train_combined, X_val, y_val, X_test, y_test,
                    input_dim= X_train_orig.shape[1], batch_size =1024, patience =20)

                print(f"m_syn={m_syn}, alpha_scale={alpha_scale} -> Validation MSE: {mse_val:.4f} -> Test MSE: {mse_test:.4f}")

                tmp.append([mse_val, mse_test])
            result.append(tmp)


        all_result.append(result)
    
    
    
    ax= [all_result[x][:21] for x in range(10)] 
    ax=np.array(ax)
    num_sim, num_split, num_m, num_min_ratio = ax.shape
    val_errors = ax[: , :, :, 0]

    # Find the indices (k, l) that minimize the validation error.
    o, k, l = np.unravel_index(np.argmin(val_errors), val_errors.shape)

    best_ratio = ratio_list[o]
    best_m = candidate_m_syn[k]
    best_alpha = candidate_alpha_scale[l]

    
    n_diff1 = int(best_ratio * n_minority)
    n_diff2 = int(best_ratio * n_majority)

    # Shuffle indices for each group.
    indices_min = np.arange(n_minority)
    np.random.shuffle(indices_min)
    indices_maj = np.arange(n_majority)
    np.random.shuffle(indices_maj)

    # Extract diffusion subset using the shuffled indices.
    XY_diff = np.vstack((
        XY_train_orig[indices_min[:n_diff1], :],
        XY_train_orig[n_minority + indices_maj[:n_diff2], :]
    ))
    regions_diff = np.concatenate((
        regions_train[indices_min[:n_diff1]],
        regions_train[n_minority + indices_maj[:n_diff2]]
    ))

    XY_reg = np.vstack((XY_train_orig[n_diff1:n_minority,:],XY_train_orig[(n_minority+n_diff2):,:]))


    m=2*max(candidate_m_syn)

    # Build file paths.
    file_X = f"synthetic_tf/synthetic_X_seed{j}_ratio{best_ratio:.1f}.npy"
    file_y = f"synthetic_tf/synthetic_y_seed{j}_ratio{best_ratio:.1f}.npy"

    # Load the synthetic data.
    X_syn_full = np.load(file_X)
    y_syn_full = np.load(file_y)

    X_train_combined = np.vstack((
        X_syn_full[:int(best_m*best_alpha), :],
        X_syn_full[int(m/2):int(m/2+best_m*(1-best_alpha)), :]
    ))
    y_train_combined = np.vstack((
        y_syn_full[:int(best_m*best_alpha), :],
        y_syn_full[int(m/2):int(m/2+best_m*(1-best_alpha)), :]
    ))

    X_train_combined = np.vstack((X_train_combined, XY_reg[:, :-1]))
    y_train_combined = np.vstack((y_train_combined, XY_reg[:, -1:]))

    ce_one={}
    # Train and evaluate on validation set
    _, mse_val, mse_test, ypred_val, ypred_test = train_and_evaluate(
        X_train_combined, y_train_combined, X_val, y_val, X_test, y_test,
        input_dim= X_train_orig.shape[1], batch_size =1024, patience =20)
    ce_one['avg']=mse_test
    
    _, mse_val, mse_test, ypred_val, ypred_test = train_and_evaluate(
    X_train_combined, y_train_combined, X_val, y_val, X_test[regions_test==1,], y_test[regions_test==1,],
    input_dim= X_train_orig.shape[1], batch_size =1024, patience =20)
    ce_one['major']=mse_test
        
        
    _, mse_val, mse_test, ypred_val, ypred_test = train_and_evaluate(
    X_train_combined, y_train_combined, X_val, y_val, X_test[regions_test==0,], y_test[regions_test==0,],
    input_dim= X_train_orig.shape[1], batch_size =1024, patience =20)
    ce_one['minor']=mse_test       
    
    ce_tf.append(ce_one)
    list_result.append(all_result)


In [13]:
import pickle

with open("result/res_tf.pkl", "wb") as f:
    pickle.dump([list_result,W_list,list_origin,ce_tf], f)

### Summary statistics for Table 1

In [6]:
[np.mean([x['major'] for x in ce_tf]),n.mean([x['minor'] for x in ce_tf]),np.mean([x['avg'] for x in ce_tf])] 

[0.42235888838768004, 0.4954527199268341, 0.4589058130979538]

In [7]:
[np.std([x['major'] for x in ce_tf]),np.std([x['minor'] for x in ce_tf]),np.std([x['avg'] for x in ce_tf])] 

[0.05049949089435363, 0.052002638188493974, 0.015319663270860734]