In [1]:
import pandas as pd
import torch
from torch_rechub.models.multi_task import SharedBottom, MMOE, PLE, AITM, fus_moe_add
from torch_rechub.models.multi_task.mtl_lib import AdaTTSp
from torch_rechub.trainers import MTLTrainer
from torch_rechub.basic.features import DenseFeature, SparseFeature
from torch_rechub.utils.data import DataGenerator
from torch_rechub.single2fus import single2fus

In [None]:
# --- Experiment configuration ---
# Model to train: one of 'SharedBottom', 'MMOE', 'PLE', 'AITM', 'AdaTTSp',
# 'fus_moe_add' ('fus_moe_add' is our method).
model_name = 'AdaTTSp'

# Optimisation settings
epoch = 100
learning_rate = 1e-3
weight_decay = 1e-6
batch_size = 1024

save_dir = './save_dir'  # passed to MTLTrainer as model_path

# Seed torch's global RNG so runs are repeatable.
seed = 2024
torch.manual_seed(seed)

In [None]:
# Pick the training device. Prefer GPU number `gpu_index` when that device exists,
# fall back to the first GPU on hosts with fewer devices, and use the CPU when
# CUDA is unavailable.
gpu_index = 2
device = 'cpu'
if torch.cuda.is_available():
    print('cuda ready...')
    # A hard-coded 'cuda:2' raises "invalid device ordinal" on machines with < 3 GPUs.
    device = f'cuda:{gpu_index}' if gpu_index < torch.cuda.device_count() else 'cuda:0'

In [4]:
def get_census_data_dict(model_name, data_path="data/census-income", embed_dim=4):
    """Load the census-income splits and build feature definitions for multi-task training.

    The train/val/test CSVs are concatenated so each sparse feature's vocabulary size
    (``max value + 1``) is computed over all splits. Missing values are filled with 0.
    The frame is then cut back into the three splits by row position.

    Labels: ``income`` is renamed ``cvr_label`` (main task) and ``marital status`` is
    renamed ``ctr_label`` (auxiliary task). A product column ``ctcvr_label`` is also
    derived; it is excluded from the features and is not returned as a label.

    Args:
        model_name: Name of the model being trained. Not used by this function; kept so
            existing call sites keep working.
        data_path: Directory containing ``census_income_train.csv``,
            ``census_income_val.csv`` and ``census_income_test.csv``.
        embed_dim: Embedding size given to every ``SparseFeature`` (default 4).

    Returns:
        ``(features, x_train, y_train, x_val, y_val, x_test, y_test)``. Each ``x_*`` is a
        dict mapping column name to a numpy array of that split's values. Each ``y_*`` is
        an array whose columns are ``[cvr_label, ctr_label]``.
    """
    df_train = pd.read_csv(data_path + '/census_income_train.csv')
    df_val = pd.read_csv(data_path + '/census_income_val.csv')
    df_test = pd.read_csv(data_path + '/census_income_test.csv')
    print("train : val : test = %d %d %d" % (len(df_train), len(df_val), len(df_test)))
    # Row offsets where the val and test splits begin in the concatenated frame.
    train_idx = len(df_train)
    val_idx = train_idx + len(df_val)

    # task 1 (as cvr): main task, income prediction
    # task 2 (as ctr): auxiliary task, marital status prediction
    data = (
        pd.concat([df_train, df_val, df_test], axis=0, ignore_index=True)
        .fillna(0)
        .rename(columns={'income': 'cvr_label', 'marital status': 'ctr_label'})
    )
    data["ctcvr_label"] = data['cvr_label'] * data['ctr_label']

    label_cols = ['cvr_label', 'ctr_label']  # the order of labels can be any
    dense_cols = ['age', 'wage per hour', 'capital gains', 'capital losses', 'divdends from stocks',
                  'num persons worked for employer', 'weeks worked in year']
    # Every remaining non-label column is treated as a categorical (sparse) feature.
    excluded = set(dense_cols) | set(label_cols) | {'ctcvr_label'}
    sparse_cols = [col for col in data.columns if col not in excluded]
    print("sparse cols:%d dense cols:%d" % (len(sparse_cols), len(dense_cols)))

    # define dense and sparse features
    features = (
        [SparseFeature(col, data[col].max() + 1, embed_dim=embed_dim) for col in sparse_cols]
        + [DenseFeature(col) for col in dense_cols]
    )
    used_cols = sparse_cols + dense_cols

    def _split(start, stop):
        # Positional slice [start:stop] of the feature columns (as a dict) and label columns.
        x = {name: data[name].values[start:stop] for name in used_cols}
        y = data[label_cols].values[start:stop]
        return x, y

    x_train, y_train = _split(None, train_idx)
    x_val, y_val = _split(train_idx, val_idx)
    x_test, y_test = _split(val_idx, None)
    return features, x_train, y_train, x_val, y_val, x_test, y_test

In [None]:
# Set model hyperparameters and build the model selected by `model_name`.
# Every supported model uses the same census-income data and two classification
# tasks, so the data and task types are prepared once instead of in every branch.
task_types = ["classification", "classification"]
tower_params_list = [{"dims": [8]}, {"dims": [8]}]  # one tower config per task

# Each builder receives the feature list returned by get_census_data_dict.
model_builders = {
    "SharedBottom": lambda feats: SharedBottom(
        feats, task_types, bottom_params={"dims": [16]}, tower_params_list=tower_params_list),
    "MMOE": lambda feats: MMOE(
        feats, task_types, 12, expert_params={"dims": [16]}, tower_params_list=tower_params_list),
    "PLE": lambda feats: PLE(
        feats, task_types, n_level=2, n_expert_specific=2, n_expert_shared=2,
        expert_params={"dims": [16]}, tower_params_list=tower_params_list),
    "AITM": lambda feats: AITM(
        feats, len(task_types), bottom_params={"dims": [16]}, tower_params_list=tower_params_list),
    # NOTE(review): AdaTTSp is given the feature list as `input_dim` — confirm against its signature.
    "AdaTTSp": lambda feats: AdaTTSp(
        input_dim=feats, expert_out_dims=[[16]], num_tasks=len(task_types), num_task_experts=12),
    "fus_moe_add": lambda feats: fus_moe_add(
        feats, task_types, n_level=3, n_expert_specific=2,
        expert_params1={"dims": [8, 8]}, expert_params2={"dims": [16, 8]},
        tower_params_list=tower_params_list),
}

# Fail fast on an unknown name instead of leaving `model` undefined (or stale from a
# previous run) for the cells below.
if model_name not in model_builders:
    raise ValueError(f"Unknown model_name {model_name!r}; choose one of {list(model_builders)}")

features, x_train, y_train, x_val, y_val, x_test, y_test = get_census_data_dict(model_name)
model = model_builders[model_name](features)
    

In [6]:
# Only the fused model is warm-started from pretrained single-task checkpoints.
# For each checkpoint, single2fus is called to map its 'mlp1'..'mlp3' modules onto
# task-specific expert `expert_idx` in CGC levels 0-2 (presumably copying weights —
# confirm against single2fus). Level count 3 matches n_level=3 used for fus_moe_add.
# The checkpoint -> expert order (0, 2, 1, 3) is intentional and preserved.
if model_name == "fus_moe_add":
    for model_path, expert_idx in (
        ("path/to/save_dir/single_model_name_1.pth", 0),
        ("path/to/save_dir/single_model_name_2.pth", 2),
        ("path/to/save_dir/single_model_name_3.pth", 1),
        ("path/to/save_dir/single_model_name_4.pth", 3),
    ):
        single_list = ['mlp1', 'mlp2', 'mlp3']
        fus_list = [f'cgc_layers.{level}.experts_specific.{expert_idx}' for level in range(3)]
        single2fus(model_path, single_list, fus_list, model)

In [8]:
# Wrap the splits in DataLoaders and train; the trainer is configured with early stopping.
dg = DataGenerator(x_train, y_train)
train_dataloader, val_dataloader, test_dataloader = dg.generate_dataloader(
    x_val=x_val, y_val=y_val, x_test=x_test, y_test=y_test, batch_size=batch_size)
mtl_trainer = MTLTrainer(
    model,
    task_types=task_types,
    optimizer_params={"lr": learning_rate, "weight_decay": weight_decay},
    n_epoch=epoch,
    earlystop_patience=50,
    device=device,
    model_path=save_dir,
)
# The result filename has two placeholders; both must be filled or str.format raises
# IndexError. Including the seed keeps runs with different seeds from sharing a file.
file_path = 'path/to/save/output/result_{}_{}.txt'.format(model_name, seed)
# started training
# NOTE(review): mode='mark1' / seed='mark2' look like placeholders for this fork's fit() — confirm.
mtl_trainer.fit(train_dataloader, val_dataloader, mode='mark1', seed='mark2', file_path=file_path)

In [None]:
# Evaluate the trained model on the test split and append the scores to the results file.
auc = mtl_trainer.evaluate(mtl_trainer.model, test_dataloader)
print(f'test auc: {auc}')

# One comma-separated line: a row label followed by the scores returned by evaluate().
# Unpacking with `*` accepts any iterable (list, tuple, numpy array), whereas
# `list + auc` only works when evaluate() returns a list.
result_line = ', '.join(map(str, ['test_result', *auc]))

try:
    with open(file_path, 'a') as file:
        file.write(result_line + '\n')
except OSError as e:
    # Best-effort: an unwritable output file should not hide the scores printed above.
    print(f"An error occurred while adding to the file: {e}")