Skip to content

Commit

Permalink
IB/cm: Split cm_alloc_msg()
Browse files Browse the repository at this point in the history
[ Upstream commit 4b4e586 ]

This is being used with two quite different flows: one attaches the
message to the priv and the other does not.

Ensure that attaching the message is consistently done under the
spinlock, and that freeing on error always detaches the message from the
cm_id_priv, also under the lock.

This makes read/write to the cm_id_priv->msg consistently locked and
consistently NULL'd when the message is freed, even in all error paths.

Link: https://lore.kernel.org/r/f692b8c89eecb34fd82244f317e478bea6c97688.1622629024.git.leonro@nvidia.com
Signed-off-by: Mark Zhang <markzhang@nvidia.com>
Signed-off-by: Leon Romanovsky <leonro@nvidia.com>
Signed-off-by: Jason Gunthorpe <jgg@nvidia.com>
Signed-off-by: Sasha Levin <sashal@kernel.org>
  • Loading branch information
jgunthorpe authored and gregkh committed Jul 14, 2021
1 parent ad6608c commit 020155e
Showing 1 changed file with 115 additions and 75 deletions.
190 changes: 115 additions & 75 deletions drivers/infiniband/core/cm.c
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,7 @@ static inline void cm_deref_id(struct cm_id_private *cm_id_priv)
complete(&cm_id_priv->comp);
}

static int cm_alloc_msg(struct cm_id_private *cm_id_priv,
struct ib_mad_send_buf **msg)
static struct ib_mad_send_buf *cm_alloc_msg(struct cm_id_private *cm_id_priv)
{
struct ib_mad_agent *mad_agent;
struct ib_mad_send_buf *m;
Expand Down Expand Up @@ -359,12 +358,42 @@ static int cm_alloc_msg(struct cm_id_private *cm_id_priv,
m->retries = cm_id_priv->max_cm_retries;

refcount_inc(&cm_id_priv->refcount);
spin_unlock_irqrestore(&cm.state_lock, flags2);
m->context[0] = cm_id_priv;
*msg = m;
return m;

out:
spin_unlock_irqrestore(&cm.state_lock, flags2);
return ret;
return ERR_PTR(ret);
}

/*
 * Allocate a MAD send buffer via cm_alloc_msg() and attach it to
 * cm_id_priv->msg. Caller must hold cm_id_priv->lock so that reads and
 * writes of ->msg are consistently serialized.
 *
 * Returns the send buffer on success, or an ERR_PTR on failure (in
 * which case ->msg is left untouched).
 */
static struct ib_mad_send_buf *
cm_alloc_priv_msg(struct cm_id_private *cm_id_priv)
{
struct ib_mad_send_buf *msg;

lockdep_assert_held(&cm_id_priv->lock);

msg = cm_alloc_msg(cm_id_priv);
if (IS_ERR(msg))
return msg;
/* Attach under the lock; undone by cm_free_priv_msg() */
cm_id_priv->msg = msg;
return msg;
}

/*
 * Free a MAD send buffer obtained from cm_alloc_priv_msg(): detach it
 * from cm_id_priv->msg (under the lock), release its AH if one was set,
 * drop the cm_id_priv reference taken at allocation, and free the MAD.
 *
 * cm_id_priv is recovered from msg->context[0], which cm_alloc_msg()
 * set at allocation time. Caller must hold cm_id_priv->lock.
 */
static void cm_free_priv_msg(struct ib_mad_send_buf *msg)
{
struct cm_id_private *cm_id_priv = msg->context[0];

lockdep_assert_held(&cm_id_priv->lock);

/* Detach; WARN if some other message was attached instead */
if (!WARN_ON(cm_id_priv->msg != msg))
cm_id_priv->msg = NULL;

if (msg->ah)
rdma_destroy_ah(msg->ah, 0);
/* Pairs with the refcount_inc() done in cm_alloc_msg() */
cm_deref_id(cm_id_priv);
ib_free_send_mad(msg);
}

static struct ib_mad_send_buf *cm_alloc_response_msg_no_ah(struct cm_port *port,
Expand Down Expand Up @@ -1508,6 +1537,7 @@ int ib_send_cm_req(struct ib_cm_id *cm_id,
struct ib_cm_req_param *param)
{
struct cm_id_private *cm_id_priv;
struct ib_mad_send_buf *msg;
struct cm_req_msg *req_msg;
unsigned long flags;
int ret;
Expand Down Expand Up @@ -1559,31 +1589,34 @@ int ib_send_cm_req(struct ib_cm_id *cm_id,
cm_id_priv->pkey = param->primary_path->pkey;
cm_id_priv->qp_type = param->qp_type;

ret = cm_alloc_msg(cm_id_priv, &cm_id_priv->msg);
if (ret)
goto out;
spin_lock_irqsave(&cm_id_priv->lock, flags);
msg = cm_alloc_priv_msg(cm_id_priv);
if (IS_ERR(msg)) {
ret = PTR_ERR(msg);
goto out_unlock;
}

req_msg = (struct cm_req_msg *) cm_id_priv->msg->mad;
req_msg = (struct cm_req_msg *)msg->mad;
cm_format_req(req_msg, cm_id_priv, param);
cm_id_priv->tid = req_msg->hdr.tid;
cm_id_priv->msg->timeout_ms = cm_id_priv->timeout_ms;
cm_id_priv->msg->context[1] = (void *) (unsigned long) IB_CM_REQ_SENT;
msg->timeout_ms = cm_id_priv->timeout_ms;
msg->context[1] = (void *)(unsigned long)IB_CM_REQ_SENT;

cm_id_priv->local_qpn = cpu_to_be32(IBA_GET(CM_REQ_LOCAL_QPN, req_msg));
cm_id_priv->rq_psn = cpu_to_be32(IBA_GET(CM_REQ_STARTING_PSN, req_msg));

trace_icm_send_req(&cm_id_priv->id);
spin_lock_irqsave(&cm_id_priv->lock, flags);
ret = ib_post_send_mad(cm_id_priv->msg, NULL);
if (ret) {
cm_free_msg(cm_id_priv->msg);
spin_unlock_irqrestore(&cm_id_priv->lock, flags);
goto out;
}
ret = ib_post_send_mad(msg, NULL);
if (ret)
goto out_free;
BUG_ON(cm_id->state != IB_CM_IDLE);
cm_id->state = IB_CM_REQ_SENT;
spin_unlock_irqrestore(&cm_id_priv->lock, flags);
return 0;
out_free:
cm_free_priv_msg(msg);
out_unlock:
spin_unlock_irqrestore(&cm_id_priv->lock, flags);
out:
return ret;
}
Expand Down Expand Up @@ -2290,9 +2323,11 @@ int ib_send_cm_rep(struct ib_cm_id *cm_id,
goto out;
}

ret = cm_alloc_msg(cm_id_priv, &msg);
if (ret)
msg = cm_alloc_priv_msg(cm_id_priv);
if (IS_ERR(msg)) {
ret = PTR_ERR(msg);
goto out;
}

rep_msg = (struct cm_rep_msg *) msg->mad;
cm_format_rep(rep_msg, cm_id_priv, param);
Expand All @@ -2301,23 +2336,24 @@ int ib_send_cm_rep(struct ib_cm_id *cm_id,

trace_icm_send_rep(cm_id);
ret = ib_post_send_mad(msg, NULL);
if (ret) {
spin_unlock_irqrestore(&cm_id_priv->lock, flags);
cm_free_msg(msg);
return ret;
}
if (ret)
goto out_free;

cm_id->state = IB_CM_REP_SENT;
cm_id_priv->msg = msg;
cm_id_priv->initiator_depth = param->initiator_depth;
cm_id_priv->responder_resources = param->responder_resources;
cm_id_priv->rq_psn = cpu_to_be32(IBA_GET(CM_REP_STARTING_PSN, rep_msg));
WARN_ONCE(param->qp_num & 0xFF000000,
"IBTA declares QPN to be 24 bits, but it is 0x%X\n",
param->qp_num);
cm_id_priv->local_qpn = cpu_to_be32(param->qp_num & 0xFFFFFF);
spin_unlock_irqrestore(&cm_id_priv->lock, flags);
return 0;

out: spin_unlock_irqrestore(&cm_id_priv->lock, flags);
out_free:
cm_free_priv_msg(msg);
out:
spin_unlock_irqrestore(&cm_id_priv->lock, flags);
return ret;
}
EXPORT_SYMBOL(ib_send_cm_rep);
Expand Down Expand Up @@ -2364,9 +2400,11 @@ int ib_send_cm_rtu(struct ib_cm_id *cm_id,
goto error;
}

ret = cm_alloc_msg(cm_id_priv, &msg);
if (ret)
msg = cm_alloc_msg(cm_id_priv);
if (IS_ERR(msg)) {
ret = PTR_ERR(msg);
goto error;
}

cm_format_rtu((struct cm_rtu_msg *) msg->mad, cm_id_priv,
private_data, private_data_len);
Expand Down Expand Up @@ -2664,10 +2702,10 @@ static int cm_send_dreq_locked(struct cm_id_private *cm_id_priv,
cm_id_priv->id.lap_state == IB_CM_MRA_LAP_RCVD)
ib_cancel_mad(cm_id_priv->av.port->mad_agent, cm_id_priv->msg);

ret = cm_alloc_msg(cm_id_priv, &msg);
if (ret) {
msg = cm_alloc_priv_msg(cm_id_priv);
if (IS_ERR(msg)) {
cm_enter_timewait(cm_id_priv);
return ret;
return PTR_ERR(msg);
}

cm_format_dreq((struct cm_dreq_msg *) msg->mad, cm_id_priv,
Expand All @@ -2679,12 +2717,11 @@ static int cm_send_dreq_locked(struct cm_id_private *cm_id_priv,
ret = ib_post_send_mad(msg, NULL);
if (ret) {
cm_enter_timewait(cm_id_priv);
cm_free_msg(msg);
cm_free_priv_msg(msg);
return ret;
}

cm_id_priv->id.state = IB_CM_DREQ_SENT;
cm_id_priv->msg = msg;
return 0;
}

Expand Down Expand Up @@ -2739,9 +2776,9 @@ static int cm_send_drep_locked(struct cm_id_private *cm_id_priv,
cm_set_private_data(cm_id_priv, private_data, private_data_len);
cm_enter_timewait(cm_id_priv);

ret = cm_alloc_msg(cm_id_priv, &msg);
if (ret)
return ret;
msg = cm_alloc_msg(cm_id_priv);
if (IS_ERR(msg))
return PTR_ERR(msg);

cm_format_drep((struct cm_drep_msg *) msg->mad, cm_id_priv,
private_data, private_data_len);
Expand Down Expand Up @@ -2934,19 +2971,19 @@ static int cm_send_rej_locked(struct cm_id_private *cm_id_priv,
case IB_CM_REP_RCVD:
case IB_CM_MRA_REP_SENT:
cm_reset_to_idle(cm_id_priv);
ret = cm_alloc_msg(cm_id_priv, &msg);
if (ret)
return ret;
msg = cm_alloc_msg(cm_id_priv);
if (IS_ERR(msg))
return PTR_ERR(msg);
cm_format_rej((struct cm_rej_msg *)msg->mad, cm_id_priv, reason,
ari, ari_length, private_data, private_data_len,
state);
break;
case IB_CM_REP_SENT:
case IB_CM_MRA_REP_RCVD:
cm_enter_timewait(cm_id_priv);
ret = cm_alloc_msg(cm_id_priv, &msg);
if (ret)
return ret;
msg = cm_alloc_msg(cm_id_priv);
if (IS_ERR(msg))
return PTR_ERR(msg);
cm_format_rej((struct cm_rej_msg *)msg->mad, cm_id_priv, reason,
ari, ari_length, private_data, private_data_len,
state);
Expand Down Expand Up @@ -3124,21 +3161,23 @@ int ib_send_cm_mra(struct ib_cm_id *cm_id,
default:
trace_icm_send_mra_unknown_err(&cm_id_priv->id);
ret = -EINVAL;
goto error1;
goto error_unlock;
}

if (!(service_timeout & IB_CM_MRA_FLAG_DELAY)) {
ret = cm_alloc_msg(cm_id_priv, &msg);
if (ret)
goto error1;
msg = cm_alloc_msg(cm_id_priv);
if (IS_ERR(msg)) {
ret = PTR_ERR(msg);
goto error_unlock;
}

cm_format_mra((struct cm_mra_msg *) msg->mad, cm_id_priv,
msg_response, service_timeout,
private_data, private_data_len);
trace_icm_send_mra(cm_id);
ret = ib_post_send_mad(msg, NULL);
if (ret)
goto error2;
goto error_free_msg;
}

cm_id->state = cm_state;
Expand All @@ -3148,13 +3187,11 @@ int ib_send_cm_mra(struct ib_cm_id *cm_id,
spin_unlock_irqrestore(&cm_id_priv->lock, flags);
return 0;

error1: spin_unlock_irqrestore(&cm_id_priv->lock, flags);
kfree(data);
return ret;

error2: spin_unlock_irqrestore(&cm_id_priv->lock, flags);
kfree(data);
error_free_msg:
cm_free_msg(msg);
error_unlock:
spin_unlock_irqrestore(&cm_id_priv->lock, flags);
kfree(data);
return ret;
}
EXPORT_SYMBOL(ib_send_cm_mra);
Expand Down Expand Up @@ -3490,38 +3527,41 @@ int ib_send_cm_sidr_req(struct ib_cm_id *cm_id,
&cm_id_priv->av,
cm_id_priv);
if (ret)
goto out;
return ret;

cm_id->service_id = param->service_id;
cm_id->service_mask = ~cpu_to_be64(0);
cm_id_priv->timeout_ms = param->timeout_ms;
cm_id_priv->max_cm_retries = param->max_cm_retries;
ret = cm_alloc_msg(cm_id_priv, &msg);
if (ret)
goto out;

cm_format_sidr_req((struct cm_sidr_req_msg *) msg->mad, cm_id_priv,
param);
msg->timeout_ms = cm_id_priv->timeout_ms;
msg->context[1] = (void *) (unsigned long) IB_CM_SIDR_REQ_SENT;

spin_lock_irqsave(&cm_id_priv->lock, flags);
if (cm_id->state == IB_CM_IDLE) {
trace_icm_send_sidr_req(&cm_id_priv->id);
ret = ib_post_send_mad(msg, NULL);
} else {
if (cm_id->state != IB_CM_IDLE) {
ret = -EINVAL;
goto out_unlock;
}

if (ret) {
spin_unlock_irqrestore(&cm_id_priv->lock, flags);
cm_free_msg(msg);
goto out;
msg = cm_alloc_priv_msg(cm_id_priv);
if (IS_ERR(msg)) {
ret = PTR_ERR(msg);
goto out_unlock;
}

cm_format_sidr_req((struct cm_sidr_req_msg *)msg->mad, cm_id_priv,
param);
msg->timeout_ms = cm_id_priv->timeout_ms;
msg->context[1] = (void *)(unsigned long)IB_CM_SIDR_REQ_SENT;

trace_icm_send_sidr_req(&cm_id_priv->id);
ret = ib_post_send_mad(msg, NULL);
if (ret)
goto out_free;
cm_id->state = IB_CM_SIDR_REQ_SENT;
cm_id_priv->msg = msg;
spin_unlock_irqrestore(&cm_id_priv->lock, flags);
out:
return 0;
out_free:
cm_free_priv_msg(msg);
out_unlock:
spin_unlock_irqrestore(&cm_id_priv->lock, flags);
return ret;
}
EXPORT_SYMBOL(ib_send_cm_sidr_req);
Expand Down Expand Up @@ -3668,9 +3708,9 @@ static int cm_send_sidr_rep_locked(struct cm_id_private *cm_id_priv,
if (cm_id_priv->id.state != IB_CM_SIDR_REQ_RCVD)
return -EINVAL;

ret = cm_alloc_msg(cm_id_priv, &msg);
if (ret)
return ret;
msg = cm_alloc_msg(cm_id_priv);
if (IS_ERR(msg))
return PTR_ERR(msg);

cm_format_sidr_rep((struct cm_sidr_rep_msg *) msg->mad, cm_id_priv,
param);
Expand Down

0 comments on commit 020155e

Please sign in to comment.