/
generation_hyper_params.py
53 lines (41 loc) · 1.45 KB
/
generation_hyper_params.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
class HParams:
    """Hyper-parameters for the sketch-generation model.

    Collects dataset locations, model sizes, optimization settings and
    sampling options in one mutable namespace. A shared module-level
    instance (``hp``) is created at import time.
    """

    def __init__(self) -> None:
        # --- checkpoints & data ---
        self.reload_index: int = 0  # index of the checkpoint to reload
        self.data_location: str = "./dataset_32_150"
        self.save_path: str = "./models_32_150"
        # QuickDraw category files used for training.
        self.category: list = [
            "airplane.npz", "angel.npz", "apple.npz", "butterfly.npz", "bus.npz",
            "cake.npz", "fish.npz", "spider.npz", "The Great Wall of China.npz", "umbrella.npz"
        ]  # 10
        # NOTE(review): 150 and 32 mirror the "_32_150" suffix of
        # data_location — keep them in sync when switching datasets.
        self.graph_number: int = 150
        self.row_column: int = 32
        self.mask_prob: float = 0.10  # masking probability — role not shown here; confirm at use site

        # --- model architecture ---
        self.enc_hidden_size: int = 256  # encoder LSTM h size
        self.dec_hidden_size: int = 512  # decoder hidden size
        self.Nz: int = 128  # encoder output size (latent dimension)
        self.M: int = 20  # presumably number of mixture components — confirm against the decoder
        self.dropout: float = 0.0

        # --- optimization ---
        self.batch_size: int = 32
        self.eta_min: float = 0.01  # with R/KL_min/wKL below: KL-annealing settings — confirm in the loss code
        self.R: float = 0.99995
        self.KL_min: float = 0.2
        self.wKL: float = 0.5
        self.lr: float = 0.001
        self.lr_decay: float = 0.99999
        self.min_lr: float = 0.00003
        self.grad_clip_encode: float = 20.  # clip threshold for the encoder
        self.grad_clip: float = 1.          # clip threshold elsewhere

        # --- sampling & sequence limits ---
        self.temperature: float = 0.05
        self.max_seq_length: int = 200
        self.min_seq_length: int = 0
        self.Nmax: int = 0  # placeholder; presumably set from the data later — confirm at load time

        # --- embeddings / GCN ---
        self.embedding_dim: int = 128
        self.gcn_out_dim: int = 128
        self.words_number: int = 256
        self.picture_size: int = self.words_number  # alias: kept equal by construction

        self.same_category_in_batch: bool = False
        self.use_cuda: bool = torch.cuda.is_available()
hp = HParams()