diff --git a/net/core/skmsg.c b/net/core/skmsg.c index b088fe07fc00a..7e7205e93258b 100644 --- a/net/core/skmsg.c +++ b/net/core/skmsg.c @@ -613,23 +613,42 @@ static void sock_drop(struct sock *sk, struct sk_buff *skb) kfree_skb(skb); } +static void sk_psock_skb_state(struct sk_psock *psock, + struct sk_psock_work_state *state, + struct sk_buff *skb, + int len, int off) +{ + spin_lock_bh(&psock->ingress_lock); + if (sk_psock_test_state(psock, SK_PSOCK_TX_ENABLED)) { + state->skb = skb; + state->len = len; + state->off = off; + } else { + sock_drop(psock->sk, skb); + } + spin_unlock_bh(&psock->ingress_lock); +} + static void sk_psock_backlog(struct work_struct *work) { struct sk_psock *psock = container_of(work, struct sk_psock, work); struct sk_psock_work_state *state = &psock->work_state; - struct sk_buff *skb; + struct sk_buff *skb = NULL; bool ingress; u32 len, off; int ret; mutex_lock(&psock->work_mutex); - if (state->skb) { + if (unlikely(state->skb)) { + spin_lock_bh(&psock->ingress_lock); skb = state->skb; len = state->len; off = state->off; state->skb = NULL; - goto start; + spin_unlock_bh(&psock->ingress_lock); } + if (skb) + goto start; while ((skb = skb_dequeue(&psock->ingress_skb))) { len = skb->len; @@ -644,9 +663,8 @@ static void sk_psock_backlog(struct work_struct *work) len, ingress); if (ret <= 0) { if (ret == -EAGAIN) { - state->skb = skb; - state->len = len; - state->off = off; + sk_psock_skb_state(psock, state, skb, + len, off); goto end; } /* Hard errors break pipe and stop xmit. */ @@ -745,6 +763,11 @@ static void __sk_psock_zap_ingress(struct sk_psock *psock) skb_bpf_redirect_clear(skb); sock_drop(psock->sk, skb); } + kfree_skb(psock->work_state.skb); + /* We null the skb here to ensure that calls to sk_psock_backlog + * do not pick up the free'd skb. + */ + psock->work_state.skb = NULL; __sk_psock_purge_ingress_msg(psock); }