Progress: change hint propagation for user defined struct #367

Open
georgematheos opened this issue Feb 13, 2021 · 0 comments

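Caching some work from today's hackathon. To propagate change hints through the construction of objects that hold several values, and through unpacking individual values from those objects, we introduce two generative functions: `Construct(type)` and `GetField(type, fieldname)`. The goal is that when only one field of an object changes, computations that read only the other fields can be recognized as unaffected.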

Usage example

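```julia
@gen (static) function foo()
    x ~ normal(0, 1)
    y ~ normal(0, 1)
    # Trace the construction with `~` so its change hint is propagated.
    p ~ Construct(Point2f0)(x, y)
    return p
end

@gen (static) function bar(p::Point2f0)
    # Only the diff of field `x` reaches `long_computation`.
    x ~ GetField(Point2f0, :x)(p)
    return long_computation(x)
end
```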

We could add sugar so that if the user writes `Point2f0(x, y)`, it is automatically converted to `Construct(Point2f0)(x, y)`, and if the user writes `point.x`, it is converted to `GetField(Point2f0, :x)(point)`.
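
One possible shape for that sugar is a purely syntactic rewrite of a `@gen (static)` function body before Gen parses it. The sketch below is hypothetical (neither `desugar` nor its arguments exist in Gen): it walks a Julia `Expr`, turning calls to known struct names into `Construct` calls and `var.field` accesses into `GetField` calls. Since `point.x` on its own does not say what type `point` has, it assumes a `var_types` table mapping variable names to struct types, which a real version might fill from argument annotations such as `p::Point2f0`.

```julia
# Hypothetical sugar: rewrite `T(args...)` into `Construct(T)(args...)` for
# struct names in `struct_types`, and `v.f` into `GetField(T, :f)(v)` when
# `var_types` says that variable `v` holds a `T`.
function desugar(ex, struct_types::Set{Symbol}, var_types::Dict{Symbol, Symbol})
    ex isa Expr || return ex    # symbols, literals and QuoteNodes are left as-is
    # Rewrite sub-expressions first so nested uses are handled too.
    args = Any[desugar(a, struct_types, var_types) for a in ex.args]
    if ex.head === :call && args[1] isa Symbol && args[1] in struct_types
        # Point2f0(x, y)  ->  Construct(Point2f0)(x, y)
        return Expr(:call, Expr(:call, :Construct, args[1]), args[2:end]...)
    elseif ex.head === :. && length(args) == 2 && args[1] isa Symbol &&
            haskey(var_types, args[1]) && args[2] isa QuoteNode
        # point.x  ->  GetField(Point2f0, :x)(point)
        return Expr(:call, Expr(:call, :GetField, var_types[args[1]], args[2]), args[1])
    end
    return Expr(ex.head, args...)
end

struct_types = Set([:Point2f0])
var_types = Dict(:point => :Point2f0)
println(desugar(:(Point2f0(x, y)), struct_types, var_types))
println(desugar(:(long_computation(point.x)), struct_types, var_types))
```

Field accesses on variables missing from `var_types` (and broadcasting calls such as `f.(x)`) are left untouched, so the rewrite only fires where the struct type is known.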

Here is an untested implementation:

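```julia
using Gen

# Change hint for an object of struct type `type`: one Diff per field,
# stored in the order given by `fieldnames(type)`.
struct StructDiff{type, T <: Tuple} <: Diff
    diffs::T
    function StructDiff{type}(diffs::T) where {type, T <: Tuple{Vararg{Diff}}}
        return new{type, T}(diffs)
    end
end

# Diff of a single field, looked up by name.
function _get_diff(diff::StructDiff, fieldname::Symbol)
    return _static_get_diff(diff, Val(fieldname))
end

# Resolve the field's position at compile time, so the lookup is a plain tuple index.
@generated function _static_get_diff(diff::StructDiff{type}, ::Val{fieldname}) where {type, fieldname}
    idx = findfirst(==(fieldname), fieldnames(type))
    idx === nothing && return :(error($(string(type, " has no field ", fieldname))))
    return :(diff.diffs[$idx])
end

# Generative function that builds a `type` from its field values;
# the constructed object doubles as the internal state.
struct Construct{type} <: Gen.CustomUpdateGF{type, type} end
Construct(type) = Construct{type}()

# Assumption: CustomUpdateGF needs `num_args`, and the constructor takes one argument per field.
Gen.num_args(::Construct{type}) where {type} = fieldcount(type)

function Gen.apply_with_state(::Construct{type}, args) where {type}
    obj = type(args...)
    return (obj, obj)                      # (retval, state)
end

# No argument changed: keep the old object and report NoChange.
function Gen.update_with_state(::Construct{type}, obj, args, argdiffs::Tuple{Vararg{NoChange}}) where {type}
    return (obj, obj, NoChange())          # (state, retval, retdiff)
end

# Some argument may have changed: rebuild the object and forward each
# argument's diff as the diff of the corresponding field.
function Gen.update_with_state(::Construct{type}, obj, args, argdiffs::Tuple) where {type}
    new_obj = type(args...)
    return (new_obj, new_obj, StructDiff{type}(argdiffs))
end

# Generative function that reads field `fieldname` of a `type`; it keeps no state.
struct GetField{type, fieldname, valtype} <: Gen.CustomUpdateGF{valtype, Nothing}
    function GetField{type, fieldname}() where {type, fieldname}
        return new{type, fieldname, fieldtype(type, fieldname)}()
    end
end
GetField(type, fieldname) = GetField{type, fieldname}()

Gen.num_args(::GetField) = 1

function Gen.apply_with_state(::GetField{type, fieldname}, (obj,)) where {type, fieldname}
    return (getproperty(obj, fieldname), nothing)
end

# The object came from Construct: report only this field's diff.
function Gen.update_with_state(::GetField{type, fieldname}, _, (obj,), argdiffs::Tuple{StructDiff{type}}) where {type, fieldname}
    return (nothing, getproperty(obj, fieldname), _get_diff(argdiffs[1], fieldname))
end

# Fallbacks when the object's diff carries no per-field information.
function Gen.update_with_state(::GetField{type, fieldname}, _, (obj,), ::Tuple{NoChange}) where {type, fieldname}
    return (nothing, getproperty(obj, fieldname), NoChange())
end
function Gen.update_with_state(::GetField{type, fieldname}, _, (obj,), ::Tuple{Diff}) where {type, fieldname}
    return (nothing, getproperty(obj, fieldname), UnknownChange())
end
```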
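
To check how the hints are meant to flow without installing Gen, here is a self-contained mock of the two update rules above. The types `MockDiff`, `MockNoChange`, `MockUnknownChange` and `MockStructDiff` are hypothetical stand-ins for Gen's `Diff`, `NoChange`, `UnknownChange` and the proposed `StructDiff`, and `Point` is an illustrative struct. It assumes `Construct` reports one diff per field and `GetField` picks out the diff of the field it reads.

```julia
# Hypothetical stand-ins for Gen's Diff, NoChange and UnknownChange,
# so the example runs without Gen installed.
abstract type MockDiff end
struct MockNoChange <: MockDiff end
struct MockUnknownChange <: MockDiff end

# Stand-in for StructDiff: one diff per field of struct type S, in field order.
struct MockStructDiff{S, D <: Tuple} <: MockDiff
    diffs::D
end
MockStructDiff{S}(diffs::D) where {S, D <: Tuple} = MockStructDiff{S, D}(diffs)

# Construct's retdiff: NoChange if no argument changed, otherwise each
# argument's diff becomes the diff of the corresponding field.
construct_retdiff(::Type{S}, ::Tuple{Vararg{MockNoChange}}) where {S} = MockNoChange()
construct_retdiff(::Type{S}, argdiffs::Tuple) where {S} = MockStructDiff{S}(argdiffs)

# GetField's retdiff: select one field's diff when per-field information exists.
function getfield_retdiff(::Type{S}, name::Symbol, d::MockStructDiff{S}) where {S}
    idx = findfirst(==(name), fieldnames(S))
    idx === nothing && error("$S has no field $name")
    return d.diffs[idx]
end
getfield_retdiff(::Type{S}, ::Symbol, ::MockNoChange) where {S} = MockNoChange()
getfield_retdiff(::Type{S}, ::Symbol, ::MockDiff) where {S} = MockUnknownChange()

# Illustrative struct with named fields.
struct Point
    x::Float64
    y::Float64
end

# Only y changed between the old and the new trace.
d = construct_retdiff(Point, (MockNoChange(), MockUnknownChange()))
println(getfield_retdiff(Point, :x, d))   # x is unaffected
println(getfield_retdiff(Point, :y, d))
```

When only `y` changes, the diff reaching the `x` field lookup stays `NoChange`, which is the hint that could let `bar` avoid rerunning `long_computation(x)`.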
@marcoct marcoct added this to Projects in Gen Summer 2021 Hackathon Jul 13, 2021
@ztangent ztangent moved this from Projects to In Progress in Gen Summer 2021 Hackathon Jul 13, 2021
@ztangent ztangent moved this from In Progress to Waiting for Review in Gen Summer 2021 Hackathon Jul 14, 2021