diff --git a/modules/__tests__/withRouter-test.js b/modules/__tests__/withRouter-test.js
index d276ce28fd..06ade17b4b 100644
--- a/modules/__tests__/withRouter-test.js
+++ b/modules/__tests__/withRouter-test.js
@@ -4,21 +4,17 @@ import { render, unmountComponentAtNode } from 'react-dom'
import createHistory from '../createMemoryHistory'
import Route from '../Route'
import Router from '../Router'
-import routerShape from '../PropTypes'
import withRouter from '../withRouter'
describe('withRouter', function () {
- class App extends Component {
- propTypes: {
- router: routerShape.isRequired
- }
- testFunction() {
- return 'hello from the test function'
- }
- render() {
- expect(this.props.router).toExist()
- return
App
- }
+ const routerStub = {
+ push() {},
+ replace() {},
+ go() {},
+ goBack() {},
+ goForward() {},
+ setRouteLeaveHook() {},
+ isActive() {}
}
let node
@@ -30,51 +26,63 @@ describe('withRouter', function () {
unmountComponentAtNode(node)
})
- it('puts router on context', function (done) {
- const WrappedApp = withRouter(App)
+ it('should put router on props', function (done) {
+ const MyComponent = withRouter(({ router }) => {
+ expect(router).toExist()
+ done()
+ return null
+ })
+
+ function App() {
+ return // Ensure no props are passed explicitly.
+ }
render((
-
+
- ), node, function () {
- done()
- })
+ ), node)
})
- it('still uses router prop if provided', function (done) {
- const Test = withRouter(function (props) {
- props.test(props)
+ it('should set displayName', function () {
+ function MyComponent() {
return null
- })
- const router = {
- push() {},
- replace() {},
- go() {},
- goBack() {},
- goForward() {},
- setRouteLeaveHook() {},
- isActive() {}
- }
- const test = function (props) {
- expect(props.router).toBe(router)
}
- render(, node, done)
+ MyComponent.displayName = 'MyComponent'
+
+ expect(withRouter(MyComponent).displayName)
+ .toEqual('withRouter(MyComponent)')
+ })
+
+ it('should use router prop if specified', function (done) {
+ const MyComponent = withRouter(({ router }) => {
+ expect(router).toBe(routerStub)
+ done()
+ return null
+ })
+
+ render(, node)
})
- it('should support withRefs as a parameter', function (done) {
- const WrappedApp = withRouter(App, { withRef: true })
- const router = {
- push() {},
- replace() {},
- go() {},
- goBack() {},
- goForward() {},
- setRouteLeaveHook() {},
- isActive() {}
+ it('should support withRef', function () {
+ const spy = expect.createSpy()
+
+ class MyComponent extends Component {
+ invokeSpy() {
+ spy()
+ }
+
+ render() {
+ return null
+ }
}
- const component = render((), node, done)
- expect(component.getWrappedInstance().testFunction()).toEqual('hello from the test function')
+
+ const WrappedComponent = withRouter(MyComponent, { withRef: true })
+
+ const instance = render(, node)
+ instance.getWrappedInstance().invokeSpy()
+
+ expect(spy).toHaveBeenCalled()
})
})
diff --git a/modules/withRouter.js b/modules/withRouter.js
index d5c6eb2b2c..b56fdc0946 100644
--- a/modules/withRouter.js
+++ b/modules/withRouter.js
@@ -1,31 +1,38 @@
+import invariant from 'invariant'
import React from 'react'
import hoistStatics from 'hoist-non-react-statics'
import { routerShape } from './PropTypes'
-import warning from './routerWarning'
-
function getDisplayName(WrappedComponent) {
return WrappedComponent.displayName || WrappedComponent.name || 'Component'
}
export default function withRouter(WrappedComponent, options) {
- const { withRef } = options || {}
+ const withRef = options && options.withRef
const WithRouter = React.createClass({
contextTypes: { router: routerShape },
propTypes: { router: routerShape },
getWrappedInstance() {
- warning(withRef, 'To access the wrappedInstance you must provide { withRef: true } as the second argument of the withRouter call')
- return this.wrappedComponent
+ invariant(
+ withRef,
+ 'To access the wrapped instance, you need to specify ' +
+ '`{ withRef: true }` as the second argument of the withRouter() call.'
+ )
+
+ return this.wrappedInstance
},
render() {
- const { router, ...props } = this.props
+ const router = this.props.router || this.context.router
+ const props = { ...this.props, router }
- if (withRef) props.ref = component =>this.wrappedComponent = component
+ if (withRef) {
+ props.ref = (c) => { this.wrappedInstance = c }
+ }
- return
+ return
}
})