Skip to content

Commit

Permalink
vsock/virtio: add transport parameter to the virtio_transport_reset_n…
Browse files Browse the repository at this point in the history
…o_sock()

[ Upstream commit 4c7246d ]

We are going to add 'struct vsock_sock *' parameter to
virtio_transport_get_ops().

In some cases, like in the virtio_transport_reset_no_sock(),
we don't have any socket assigned to the packet received,
so we can't use the virtio_transport_get_ops().

In order to allow virtio_transport_reset_no_sock() to use the
'.send_pkt' callback from the 'vhost_transport' or 'virtio_transport',
we add the 'struct virtio_transport *' to it and to its caller:
virtio_transport_recv_pkt().

We moved the 'vhost_transport' and 'virtio_transport' definition,
to pass their address to the virtio_transport_recv_pkt().

Reviewed-by: Stefan Hajnoczi <stefanha@redhat.com>
Signed-off-by: Stefano Garzarella <sgarzare@redhat.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
Signed-off-by: Sasha Levin <sashal@kernel.org>
  • Loading branch information
stefano-garzarella authored and gregkh committed Oct 7, 2020
1 parent 14c79ef commit 215459f
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 134 deletions.
94 changes: 47 additions & 47 deletions drivers/vhost/vsock.c
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,52 @@ static bool vhost_vsock_more_replies(struct vhost_vsock *vsock)
return val < vq->num;
}

static struct virtio_transport vhost_transport = {
.transport = {
.get_local_cid = vhost_transport_get_local_cid,

.init = virtio_transport_do_socket_init,
.destruct = virtio_transport_destruct,
.release = virtio_transport_release,
.connect = virtio_transport_connect,
.shutdown = virtio_transport_shutdown,
.cancel_pkt = vhost_transport_cancel_pkt,

.dgram_enqueue = virtio_transport_dgram_enqueue,
.dgram_dequeue = virtio_transport_dgram_dequeue,
.dgram_bind = virtio_transport_dgram_bind,
.dgram_allow = virtio_transport_dgram_allow,

.stream_enqueue = virtio_transport_stream_enqueue,
.stream_dequeue = virtio_transport_stream_dequeue,
.stream_has_data = virtio_transport_stream_has_data,
.stream_has_space = virtio_transport_stream_has_space,
.stream_rcvhiwat = virtio_transport_stream_rcvhiwat,
.stream_is_active = virtio_transport_stream_is_active,
.stream_allow = virtio_transport_stream_allow,

.notify_poll_in = virtio_transport_notify_poll_in,
.notify_poll_out = virtio_transport_notify_poll_out,
.notify_recv_init = virtio_transport_notify_recv_init,
.notify_recv_pre_block = virtio_transport_notify_recv_pre_block,
.notify_recv_pre_dequeue = virtio_transport_notify_recv_pre_dequeue,
.notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
.notify_send_init = virtio_transport_notify_send_init,
.notify_send_pre_block = virtio_transport_notify_send_pre_block,
.notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue,
.notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,

.set_buffer_size = virtio_transport_set_buffer_size,
.set_min_buffer_size = virtio_transport_set_min_buffer_size,
.set_max_buffer_size = virtio_transport_set_max_buffer_size,
.get_buffer_size = virtio_transport_get_buffer_size,
.get_min_buffer_size = virtio_transport_get_min_buffer_size,
.get_max_buffer_size = virtio_transport_get_max_buffer_size,
},

.send_pkt = vhost_transport_send_pkt,
};

static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
{
struct vhost_virtqueue *vq = container_of(work, struct vhost_virtqueue,
Expand Down Expand Up @@ -440,7 +486,7 @@ static void vhost_vsock_handle_tx_kick(struct vhost_work *work)
if (le64_to_cpu(pkt->hdr.src_cid) == vsock->guest_cid &&
le64_to_cpu(pkt->hdr.dst_cid) ==
vhost_transport_get_local_cid())
virtio_transport_recv_pkt(pkt);
virtio_transport_recv_pkt(&vhost_transport, pkt);
else
virtio_transport_free_pkt(pkt);

Expand Down Expand Up @@ -793,52 +839,6 @@ static struct miscdevice vhost_vsock_misc = {
.fops = &vhost_vsock_fops,
};

static struct virtio_transport vhost_transport = {
.transport = {
.get_local_cid = vhost_transport_get_local_cid,

.init = virtio_transport_do_socket_init,
.destruct = virtio_transport_destruct,
.release = virtio_transport_release,
.connect = virtio_transport_connect,
.shutdown = virtio_transport_shutdown,
.cancel_pkt = vhost_transport_cancel_pkt,

.dgram_enqueue = virtio_transport_dgram_enqueue,
.dgram_dequeue = virtio_transport_dgram_dequeue,
.dgram_bind = virtio_transport_dgram_bind,
.dgram_allow = virtio_transport_dgram_allow,

.stream_enqueue = virtio_transport_stream_enqueue,
.stream_dequeue = virtio_transport_stream_dequeue,
.stream_has_data = virtio_transport_stream_has_data,
.stream_has_space = virtio_transport_stream_has_space,
.stream_rcvhiwat = virtio_transport_stream_rcvhiwat,
.stream_is_active = virtio_transport_stream_is_active,
.stream_allow = virtio_transport_stream_allow,

.notify_poll_in = virtio_transport_notify_poll_in,
.notify_poll_out = virtio_transport_notify_poll_out,
.notify_recv_init = virtio_transport_notify_recv_init,
.notify_recv_pre_block = virtio_transport_notify_recv_pre_block,
.notify_recv_pre_dequeue = virtio_transport_notify_recv_pre_dequeue,
.notify_recv_post_dequeue = virtio_transport_notify_recv_post_dequeue,
.notify_send_init = virtio_transport_notify_send_init,
.notify_send_pre_block = virtio_transport_notify_send_pre_block,
.notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue,
.notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue,

.set_buffer_size = virtio_transport_set_buffer_size,
.set_min_buffer_size = virtio_transport_set_min_buffer_size,
.set_max_buffer_size = virtio_transport_set_max_buffer_size,
.get_buffer_size = virtio_transport_get_buffer_size,
.get_min_buffer_size = virtio_transport_get_min_buffer_size,
.get_max_buffer_size = virtio_transport_get_max_buffer_size,
},

.send_pkt = vhost_transport_send_pkt,
};

static int __init vhost_vsock_init(void)
{
int ret;
Expand Down
3 changes: 2 additions & 1 deletion include/linux/virtio_vsock.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,8 @@ virtio_transport_dgram_enqueue(struct vsock_sock *vsk,

void virtio_transport_destruct(struct vsock_sock *vsk);

void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt);
void virtio_transport_recv_pkt(struct virtio_transport *t,
struct virtio_vsock_pkt *pkt);
void virtio_transport_free_pkt(struct virtio_vsock_pkt *pkt);
void virtio_transport_inc_tx_pkt(struct virtio_vsock_sock *vvs, struct virtio_vsock_pkt *pkt);
u32 virtio_transport_get_credit(struct virtio_vsock_sock *vvs, u32 wanted);
Expand Down
160 changes: 80 additions & 80 deletions net/vmw_vsock/virtio_transport.c
Original file line number Diff line number Diff line change
Expand Up @@ -86,33 +86,6 @@ static u32 virtio_transport_get_local_cid(void)
return ret;
}

static void virtio_transport_loopback_work(struct work_struct *work)
{
struct virtio_vsock *vsock =
container_of(work, struct virtio_vsock, loopback_work);
LIST_HEAD(pkts);

spin_lock_bh(&vsock->loopback_list_lock);
list_splice_init(&vsock->loopback_list, &pkts);
spin_unlock_bh(&vsock->loopback_list_lock);

mutex_lock(&vsock->rx_lock);

if (!vsock->rx_run)
goto out;

while (!list_empty(&pkts)) {
struct virtio_vsock_pkt *pkt;

pkt = list_first_entry(&pkts, struct virtio_vsock_pkt, list);
list_del_init(&pkt->list);

virtio_transport_recv_pkt(pkt);
}
out:
mutex_unlock(&vsock->rx_lock);
}

static int virtio_transport_send_pkt_loopback(struct virtio_vsock *vsock,
struct virtio_vsock_pkt *pkt)
{
Expand Down Expand Up @@ -370,59 +343,6 @@ static bool virtio_transport_more_replies(struct virtio_vsock *vsock)
return val < virtqueue_get_vring_size(vq);
}

static void virtio_transport_rx_work(struct work_struct *work)
{
struct virtio_vsock *vsock =
container_of(work, struct virtio_vsock, rx_work);
struct virtqueue *vq;

vq = vsock->vqs[VSOCK_VQ_RX];

mutex_lock(&vsock->rx_lock);

if (!vsock->rx_run)
goto out;

do {
virtqueue_disable_cb(vq);
for (;;) {
struct virtio_vsock_pkt *pkt;
unsigned int len;

if (!virtio_transport_more_replies(vsock)) {
/* Stop rx until the device processes already
* pending replies. Leave rx virtqueue
* callbacks disabled.
*/
goto out;
}

pkt = virtqueue_get_buf(vq, &len);
if (!pkt) {
break;
}

vsock->rx_buf_nr--;

/* Drop short/long packets */
if (unlikely(len < sizeof(pkt->hdr) ||
len > sizeof(pkt->hdr) + pkt->len)) {
virtio_transport_free_pkt(pkt);
continue;
}

pkt->len = len - sizeof(pkt->hdr);
virtio_transport_deliver_tap_pkt(pkt);
virtio_transport_recv_pkt(pkt);
}
} while (!virtqueue_enable_cb(vq));

out:
if (vsock->rx_buf_nr < vsock->rx_buf_max_nr / 2)
virtio_vsock_rx_fill(vsock);
mutex_unlock(&vsock->rx_lock);
}

/* event_lock must be held */
static int virtio_vsock_event_fill_one(struct virtio_vsock *vsock,
struct virtio_vsock_event *event)
Expand Down Expand Up @@ -586,6 +506,86 @@ static struct virtio_transport virtio_transport = {
.send_pkt = virtio_transport_send_pkt,
};

static void virtio_transport_loopback_work(struct work_struct *work)
{
struct virtio_vsock *vsock =
container_of(work, struct virtio_vsock, loopback_work);
LIST_HEAD(pkts);

spin_lock_bh(&vsock->loopback_list_lock);
list_splice_init(&vsock->loopback_list, &pkts);
spin_unlock_bh(&vsock->loopback_list_lock);

mutex_lock(&vsock->rx_lock);

if (!vsock->rx_run)
goto out;

while (!list_empty(&pkts)) {
struct virtio_vsock_pkt *pkt;

pkt = list_first_entry(&pkts, struct virtio_vsock_pkt, list);
list_del_init(&pkt->list);

virtio_transport_recv_pkt(&virtio_transport, pkt);
}
out:
mutex_unlock(&vsock->rx_lock);
}

static void virtio_transport_rx_work(struct work_struct *work)
{
struct virtio_vsock *vsock =
container_of(work, struct virtio_vsock, rx_work);
struct virtqueue *vq;

vq = vsock->vqs[VSOCK_VQ_RX];

mutex_lock(&vsock->rx_lock);

if (!vsock->rx_run)
goto out;

do {
virtqueue_disable_cb(vq);
for (;;) {
struct virtio_vsock_pkt *pkt;
unsigned int len;

if (!virtio_transport_more_replies(vsock)) {
/* Stop rx until the device processes already
* pending replies. Leave rx virtqueue
* callbacks disabled.
*/
goto out;
}

pkt = virtqueue_get_buf(vq, &len);
if (!pkt) {
break;
}

vsock->rx_buf_nr--;

/* Drop short/long packets */
if (unlikely(len < sizeof(pkt->hdr) ||
len > sizeof(pkt->hdr) + pkt->len)) {
virtio_transport_free_pkt(pkt);
continue;
}

pkt->len = len - sizeof(pkt->hdr);
virtio_transport_deliver_tap_pkt(pkt);
virtio_transport_recv_pkt(&virtio_transport, pkt);
}
} while (!virtqueue_enable_cb(vq));

out:
if (vsock->rx_buf_nr < vsock->rx_buf_max_nr / 2)
virtio_vsock_rx_fill(vsock);
mutex_unlock(&vsock->rx_lock);
}

static int virtio_vsock_probe(struct virtio_device *vdev)
{
vq_callback_t *callbacks[] = {
Expand Down
12 changes: 6 additions & 6 deletions net/vmw_vsock/virtio_transport_common.c
Original file line number Diff line number Diff line change
Expand Up @@ -696,9 +696,9 @@ static int virtio_transport_reset(struct vsock_sock *vsk,
/* Normally packets are associated with a socket. There may be no socket if an
* attempt was made to connect to a socket that does not exist.
*/
static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
static int virtio_transport_reset_no_sock(const struct virtio_transport *t,
struct virtio_vsock_pkt *pkt)
{
const struct virtio_transport *t;
struct virtio_vsock_pkt *reply;
struct virtio_vsock_pkt_info info = {
.op = VIRTIO_VSOCK_OP_RST,
Expand All @@ -718,7 +718,6 @@ static int virtio_transport_reset_no_sock(struct virtio_vsock_pkt *pkt)
if (!reply)
return -ENOMEM;

t = virtio_transport_get_ops();
if (!t) {
virtio_transport_free_pkt(reply);
return -ENOTCONN;
Expand Down Expand Up @@ -1060,7 +1059,8 @@ static bool virtio_transport_space_update(struct sock *sk,
/* We are under the virtio-vsock's vsock->rx_lock or vhost-vsock's vq->mutex
* lock.
*/
void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
void virtio_transport_recv_pkt(struct virtio_transport *t,
struct virtio_vsock_pkt *pkt)
{
struct sockaddr_vm src, dst;
struct vsock_sock *vsk;
Expand All @@ -1082,7 +1082,7 @@ void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
le32_to_cpu(pkt->hdr.fwd_cnt));

if (le16_to_cpu(pkt->hdr.type) != VIRTIO_VSOCK_TYPE_STREAM) {
(void)virtio_transport_reset_no_sock(pkt);
(void)virtio_transport_reset_no_sock(t, pkt);
goto free_pkt;
}

Expand All @@ -1093,7 +1093,7 @@ void virtio_transport_recv_pkt(struct virtio_vsock_pkt *pkt)
if (!sk) {
sk = vsock_find_bound_socket(&dst);
if (!sk) {
(void)virtio_transport_reset_no_sock(pkt);
(void)virtio_transport_reset_no_sock(t, pkt);
goto free_pkt;
}
}
Expand Down

0 comments on commit 215459f

Please sign in to comment.