-
Notifications
You must be signed in to change notification settings - Fork 0
/
TransE.cpp
418 lines (398 loc) · 15.7 KB
/
TransE.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
string version;                         // "bern" or "unif"; used as the suffix of the output embedding files
char buf[100000],buf1[100000];          // scratch buffers for fscanf string reads (buf1 appears unused here — TODO confirm)
int relation_num,entity_num;            // number of distinct relations / entities loaded so far
map<string,int> relation2id,entity2id;  // name -> integer id for relations and entities
map<int,string> id2entity,id2relation;  // reverse maps: integer id -> name
//all these dictionaries have been explained as and when required
map<int,map<int,int> > left_entity,right_entity; // relation id -> (entity id -> occurrence count) on the head / tail side
map<int,double> left_num,right_num;     // relation id -> average head / tail multiplicity (bernoulli-sampling statistics)
class Train{
public:
    // ok[(head, relation)][tail] == 1 for every triple seen in training,
    // so corrupted samples can be rejected if they are actually true.
    map<pair<int,int>, map<int,int> > ok;

    // Register one training triple: head entity x, tail entity y, relation z.
    void add(int x,int y,int z)
    {
        fb_h.push_back(x);        // head entity id
        fb_r.push_back(z);        // relation id
        fb_l.push_back(y);        // tail entity id
        ok[make_pair(x,z)][y]=1;  // remember (x, z, y) as a true triple
    }

    // Train embeddings of dimension n_in with learning rate rate_in and
    // ranking margin margin_in. method_in selects negative sampling:
    // 0 = uniform, 1 = bernoulli (Wang et al. 2014).
    void run(int n_in,double rate_in,double margin_in,int method_in)
    {
        n = n_in;
        rate = rate_in;
        margin = margin_in;
        method = method_in;
        // One n-dimensional vector per relation and entity, plus matching
        // scratch copies (_tmp) that are updated inside each minibatch and
        // committed back at batch end.
        relation_vec.resize(relation_num);
        for (int i=0; i<relation_vec.size(); i++)
            relation_vec[i].resize(n);
        entity_vec.resize(entity_num);
        for (int i=0; i<entity_vec.size(); i++)
            entity_vec[i].resize(n);
        relation_tmp.resize(relation_num);
        for (int i=0; i<relation_tmp.size(); i++)
            relation_tmp[i].resize(n);
        entity_tmp.resize(entity_num);
        for (int i=0; i<entity_tmp.size(); i++)
            entity_tmp[i].resize(n);
        // Random initialization as in the TransE paper: values drawn by the
        // external randn helper, truncated to [-6/sqrt(n), 6/sqrt(n)].
        for (int i=0; i<relation_num; i++)
        {
            for (int ii=0; ii<n; ii++)
                relation_vec[i][ii] = randn(0,1.0/n,-6/sqrt(n),6/sqrt(n));
        }
        for (int i=0; i<entity_num; i++)
        {
            for (int ii=0; ii<n; ii++)
                entity_vec[i][ii] = randn(0,1.0/n,-6/sqrt(n),6/sqrt(n));
            // Entities additionally satisfy the ||e|| <= 1 constraint.
            norm(entity_vec[i]);
        }
        bfgs();
    }
private:
    int n,method;        // embedding dimension; sampling method (0 unif, 1 bern)
    double res;          // accumulated margin loss of the current epoch
    double rate,margin;  // SGD learning rate and ranking margin
    vector<int> fb_h,fb_l,fb_r;  // parallel arrays: head, tail, relation ids of all triples
    vector<vector<double> > relation_vec,entity_vec;  // current embeddings
    vector<vector<double> > relation_tmp,entity_tmp;  // scratch copies updated within a batch

    // Project a vector back onto the unit ball: divide by its length when
    // the length exceeds 1. Return value is unused (always 0).
    double norm(vector<double> &a)
    {
        double x = vec_len(a);
        if (x>1)
            for (int ii=0; ii<a.size(); ii++)
                a[ii]/=x;
        return 0;
    }

    // Uniform-ish random integer in [0, x).
    // NOTE(fix): the original computed (rand()*rand())%x in int, which
    // overflows (undefined behavior) whenever the product exceeds INT_MAX.
    // Widening to long long keeps the product exact and non-negative.
    int rand_max(int x)
    {
        int res = (int)(((long long)rand()*rand())%x);
        while (res<0)
            res+=x;
        return res;
    }

    // Minibatch SGD training loop (the name "bfgs" is kept from the original
    // code; no quasi-Newton method is involved). Writes the learnt
    // embeddings to disk after every epoch.
    void bfgs()
    {
        res=0;
        int nbatches=100;   // number of batches per epoch
        int nepoch = 1000;  // number of epochs
        int batchsize = fb_h.size()/nbatches;  // triples per batch
        for (int epoch=0; epoch<nepoch; epoch++)
        {
            res=0;
            for (int batch = 0; batch<nbatches; batch++)
            {
                // Work on scratch copies; commit once per batch.
                relation_tmp=relation_vec;
                entity_tmp = entity_vec;
                for (int k=0; k<batchsize; k++)
                {
                    int i=rand_max(fb_h.size());  // index of a random training triple
                    int j=rand_max(entity_num);   // candidate corrupting entity id
                    // Bernoulli trick: corrupt the tail with probability
                    // tph/(tph+hpt) for this relation (right_num / left_num
                    // are the average tail/head multiplicities). Uniform
                    // sampling uses a fixed 50/50 split.
                    double pr = 1000*right_num[fb_r[i]]/(right_num[fb_r[i]]+left_num[fb_r[i]]);
                    if (method == 0)
                        pr = 500;
                    if (rand()%1000<pr)
                    {
                        // Corrupt the tail: resample j until (h, r, j) is
                        // NOT a true triple of the training set.
                        while (ok[make_pair(fb_h[i],fb_r[i])].count(j)>0)
                            j=rand_max(entity_num);
                        // (fb_h[i], fb_l[i], fb_r[i]) is the true triple,
                        // (fb_h[i], j, fb_r[i]) the tail-corrupted one.
                        train_kb(fb_h[i],fb_l[i],fb_r[i],fb_h[i],j,fb_r[i]);
                    }
                    else
                    {
                        // Corrupt the head: resample j until (j, r, l) is
                        // NOT a true triple of the training set.
                        while (ok[make_pair(j,fb_r[i])].count(fb_l[i])>0)
                            j=rand_max(entity_num);
                        // (fb_h[i], fb_l[i], fb_r[i]) is the true triple,
                        // (j, fb_l[i], fb_r[i]) the head-corrupted one.
                        train_kb(fb_h[i],fb_l[i],fb_r[i],j,fb_l[i],fb_r[i]);
                    }
                    // Re-impose the norm constraints on every vector touched
                    // by this update.
                    norm(relation_tmp[fb_r[i]]);
                    norm(entity_tmp[fb_h[i]]);
                    norm(entity_tmp[fb_l[i]]);
                    norm(entity_tmp[j]);
                }
                relation_vec = relation_tmp;
                entity_vec = entity_tmp;
            }
            cout<<"epoch:"<<epoch<<' '<<res<<endl;
            // Persist the current embeddings after each epoch.
            FILE* f2 = fopen(("relation2vec."+version).c_str(),"w");
            FILE* f3 = fopen(("entity2vec."+version).c_str(),"w");
            // NOTE(fix): the original never checked fopen results;
            // fprintf on a NULL stream is undefined behavior.
            if (f2==NULL || f3==NULL)
            {
                cout<<"cannot open output embedding files"<<endl;
                exit(1);
            }
            for (int i=0; i<relation_num; i++)
            {
                for (int ii=0; ii<n; ii++)
                    fprintf(f2,"%.6lf\t",relation_vec[i][ii]);
                fprintf(f2,"\n");
            }
            for (int i=0; i<entity_num; i++)
            {
                for (int ii=0; ii<n; ii++)
                    fprintf(f3,"%.6lf\t",entity_vec[i][ii]);
                fprintf(f3,"\n");
            }
            fclose(f2);
            fclose(f3);
        }
    }

    // Distance of the triple (e1, rel, e2): || e2 - e1 - rel || under L1,
    // or the squared L2 norm, chosen by the global L1_flag.
    double calc_sum(int e1,int e2,int rel)
    {
        double sum=0;
        if (L1_flag)
            for (int ii=0; ii<n; ii++)
                sum+=fabs(entity_vec[e2][ii]-entity_vec[e1][ii]-relation_vec[rel][ii]);
        else
            for (int ii=0; ii<n; ii++)
                sum+=sqr(entity_vec[e2][ii]-entity_vec[e1][ii]-relation_vec[rel][ii]);
        return sum;
    }

    // One SGD step on the pairwise margin loss: pull the true triple (a)
    // together and push the corrupted triple (b) apart.
    // All updates go into the _tmp scratch arrays.
    void gradient(int e1_a,int e2_a,int rel_a,int e1_b,int e2_b,int rel_b)
    {
        for (int ii=0; ii<n; ii++)
        {
            // Gradient of the squared-L2 distance wrt (e2 - e1 - rel);
            // under L1 only the sign is kept (subgradient).
            double x = 2*(entity_vec[e2_a][ii]-entity_vec[e1_a][ii]-relation_vec[rel_a][ii]);
            if (L1_flag)
            {
                if (x>0)
                    x=1;
                else
                    x=-1;
            }
            // True triple: descend, shrinking its distance
            // (-=-1*rate*x is an increment by rate*x, kept verbatim).
            relation_tmp[rel_a][ii]-=-1*rate*x;
            entity_tmp[e1_a][ii]-=-1*rate*x;
            entity_tmp[e2_a][ii]+=-1*rate*x;
            x = 2*(entity_vec[e2_b][ii]-entity_vec[e1_b][ii]-relation_vec[rel_b][ii]);
            if (L1_flag)
            {
                if (x>0)
                    x=1;
                else
                    x=-1;
            }
            // Corrupted triple: ascend, growing its distance.
            relation_tmp[rel_b][ii]-=rate*x;
            entity_tmp[e1_b][ii]-=rate*x;
            entity_tmp[e2_b][ii]+=rate*x;
        }
    }

    // Margin-ranking update for one (true, corrupted) triple pair.
    // (e1_a, e2_a, rel_a) is the true triple, (e1_b, e2_b, rel_b) corrupted.
    void train_kb(int e1_a,int e2_a,int rel_a,int e1_b,int e2_b,int rel_b)
    {
        double sum1 = calc_sum(e1_a,e2_a,rel_a);  // distance of the true triple
        double sum2 = calc_sum(e1_b,e2_b,rel_b);  // distance of the corrupted triple
        // Hinge loss: update only when the margin is violated.
        if (sum1+margin>sum2)
        {
            res+=margin+sum1-sum2;  // accumulate epoch loss for reporting
            gradient( e1_a, e2_a, rel_a, e1_b, e2_b, rel_b);
        }
    }
};
Train train;
// Load entity/relation id maps and the training triples from ../data/,
// fill the global dictionaries and the `train` object, and compute the
// per-relation head/tail multiplicity statistics (left_num / right_num)
// used by bernoulli negative sampling.
void prepare()
{
    /* entity2id.txt lines look like
       /m/06rf7 0
       /m/0c94fn 1
       i.e. a freebase mid and its unique integer id.
       relation2id.txt has the same layout for relation names. */
    FILE* f1 = fopen("../data/entity2id.txt","r");
    FILE* f2 = fopen("../data/relation2id.txt","r");
    // NOTE(fix): the original never checked fopen results; a missing data
    // file made the fscanf calls below dereference NULL.
    if (f1==NULL || f2==NULL)
    {
        cout<<"cannot open ../data/entity2id.txt or ../data/relation2id.txt"<<endl;
        exit(1);
    }
    int x;
    // Build the string<->id maps for entities and count them.
    while (fscanf(f1,"%s%d",buf,&x)==2)
    {
        string st=buf;
        entity2id[st]=x;
        id2entity[x]=st;
        entity_num++;
    }
    // Same for relations.
    while (fscanf(f2,"%s%d",buf,&x)==2)
    {
        string st=buf;
        relation2id[st]=x;
        id2relation[x]=st;
        relation_num++;
    }
    // NOTE(fix): f1 and f2 were leaked in the original.
    fclose(f1);
    fclose(f2);
    /* train.txt: one triple per line,
       /m/027rn /m/06cx9 /location/country/form_of_government
       i.e. head entity, tail entity, relation name. */
    FILE* f_kb = fopen("../data/train.txt","r");
    if (f_kb==NULL)
    {
        cout<<"cannot open ../data/train.txt"<<endl;
        exit(1);
    }
    while (fscanf(f_kb,"%s",buf)==1)
    {
        string s1=buf;        // head entity mid
        fscanf(f_kb,"%s",buf);
        string s2=buf;        // tail entity mid
        fscanf(f_kb,"%s",buf);
        string s3=buf;        // relation name
        // Report entities that were absent from entity2id.txt.
        if (entity2id.count(s1)==0)
        {
            cout<<"miss entity:"<<s1<<endl;
        }
        if (entity2id.count(s2)==0)
        {
            cout<<"miss entity:"<<s2<<endl;
        }
        // Unknown relations get a fresh id assigned on the fly.
        if (relation2id.count(s3)==0)
        {
            relation2id[s3] = relation_num;
            relation_num++;
        }
        // Per-relation histograms of head and tail entities, used below for
        // the bernoulli corruption statistics.
        left_entity[relation2id[s3]][entity2id[s1]]++;
        right_entity[relation2id[s3]][entity2id[s2]]++;
        // Hand the triple to the trainer.
        train.add(entity2id[s1],entity2id[s2],relation2id[s3]);
    }
    // left_num[r] = average multiplicity of a head entity under relation r
    // (total head occurrences / number of distinct heads).
    for (int i=0; i<relation_num; i++)
    {
        double sum1=0,sum2=0;
        for (map<int,int>::iterator it = left_entity[i].begin(); it!=left_entity[i].end(); it++)
        {
            sum1++;            // distinct head entities
            sum2+=it->second;  // total head occurrences
        }
        left_num[i]=sum2/sum1;
    }
    // right_num[r] = the same statistic on the tail side.
    for (int i=0; i<relation_num; i++)
    {
        double sum1=0,sum2=0;
        for (map<int,int>::iterator it = right_entity[i].begin(); it!=right_entity[i].end(); it++)
        {
            sum1++;
            sum2+=it->second;
        }
        right_num[i]=sum2/sum1;
    }
    cout<<"relation_num="<<relation_num<<endl;
    cout<<"entity_num="<<entity_num<<endl;
    fclose(f_kb);
}
int ArgPos(char *str, int argc, char **argv) {
    // Scan argv (skipping argv[0]) for an exact match of str.
    // Returns the index of the matching argument, or -1 if absent.
    // Exits with an error when the flag is the final argument, since every
    // recognized flag is expected to be followed by a value.
    for (int idx = 1; idx < argc; ++idx) {
        if (strcmp(str, argv[idx]) != 0)
            continue;
        if (idx == argc - 1) {
            printf("Argument missing for %s\n", str);
            exit(1);
        }
        return idx;
    }
    return -1;
}
// Entry point. Command-line options:
//   -size   embedding dimension n (default 100)
//   -rate   SGD learning rate (default 0.001)
//   -margin ranking margin (default 1)
//   -method negative sampling: 0 = unif, 1 = bern (default 1)
int main(int argc,char**argv)
{
    srand((unsigned) time(NULL));
    int method = 1;
    int n = 100;          // dimension of the embedding space
    double rate = 0.001;  // learning rate
    double margin = 1;    // ranking margin
    int i;
    if ((i = ArgPos((char *)"-size", argc, argv)) > 0) n = atoi(argv[i + 1]);
    // NOTE(fix): -rate was documented but never parsed in the original,
    // so the learning rate was silently stuck at 0.001.
    if ((i = ArgPos((char *)"-rate", argc, argv)) > 0) rate = atof(argv[i + 1]);
    // NOTE(fix): margin is a double; the original used atoi, which silently
    // truncated fractional margins like 0.5 to 0.
    if ((i = ArgPos((char *)"-margin", argc, argv)) > 0) margin = atof(argv[i + 1]);
    if ((i = ArgPos((char *)"-method", argc, argv)) > 0) method = atoi(argv[i + 1]);
    cout<<"size = "<<n<<endl;
    cout<<"learing rate = "<<rate<<endl;
    cout<<"margin = "<<margin<<endl;
    // The method name doubles as the suffix of the output embedding files
    // (relation2vec.bern, entity2vec.unif, ...).
    if (method)
        version = "bern";
    else
        version = "unif";
    cout<<"method = "<<version<<endl;
    // Build all dictionaries and load the triples, then train.
    prepare();
    train.run(n,rate,margin,method);
}