Skip to content

Commit

Permalink
ntsync: Introduce NTSYNC_IOC_CREATE_MUTEX.
Browse files Browse the repository at this point in the history
This corresponds to the NT syscall NtCreateMutant().

An NT mutex is recursive, with a 32-bit recursion counter. When acquired via
NtWaitForMultipleObjects(), the recursion counter is incremented by one.

The OS records the thread which acquired it. However, in order to keep this
driver self-contained, the owning thread ID is managed by user-space, and passed
as a parameter to all relevant ioctls.

The initial owner and recursion count, if any, are specified when the mutex is
created.

Signed-off-by: Elizabeth Figura <zfigura@codeweavers.com>
Signed-off-by: Alexandre Frade <kernel@xanmod.org>
  • Loading branch information
Elizabeth Figura authored and xanmod committed Mar 1, 2024
1 parent 3d82a2f commit cd2da9e
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 0 deletions.
67 changes: 67 additions & 0 deletions drivers/misc/ntsync.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

enum ntsync_type {
NTSYNC_TYPE_SEM,
NTSYNC_TYPE_MUTEX,
};

/*
Expand Down Expand Up @@ -53,6 +54,10 @@ struct ntsync_obj {
__u32 count;
__u32 max;
} sem;
struct {
__u32 count;
__u32 owner;
} mutex;
} u;

/*
Expand Down Expand Up @@ -132,6 +137,10 @@ static bool is_signaled(struct ntsync_obj *obj, __u32 owner)
switch (obj->type) {
case NTSYNC_TYPE_SEM:
return !!obj->u.sem.count;
case NTSYNC_TYPE_MUTEX:
if (obj->u.mutex.owner && obj->u.mutex.owner != owner)
return false;
return obj->u.mutex.count < UINT_MAX;
}

WARN(1, "bad object type %#x\n", obj->type);
Expand Down Expand Up @@ -174,6 +183,10 @@ static void try_wake_all(struct ntsync_device *dev, struct ntsync_q *q,
case NTSYNC_TYPE_SEM:
obj->u.sem.count--;
break;
case NTSYNC_TYPE_MUTEX:
obj->u.mutex.count++;
obj->u.mutex.owner = q->owner;
break;
}
}
wake_up_process(q->task);
Expand Down Expand Up @@ -215,6 +228,28 @@ static void try_wake_any_sem(struct ntsync_obj *sem)
}
}

static void try_wake_any_mutex(struct ntsync_obj *mutex)
{
struct ntsync_q_entry *entry;

lockdep_assert_held(&mutex->lock);

list_for_each_entry(entry, &mutex->any_waiters, node) {
struct ntsync_q *q = entry->q;

if (mutex->u.mutex.count == UINT_MAX)
break;
if (mutex->u.mutex.owner && mutex->u.mutex.owner != q->owner)
continue;

if (atomic_cmpxchg(&q->signaled, -1, entry->index) == -1) {
mutex->u.mutex.count++;
mutex->u.mutex.owner = q->owner;
wake_up_process(q->task);
}
}
}

/*
* Actually change the semaphore state, returning -EOVERFLOW if it is made
* invalid.
Expand Down Expand Up @@ -374,6 +409,33 @@ static int ntsync_create_sem(struct ntsync_device *dev, void __user *argp)
return put_user(fd, &user_args->sem);
}

static int ntsync_create_mutex(struct ntsync_device *dev, void __user *argp)
{
struct ntsync_mutex_args __user *user_args = argp;
struct ntsync_mutex_args args;
struct ntsync_obj *mutex;
int fd;

if (copy_from_user(&args, argp, sizeof(args)))
return -EFAULT;

if (!args.owner != !args.count)
return -EINVAL;

mutex = ntsync_alloc_obj(dev, NTSYNC_TYPE_MUTEX);
if (!mutex)
return -ENOMEM;
mutex->u.mutex.count = args.count;
mutex->u.mutex.owner = args.owner;
fd = ntsync_obj_get_fd(mutex);
if (fd < 0) {
kfree(mutex);
return fd;
}

return put_user(fd, &user_args->mutex);
}

static struct ntsync_obj *get_obj(struct ntsync_device *dev, int fd)
{
struct file *file = fget(fd);
Expand Down Expand Up @@ -493,6 +555,9 @@ static void try_wake_any_obj(struct ntsync_obj *obj)
case NTSYNC_TYPE_SEM:
try_wake_any_sem(obj);
break;
case NTSYNC_TYPE_MUTEX:
try_wake_any_mutex(obj);
break;
}
}

Expand Down Expand Up @@ -681,6 +746,8 @@ static long ntsync_char_ioctl(struct file *file, unsigned int cmd,
void __user *argp = (void __user *)parm;

switch (cmd) {
case NTSYNC_IOC_CREATE_MUTEX:
return ntsync_create_mutex(dev, argp);
case NTSYNC_IOC_CREATE_SEM:
return ntsync_create_sem(dev, argp);
case NTSYNC_IOC_WAIT_ALL:
Expand Down
7 changes: 7 additions & 0 deletions include/uapi/linux/ntsync.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ struct ntsync_sem_args {
__u32 max;
};

struct ntsync_mutex_args {
__u32 mutex;
__u32 owner;
__u32 count;
};

struct ntsync_wait_args {
__u64 timeout;
__u64 objs;
Expand All @@ -30,6 +36,7 @@ struct ntsync_wait_args {
#define NTSYNC_IOC_CREATE_SEM _IOWR('N', 0x80, struct ntsync_sem_args)
#define NTSYNC_IOC_WAIT_ANY _IOWR('N', 0x82, struct ntsync_wait_args)
#define NTSYNC_IOC_WAIT_ALL _IOWR('N', 0x83, struct ntsync_wait_args)
#define NTSYNC_IOC_CREATE_MUTEX _IOWR('N', 0x84, struct ntsync_sem_args)

#define NTSYNC_IOC_SEM_POST _IOWR('N', 0x81, __u32)

Expand Down

0 comments on commit cd2da9e

Please sign in to comment.