Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 69 additions & 6 deletions aten/src/TH/THAllocator.c
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ struct THMapAllocatorContext_ {
ptrdiff_t size; /* mapped size */
#ifdef _WIN32
HANDLE handle;
HANDLE event;
char *eventname;
#else
int fd;
#endif
Expand All @@ -65,6 +67,9 @@ typedef struct {
} THMapInfo;

char * unknown_filename = "filename not specified";
#ifdef _WIN32
char * unknown_eventname = "eventname not specified";
#endif

THMapAllocatorContext *THMapAllocatorContext_new(const char *filename, int flags)
{
Expand All @@ -79,9 +84,20 @@ THMapAllocatorContext *THMapAllocatorContext_new(const char *filename, int flags
if (filename) {
ctx->filename = THAlloc(strlen(filename)+1);
strcpy(ctx->filename, filename);
#ifdef _WIN32
char *suffixname = "_event";
size_t namelen = strlen(filename)+1+strlen(suffixname);
ctx->eventname = THAlloc(namelen);
strcpy(ctx->eventname, ctx->filename);
strcat(ctx->eventname, suffixname);
#endif
} else {
ctx->filename = unknown_filename;
#ifdef _WIN32
ctx->eventname = unknown_eventname;
#endif
}

ctx->flags = flags;
ctx->size = 0;
#ifdef _WIN32
Expand Down Expand Up @@ -126,11 +142,37 @@ ptrdiff_t THMapAllocatorContext_size(THMapAllocatorContext *ctx)

void THMapAllocatorContext_free(THMapAllocatorContext *ctx)
{
if (ctx->filename != unknown_filename)
if (ctx->filename != unknown_filename) {
THFree(ctx->filename);
#ifdef _WIN32
THFree(ctx->eventname);
#endif
}
THFree(ctx);
}

#ifdef _WIN32
typedef struct{
HANDLE event;
HANDLE handle;
HANDLE wait;
} ReleaseContext;
static VOID CALLBACK WaitForReleaseHandle(PVOID lpParam, BOOLEAN TimerOrWaitFired)
{
if (lpParam) {
ReleaseContext *ctx = (ReleaseContext *)lpParam;

SetEvent(ctx->event);
CloseHandle(ctx->event);
CloseHandle(ctx->handle);

UnregisterWait(ctx->wait);

THFree(ctx);
}
}
#endif

static void *_map_alloc(void* ctx_, ptrdiff_t size)
{
if (size == 0)
Expand All @@ -143,28 +185,38 @@ static void *_map_alloc(void* ctx_, ptrdiff_t size)
if (ctx->flags & TH_ALLOCATOR_MAPPED_SHAREDMEM)
{
char *filename;
char *eventname;
LARGE_INTEGER hfilesz;

if (ctx->filename[0] == '/')
if (ctx->filename[0] == '/') {
filename = ctx->filename + 1;
else
eventname = ctx->eventname + 1;
}
else {
filename = ctx->filename;
eventname = ctx->eventname;
}

hfilesz.QuadPart = size;

if (ctx->flags & TH_ALLOCATOR_MAPPED_EXCLUSIVE)
{
ctx->handle = CreateFileMapping(INVALID_HANDLE_VALUE, NULL, PAGE_READWRITE, hfilesz.HighPart, hfilesz.LowPart, filename);
ctx->event = CreateEvent(NULL, FALSE, FALSE, eventname);
}
else if (ctx->flags & TH_ALLOCATOR_MAPPED_NOCREATE)
{
ctx->handle = OpenFileMapping(FILE_MAP_ALL_ACCESS, FALSE, filename);
ctx->event = OpenEvent(EVENT_ALL_ACCESS, FALSE, eventname);
}
else
{
THError("Excpected either TH_ALLOCATOR_MAPPED_EXCLUSIVE or TH_ALLOCATOR_MAPPED_NOCREATE");
THError("Expected either TH_ALLOCATOR_MAPPED_EXCLUSIVE or TH_ALLOCATOR_MAPPED_NOCREATE");
}

if (ctx->event == NULL)
THError("Couldn't open shared event: <%s>, error code: <%d>", eventname, GetLastError());

if (ctx->handle == NULL)
THError("Couldn't open shared file mapping: <%s>, error code: <%d>", filename, GetLastError());

Expand Down Expand Up @@ -478,6 +530,17 @@ static void * THRefcountedMapAllocator_alloc(void *_ctx, ptrdiff_t size) {
char *data = ((char*)ptr) + TH_ALLOC_ALIGNMENT;
THMapInfo *map_info = (THMapInfo*)ptr;

#ifdef _WIN32
ReleaseContext* r_ctx = (ReleaseContext *) THAlloc(sizeof(ReleaseContext));
r_ctx->handle = ctx->handle;
r_ctx->event = ctx->event;
r_ctx->wait = NULL;
BOOL can_wait = RegisterWaitForSingleObject(&r_ctx->wait, ctx->event, WaitForReleaseHandle, (PVOID)r_ctx, INFINITE, WT_EXECUTEONLYONCE);
if (!can_wait) {
THError("Couldn't register wait on event, error code: <%d>", GetLastError());
}
#endif

if (ctx->flags & TH_ALLOCATOR_MAPPED_EXCLUSIVE)
map_info->refcount = 1;
else
Expand All @@ -497,9 +560,9 @@ static void THRefcountedMapAllocator_free(void* ctx_, void* data) {
#ifdef _WIN32
THMapInfo *info = (THMapInfo*)(((char*)data) - TH_ALLOC_ALIGNMENT);
if (THAtomicDecrementRef(&info->refcount)) {
CloseHandle(ctx->handle);
SetEvent(ctx->event);
}
if(UnmapViewOfFile(data) == 0)
if(UnmapViewOfFile(((char*)data) - TH_ALLOC_ALIGNMENT) == 0)
THError("could not unmap the shared memory file");
#else /* _WIN32 */

Expand Down
3 changes: 0 additions & 3 deletions torch/utils/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,9 +413,6 @@ def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sam
raise ValueError('num_workers cannot be negative; '
'use num_workers=0 to disable multiprocessing.')

if sys.platform == "win32" and self.num_workers > 0:
raise ValueError('num_workers > 0 is not supported on Windows')

if batch_sampler is None:
if sampler is None:
if shuffle:
Expand Down