|
5 | 5 | "encoding/json" |
6 | 6 | "errors" |
7 | 7 | "fmt" |
| 8 | + "io" |
| 9 | + "net/http" |
| 10 | + "net/url" |
8 | 11 | "os" |
9 | 12 | "os/signal" |
10 | 13 | "path/filepath" |
@@ -341,24 +344,65 @@ func writeOutput(outputPath string, output []byte) error { |
341 | 344 | } |
342 | 345 |
|
343 | 346 | func writeDataURLOutput(outputString string, outputPath string, addExtension bool) error { |
344 | | - dataurlObj, err := dataurl.DecodeString(outputString) |
345 | | - if err != nil { |
346 | | - return fmt.Errorf("Failed to decode dataurl: %w", err) |
| 347 | + var output []byte |
| 348 | + var contentType string |
| 349 | + |
| 350 | + if httpURL, ok := getHTTPURL(outputString); ok { |
| 351 | + resp, err := http.Get(httpURL.String()) |
| 352 | + if err != nil { |
| 353 | + return fmt.Errorf("Failed to fetch URL: %w", err) |
| 354 | + } |
| 355 | + defer resp.Body.Close() |
| 356 | + |
| 357 | + output, err = io.ReadAll(resp.Body) |
| 358 | + if err != nil { |
| 359 | + return fmt.Errorf("Failed to read response: %w", err) |
| 360 | + } |
| 361 | + contentType = resp.Header.Get("Content-Type") |
| 362 | + contentType = useExtensionIfUnknownContentType(contentType, output, outputString) |
| 363 | + |
| 364 | + } else { |
| 365 | + dataurlObj, err := dataurl.DecodeString(outputString) |
| 366 | + if err != nil { |
| 367 | + return fmt.Errorf("Failed to decode dataurl: %w", err) |
| 368 | + } |
| 369 | + output = dataurlObj.Data |
| 370 | + contentType = dataurlObj.ContentType() |
347 | 371 | } |
348 | | - output := dataurlObj.Data |
349 | 372 |
|
350 | 373 | if addExtension { |
351 | | - extension := mime.ExtensionByType(dataurlObj.ContentType()) |
352 | | - if extension != "" { |
353 | | - outputPath += extension |
| 374 | + if ext := mime.ExtensionByType(contentType); ext != "" { |
| 375 | + outputPath += ext |
354 | 376 | } |
355 | 377 | } |
356 | 378 |
|
357 | | - if err := writeOutput(outputPath, output); err != nil { |
358 | | - return err |
| 379 | + return writeOutput(outputPath, output) |
| 380 | +} |
| 381 | + |
| 382 | +func getHTTPURL(str string) (*url.URL, bool) { |
| 383 | + u, err := url.Parse(str) |
| 384 | + if err == nil && (u.Scheme == "http" || u.Scheme == "https") { |
| 385 | + return u, true |
359 | 386 | } |
| 387 | + return nil, false |
| 388 | +} |
360 | 389 |
|
361 | | - return nil |
| 390 | +func useExtensionIfUnknownContentType(contentType string, content []byte, filename string) string { |
| 391 | + // If contentType is empty or application/octet-string, first attempt to get the |
| 392 | + // content type from the file extension, and if that fails, try to guess it from |
| 393 | + // the content itself. |
| 394 | + |
| 395 | + if contentType == "" || contentType == "application/octet-stream" { |
| 396 | + if ext := filepath.Ext(filename); ext != "" { |
| 397 | + if mimeType := mime.TypeByExtension(ext); mimeType != "" { |
| 398 | + return mimeType |
| 399 | + } |
| 400 | + } |
| 401 | + if detected := http.DetectContentType(content); detected != "" { |
| 402 | + return detected |
| 403 | + } |
| 404 | + } |
| 405 | + return contentType |
362 | 406 | } |
363 | 407 |
|
364 | 408 | func parseInputFlags(inputs []string) (predict.Inputs, error) { |
|
0 commit comments