Skip to content

Commit

Permalink
Implement more feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
Murderlon committed Sep 19, 2023
1 parent 9a0ceb8 commit 80f6e3c
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 139 deletions.
195 changes: 95 additions & 100 deletions lib/safe-http-client.js
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import dns from 'node:dns/promises'

import {fetch} from 'undici'
import ipaddr from 'ipaddr.js'

import {defaultMimeTypes} from './constants.js'

export class HttpError extends Error {
Expand All @@ -28,122 +26,119 @@ export class HttpError extends Error {
}
}

export class SafeHttpClient {
/** @param {number} [maxSize] */
constructor(maxSize) {
this.maxSize = maxSize
/**
* Check that the URL is valid, the prototol is allowed, and that the host is
* a safe unicast address.
*
* @param {string} url
* URL to check.
* @returns {Promise<void>}
* URL object.
*/
export async function checkUrl(url) {
// Throws if the URL is invalid
const validUrl = new URL(url)
const {protocol, hostname} = validUrl

// Don't allow aother protocols like file:// URLs
if (!['http:', 'https:'].includes(protocol)) {
throw new Error('Bad protocol')
}

/**
* Check if the URL is a valid URL or IP, that the prototol is valid,
* and the host is a safe unicast address.
* @param {string} url
*/
static async checkUrl(url) {
// Throws if the URL is invalid
const validUrl = new URL(url)
const {protocol, hostname} = validUrl

// Don't allow aother protocols like file:// URLs
if (!['http:', 'https:'].includes(protocol)) {
throw new Error('Bad protocol')
}

try {
var {address} = await dns.lookup(hostname)
} catch (err) {
throw new Error('Bad url host')
}

/**
* Server Side Request Forgery (SSRF) Protection.
*
* SSRF is an attack where an attacker can trick a server into making unexpected network connections.
* This can lead to unauthorized access to internal resources, information disclosure,
* denial-of-service attacks, or even remote code execution.
*
* One common SSRF vector is tricking the server into making requests to internal IP addresses
* or to other services within the network that the server shouldn't be accessing. This can
* expose sensitive internal data or systems.
*
* Unicast addresses are typically used for communication between hosts on the public internet.
* By only allowing addresses in the 'unicast' range, we can prevent SSRF attacks targeting
* non-public IP ranges, such as private, multicast, and reserved IPs.
*/
if (ipaddr.process(address).range() !== 'unicast') {
throw new Error('Bad url host')
}

return validUrl
try {
var {address} = await dns.lookup(hostname)
} catch (err) {
throw new Error('Bad url host')
}

/**
* Fetch a URL.
* Server Side Request Forgery (SSRF) Protection.
*
* SSRF is an attack where an attacker can trick a server into making unexpected network connections.
* This can lead to unauthorized access to internal resources, information disclosure,
* denial-of-service attacks, or even remote code execution.
*
* One common SSRF vector is tricking the server into making requests to internal IP addresses
* or to other services within the network that the server shouldn't be accessing. This can
* expose sensitive internal data or systems.
*
* @param {URL | string} url
* URL.
* @param {import('undici').RequestInit} options
* Configuration, passed through to `fetch`.
* @returns {Promise<{buffer?: Buffer, headers: import('undici').Headers}>}
* Buffer of response (except when `HEAD`) and headers.
* Unicast addresses are typically used for communication between hosts on the public internet.
* By only allowing addresses in the 'unicast' range, we can prevent SSRF attacks targeting
* non-public IP ranges, such as private, multicast, and reserved IPs.
*/
async safeFetch(url, options) {
let response = await fetch(url, options)

// If there's a redirect, check the redirected URL for SSRF and then follow it if it's valid.
if ([301, 302, 303, 307, 308].includes(response.status)) {
const redirectedUrl = response.headers.get('location')
if (ipaddr.process(address).range() !== 'unicast') {
throw new Error('Bad url host')
}
}

if (!redirectedUrl) {
throw new HttpError(400, 'Missing `Location` header')
}
/**
* Fetch a URL.
*
* @param {URL | string} url
* URL.
* @param {import('undici').RequestInit} options
* Configuration, passed through to `fetch`.
* @param {number} [maxSize]
* The max size in bytes to download.
* @returns {Promise<{buffer?: Buffer, headers: import('undici').Headers}>}
* Buffer of response (except when `HEAD`) and headers.
*/
export async function safeFetch(url, options, maxSize) {
let response = await fetch(url, options)

// If there's a redirect, check the redirected URL for SSRF and then follow it if it's valid.
if ([301, 302, 303, 307, 308].includes(response.status)) {
const redirectedUrl = response.headers.get('location')

if (!redirectedUrl) {
throw new HttpError(400, 'Missing `Location` header')
}

await SafeHttpClient.checkUrl(redirectedUrl)
await checkUrl(redirectedUrl)

response = await fetch(redirectedUrl, {
...options,
// Do not allow another redirect
redirect: 'error'
})
}
response = await fetch(redirectedUrl, {
...options,
// Do not allow another redirect
redirect: 'error'
})
}

const contentType = response.headers.get('content-type')
if (!contentType) {
throw new HttpError(400, 'Empty content-type header')
}
const contentType = response.headers.get('content-type')
if (!contentType) {
throw new HttpError(400, 'Empty content-type header')
}

if (!defaultMimeTypes.includes(contentType)) {
throw new HttpError(400, 'Unsupported content-type returned')
}
if (!defaultMimeTypes.includes(contentType)) {
throw new HttpError(400, 'Unsupported content-type returned')
}

if (options.method === 'HEAD') {
return {headers: response.headers}
}
if (options.method === 'HEAD') {
return {headers: response.headers}
}

if (!response.body) {
throw new HttpError(400, 'No response body')
}
if (!response.body) {
throw new HttpError(400, 'No response body')
}

/** @type {Array<Buffer>} */
const chunks = []
const reader = response.body.getReader()
let currentByteLength = 0
/** @type {Array<Buffer>} */
const chunks = []
const reader = response.body.getReader()
let currentByteLength = 0

while (true) {
const {done, value} = await reader.read()
if (done) {
break
}
chunks.push(value)
while (true) {
const {done, value} = await reader.read()
if (done) {
break
}
chunks.push(value)

if (this.maxSize) {
currentByteLength += value.length
if (currentByteLength > this.maxSize) {
throw new HttpError(413, 'Content-Length exceeded')
}
if (maxSize) {
currentByteLength += value.length
if (currentByteLength > maxSize) {
throw new HttpError(413, 'Content-Length exceeded')
}
}

return {buffer: Buffer.concat(chunks), headers: response.headers}
}

return {buffer: Buffer.concat(chunks), headers: response.headers}
}
52 changes: 27 additions & 25 deletions lib/server.js
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import http from 'node:http'
import net from 'node:net'
import {EventEmitter} from 'node:events'
import crypto from 'node:crypto'
import url from 'node:url'

import http from 'node:http'
import net from 'node:net'
import {Headers} from 'undici'

import {SafeHttpClient, HttpError} from './safe-http-client.js'
import {checkUrl, safeFetch, HttpError} from './safe-http-client.js'
import {
securityHeaders,
defaultRequestHeaders,
Expand Down Expand Up @@ -64,9 +61,12 @@ export class Server extends EventEmitter {
}

/**
* Start the server. Identical to `net.Server.listen()`.
* Start the server.
*
* @param {Parameters<InstanceType<typeof net.Server>['listen']>} args
* @public
* Arguments passedf to `net.Server.listen`.
* @returns {net.Server}
* Server.
*/
listen(...args) {
return http.createServer(this.handle.bind(this)).listen(...args)
Expand All @@ -77,14 +77,13 @@ export class Server extends EventEmitter {
* Integrate with your own server by calling this method and routing all requests to it.
* @param {http.IncomingMessage} req
* @param {http.ServerResponse} res
* @public
*/
async handle(req, res) {
if (req.method !== 'GET' && req.method !== 'HEAD') {
return this.write(res, 405, 'Method not allowed')
}

const paths = url.parse(req.url || '')?.path?.split('/')
const paths = req.url?.split('/')

if (!paths || paths.length < 3) {
return this.write(res, 404, 'Malformed request')
Expand All @@ -98,7 +97,7 @@ export class Server extends EventEmitter {
}

try {
var validUrl = await SafeHttpClient.checkUrl(decodedUrl)
await checkUrl(decodedUrl)
} catch (err) {
const exception = /** @type {Error} */ (err)
return this.write(res, 400, exception.message)
Expand All @@ -115,17 +114,20 @@ export class Server extends EventEmitter {
// TODO: respect forwarded headers (check if not private IP)
const filterRequestHeaders = filterHeaders(defaultRequestHeaders)
const filterResponseHeaders = filterHeaders(defaultResponseHeaders)
const client = new SafeHttpClient(this.options.maxSize)
const {buffer, headers: resHeaders} = await client.safeFetch(validUrl, {
// @ts-expect-error: `IncomingHttpHeaders` can be passed to `Headers`
headers: filterRequestHeaders(new Headers(req.headers)),
method: req.method,
// We can't blindly follow redirects as the initial checkUrl
// might have been safe, but the redirect location might not be.
// SafeHttpClient will check the redirect location before following it.
redirect: 'manual',
signal
})
const {buffer, headers: resHeaders} = await safeFetch(
decodedUrl,
{
// @ts-expect-error: `IncomingHttpHeaders` can be passed to `Headers`
headers: filterRequestHeaders(new Headers(req.headers)),
method: req.method,
// We can't blindly follow redirects as the initial checkUrl
// might have been safe, but the redirect location might not be.
// safeFetch will check the redirect location before following it.
redirect: 'manual',
signal
},
this.options.maxSize
)

const headers = {
...securityHeaders,
Expand All @@ -148,8 +150,8 @@ export class Server extends EventEmitter {
if (err.name === 'AbortError') {
return
}
const msg = err.message || 'Internal server error'
return this.write(res, 500, msg)
console.error(err)
return this.write(res, 500, 'Internal server error')
}
}
}
Expand All @@ -161,7 +163,7 @@ export class Server extends EventEmitter {
*/
verifyHmac(receivedDigest, hex) {
// Hex-decode the URL
const decodedUrl = Buffer.from(hex, 'hex').toString()
const decodedUrl = String(Buffer.from(hex, 'hex'))

// Verify the HMAC digest to ensure the URL hasn't been tampered with
const hmac = crypto.createHmac('sha1', this.options.secret)
Expand Down
Loading

0 comments on commit 80f6e3c

Please sign in to comment.