@@ -456,6 +456,8 @@ struct io_ring_ctx {
 		struct work_struct		exit_work;
 		struct list_head		tctx_list;
 		struct completion		ref_comp;
+		u32				iowq_limits[2];
+		bool				iowq_limits_set;
 	};
 };
 
@@ -1368,11 +1370,6 @@ static void io_req_track_inflight(struct io_kiocb *req)
 	}
 }
 
-static inline void io_unprep_linked_timeout(struct io_kiocb *req)
-{
-	req->flags &= ~REQ_F_LINK_TIMEOUT;
-}
-
 static struct io_kiocb *__io_prep_linked_timeout(struct io_kiocb *req)
 {
 	if (WARN_ON_ONCE(!req->link))
@@ -6983,7 +6980,7 @@ static void __io_queue_sqe(struct io_kiocb *req)
 	switch (io_arm_poll_handler(req)) {
 	case IO_APOLL_READY:
 		if (linked_timeout)
-			io_unprep_linked_timeout(req);
+			io_queue_linked_timeout(linked_timeout);
 		goto issue_sqe;
 	case IO_APOLL_ABORTED:
 		/*
@@ -9638,7 +9635,16 @@ static int __io_uring_add_tctx_node(struct io_ring_ctx *ctx)
 		ret = io_uring_alloc_task_context(current, ctx);
 		if (unlikely(ret))
 			return ret;
+
 		tctx = current->io_uring;
+		if (ctx->iowq_limits_set) {
+			unsigned int limits[2] = { ctx->iowq_limits[0],
+						   ctx->iowq_limits[1], };
+
+			ret = io_wq_max_workers(tctx->io_wq, limits);
+			if (ret)
+				return ret;
+		}
 	}
 	if (!xa_load(&tctx->xa, (unsigned long)ctx)) {
 		node = kmalloc(sizeof(*node), GFP_KERNEL);
@@ -10643,7 +10649,9 @@ static int io_unregister_iowq_aff(struct io_ring_ctx *ctx)
 
 static int io_register_iowq_max_workers(struct io_ring_ctx *ctx,
 					void __user *arg)
+	__must_hold(&ctx->uring_lock)
 {
+	struct io_tctx_node *node;
 	struct io_uring_task *tctx = NULL;
 	struct io_sq_data *sqd = NULL;
 	__u32 new_count[2];
@@ -10674,13 +10682,19 @@ static int io_register_iowq_max_workers(struct io_ring_ctx *ctx,
 		tctx = current->io_uring;
 	}
 
-	ret = -EINVAL;
-	if (!tctx || !tctx->io_wq)
-		goto err;
+	BUILD_BUG_ON(sizeof(new_count) != sizeof(ctx->iowq_limits));
 
-	ret = io_wq_max_workers(tctx->io_wq, new_count);
-	if (ret)
-		goto err;
+	memcpy(ctx->iowq_limits, new_count, sizeof(new_count));
+	ctx->iowq_limits_set = true;
+
+	ret = -EINVAL;
+	if (tctx && tctx->io_wq) {
+		ret = io_wq_max_workers(tctx->io_wq, new_count);
+		if (ret)
+			goto err;
+	} else {
+		memset(new_count, 0, sizeof(new_count));
+	}
 
 	if (sqd) {
 		mutex_unlock(&sqd->lock);
1068510699 if (sqd ) {
1068610700 mutex_unlock (& sqd -> lock );
@@ -10690,6 +10704,22 @@ static int io_register_iowq_max_workers(struct io_ring_ctx *ctx,
 	if (copy_to_user(arg, new_count, sizeof(new_count)))
 		return -EFAULT;
 
+	/* that's it for SQPOLL, only the SQPOLL task creates requests */
+	if (sqd)
+		return 0;
+
+	/* now propagate the restriction to all registered users */
+	list_for_each_entry(node, &ctx->tctx_list, ctx_node) {
+		struct io_uring_task *tctx = node->task->io_uring;
+
+		if (WARN_ON_ONCE(!tctx->io_wq))
+			continue;
+
+		for (i = 0; i < ARRAY_SIZE(new_count); i++)
+			new_count[i] = ctx->iowq_limits[i];
+		/* ignore errors, it always returns zero anyway */
+		(void)io_wq_max_workers(tctx->io_wq, new_count);
+	}
 	return 0;
 err:
 	if (sqd) {
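
For context, below is a minimal userspace sketch (not part of the patch) of the register opcode whose behaviour this change extends: IORING_REGISTER_IOWQ_MAX_WORKERS caps the ring's bounded and unbounded io-wq worker pools, and with the stored ctx->iowq_limits the cap now also applies to tasks that attach to the ring afterwards. The sketch assumes liburing >= 2.1, which exposes the opcode as io_uring_register_iowq_max_workers(); the specific limit values and queue depth are arbitrary.

```c
/* Hypothetical usage sketch, assuming liburing >= 2.1. */
#include <liburing.h>
#include <stdio.h>

int main(void)
{
	struct io_uring ring;
	/* values[0]: bounded pool limit, values[1]: unbounded pool limit */
	unsigned int values[2] = { 4, 8 };
	int ret;

	ret = io_uring_queue_init(8, &ring, 0);
	if (ret < 0) {
		fprintf(stderr, "queue_init: %d\n", ret);
		return 1;
	}

	/*
	 * Wraps io_uring_register(IORING_REGISTER_IOWQ_MAX_WORKERS).
	 * With this patch the kernel records the limits in
	 * ctx->iowq_limits and re-applies them to any task that
	 * registers with the ring later.  On success the previous
	 * limits are written back into values[].
	 */
	ret = io_uring_register_iowq_max_workers(&ring, values);
	if (ret < 0)
		fprintf(stderr, "register_iowq_max_workers: %d\n", ret);
	else
		printf("previous limits: bounded=%u unbounded=%u\n",
		       values[0], values[1]);

	io_uring_queue_exit(&ring);
	return 0;
}
```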