APIs
FlameSky edited this page on Jun 9, 2022 · 7 revisions
The `MMSA_run` function is the main entry point of this project. It runs multimodal sentiment analysis (MSA) experiments on the datasets and models specified in its parameters.
Definition:
```python
def MMSA_run(
    model_name: str, dataset_name: str, config_file: str = "",
    config: dict = None, seeds: list = [], is_tune: bool = False,
    tune_times: int = 50, feature_T: str = "", feature_A: str = "",
    feature_V: str = "", model_save_dir: str = "", res_save_dir: str = "",
    log_dir: str = "", gpu_ids: list = [0], num_workers: int = 4,
    verbose_level: int = 1
)
```
Parameters:
- `model_name` (required): Name of the MSA model. See Supported Models for details.
- `dataset_name` (required): Name of the MSA dataset. See Supported Datasets for details.
- `config_file`: Path to a config file. If not specified, the default config files are used. See Config Files for details.
- `config`: Config as a Python dict, used to override arguments in `config_file`. Ignored in tune mode.
- `seeds`: List of random seeds. Default: `[1111, 1112, 1113, 1114, 1115]`
- `is_tune`: Switches tuning mode on or off. See Tuning Mode for details. Default: `False`
- `tune_times`: Number of hyperparameter sets to try in tuning mode. Default: `50`
- `feature_T`: Path to the text feature file. Pass an empty string to use the default BERT features. Default: `""`
- `feature_A`: Path to the audio feature file. Pass an empty string to use the default features provided by the dataset creators. Default: `""`
- `feature_V`: Path to the video feature file. Pass an empty string to use the default features provided by the dataset creators. Default: `""`
- `model_save_dir`: Directory in which to save trained models. Default: `~/MMSA/saved_models`
- `res_save_dir`: Directory in which to save CSV results. Default: `~/MMSA/results`
- `log_dir`: Directory in which to save log files. Default: `~/MMSA/logs`
- `gpu_ids`: IDs of the GPUs to use. If an empty list is provided, the GPU with the most free memory is assigned automatically. Default: `[0]`. Currently only a single GPU is supported.
- `num_workers`: Number of workers used for data loading. Default: `4`
- `verbose_level`: Verbosity of stdout output: `0` for error, `1` for info, `2` for debug. Default: `1`
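The defaults and override rules in the parameter list can be summarized in a short sketch. The helper `resolve_run_options` below is hypothetical and is not part of MMSA. It assumes that an empty `seeds` list or an empty `model_save_dir` string (the values shown in the signature) fall back to the documented defaults, and it takes each GPU's free memory as an explicit `free_memory` dict instead of querying the hardware. The mapping from `verbose_level` to Python `logging` levels is likewise an assumption made for illustration.

```python
import logging
import os
from typing import Optional

# Documented defaults for MMSA_run (values taken from the parameter list).
DEFAULT_SEEDS = [1111, 1112, 1113, 1114, 1115]
# Assumed mapping of verbose_level onto the standard logging levels.
VERBOSE_TO_LOGGING = {0: logging.ERROR, 1: logging.INFO, 2: logging.DEBUG}


def resolve_run_options(
    file_config: dict,
    config: Optional[dict] = None,
    is_tune: bool = False,
    seeds: Optional[list] = None,
    model_save_dir: str = "",
    gpu_ids: Optional[list] = None,
    free_memory: Optional[dict] = None,
    verbose_level: int = 1,
) -> dict:
    """Hypothetical helper mirroring the documented parameter semantics.

    Not MMSA's implementation: it only shows how the defaults and the
    override rules described above fit together.
    """
    # `config` overrides values loaded from the config file, except in tune mode.
    merged = dict(file_config)
    if config and not is_tune:
        merged.update(config)

    # Assumption: an empty (or missing) seed list falls back to the default seeds.
    resolved_seeds = list(seeds) if seeds else list(DEFAULT_SEEDS)

    # Assumption: an empty save dir falls back to the documented default path.
    save_dir = model_save_dir or os.path.expanduser("~/MMSA/saved_models")

    # Empty gpu_ids: pick the GPU with the most free memory.
    # `free_memory` (GPU id -> free bytes) is a hypothetical input for this sketch.
    if gpu_ids is None:
        gpu_ids = [0]
    if not gpu_ids:
        if not free_memory:
            raise ValueError("no GPU memory information to choose from")
        gpu_ids = [max(free_memory, key=free_memory.get)]
    # Only a single GPU is supported; this sketch rejects longer lists.
    if len(gpu_ids) > 1:
        raise ValueError("only a single GPU is currently supported")

    return {
        "config": merged,
        "seeds": resolved_seeds,
        "model_save_dir": save_dir,
        "gpu_ids": gpu_ids,
        "log_level": VERBOSE_TO_LOGGING[verbose_level],
    }
```

For example, `resolve_run_options({}, gpu_ids=[], free_memory={0: 2_000, 1: 8_000})` would select GPU `1`, and passing `is_tune=True` leaves the file config untouched even when `config` is given.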
Example Usage:
```python
from MMSA import MMSA_run

# run LMF on MOSI with default params
MMSA_run('lmf', 'mosi')

# tune MulT on MOSI with default param ranges, using a single seed
MMSA_run('mult', 'mosi', is_tune=True, seeds=[1111])
```
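In tuning mode, `tune_times` hyperparameter sets are evaluated. As a minimal sketch, assuming the sets are drawn by random search, the hypothetical helper `sample_hyperparameter_sets` below picks one candidate value per hyperparameter for each trial. The search space shown is invented for illustration; in MMSA the ranges come from the config file, and its actual sampling strategy may differ.

```python
import random


def sample_hyperparameter_sets(search_space: dict, tune_times: int, seed: int = 1111) -> list:
    """Draw `tune_times` random hyperparameter sets from `search_space`.

    Hypothetical illustration of tuning mode as random search; each key maps
    to a list of candidate values.
    """
    rng = random.Random(seed)  # dedicated RNG so the sampled sets are reproducible
    trials = []
    for _ in range(tune_times):
        # Pick one candidate per hyperparameter; sorted keys keep the draw order stable.
        trials.append({name: rng.choice(search_space[name]) for name in sorted(search_space)})
    return trials


# Hypothetical search space for a model such as MulT.
space = {
    "learning_rate": [1e-4, 5e-4, 1e-3],
    "batch_size": [16, 32, 64],
    "dropout": [0.0, 0.1, 0.3],
}
for i, params in enumerate(sample_hyperparameter_sets(space, tune_times=5)):
    print(i, params)
```

Using a dedicated `random.Random(seed)` instance keeps the sampled sets reproducible without touching the global random state.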
The `get_config_regression` function retrieves config from a config file