diff --git a/drivers/usb/usbip/vhci_sysfs.c b/drivers/usb/usbip/vhci_sysfs.c index 18a0fa649cd41..e64ea314930be 100644 --- a/drivers/usb/usbip/vhci_sysfs.c +++ b/drivers/usb/usbip/vhci_sysfs.c @@ -312,6 +312,8 @@ static ssize_t attach_store(struct device *dev, struct device_attribute *attr, struct vhci *vhci; int err; unsigned long flags; + struct task_struct *tcp_rx = NULL; + struct task_struct *tcp_tx = NULL; /* * @rhport: port number of vhci_hcd @@ -360,9 +362,24 @@ static ssize_t attach_store(struct device *dev, struct device_attribute *attr, return -EINVAL; } - /* now need lock until setting vdev status as used */ + /* create threads before locking */ + tcp_rx = kthread_create(vhci_rx_loop, &vdev->ud, "vhci_rx"); + if (IS_ERR(tcp_rx)) { + sockfd_put(socket); + return -EINVAL; + } + tcp_tx = kthread_create(vhci_tx_loop, &vdev->ud, "vhci_tx"); + if (IS_ERR(tcp_tx)) { + kthread_stop(tcp_rx); + sockfd_put(socket); + return -EINVAL; + } + + /* get task structs now */ + get_task_struct(tcp_rx); + get_task_struct(tcp_tx); - /* begin a lock */ + /* now begin lock until setting vdev status set */ spin_lock_irqsave(&vhci->lock, flags); spin_lock(&vdev->ud.lock); @@ -372,6 +389,8 @@ static ssize_t attach_store(struct device *dev, struct device_attribute *attr, spin_unlock_irqrestore(&vhci->lock, flags); sockfd_put(socket); + kthread_stop_put(tcp_rx); + kthread_stop_put(tcp_tx); dev_err(dev, "port %d already used\n", rhport); /* @@ -390,14 +409,16 @@ static ssize_t attach_store(struct device *dev, struct device_attribute *attr, vdev->speed = speed; vdev->ud.sockfd = sockfd; vdev->ud.tcp_socket = socket; + vdev->ud.tcp_rx = tcp_rx; + vdev->ud.tcp_tx = tcp_tx; vdev->ud.status = VDEV_ST_NOTASSIGNED; spin_unlock(&vdev->ud.lock); spin_unlock_irqrestore(&vhci->lock, flags); /* end the lock */ - vdev->ud.tcp_rx = kthread_get_run(vhci_rx_loop, &vdev->ud, "vhci_rx"); - vdev->ud.tcp_tx = kthread_get_run(vhci_tx_loop, &vdev->ud, "vhci_tx"); + wake_up_process(vdev->ud.tcp_rx); + wake_up_process(vdev->ud.tcp_tx); rh_port_connect(vdev, speed);