diff --git a/include/net/scm.h b/include/net/scm.h index 745460fa2f02cd..d456f4c71a323b 100644 --- a/include/net/scm.h +++ b/include/net/scm.h @@ -49,7 +49,7 @@ static __inline__ void scm_set_cred(struct scm_cookie *scm, struct pid *pid, const struct cred *cred) { scm->pid = get_pid(pid); - scm->cred = get_cred(cred); + scm->cred = cred ? get_cred(cred) : NULL; cred_to_ucred(pid, cred, &scm->creds); } @@ -73,8 +73,7 @@ static __inline__ void scm_destroy(struct scm_cookie *scm) static __inline__ int scm_send(struct socket *sock, struct msghdr *msg, struct scm_cookie *scm) { - scm_set_cred(scm, task_tgid(current), current_cred()); - scm->fp = NULL; + memset(scm, 0, sizeof(*scm)); unix_get_peersec_dgram(sock, scm); if (msg->msg_controllen <= 0) return 0; diff --git a/net/core/scm.c b/net/core/scm.c index 811b53fb330e8a..ff52ad0a51501c 100644 --- a/net/core/scm.c +++ b/net/core/scm.c @@ -173,7 +173,7 @@ int __scm_send(struct socket *sock, struct msghdr *msg, struct scm_cookie *p) if (err) goto error; - if (pid_vnr(p->pid) != p->creds.pid) { + if (!p->pid || pid_vnr(p->pid) != p->creds.pid) { struct pid *pid; err = -ESRCH; pid = find_get_pid(p->creds.pid); @@ -183,8 +183,9 @@ int __scm_send(struct socket *sock, struct msghdr *msg, struct scm_cookie *p) p->pid = pid; } - if ((p->cred->euid != p->creds.uid) || - (p->cred->egid != p->creds.gid)) { + if (!p->cred || + (p->cred->euid != p->creds.uid) || + (p->cred->egid != p->creds.gid)) { struct cred *cred; err = -ENOMEM; cred = prepare_creds(); @@ -193,7 +194,8 @@ int __scm_send(struct socket *sock, struct msghdr *msg, struct scm_cookie *p) cred->uid = cred->euid = p->creds.uid; cred->gid = cred->egid = p->creds.gid; - put_cred(p->cred); + if (p->cred) + put_cred(p->cred); p->cred = cred; } break; diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c index 4330db99fabff3..1201b6d4183d89 100644 --- a/net/netlink/af_netlink.c +++ b/net/netlink/af_netlink.c @@ -1324,10 +1324,9 @@ static int netlink_sendmsg(struct kiocb *kiocb, struct socket *sock, if (msg->msg_flags&MSG_OOB) return -EOPNOTSUPP; - if (NULL == siocb->scm) { + if (NULL == siocb->scm) siocb->scm = &scm; - memset(&scm, 0, sizeof(scm)); - } + err = scm_send(sock, msg, siocb->scm); if (err < 0) return err; diff --git a/net/unix/af_unix.c b/net/unix/af_unix.c index ec68e1c05b85ee..466fbcc5cf77a9 100644 --- a/net/unix/af_unix.c +++ b/net/unix/af_unix.c @@ -1381,8 +1381,10 @@ static int unix_attach_fds(struct scm_cookie *scm, struct sk_buff *skb) static int unix_scm_to_skb(struct scm_cookie *scm, struct sk_buff *skb, bool send_fds) { int err = 0; + UNIXCB(skb).pid = get_pid(scm->pid); - UNIXCB(skb).cred = get_cred(scm->cred); + if (scm->cred) + UNIXCB(skb).cred = get_cred(scm->cred); UNIXCB(skb).fp = NULL; if (scm->fp && send_fds) err = unix_attach_fds(scm, skb); @@ -1391,6 +1393,24 @@ static int unix_scm_to_skb(struct scm_cookie *scm, struct sk_buff *skb, bool sen return err; } +/* + * Some apps rely on write() giving SCM_CREDENTIALS + * We include credentials if source or destination socket + * asserted SOCK_PASSCRED. + */ +static void maybe_add_creds(struct sk_buff *skb, const struct socket *sock, + const struct sock *other) +{ + if (UNIXCB(skb).cred) + return; + if (test_bit(SOCK_PASSCRED, &sock->flags) || + !other->sk_socket || + test_bit(SOCK_PASSCRED, &other->sk_socket->flags)) { + UNIXCB(skb).pid = get_pid(task_tgid(current)); + UNIXCB(skb).cred = get_current_cred(); + } +} + /* * Send AF_UNIX data. */ @@ -1538,6 +1558,7 @@ static int unix_dgram_sendmsg(struct kiocb *kiocb, struct socket *sock, if (sock_flag(other, SOCK_RCVTSTAMP)) __net_timestamp(skb); + maybe_add_creds(skb, sock, other); skb_queue_tail(&other->sk_receive_queue, skb); if (max_level > unix_sk(other)->recursion_level) unix_sk(other)->recursion_level = max_level; @@ -1652,6 +1673,7 @@ static int unix_stream_sendmsg(struct kiocb *kiocb, struct socket *sock, (other->sk_shutdown & RCV_SHUTDOWN)) goto pipe_err_free; + maybe_add_creds(skb, sock, other); skb_queue_tail(&other->sk_receive_queue, skb); if (max_level > unix_sk(other)->recursion_level) unix_sk(other)->recursion_level = max_level;