-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.h
159 lines (129 loc) · 3.54 KB
/
utils.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#pragma once
#include "domain.h"
using namespace std;
using namespace SETTINGS;
// Helper functions
// Number of entries in HA_Labels, i.e. the count of valid HA indices
// [0, numHA). Evaluated once at program start-up.
// NOTE(review): this is dynamic initialization that reads HA_Labels, so it
// relies on HA_Labels (declared in domain.h) being initialized before this
// line — true if it is defined earlier in the same translation unit; confirm.
const uint numHA = HA_Labels.size();
// Return the text label of an HA index by looking it up in HA_Labels.
// No bounds check is performed: `ha` must be a valid index into HA_Labels.
string print(HA ha){
    string label = HA_Labels[ha];
    return label;
}
// Map a text label back to its HA index by linear search over HA_Labels.
//   str     label to look up (taken by const reference: no copy per call)
//   returns the first index i with HA_Labels[i] == str, or 0 if no label
//           matches.
// NOTE(review): the not-found result 0 is also a valid index, so an unknown
// label is indistinguishable from HA_Labels[0]. Callers that must reject bad
// input should verify the round trip themselves.
HA to_label(const string& str){
    for(uint i = 0; i < numHA; ++i){
        if(HA_Labels[i] == str){
            return i;
        }
    }
    return 0;
}
// Define valid transitions
// Cache holding every HA index 0..numHA-1. It is filled lazily by the first
// call to get_valid_ha() with use_safe_transitions == false. Access is not
// synchronized.
vector<HA> all;
// List the HAs that may follow `ha`.
//   use_safe_transitions == true : the entry valid_transitions[ha]
//   use_safe_transitions == false: every HA index, built once into `all`
// The list is returned by value (a copy).
vector<HA> get_valid_ha(HA ha, bool use_safe_transitions=USE_SAFE_TRANSITIONS){
    if(use_safe_transitions)
        return valid_transitions[ha];
    if(all.empty()){
        for(uint idx = 0; idx < numHA; ++idx)
            all.push_back(idx);
    }
    return all;
}
// Random error distributions
// NOTE(review): neither `rd` nor `gen` is referenced by the helpers in this
// header. flip(), correct() and pointError() draw from the C rand() stream
// instead. Including code may use them; confirm before removing.
// Non-deterministic entropy source.
random_device rd;
// Standard-library engine seeded with the constant 0, so its sequence is
// the same on every run with the same standard library implementation.
default_random_engine gen(0);
// Random generator for Bernoulli trial
// Returns true with probability p, drawing one value from the C rand() stream.
// p <= 0 never succeeds; p >= 1 always succeeds.
bool flip(double p){
    // Dividing by RAND_MAX + 1 puts rv in [0, 1). Combined with a strict '<',
    // flip(0) can never return true. The previous rand()/RAND_MAX with '<='
    // returned true for p == 0 whenever rand() yielded 0.
    const double rv = static_cast<double>(rand()) / (static_cast<double>(RAND_MAX) + 1.0);
    return rv < p;
}
// Evaluate the logistic (sigmoid) curve 1 / (1 + e^(-spread * (input - midpoint))).
//   midpoint  input value at which the curve equals 0.5
//   spread    steepness; a negative value mirrors the curve
//   input     point at which to evaluate
double logistic(double midpoint, double spread, double input){
    const double exponent = -spread * (input - midpoint);
    return 1.0 / (1.0 + exp(exponent));
}
// If transition is invalid, randomly choose another one
//   prev_ha               HA the transition starts from
//   ha                    proposed next HA
//   use_safe_transitions  when false, every transition is accepted as-is
// Returns `ha` unchanged when safe transitions are off, or when `ha` is among
// get_valid_ha(prev_ha, ...). Otherwise returns an allowed HA picked with
// rand() % (number of allowed HAs).
// If no transition is allowed from prev_ha, `ha` is returned unchanged.
// Previously this case computed rand() % 0, which is undefined behavior.
HA correct(HA prev_ha=0, HA ha=0, bool use_safe_transitions=USE_SAFE_TRANSITIONS) {
    if(!use_safe_transitions){
        return ha;
    }
    const vector<HA> all_possible_ha = get_valid_ha(prev_ha, use_safe_transitions);
    const bool is_allowed =
        find(all_possible_ha.begin(), all_possible_ha.end(), ha) != all_possible_ha.end();
    if(is_allowed || all_possible_ha.empty()){
        return ha;
    }
    const size_t index = static_cast<size_t>(rand()) % all_possible_ha.size();
    return all_possible_ha[index];
}
// randomly transition to another HA
//   prev_ha               HA the transition starts from
//   ha                    HA to keep when the accuracy trial succeeds
//   accuracy              probability (via flip) of keeping `ha`
//   use_safe_transitions  restricts replacements to get_valid_ha(prev_ha, true)
// On failure of the accuracy trial, returns a replacement picked with
// rand() % (number of candidates). The replacement may coincide with `ha`.
// If there are no candidates, `ha` is returned unchanged. Previously this
// case computed rand() % 0, which is undefined behavior.
HA pointError(HA prev_ha=0, HA ha=0, double accuracy=POINT_ACCURACY, bool use_safe_transitions=USE_SAFE_TRANSITIONS){
    if(flip(accuracy)){
        return ha;
    }
    const vector<HA> all_possible_ha = get_valid_ha(prev_ha, use_safe_transitions);
    if(all_possible_ha.empty()){
        return ha;
    }
    const size_t index = static_cast<size_t>(rand()) % all_possible_ha.size();
    return all_possible_ha[index];
}
// Exponentiate some values, take their sum, then take the log of the sum:
//   log(sum_i exp(vals[i])), computed stably by factoring out the maximum.
// Returns 0 for an empty input (existing behavior callers may rely on).
// If the maximum is +inf or -inf, the result is that infinity. Without this
// guard, exp(inf - inf) produced NaN (e.g. all inputs -inf, a common "log 0").
// Takes a const reference, so temporaries and const vectors are accepted.
double logsumexp(const vector<double>& vals) {
    if (vals.empty()){
        return 0;
    }
    const double max_elem = *max_element(vals.begin(), vals.end());
    if (std::isinf(max_elem)) {
        return max_elem;
    }
    const double sum = accumulate(vals.begin(), vals.end(), 0.0,
        [max_elem](double acc, double v) { return acc + exp(v - max_elem); });
    return max_elem + log(sum);
}
// Two-argument log-sum-exp: log(exp(d1) + exp(d2)), computed stably.
// If the larger argument is +inf or -inf, that infinity is returned. Without
// this guard, exp(inf - inf) produced NaN (e.g. logsumexp(-inf, -inf)).
double logsumexp(double d1, double d2) {
    const double max_elem = max(d1, d2);
    if (std::isinf(max_elem)) {
        return max_elem;
    }
    return max_elem + log(exp(d1 - max_elem) + exp(d2 - max_elem));
}
// Natural log of the normal density N(mu, sigma) evaluated at x:
//   -log(sigma) - 0.5*log(2*pi) - 0.5*((x - mu)/sigma)^2
// sigma is a standard deviation; a non-positive value yields NaN or inf.
double logpdf(double x, double mu, double sigma){
    const double z = (x - mu) / sigma;
    const double log_norm = (-log(sigma)) - (0.5 * log(2 * M_PI));
    return log_norm - 0.5 * (z * z);
}
// Trimming a string
// Remove leading whitespace (as classified by isspace) from s, in place.
inline void ltrim(string &s) {
    auto is_space = [](unsigned char ch) { return isspace(ch) != 0; };
    const auto first_kept = find_if_not(s.begin(), s.end(), is_space);
    s.erase(s.begin(), first_kept);
}
// Remove trailing whitespace (as classified by isspace) from s, in place.
inline void rtrim(string &s) {
    auto is_space = [](unsigned char ch) { return isspace(ch) != 0; };
    // Walk backwards to the last non-space character; base() converts the
    // reverse iterator to the forward position just past it.
    const auto last_kept = find_if_not(s.rbegin(), s.rend(), is_space);
    s.erase(last_kept.base(), s.end());
}
// Remove whitespace from both ends of s, in place.
// The two passes are independent, so their order does not affect the result.
inline void trim(string &s) {
    ltrim(s);
    rtrim(s);
}
// Logistic curve with arguments ordered (input, midpoint, slope).
// It forwards to logistic(midpoint, spread, input).
double logistic2(double x, double mid, double slope) {
    const double midpoint = mid;
    const double spread = slope;
    return logistic(midpoint, spread, x);
}
// Draw a single Bernoulli(p) outcome; this is an alias for flip(p).
bool sample(double p) {
    const bool outcome = flip(p);
    return outcome;
}
// Named arithmetic primitive: x + y.
double Plus(double x, double y) {
    const double total = x + y;
    return total;
}
// Named arithmetic primitive: x - y.
double Minus(double x, double y) {
    const double difference = x - y;
    return difference;
}
// Absolute value of a double.
// Uses std::fabs explicitly. An unqualified abs(x) depends on <cmath>'s
// double overload being visible; if only the C int abs is found, the
// fractional part is silently truncated (e.g. abs(-0.75) -> 0).
double Abs(double x) {
    return std::fabs(x);
}
// Named arithmetic primitive: x * y.
double Times(double x, double y) {
    const double product = x * y;
    return product;
}
// Named arithmetic primitive: x / y. No zero check is done; on IEEE-754
// platforms y == 0 yields +/-inf or NaN rather than an error.
double DividedBy(double x, double y) {
    const double quotient = x / y;
    return quotient;
}
// Logical AND, reported as a double: 1.0 if both x and y are true, else 0.0.
// NOTE(review): unlike Lt/Gt this returns double, not bool. Switching to bool
// would change integer arithmetic on the result (e.g. And(a,b)/2).
double And(bool x, bool y) {
    return (x && y) ? 1.0 : 0.0;
}
// Logical OR, reported as a double: 1.0 if x or y is true, else 0.0.
double Or(bool x, bool y) {
    return (x || y) ? 1.0 : 0.0;
}
// Strict less-than: true iff x < y (false if either operand is NaN).
bool Lt(double x, double y) {
    return y > x;
}
// Strict greater-than: true iff x > y (false if either operand is NaN).
bool Gt(double x, double y) {
    return y < x;
}