Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Gpu draft #391

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/Finch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ using Compat
using DataStructures
using JSON
using CIndices
#using CUDA

export @finch, @finch_program, @finch_code, @finch_kernel, value

Expand Down Expand Up @@ -47,8 +48,11 @@ export choose, minby, maxby, overwrite, initwrite, d
export default, AsArray

export parallelAnalysis, ParallelAnalysisResults
export parallel, realextent, extent, dimless

# Parallel-dimension constructors plus the device, local-vector and
# local-memory types for each supported architecture.
export parallel, gpublock_parallel, gputhread_parallel, extent, dimless
export CPU, CPULocalVector, CPULocalMemory
export GPUBlock, GPUBlockLocalVector, GPUBlockLocalMemory
# Fixed spelling: this used to export `GPUThreadkLocalVector`, a name with no
# definition in the architecture code; the struct is `GPUThreadLocalVector`.
export GPUThread, GPUThreadLocalVector, GPUThreadLocalMemory

export Limit, Eps

Expand Down
152 changes: 151 additions & 1 deletion src/architecture.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,154 @@ end

# Give a CPU thread its own entry of a per-thread vector, selected by `tid`.
moveto(vec::CPULocalVector, task::CPUThread) = vec.data[task.tid]
end

## GPU part
"""
    GPUBlock(n)
    GPUBlock()

Device descriptor holding a single integer `n`. The zero-argument form sets
`n` to `Threads.nthreads()`.
"""
struct GPUBlock <: AbstractDevice
    n::Int
end

function GPUBlock()
    return GPUBlock(Threads.nthreads())
end
# Compile-time representation of a `GPUBlock` device.
#   ex: expression (or symbol) naming the runtime device; may be `nothing`,
#       in which case lowering rebuilds the device from `n`.
#   n:  virtual value standing for the device's `n` field.
@kwdef struct VirtualGPUBlock <: AbstractVirtualDevice
ex
n
end
# Virtualize a `GPUBlock`: bind the device expression to a fresh name in the
# context preamble, then describe the device (and its `n`) through that name.
function virtualize(ex, ::Type{GPUBlock}, ctx)
    handle = freshen(ctx, :gpublock)
    binding = quote
        $handle = $ex
    end
    push!(ctx.preamble, binding)
    n_virtual = virtualize(:($handle.n), Int, ctx)
    return VirtualGPUBlock(handle, n_virtual)
end
# Lower a virtual GPU block device: reuse the bound runtime expression when
# there is one, otherwise emit a `GPUBlock(n)` constructor call.
function lower(device::VirtualGPUBlock, ctx::AbstractCompiler, ::DefaultStyle)
    # The fallback is built unconditionally so `ctx(device.n)` always runs.
    rebuilt = :(GPUBlock($(ctx(device.n))))
    return something(device.ex, rebuilt)
end

# A virtual GPU block device is wrapped with `virtual` when used as a leaf.
function FinchNotation.finch_leaf(device::VirtualGPUBlock)
    return virtual(device)
end

"""
    GPUThreadBlock{Parent}(tid, dev, parent)

Task associated with a `GPUBlock` device.

- `tid`: integer id of this task (used to index per-task storage).
- `dev`: the `GPUBlock` device the task belongs to.
- `parent`: the enclosing task.
"""
struct GPUThreadBlock{Parent} <: AbstractTask
    tid::Int
    dev::GPUBlock
    parent::Parent
end

# The device lives in the `dev` field. The accessor previously read
# `task.device`, which is not a field of this struct and would throw.
get_device(task::GPUThreadBlock) = task.dev
get_task(task::GPUThreadBlock) = task.parent
# Compile-time representation of a `GPUThreadBlock` task.
#   tid:    virtual value for the task id
#   dev:    virtual device the task belongs to
#   parent: virtual parent task
struct VirtualGPUThreadBlock <: AbstractVirtualTask
tid
dev::VirtualGPUBlock
parent
end
# Virtualize a `GPUThreadBlock` task.
#
# The task expression `ex` is first bound to a fresh name in the preamble
# (the same pattern used for the `GPUBlock` device), and each field is then
# virtualized through that name. The previous version interpolated `$sym`
# without ever defining `sym`, so it failed with an UndefVarError.
function virtualize(ex, ::Type{GPUThreadBlock{Parent}}, ctx) where {Parent}
    sym = freshen(ctx, :gputhreadblock)
    push!(ctx.preamble, quote
        $sym = $ex
    end)
    return VirtualGPUThreadBlock(
        virtualize(:($sym.tid), Int, ctx),
        virtualize(:($sym.dev), GPUBlock, ctx),
        virtualize(:($sym.parent), Parent, ctx),
    )
end
# Lower a virtual GPU block task to a `GPUThreadBlock(tid, dev, parent)`
# constructor expression, lowering each component in field order.
function lower(task::VirtualGPUThreadBlock, ctx::AbstractCompiler, ::DefaultStyle)
    tid_ex = ctx(task.tid)
    dev_ex = ctx(task.dev)
    parent_ex = ctx(task.parent)
    return :(GPUThreadBlock($tid_ex, $dev_ex, $parent_ex))
end
# A virtual GPU block task is wrapped with `virtual` when used as a leaf.
FinchNotation.finch_leaf(device::VirtualGPUThreadBlock) = virtual(device)

# `VirtualGPUThreadBlock` stores its device in `dev`; the accessor previously
# read the non-existent field `device`.
virtual_get_device(task::VirtualGPUThreadBlock) = task.dev
virtual_get_task(task::VirtualGPUThreadBlock) = task.parent

# Destination marker for `moveto`: moving a `Vector` here produces a
# `GPUBlockLocalVector` holding `device.n` copies of it.
struct GPUBlockLocalMemory
device::GPUBlock
end
# Replicate `vec` once per slot of the target device (`device.n` copies) and
# wrap the replicas in a `GPUBlockLocalVector`.
function moveto(vec::V, mem::GPUBlockLocalMemory) where {V <: Vector}
    device = mem.device
    replicas = [copy(vec) for _ in 1:device.n]
    return GPUBlockLocalVector{V}(device, replicas)
end

# Per-task storage for a `GPUBlock` device.
#   device: the device the storage belongs to
#   data:   one `V` per slot; `moveto` hands slot `task.tid` to a task
struct GPUBlockLocalVector{V}
device::GPUBlock
data::Vector{V}
end

# Build a local vector with `device.n` slots, each an empty `V`.
function GPUBlockLocalVector{V}(device::GPUBlock) where {V}
    slots = [V([]) for _ in 1:device.n]
    return GPUBlockLocalVector{V}(device, slots)
end

# Element type and dimensionality are those of the wrapped vector type `V`.
function Base.eltype(::Type{GPUBlockLocalVector{V}}) where {V}
    return eltype(V)
end

function Base.ndims(::Type{GPUBlockLocalVector{V}}) where {V}
    return ndims(V)
end

# Moving a plain vector to the device itself returns the same vector.
moveto(vec::Vector, device::GPUBlock) = vec

# Moving a plain vector to a task returns a copy of it.
moveto(vec::Vector, task::GPUThreadBlock) = copy(vec)

# Moving a local vector to a task selects that task's slot by `tid`.
moveto(vec::GPUBlockLocalVector, task::GPUThreadBlock) = vec.data[task.tid]


"""
    GPUThread(n)
    GPUThread()

Device descriptor holding a single integer `n`. The zero-argument form sets
`n` to `Threads.nthreads()`.
"""
struct GPUThread <: AbstractDevice
    n::Int
end

function GPUThread()
    return GPUThread(Threads.nthreads())
end
# Compile-time representation of a `GPUThread` device.
#   ex: expression (or symbol) naming the runtime device; may be `nothing`,
#       in which case lowering rebuilds the device from `n`.
#   n:  virtual value standing for the device's `n` field.
@kwdef struct VirtualGPUThread <: AbstractVirtualDevice
ex
n
end
# Virtualize a `GPUThread`: bind the device expression to a fresh name in the
# context preamble, then describe the device (and its `n`) through that name.
function virtualize(ex, ::Type{GPUThread}, ctx)
    handle = freshen(ctx, :gputhread)
    binding = quote
        $handle = $ex
    end
    push!(ctx.preamble, binding)
    n_virtual = virtualize(:($handle.n), Int, ctx)
    return VirtualGPUThread(handle, n_virtual)
end
# Lower a virtual GPU thread device: reuse the bound runtime expression when
# there is one, otherwise emit a `GPUThread(n)` constructor call.
function lower(device::VirtualGPUThread, ctx::AbstractCompiler, ::DefaultStyle)
    # The fallback is built unconditionally so `ctx(device.n)` always runs.
    rebuilt = :(GPUThread($(ctx(device.n))))
    return something(device.ex, rebuilt)
end

# A virtual GPU thread device is wrapped with `virtual` when used as a leaf.
function FinchNotation.finch_leaf(device::VirtualGPUThread)
    return virtual(device)
end

"""
    GPUThreadThread{Parent}(tid, dev, parent)

Task associated with a `GPUThread` device.

- `tid`: integer id of this task (used to index per-task storage).
- `dev`: the `GPUThread` device the task belongs to.
- `parent`: the enclosing task.
"""
struct GPUThreadThread{Parent} <: AbstractTask
    tid::Int
    dev::GPUThread
    parent::Parent
end

# The device lives in the `dev` field. The accessor previously read
# `task.device`, which is not a field of this struct and would throw.
get_device(task::GPUThreadThread) = task.dev
get_task(task::GPUThreadThread) = task.parent
# Compile-time representation of a `GPUThreadThread` task.
#   tid:    virtual value for the task id
#   dev:    virtual device the task belongs to
#   parent: virtual parent task
struct VirtualGPUThreadThread <: AbstractVirtualTask
tid
dev::VirtualGPUThread
parent
end
# Virtualize a `GPUThreadThread` task.
#
# The task expression `ex` is first bound to a fresh name in the preamble
# (the same pattern used for the `GPUThread` device), and each field is then
# virtualized through that name. The previous version interpolated `$sym`
# without ever defining `sym`, so it failed with an UndefVarError.
function virtualize(ex, ::Type{GPUThreadThread{Parent}}, ctx) where {Parent}
    sym = freshen(ctx, :gputhreadthread)
    push!(ctx.preamble, quote
        $sym = $ex
    end)
    return VirtualGPUThreadThread(
        virtualize(:($sym.tid), Int, ctx),
        virtualize(:($sym.dev), GPUThread, ctx),
        virtualize(:($sym.parent), Parent, ctx),
    )
end
# Lower a virtual GPU thread task to a `GPUThreadThread(tid, dev, parent)`
# constructor expression, lowering each component in field order.
function lower(task::VirtualGPUThreadThread, ctx::AbstractCompiler, ::DefaultStyle)
    tid_ex = ctx(task.tid)
    dev_ex = ctx(task.dev)
    parent_ex = ctx(task.parent)
    return :(GPUThreadThread($tid_ex, $dev_ex, $parent_ex))
end
# A virtual GPU thread task is wrapped with `virtual` when used as a leaf.
FinchNotation.finch_leaf(device::VirtualGPUThreadThread) = virtual(device)

# `VirtualGPUThreadThread` stores its device in `dev`; the accessor previously
# read the non-existent field `device`.
virtual_get_device(task::VirtualGPUThreadThread) = task.dev
virtual_get_task(task::VirtualGPUThreadThread) = task.parent

# Destination marker for `moveto`: moving a `Vector` here produces a
# `GPUThreadLocalVector` holding `device.n` copies of it.
struct GPUThreadLocalMemory
device::GPUThread
end
# Replicate `vec` once per slot of the target device (`device.n` copies) and
# wrap the replicas in a `GPUThreadLocalVector`.
function moveto(vec::V, mem::GPUThreadLocalMemory) where {V <: Vector}
    device = mem.device
    replicas = [copy(vec) for _ in 1:device.n]
    return GPUThreadLocalVector{V}(device, replicas)
end

# Per-task storage for a `GPUThread` device.
#   device: the device the storage belongs to
#   data:   one `V` per slot; `moveto` hands slot `task.tid` to a task
struct GPUThreadLocalVector{V}
device::GPUThread
data::Vector{V}
end

# Build a local vector with `device.n` slots, each an empty `V`.
function GPUThreadLocalVector{V}(device::GPUThread) where {V}
    slots = [V([]) for _ in 1:device.n]
    return GPUThreadLocalVector{V}(device, slots)
end

# Element type and dimensionality are those of the wrapped vector type `V`.
function Base.eltype(::Type{GPUThreadLocalVector{V}}) where {V}
    return eltype(V)
end

function Base.ndims(::Type{GPUThreadLocalVector{V}}) where {V}
    return ndims(V)
end

# Moving a plain vector to the device itself returns the same vector.
moveto(vec::Vector, device::GPUThread) = vec

# Moving a plain vector to a task returns a copy of it.
moveto(vec::Vector, task::GPUThreadThread) = copy(vec)

# Moving a local vector to a task selects that task's slot by `tid`.
moveto(vec::GPUThreadLocalVector, task::GPUThreadThread) = vec.data[task.tid]
35 changes: 35 additions & 0 deletions src/dimensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,41 @@ function virtual_call(::typeof(parallel), ctx, ext, device)
end
end

# Begin GPU part
"""
    gpublock_parallel(dim, device=GPUBlock(Threads.nthreads()))

Wrap dimension `dim` in a `ParallelDimension` targeting `device`.

The default device size is written as `Threads.nthreads()`, matching the
`GPUBlock()` constructor and the virtual counterpart of this function, so it
does not depend on an unqualified `nthreads` being in scope.
"""
gpublock_parallel(dim, device=GPUBlock(Threads.nthreads())) = ParallelDimension(dim, device)

# Compile-time `gpublock_parallel(ext)` with no explicit device: when `ext` is
# a virtual node, build a default virtual GPU block device sized by a cached
# `Threads.nthreads()` value and defer to the four-argument method. Any other
# kind of `ext` yields `nothing`.
function virtual_call(::typeof(gpublock_parallel), ctx, ext)
    ext.kind === virtual || return nothing
    n = cache!(ctx, :n, value(:(Threads.nthreads()), Int))
    default_device = finch_leaf(VirtualGPUBlock(nothing, n))
    return virtual_call(gpublock_parallel, ctx, ext, default_device)
end

# Compile-time `gpublock_parallel(ext, device)`: resolve the device, then wrap
# the value of a virtual `ext` in a `ParallelDimension`. Any other kind of
# `ext` yields `nothing`.
function virtual_call(::typeof(gpublock_parallel), ctx, ext, device)
    resolved = resolve(device, ctx) #TODO this feels broken
    return ext.kind === virtual ? ParallelDimension(ext.val, resolved) : nothing
end

"""
    gputhread_parallel(dim, device=GPUThread(Threads.nthreads()))

Wrap dimension `dim` in a `ParallelDimension` targeting `device`.

The default device size is written as `Threads.nthreads()`, matching the
`GPUThread()` constructor and the virtual counterpart of this function, so it
does not depend on an unqualified `nthreads` being in scope.
"""
gputhread_parallel(dim, device=GPUThread(Threads.nthreads())) = ParallelDimension(dim, device)

# Compile-time `gputhread_parallel(ext)` with no explicit device: when `ext`
# is a virtual node, build a default virtual GPU thread device sized by a
# cached `Threads.nthreads()` value and defer to the four-argument method.
# Any other kind of `ext` yields `nothing`.
function virtual_call(::typeof(gputhread_parallel), ctx, ext)
    ext.kind === virtual || return nothing
    n = cache!(ctx, :n, value(:(Threads.nthreads()), Int))
    default_device = finch_leaf(VirtualGPUThread(nothing, n))
    return virtual_call(gputhread_parallel, ctx, ext, default_device)
end

# Compile-time `gputhread_parallel(ext, device)`: resolve the device, then
# wrap the value of a virtual `ext` in a `ParallelDimension`. Any other kind
# of `ext` yields `nothing`.
function virtual_call(::typeof(gputhread_parallel), ctx, ext, device)
    resolved = resolve(device, ctx) #TODO this feels broken
    return ext.kind === virtual ? ParallelDimension(ext.val, resolved) : nothing
end
# End GPU part


# Turn a `ParallelDimension` back into a call node built from its extent and
# device.
# NOTE(review): the call always names `parallel`, including for dimensions
# created through `gpublock_parallel` / `gputhread_parallel` — confirm that
# losing the original constructor here is intended.
virtual_uncall(ext::ParallelDimension) = call(parallel, ext.ext, ext.device)

# A `ParallelDimension` is wrapped with `virtual` when used as a leaf.
FinchNotation.finch_leaf(x::ParallelDimension) = virtual(x)
Expand Down
15 changes: 14 additions & 1 deletion src/environment.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ end
preamble::Vector{Any} = []
epilogue::Vector{Any} = []
task = VirtualSerial()
topPreamble::Vector{Any} = []
top::Bool = true
end

"""
Expand Down Expand Up @@ -66,22 +68,33 @@ Call f on a subcontext of `ctx` and return the result. Variable bindings,
preambles, and epilogues defined in the subcontext will not escape the call to
contain.
"""
function contain(f, ctx::AbstractCompiler, task=nothing)
function contain(f, ctx::AbstractCompiler, task=nothing, top=false)
ctx_2 = shallowcopy(ctx)
ctx_2.topPreamble = ctx.topPreamble
ctx_2.top = top
ctx_2.task = something(task, ctx.task)
preamble = Expr(:block)
ctx_2.preamble = preamble.args
epilogue = Expr(:block)
ctx_2.epilogue = epilogue.args
body = f(ctx_2)
if (ctx.top && length(ctx_2.topPreamble) > 0)
toppre2 = ctx_2.topPreamble
toppre = :($(toppre2...),)
else
toppre = Expr(:block)
end

if epilogue == Expr(:block)
return quote
$toppre
$preamble
$body
end
else
res = freshen(ctx_2, :res)
return quote
$toppre
$preamble
$res = $(contain_epilogue_helper(body, epilogue))
$epilogue
Expand Down
Loading
Loading