In [1]:
# Switch the working directory to the parent folder before importing; the
# project-local modules imported below (utils, mmoe, trainer_and_evaluator)
# are presumably resolved from there — TODO confirm.
# NOTE(review): the path is relative, so re-running this cell climbs one more
# level each time. Run it once per kernel.
%cd ../
import pandas as pd
from torch.utils.data import Dataset,DataLoader
import torch
import numpy as np
import ast
import torch.nn as nn
import torch.optim as optim
import random
from matplotlib import pyplot as plt
import os
import shutil
from datetime import datetime
import wandb
from utils import set_all_seed, load_exp2_dataset
from mmoe import MMoE
from trainer_and_evaluator import train_MMoE, eval_MMoE

C:\jupyter file\CMU_3\distributed ML\project



In [2]:
# Seed through the project helper before any data is loaded
# (presumably covers random / numpy / torch — see utils.set_all_seed).
SEED = 42
set_all_seed(SEED)

# The loader hands back the three splits, unpacked in train / val / test order.
train_dataset, val_dataset, test_dataset = load_exp2_dataset()

Absolute Pearson correlation coefficient: 0.16661069841980908
P-value: 0.0


In [3]:
# Experiment configuration: hyperparameters, the wandb run name, and the path
# where the best checkpoint is stored.

# Forwarded as keyword arguments to train_MMoE(**train_params).
train_params = {
    "batch_size": 256,
    "N_epochs": 50,
    "lr": 0.0001
}

# Forwarded as keyword arguments to MMoE(**model_params).
model_params={
    'feature_dim': 114,
    'expert_dim': 32,
    'n_expert': 4,
    'n_task': 2,
    'use_gate': True,
    'gate_dropout': 0,
    'tower_dropout': 0,
    'expert_dropout': 0
}

model_name="exp2_MMoE"
# exist_ok=True keeps this cell idempotent on re-run and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs("model/"+model_name, exist_ok=True)

# "key=value" pairs joined with "_" give a human-readable run identifier.
train_params_str = "_".join(f"{key}={value}" for key, value in train_params.items())
model_params_str = "_".join(f"{key}={value}" for key, value in model_params.items())
# Values only (no keys), used to keep the checkpoint path short.
short_model_params_str = "_".join(f"{value}" for value in model_params.values())

wandb_name=model_name+":"+train_params_str+"_"+model_params_str

# The checkpoint path uses short_model_params_str because Windows limits a
# full path to 260 characters by default. Despite the "_dir" suffix this is
# the checkpoint file itself: it is later passed straight to torch.load().
bestmodel_save_dir="model/"+model_name+"/"+train_params_str+"_"+short_model_params_str

print(wandb_name)
print(bestmodel_save_dir)

# Start the wandb run. The hyperparameters are passed as structured config in
# addition to being encoded in the run name, so runs can be filtered and
# compared by individual hyperparameter in the wandb UI. The two dicts have
# no overlapping keys, so merging them loses nothing.
wandb.init(
    project="mmoe",
    name=wandb_name,
    config={**train_params, **model_params},
)

exp2_MMoE:batch_size=256_N_epochs=50_lr=0.0001_feature_dim=114_expert_dim=32_n_expert=4_n_task=2_use_gate=True_gate_dropout=0_tower_dropout=0_expert_dropout=0
model/exp2_MMoE/batch_size=256_N_epochs=50_lr=0.0001_114_32_4_2_True_0_0_0


[34m[1mwandb[0m: Currently logged in as: [33myuntaozh[0m ([33mzhengyuntao[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
# Instantiate the model from the hyperparameter dict.
mymodel = MMoE(**model_params)

# Total element count summed over every parameter tensor of the model.
nParams = sum(tensor.nelement() for tensor in mymodel.parameters())
print('Number of parameters: %d' % nParams)

# Same count, restricted to parameters whose name begins with "Expert_Gate".
nParams_in_mmoe = sum(
    tensor.nelement()
    for param_name, tensor in mymodel.named_parameters()
    if param_name.startswith("Expert_Gate")
)
print('Number of parameters in MMoE: %d' % nParams_in_mmoe)

Number of parameters: 11978
Number of parameters in MMoE: 9976


In [5]:
# Train on the training split and validate on the validation split.
# bestmodel_save_dir is handed to the trainer as the checkpoint location; the
# training log reports a save whenever an epoch is the best so far.
# The three return values are unpacked as train losses, validation losses and
# adam_batch_loss (presumably per-batch losses — confirm in
# trainer_and_evaluator.train_MMoE).
losses, val_losses, adam_batch_loss = train_MMoE(
    mymodel,
    train_dataset,
    val_dataset,
    bestmodel_save_dir,
    **train_params,
)

Epoch=0,train_loss=1.1298118829727173,val_loss=0.8646352887153625
current epoch is the best so far. Saving model...
Epoch=1,train_loss=0.8385598063468933,val_loss=0.7746270298957825
current epoch is the best so far. Saving model...
Epoch=2,train_loss=0.633605420589447,val_loss=0.498717725276947
current epoch is the best so far. Saving model...
Epoch=3,train_loss=0.4201555550098419,val_loss=0.36541083455085754
current epoch is the best so far. Saving model...
Epoch=4,train_loss=0.3377324938774109,val_loss=0.317330539226532
current epoch is the best so far. Saving model...
Epoch=5,train_loss=0.2973834276199341,val_loss=0.28491726517677307
current epoch is the best so far. Saving model...
Epoch=6,train_loss=0.274476021528244,val_loss=0.26628434658050537
current epoch is the best so far. Saving model...
Epoch=7,train_loss=0.25862789154052734,val_loss=0.2570919692516327
current epoch is the best so far. Saving model...
Epoch=8,train_loss=0.2460143268108368,val_loss=0.2457582652568817
curren

In [6]:
# Load the best model based on validation: rebuild the architecture with the
# same hyperparameters, then restore the saved weights into it.
mybestmodel = MMoE(**model_params)
# map_location="cpu" lets a checkpoint written from a GPU run be restored on a
# CPU-only machine; load_state_dict then copies the tensors into the model's
# own parameters, so the model's device is unaffected.
best_state_dict = torch.load(bestmodel_save_dir, map_location="cpu")
# Kept as the last expression so the cell still displays the key-matching result.
mybestmodel.load_state_dict(best_state_dict)

<All keys matched successfully>

In [7]:
# Evaluate the restored best checkpoint on the test split.
# eval_MMoE returns two values, unpacked here as auc1 / auc2 (presumably the
# per-task AUCs that it prints — confirm in trainer_and_evaluator.eval_MMoE).
auc1, auc2=eval_MMoE(mybestmodel, test_dataset)

AUC: 0.9668353143810063
AUC: 0.9906815316952533


In [8]:
# Close the wandb run that was opened with wandb.init() earlier in this notebook.
wandb.finish()

VBox(children=(Label(value='1.244 MB of 1.244 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
task1_loss,█▄▆▃▅▅▃▃▂▂▂▂▂▂▂▂▂▂▃▂▂▁▁▃▁▃▂▁▁▂▂▁▂▂▁▂▁▂▃▂
task2_loss,██▅▃▂▂▂▁▂▂▂▂▂▁▂▁▂▂▂▂▁▁▂▂▂▁▂▁▂▁▁▁▂▁▁▂▁▂▂▁
task_0/expert_0_weight,▁▁▁█▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
task_0/expert_1_weight,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▄▁▁▁▁▁▁▁
task_0/expert_2_weight,▆▄█▄▄▅▅▆▆█▅▅▄▃▃▃▂▄▃▃▄▂▂▃▅▂▃▁▃▃▃▄▃▁▂▂▁▂▂▁
task_0/expert_3_weight,▃▅▁▅▅▄▄▃▃▁▄▄▅▆▆▆▇▅▆▆▅▇▇▆▄▇▆█▆▆▅▅▆█▇▇█▇▇█
task_1/expert_0_weight,▁▁▂▂▂▂▂▂▂▂▂▁▂▅▅▃▃▂▄▃▂▃▂▂▃▄▂▃▅▅▆▆▅▇▅▆▆▅█▄
task_1/expert_1_weight,▁▁▄█▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
task_1/expert_2_weight,██▇▆▇▇▇▇▇▇▇█▇▄▄▆▆▇▅▆▇▆▇▇▆▅▇▆▄▄▃▃▄▂▄▃▃▄▁▅
task_1/expert_3_weight,█▁▁▂▁▁▁▁▁█▁▁▁▁▁▁▂▁▃▁▁▁▁▁▁▁▁▁▁▁▁▅▁▁▁▁▂▁▁▁

0,1
task1_loss,0.0562
task2_loss,0.0649
task_0/expert_0_weight,0.0
task_0/expert_1_weight,0.00042
task_0/expert_2_weight,0.03841
task_0/expert_3_weight,0.96117
task_1/expert_0_weight,0.15859
task_1/expert_1_weight,0.0
task_1/expert_2_weight,0.84141
task_1/expert_3_weight,0.0
