Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(utils): make derive glitch free #335

Merged
merged 6 commits into from
Jan 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 56 additions & 16 deletions src/utils/derive.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ import { getVersion, proxy, subscribe } from '../vanilla'

type DeriveGet = <T extends object>(proxyObject: T) => T

type Subscriptions<U extends object> = Map<
object,
[callbackMap: Map<keyof U, () => void>, unsubscribe: () => void]
>
type Subscription<U extends object> = [
callbackMap: Map<keyof U, () => void>,
unsubscribe: () => void
]

type Subscriptions<U extends object> = Map<object, Subscription<U>>

const subscriptionsCache = new WeakMap<object, Subscriptions<object>>()

Expand All @@ -24,6 +26,25 @@ const getSubscriptions = (proxyObject: object) => {
// It's not expected to use this in production.
export const unstable_getDeriveSubscriptions = getSubscriptions

// to make derive glitch free: https://github.com/pmndrs/valtio/pull/335
const pendingCountMap = new WeakMap<object, number>()
const markPending = (proxyObject: object) => {
const count = pendingCountMap.get(proxyObject) || 0
pendingCountMap.set(proxyObject, count + 1)
}
const unmarkPending = (proxyObject: object) => {
const count = pendingCountMap.get(proxyObject) || 0
if (count > 1) {
pendingCountMap.set(proxyObject, count - 1)
} else {
pendingCountMap.delete(proxyObject)
}
}
const isPending = (proxyObject: object) => {
const count = pendingCountMap.get(proxyObject) || 0
return count > 0
}

/**
* derive
*
Expand Down Expand Up @@ -61,10 +82,11 @@ export const derive = <T extends object, U extends object>(
const notifyInSync = options?.sync
const subscriptions: Subscriptions<U> = getSubscriptions(proxyObject)
const addSubscription = (p: object, key: keyof U, callback: () => void) => {
const subscription = subscriptions.get(p)
if (subscription) {
subscription[0].set(key, callback)
} else {
let subscription = subscriptions.get(p)
if (!subscription) {
const notify = () =>
(subscription as Subscription<U>)[0].forEach((cb) => cb())
let promise: Promise<void> | undefined
const unsubscribe = subscribe(
p,
(ops) => {
Expand All @@ -77,14 +99,28 @@ export const derive = <T extends object, U extends object>(
// only setting derived properties
return
}
subscriptions.get(p)?.[0].forEach((cb) => {
cb()
})
if (promise) {
// already scheduled
return
}
markPending(p)
if (notifyInSync) {
unmarkPending(p)
notify()
} else {
promise = Promise.resolve().then(() => {
promise = undefined
unmarkPending(p)
notify()
})
}
},
notifyInSync
true
)
subscriptions.set(p, [new Map([[key, callback]]), unsubscribe])
subscription = [new Map(), unsubscribe]
subscriptions.set(p, subscription)
}
subscription[0].set(key, callback)
}
const removeSubscription = (p: object, key: keyof U) => {
const subscription = subscriptions.get(p)
Expand All @@ -105,6 +141,10 @@ export const derive = <T extends object, U extends object>(
let lastDependencies: Map<object, number> | null = null
const evaluate = () => {
if (lastDependencies) {
if (Array.from(lastDependencies).some(([p]) => isPending(p))) {
// some dependencies are still pending
return
}
if (
Array.from(lastDependencies).every(([p, n]) => getVersion(p) === n)
) {
Expand All @@ -118,7 +158,7 @@ export const derive = <T extends object, U extends object>(
return p
}
const value = fn(get)
const subscribe = () => {
const subscribeToDependencies = () => {
dependencies.forEach((_, p) => {
if (!lastDependencies?.has(p)) {
addSubscription(p, key, evaluate)
Expand All @@ -132,9 +172,9 @@ export const derive = <T extends object, U extends object>(
lastDependencies = dependencies
}
if (value instanceof Promise) {
value.finally(subscribe)
value.finally(subscribeToDependencies)
} else {
subscribe()
subscribeToDependencies()
}
proxyObject[key] = value
}
Expand Down
134 changes: 133 additions & 1 deletion tests/derive.test.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ it('nested emulation with derive', async () => {
{
doubled: (get) => computeDouble(get(state.math).count),
},
{ proxy: state.math }
{ proxy: state.math, sync: true }
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is technically a behavioral change, but hope it's not too critical.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So if it's async it can't be glitch free?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, but I didn't confirm it in the nested scenario. Without sync in this case, it simply double call the subscriber. It's a side effect of supporting glitch free.

)

const callback = jest.fn()
Expand Down Expand Up @@ -299,3 +299,135 @@ it('basic underive', async () => {
await Promise.resolve()
expect(callback).toBeCalledTimes(1)
})

describe('glitch free', () => {
it('basic (#296)', async () => {
const state = proxy({ value: 0 })
const derived1 = derive({
value: (get) => get(state).value,
})
const derived2 = derive({
value: (get) => get(derived1).value,
})
const computeValue = jest.fn((get) => {
const v0 = get(state).value
const v1 = get(derived1).value
const v2 = get(derived2).value
return v0 + (v1 - v2)
})
const derived3 = derive({
value: (get) => computeValue(get),
})

const App = () => {
const snap = useSnapshot(derived3)
return (
<div>
value: {snap.value}
<button onClick={() => ++state.value}>button</button>
</div>
)
}

const { getByText, findByText } = render(
<StrictMode>
<App />
</StrictMode>
)

await findByText('value: 0')
expect(computeValue).toBeCalledTimes(1)

fireEvent.click(getByText('button'))
await findByText('value: 1')
expect(computeValue).toBeCalledTimes(2)
})

it('same value', async () => {
const state = proxy({ value: 0 })
const derived1 = derive({
value: (get) => get(state).value * 0,
})
const derived2 = derive({
value: (get) => get(derived1).value * 0,
})
const computeValue = jest.fn((get) => {
const v0 = get(state).value
const v1 = get(derived1).value
const v2 = get(derived2).value
return v0 + (v1 - v2)
})
const derived3 = derive({
value: (get) => computeValue(get),
})

const App = () => {
const snap = useSnapshot(derived3)
return (
<div>
value: {snap.value}
<button onClick={() => ++state.value}>button</button>
</div>
)
}

const { getByText, findByText } = render(
<StrictMode>
<App />
</StrictMode>
)

await findByText('value: 0')
expect(computeValue).toBeCalledTimes(1)

fireEvent.click(getByText('button'))
await findByText('value: 1')
expect(computeValue).toBeCalledTimes(2)
})

it('double chain', async () => {
const state = proxy({ value: 0 })
const derived1 = derive({
value: (get) => get(state).value,
})
const derived2 = derive({
value: (get) => get(derived1).value,
})
const derived3 = derive({
value: (get) => get(derived2).value,
})
const computeValue = jest.fn((get) => {
const v0 = get(state).value
const v1 = get(derived1).value
const v2 = get(derived2).value
const v3 = get(derived3).value
return v0 + (v1 - v2) + v3 * 0
})
const derived4 = derive({
value: (get) => computeValue(get),
})

const App = () => {
const snap = useSnapshot(derived4)
return (
<div>
value: {snap.value}
<button onClick={() => ++state.value}>button</button>
</div>
)
}

const { getByText, findByText } = render(
<StrictMode>
<App />
</StrictMode>
)

await findByText('value: 0')
expect(computeValue).toBeCalledTimes(1)

fireEvent.click(getByText('button'))
await findByText('value: 1')
expect(computeValue).toBeCalledTimes(2)
})
})