-
-
Notifications
You must be signed in to change notification settings - Fork 467
/
utils.js
414 lines (329 loc) · 10.4 KB
/
utils.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
const fs = require('fs');
const { env } = require('./env.js');
/**
 * Minimal stand-in for a Fetch API `Response`, backed by a file on the local
 * file system. It exposes the subset of the `Response` surface used in this
 * file: `status`, `statusText`, `headers.get()`, `body`, `clone()`,
 * `arrayBuffer()`, `blob()`, `text()` and `json()`.
 */
class FileResponse {
    /**
     * @param {string} filePath Path of the file to expose as a response.
     */
    constructor(filePath) {
        this.filePath = filePath;
        this.headers = {};
        // Mimic `Headers.get`: header names are case-insensitive and a missing
        // header yields `null` (header values are stored under lower-case keys).
        this.headers.get = (name) => this.headers[String(name).toLowerCase()] ?? null;

        this.exists = fs.existsSync(filePath);
        if (!this.exists) {
            this.status = 404;
            this.statusText = 'Not Found';
            this.body = null;
            return;
        }

        this.status = 200;
        this.statusText = 'OK';
        // Stored as a number (the file size in bytes).
        this.headers['content-length'] = fs.statSync(filePath).size;
        this.updateContentType();

        this.body = new ReadableStream({
            // The whole file is delivered as a single chunk. Because `start` is
            // async, a failed read errors the stream instead of surfacing as an
            // unhandled promise rejection.
            start: async (controller) => {
                const buffer = await this.arrayBuffer();
                controller.enqueue(new Uint8Array(buffer));
                controller.close();
            },
        });
    }

    /**
     * Set the `content-type` header from the file extension, falling back to
     * `application/octet-stream` for unknown extensions.
     */
    updateContentType() {
        const contentTypes = new Map([
            ['txt', 'text/plain'],
            ['html', 'text/html'],
            ['css', 'text/css'],
            ['js', 'text/javascript'],
            ['json', 'application/json'],
            ['png', 'image/png'],
            ['jpg', 'image/jpeg'],
            ['jpeg', 'image/jpeg'],
            ['gif', 'image/gif'],
        ]);
        const extension = this.filePath.split('.').pop().toLowerCase();
        this.headers['content-type'] = contentTypes.get(extension) ?? 'application/octet-stream';
    }

    /**
     * Create an independent response for the same file. Status and headers
     * are re-derived from the file by the constructor.
     * @returns {FileResponse}
     */
    clone() {
        return new FileResponse(this.filePath);
    }

    /**
     * Read the file and return its bytes.
     * @returns {Promise<ArrayBuffer>} A buffer containing exactly the file's bytes.
     */
    async arrayBuffer() {
        const data = await fs.promises.readFile(this.filePath);
        const { buffer, byteOffset, byteLength } = data;
        if (byteOffset === 0 && byteLength === buffer.byteLength) {
            return buffer;
        }
        // A Node Buffer may be a view onto a larger ArrayBuffer; copy out only
        // the bytes that belong to this file.
        return buffer.slice(byteOffset, byteOffset + byteLength);
    }

    /**
     * Read the file as a Blob typed with the `content-type` header.
     * @returns {Promise<Blob>}
     */
    async blob() {
        const data = await fs.promises.readFile(this.filePath);
        return new Blob([data], { type: this.headers['content-type'] });
    }

    /**
     * Read the file as UTF-8 text.
     * @returns {Promise<string>}
     */
    async text() {
        const data = await fs.promises.readFile(this.filePath, 'utf8');
        return data;
    }

    /**
     * Read the file as UTF-8 text and parse it as JSON.
     * @returns {Promise<any>}
     */
    async json() {
        return JSON.parse(await this.text());
    }
}
/**
 * Tell whether a string parses as an absolute URL using the http or https scheme.
 * Adapted from https://stackoverflow.com/a/43467144
 * @param {string} string Candidate URL.
 * @returns {boolean} `true` only for parsable `http:` / `https:` URLs.
 */
function isValidHttpUrl(string) {
    let protocol;
    try {
        ({ protocol } = new URL(string));
    } catch {
        // Not parsable as an absolute URL (e.g. a relative file path).
        return false;
    }
    return ['http:', 'https:'].includes(protocol);
}
/**
 * Retrieve a file either from the local file system or over the network.
 * A `FileResponse` is used when `env.useFS` is enabled and `url` is not an
 * http(s) URL; every other case goes through `fetch`.
 * @param {string} url File path or URL to load.
 * @returns {Promise<FileResponse|Response>}
 */
async function getFile(url) {
    const readFromDisk = env.useFS && !isValidHttpUrl(url);
    return readFromDisk ? new FileResponse(url) : fetch(url);
}
/**
 * Invoke a progress callback with `data`, if a callback was supplied.
 * Both `null` and `undefined` mean "no callback", so callers may simply omit
 * the argument instead of having to pass `null` explicitly.
 * @param {Function|null|undefined} progressCallback Callback to notify.
 * @param {any} data Payload forwarded to the callback.
 */
function dispatchCallback(progressCallback, data) {
    if (progressCallback === null || progressCallback === undefined) {
        return;
    }
    progressCallback(data);
}
/**
 * Load a model file, using the `transformers-cache` cache when `env.useCache`
 * is enabled, and report progress through `progressCallback`.
 *
 * The callback receives objects whose `status` is, in order, `'initiate'`,
 * `'download'`, zero or more `'progress'`, then `'done'`.
 *
 * @param {string} modelPath Base path or URL of the model.
 * @param {string} fileName Name of the file inside `modelPath`.
 * @param {Function|null} [progressCallback=null] Receives status updates.
 * @param {boolean} [fatal=true] Whether a 404 response throws (`true`) or
 *   makes the function return `null` (`false`).
 * @returns {Promise<Uint8Array|null>} The file contents, or `null` for a
 *   missing optional file.
 * @throws {Error} If the file is not found and `fatal` is true.
 */
async function getModelFile(modelPath, fileName, progressCallback = null, fatal = true) {
    // Initiate session
    dispatchCallback(progressCallback, {
        status: 'initiate',
        name: modelPath,
        file: fileName
    });

    let cache;
    if (env.useCache) {
        cache = await caches.open('transformers-cache');
    }

    const request = pathJoin(modelPath, fileName);

    let response;
    let responseToCache;

    if (!env.useCache || (response = await cache.match(request)) === undefined) {
        // Caching not available, or model is not cached, so we perform the request
        response = await getFile(request);

        if (response.status === 404) {
            if (fatal) {
                throw Error(`File not found. Could not locate "${request}".`);
            }
            // File not found, but this file is optional.
            // TODO in future, cache the response
            return null;
        }

        if (env.useCache) {
            // Only clone if a cache is available: the body of `response` is
            // consumed below, so the cache needs its own copy.
            responseToCache = response.clone();
        }
    }

    // Start downloading
    dispatchCallback(progressCallback, {
        status: 'download',
        name: modelPath,
        file: fileName
    });

    const buffer = await readResponse(response, data => {
        dispatchCallback(progressCallback, {
            status: 'progress',
            ...data,
            name: modelPath,
            file: fileName
        });
    });

    // Check again whether request is in cache. If not, we add the response to the cache
    if (responseToCache !== undefined && await cache.match(request) === undefined) {
        // Caching is best-effort and deliberately not awaited, but a failed
        // write must not surface as an unhandled promise rejection.
        cache.put(request, responseToCache).catch((err) => {
            console.warn(`Unable to add "${request}" to the cache:`, err);
        });
    }

    dispatchCallback(progressCallback, {
        status: 'done',
        name: modelPath,
        file: fileName
    });

    return buffer;
}
/**
 * Load a model file via `getModelFile` and parse it as UTF-8 encoded JSON.
 * @param {string} modelPath Base path or URL of the model.
 * @param {string} fileName Name of the JSON file.
 * @param {Function|null} [progressCallback=null] Forwarded to `getModelFile`.
 * @param {boolean} [fatal=true] Forwarded to `getModelFile`.
 * @returns {Promise<any>} The parsed JSON, or an empty object when
 *   `getModelFile` returns `null`.
 */
async function fetchJSON(modelPath, fileName, progressCallback = null, fatal = true) {
    const bytes = await getModelFile(modelPath, fileName, progressCallback, fatal);

    // No file was loaded: fall back to an empty object.
    if (bytes === null) return {};

    const text = new TextDecoder('utf-8').decode(bytes);
    return JSON.parse(text);
}
/**
 * Read the whole body of a response into a `Uint8Array`, reporting progress
 * after every chunk.
 *
 * The buffer is pre-sized from the `Content-Length` header when it holds a
 * positive integer. If the header is missing or unusable, or the body turns
 * out to be longer than declared, the buffer grows as chunks arrive. If the
 * body is shorter than declared, the result is trimmed to the bytes read.
 *
 * @param {{headers: {get: Function}, body: ReadableStream}} response
 *   Response-like object to read.
 * @param {Function} progressCallback Called with `{progress, loaded, total}`
 *   after each chunk; `progress` is a percentage of `total`.
 * @returns {Promise<Uint8Array>} The bytes of the body.
 */
async function readResponse(response, progressCallback) {
    const contentLength = response.headers.get('Content-Length');
    // Header lookups may report a missing header as `null` or `undefined`.
    if (contentLength === null || contentLength === undefined) {
        console.warn('Unable to determine content-length from response headers. Will expand buffer when needed.')
    }

    // An unparsable or non-positive length is treated as unknown (0).
    const declaredLength = Number.parseInt(contentLength, 10);
    let total = Number.isFinite(declaredLength) && declaredLength > 0 ? declaredLength : 0;
    let buffer = new Uint8Array(total);
    let loaded = 0;

    const reader = response.body.getReader();
    while (true) {
        const { done, value } = await reader.read();
        if (done) break;

        const newLoaded = loaded + value.length;
        if (newLoaded > total) {
            // Adding the new data would overflow the buffer, so extend it
            // and copy over what has been read so far.
            total = newLoaded;
            const expanded = new Uint8Array(total);
            expanded.set(buffer);
            buffer = expanded;
        }
        buffer.set(value, loaded);
        loaded = newLoaded;

        progressCallback({
            progress: (loaded / total) * 100,
            loaded: loaded,
            total: total,
        });
    }

    // Drop the unused, zero-filled tail when fewer bytes than declared arrived.
    return loaded < buffer.length ? buffer.slice(0, loaded) : buffer;
}
/**
 * Join path segments with a single `/` between them.
 * One leading slash is stripped from every segment except the first, and one
 * trailing slash from every segment except the last.
 * Adapted from https://stackoverflow.com/a/55142565
 * @param {...string} parts Segments to join.
 * @returns {string} The joined path.
 */
function pathJoin(...parts) {
    const lastPosition = parts.length - 1;
    const cleaned = [];

    parts.forEach((segment, position) => {
        let piece = segment;
        if (position !== 0) {
            piece = piece.replace(/^\//, '');
        }
        if (position !== lastPosition) {
            piece = piece.replace(/\/$/, '');
        }
        cleaned.push(piece);
    });

    return cleaned.join('/');
}
/**
 * Build a new object whose keys are the values of `data` and whose values are
 * the corresponding keys of `data`.
 * See https://ultimatecourses.com/blog/reverse-object-keys-and-values-in-javascript
 * @param {Object} data Object to invert.
 * @returns {Object} The inverted object.
 */
function reverseDictionary(data) {
    const swappedPairs = [];
    for (const [key, value] of Object.entries(data)) {
        swappedPairs.push([value, key]);
    }
    return Object.fromEntries(swappedPairs);
}
/**
 * Find the index of the largest element of an array.
 * Ties resolve to the earliest index.
 * Adapted from https://stackoverflow.com/a/11301464
 * @param {ArrayLike<number>} arr Values to scan.
 * @returns {number} Index of the maximum, or -1 when `arr` is empty.
 */
function indexOfMax(arr) {
    if (arr.length === 0) return -1;

    let bestIndex = 0;
    let bestValue = arr[0];
    let position = 1;
    while (position < arr.length) {
        const candidate = arr[position];
        if (candidate > bestValue) {
            bestValue = candidate;
            bestIndex = position;
        }
        ++position;
    }
    return bestIndex;
}
/**
 * Compute the softmax of an array of numbers.
 * The maximum is subtracted before exponentiating so large inputs do not
 * overflow.
 * @param {number[]|Float32Array|Float64Array} arr Input scores.
 * @returns {number[]|Float32Array|Float64Array} Values of the same array type
 *   as `arr`; an empty input gives an empty result.
 */
function softmax(arr) {
    // Fold instead of `Math.max(...arr)`: spreading a very large array into
    // call arguments can exceed the engine's argument limit and throw.
    const max = arr.reduce((best, x) => (x > best ? x : best), -Infinity);

    // Compute the exponentials of the shifted values
    const exps = arr.map(x => Math.exp(x - max));

    // Compute the sum of the exponentials
    const sumExps = exps.reduce((acc, val) => acc + val, 0);

    // Normalize so the values sum to 1
    return exps.map(x => x / sumExps);
}
/**
 * Compute the natural log of the softmax of an array of numbers.
 *
 * The result is computed directly as `x - max - log(sum(exp(x - max)))`.
 * Taking the log of already-computed softmax values would return `-Infinity`
 * for entries whose probability underflows to 0.
 * @param {number[]|Float32Array|Float64Array} arr Input scores.
 * @returns {number[]|Float32Array|Float64Array} Log-probabilities, of the same
 *   array type as `arr`; an empty input gives an empty result.
 */
function log_softmax(arr) {
    // Fold rather than spread so very large arrays do not hit the engine's
    // argument-count limit.
    const max = arr.reduce((best, x) => (x > best ? x : best), -Infinity);

    // log of the softmax denominator, computed on max-shifted values
    const logSumExps = Math.log(arr.reduce((acc, x) => acc + Math.exp(x - max), 0));

    return arr.map(x => x - max - logSumExps);
}
/**
 * Escape every regular-expression metacharacter in a string so it can be
 * embedded in a pattern and matched literally.
 * @param {string} string Text to escape.
 * @returns {string} The text with each special character preceded by a backslash.
 */
function escapeRegExp(string) {
    const specialCharacters = /[.*+?^${}()|[\]\\]/g;
    // Prefix each matched character with a backslash.
    return string.replace(specialCharacters, (match) => `\\${match}`);
}
/**
 * Pair every score with its index and order the pairs by descending score.
 * @param {Iterable<number>|ArrayLike<number>} items Scores to rank.
 * @param {number} [top_k=0] Number of leading pairs to keep; values that are
 *   not greater than 0 keep all of them.
 * @returns {Array<[number, number]>} `[index, score]` pairs, highest score first.
 */
function getTopItems(items, top_k = 0) {
    const ranked = Array.from(items, (score, index) => [index, score]);
    ranked.sort((left, right) => right[1] - left[1]);
    return top_k > 0 ? ranked.slice(0, top_k) : ranked;
}
/**
 * Compute the dot product of two arrays, pairing elements by index.
 * @param {number[]} arr1 First vector; its elements drive the iteration.
 * @param {number[]} arr2 Second vector.
 * @returns {number} Sum of the element-wise products.
 */
function dot(arr1, arr2) {
    let total = 0;
    arr1.forEach((value, index) => {
        total += value * arr2[index];
    });
    return total;
}
/**
 * Compute the cosine similarity of two vectors: their dot product divided by
 * the product of their magnitudes.
 * @param {number[]} arr1 First vector.
 * @param {number[]} arr2 Second vector.
 * @returns {number} The cosine similarity.
 */
function cos_sim(arr1, arr2) {
    const numerator = dot(arr1, arr2);
    const denominator = magnitude(arr1) * magnitude(arr2);
    return numerator / denominator;
}
/**
 * Compute the Euclidean length of a vector: the square root of the sum of
 * its squared elements.
 * @param {number[]} arr Vector components.
 * @returns {number} The magnitude (0 for an empty array).
 */
function magnitude(arr) {
    let sumOfSquares = 0;
    arr.forEach((component) => {
        sumOfSquares += component * component;
    });
    return Math.sqrt(sumOfSquares);
}
/**
 * Base class whose instances can be invoked like functions.
 * Calling an instance forwards all arguments to its `_call` method, which
 * subclasses are expected to override.
 */
class Callable extends Function {
    constructor() {
        // The object returned from the constructor is a plain function that
        // delegates to `_call`; it replaces the usual `this`.
        function closure(...args) {
            return closure._call(...args);
        }
        // Re-parent the function so it inherits the methods of the class
        // that was actually instantiated.
        Object.setPrototypeOf(closure, new.target.prototype);
        return closure;
    }

    /**
     * Default implementation; always throws.
     * @throws {Error} Always, to signal that a subclass must override it.
     */
    _call(...args) {
        throw Error('Must implement _call method in subclass')
    }
}
/**
 * Check whether a value is a string primitive or a `String` wrapper object.
 * @param {any} text Value to test.
 * @returns {boolean}
 */
function isString(text) {
    if (typeof text === 'string') {
        return true;
    }
    return text instanceof String;
}
/**
 * Check whether a value is a whole number: either a `bigint` or a number
 * accepted by `Number.isInteger`.
 * @param {any} x Value to test.
 * @returns {boolean}
 */
function isIntegralNumber(x) {
    if (typeof x === 'bigint') {
        return true;
    }
    return Number.isInteger(x);
}
/**
 * Check that a value is neither `undefined` nor `null`.
 * @param {any} x Value to test.
 * @returns {boolean} `false` only for `undefined` and `null`.
 */
function exists(x) {
    return !(x === undefined || x === null);
}
// Public API of this module (CommonJS exports).
// `FileResponse`, `isValidHttpUrl` and `readResponse` are defined above but are
// not part of this export list.
module.exports = {
Callable,
getModelFile,
dispatchCallback,
fetchJSON,
pathJoin,
reverseDictionary,
indexOfMax,
softmax,
log_softmax,
escapeRegExp,
getTopItems,
dot,
cos_sim,
magnitude,
getFile,
isIntegralNumber,
isString,
exists
};