-
Notifications
You must be signed in to change notification settings - Fork 0
/
learn_data_neural_comp.py
114 lines (88 loc) · 3.21 KB
/
learn_data_neural_comp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import argparse
from h5py import File
import logging
import numpy as np
import os
import random
import torch
from module.module import Module
from learner.data_neural_complexity_learner import DataNeuralComplexityLearner
from module.model import Model
###############################################################################
def main():
    """Command-line entry point.

    Loads an HDF5 dataset, standardizes the inputs, shuffles and splits
    them into train/validation sets, then fits a
    DataNeuralComplexityLearner on an MNISTModel.

    Positional CLI args: data (dataset name under ./data, without .h5),
    path (output csv path relative to this file). Optional: --seed,
    --batch_size, --val (validation fraction), --decay, --epoch,
    --save_size.
    """
    logging.basicConfig(level=logging.INFO)
    # logging.getLogger().disabled = True
    logging.StreamHandler.terminator = ""
    arg_parser = argparse.ArgumentParser(description='')
    arg_parser.add_argument(
        "data", metavar="data", type=str,
        help="data")
    arg_parser.add_argument(
        "path", metavar="path", type=str,
        help="path csv")
    arg_parser.add_argument(
        "--seed", metavar="seed", default=0, type=int,
        help="seed")
    # SGD
    arg_parser.add_argument(
        "--batch_size", metavar="batch_size", default=1, type=int,
        help="batch_size")
    arg_parser.add_argument(
        "--val", metavar="val", default=0.1, type=float,
        help="val")
    arg_parser.add_argument(
        "--decay", metavar="decay", default=0.0, type=float,
        help="decay")
    arg_parser.add_argument(
        "--epoch", metavar="epoch", default=1, type=int,
        help="epoch")
    arg_parser.add_argument(
        "--save_size", metavar="save_size", default=10, type=int,
        help="save_size")
    # ----------------------------------------------------------------------- #
    arg_list = arg_parser.parse_known_args()[0]
    data_name = arg_list.data
    path = arg_list.path
    # seed*3 keeps the derived seeds (seed, seed+1, seed+2 below) from
    # colliding between consecutive user-supplied seed values.
    seed = arg_list.seed*3
    # SGLD
    batch_size = arg_list.batch_size
    val = arg_list.val
    decay = arg_list.decay
    epoch = arg_list.epoch
    save_size = arg_list.save_size
    # ----------------------------------------------------------------------- #
    device_list = ["cuda", "cpu"]
    # ----------------------------------------------------------------------- #
    path = os.path.join(os.path.dirname(__file__), path)
    # Copy the arrays out and close the HDF5 handle promptly (the
    # original left the h5py File open for the whole run).
    with File(os.path.join("data", data_name + ".h5"), "r") as data_file:
        x_train = np.array(data_file["x_train"])
        y_train = np.array(data_file["y_train"])
    x_mean, x_std = Module("MeanStd")(x_train)
    x_train = Module("StandardScaler")(x_train, x_mean=x_mean, x_std=x_std)
    y_train = np.expand_dims(y_train, 1)
    y_train = y_train.astype(np.int64)
    # Seed every RNG in play so the shuffle/split below is reproducible.
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    permutation = np.arange(x_train.shape[0])
    np.random.shuffle(permutation)
    x_train = x_train[permutation]
    y_train = y_train[permutation]
    # First m shuffled samples train; the last val-fraction validates.
    m = len(x_train) - int(val*len(x_train))
    # We take the two sets
    x_val = x_train[m:]
    y_val = y_train[m:]
    x_train = x_train[:m]
    y_train = y_train[:m]
    # ----------------------------------------------------------------------- #
    # We learn the models
    logging.info("We learn the models...\n")
    model = Model("MNISTModel", seed=seed+1)
    model.to(device_list)
    learner = DataNeuralComplexityLearner(
        path, model, decay, batch_size, epoch, save_size,
        seed=seed+2)
    learner.fit(x_train, y_train, x_val, y_val)
###############################################################################
# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()