In [1]:
import argparse
from utils import str2bool
import os 
# Configuration for the word2vec / kmeans / SSCL stages, expressed as an
# argparse parser so the same options work from a script or from this notebook.
parser = argparse.ArgumentParser()

# ---- framework ----
parser.add_argument('--task', default='train',
                    help='train | test | others')

parser.add_argument('--data_dir', default='../data',
                    help='data directory')

parser.add_argument('--file_train', default='train.txt',
                    help='training data file name')

parser.add_argument('--base_model_dir', default='../nats_results/sscl_models',
                    help='base model dir')

parser.add_argument('--n_epoch', default=10, type=int,
                    help='epochs')

parser.add_argument('--batch_size', default=50, type=int,
                    help='batch size')

parser.add_argument('--checkpoint', default=300, type=int,
                    help='how often do you want to save model')

parser.add_argument('--continue_training', default=True, type=str2bool,
                    help='do you want to continue train previous model saved in computer')

parser.add_argument('--train_base_model', default=False, type=str2bool,
                    help='do you want to train word embedding')

parser.add_argument('--use_optimal_model', default=True, type=str2bool,
                    help='weather to use optimal model')

parser.add_argument('--model_optimal_key', default='0,0',
                    help='epoch, batch')

# ---- learning rate and scheduler ----
parser.add_argument('--lr_schedule', default='warm-up',
                    help='warm-up | build-in | None')

parser.add_argument('--learning_rate', default=0.0002, type=float,
                    help='learning rate')

parser.add_argument('--grad_clip', default=2.0, type=float,
                    help='clip the gradient norm')

parser.add_argument('--step_size', default=2, type=int,
                    help='stepLR scheduler decay period')

# float, not int: the default (0.8) is fractional, and with type=int argparse
# would reject any fractional value supplied on the command line.
parser.add_argument('--step_decay', default=0.8, type=float,
                    help='decay rate')

parser.add_argument('--warmup_step', default=3000, type=int,
                    help='the step where the learning rate go to top if apply warmup scheduler')

parser.add_argument('--model_size', default=100000, type=int,
                    help='to modify learning rate')


# ---- user specified arguments ----
parser.add_argument('--device', default='cuda:0',
                    help='device')

# sscl
parser.add_argument('--distance', default='cosine',
                    help='method to compute distance')

parser.add_argument('--emb_size', default=128, type=int,
                    help='embedding size')

parser.add_argument('--max_seq_len', default=30, type=int,
                    help='maximum sequence length')

parser.add_argument('--min_seq_len', default=3, type=int,
                    help='minimum sequence length')

parser.add_argument('--smooth_factor', default=0.5, type=float,
                    help='smooth factor of attention')

# word2vec
parser.add_argument('--file_train_w2v', default='train_w2v.txt',
                    help='to train word embedding(file)')

parser.add_argument('--window', default=5, type=int,
                    help='window size')

parser.add_argument('--min_count', default=10, type=int,
                    help='the minimum count of word')

parser.add_argument('--workers', default=8, type=int,
                    help='workers')

# kmeans
parser.add_argument('--kmeans_init', default='vanilla',
                    help='vanilla | kmeans_init.txt')

parser.add_argument('--kmeans_seeds', default=0, type=int,
                    help='kmeans random seed')

parser.add_argument('--n_clusters', default=30, type=int,
                    help='the number of clusters')

parser.add_argument('--n_keywords', default=10, type=int,
                    help='the number of closest words to choose')

# ---- report ----
# nargs='+' gathers whole whitespace-separated tokens into a list, e.g.
# `--report loss lr` -> ['loss', 'lr']. The previous type=list applied list()
# to the raw string and produced a list of single characters instead.
parser.add_argument('--report',
                    default=['loss', 'lr', 'loss_without_reg', 'loss_reg', 'asp_weight_norm'],
                    nargs='+',
                    help='weather to report ["loss", "lr", "loss_without_reg", "loss_reg"]')

parser.add_argument('--monitor', default=['asp_weight', 'attention'], nargs='+',
                    help='weather to moniter p_norm or attention_norm')

# Parse an explicit empty argv so every option takes its default and the
# kernel's own command line is never read.
args = parser.parse_args([])

In [2]:
def run_task(args):
    """Run the pipeline stage selected by ``args.task``.

    Handled task names:

    * ``'word2vec'`` -- build ``run_w2v(args)``, call ``train_model()`` and
      then ``save_vocab_and_vectors()``.
    * ``'kmeans'`` -- build ``run_kmeans(args)`` and call ``train()``. Other
      names that merely start with ``'kmeans'`` do nothing.
    * names starting with ``'sscl'`` -- replace ``args.device`` with
      ``torch.device(args.device)`` and build ``modelSSCL(args)``; only
      ``'sscl-train'`` goes on to call ``model.train()``.

    Any other task name is ignored. Always returns ``None``. Each stage's
    module is imported only inside the branch that needs it.
    """
    task_name = args.task

    if task_name == 'word2vec':
        from word2vec import run_w2v
        w2v_runner = run_w2v(args)
        w2v_runner.train_model()
        w2v_runner.save_vocab_and_vectors()
        return

    if task_name.startswith('kmeans'):
        # Only the exact name triggers clustering; other kmeans* names are no-ops.
        if task_name == 'kmeans':
            from kmeans import run_kmeans
            run_kmeans(args).train()
        return

    if task_name.startswith('sscl'):
        import torch
        # Overwrites the device string on the shared args object in place.
        args.device = torch.device(args.device)
        from modelSSCL import modelSSCL
        sscl_model = modelSSCL(args)
        if task_name == 'sscl-train':
            sscl_model.train()

In [3]:
# Select the word2vec task on the shared `args` object and dispatch it through
# run_task (which trains the model and saves the vocabulary and vectors).
# NOTE(review): the recorded output of this cell is the cell's own text as a
# string, which suggests it was quoted out when last executed -- confirm
# before assuming this stage actually ran.
args.task = 'word2vec'
run_task(args)

"args.task = 'word2vec'\nrun_task(args)"

In [4]:
# Select the kmeans task on the shared `args` object and dispatch it through
# run_task (which builds run_kmeans(args) and calls its train()).
# NOTE(review): the recorded output of this cell is the cell's own text as a
# string, which suggests it was quoted out when last executed -- confirm
# before assuming this stage actually ran.
args.task = 'kmeans'
run_task(args)

"args.task = 'kmeans'\nrun_task(args)"

In [5]:
# Select SSCL training: run_task converts args.device to a torch.device,
# builds modelSSCL(args) and calls model.train(). Note that this permanently
# replaces the device string on the shared `args` object.
args.task = 'sscl-train'
run_task(args)

The size of vocabulary :6273.
{'embedding': Embedding(6273, 128)}
{'asp_weight': Linear(in_features=128, out_features=30, bias=True),
 'aspect_embedding': Embedding(30, 128, padding_idx=0),
 'attn_kernel': Linear(in_features=128, out_features=128, bias=True)}
Total number of trainable parameters 24222.
loading data...
batch number: 5958
writing data...
[>>>>>>>>>>>>>>>>>>>>>>>>>] 100% 
The number of batch is 5958.
epoch:0, batch:5955/5958, lr:8.2e-05, loss:3.1149, time:0.1072hhh
loading data...
batch number: 5958
writing data...
[>>>>>>>>>>>>>>>>>>>>>>>>>] 100% 
The number of batch is 5958.
epoch:1, batch:5955/5958, lr:5.8e-05, loss:3.2104, time:0.2215h
loading data...
batch number: 5958
writing data...
[>>>>>>>>>>>>>>>>>>>>>>>>>] 100% 
The number of batch is 5958.
epoch:2, batch:5955/5958, lr:4.7e-05, loss:3.2529, time:0.3423h
loading data...
batch number: 5958
writing data...
[>>>>>>>>>>>>>>>>>>>>>>>>>] 100% 
The number of batch is 5958.
epoch:3, batch:5955/5958, lr:4.1e-05, loss:2.9