@@ -59,3 +59,139 @@ bb0(%orig : $@callee_guaranteed (Float) -> Float):
5959// CHECK: [[EXTRACTED_VJP:%.*]] = differentiable_function_extract [vjp] [[DIFF_FN]] : $@differentiable(reverse) @callee_guaranteed (Float) -> Float
6060// CHECK: return [[EXTRACTED_VJP]] : $@callee_guaranteed (Float) -> (Float, @owned @callee_guaranteed (Float) -> Float)
6161// CHECK-LABEL: } // end sil function 'differentiable_function_extract_vjp_undefined'
62+
63+ // MARK: `convert_function` hoisting
64+
65+ // This should optimize down single partial_apply that escapes
66+ sil @differential_function_convert_single_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> @callee_guaranteed (@in_guaranteed Float) -> Float {
67+ bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
68+ %thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float
69+
70+ %pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
71+ %conv_pa = convert_function %pa to $@callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <Float>
72+
73+ %diff_fn = differentiable_function [parameters 0] [results 0] %conv_pa with_derivative {
74+ undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <τ_0_1>) for <Float, Float>,
75+ undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_1>) for <Float, Float>
76+ }
77+
78+ debug_value %diff_fn, let, name "f", argno 1
79+
80+ %conv_diff = convert_function %diff_fn to $@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
81+ %conv_orig = differentiable_function_extract [original] %conv_diff
82+ return %conv_orig
83+ }
84+
85+ // CHECK-LABEL: sil @differential_function_convert_single_use
86+ // CHECK: bb0(%[[ORIG_FN:.*]] : $@convention(thin) (Float) -> Float, %[[THUNK:.*]] : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float)
87+ // CHECK: %[[TT_CONV:.*]] = thin_to_thick_function %[[ORIG_FN]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
88+ // CHECK: %[[PA:.*]] = partial_apply [callee_guaranteed] %[[THUNK]](%[[TT_CONV]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
89+ // CHECKL return %[[PA]] : $@callee_guaranteed (@in_guaranteed Float) -> Float
90+
91+ sil @blackhole : $(@differentiable(reverse) @callee_guaranteed @substituted<T> (@in_guaranteed T) -> Float for <Float>) -> ()
92+
93+ // differentiable_function has multiple uses, so we cannot commute it with convert_function, check that all instructions are there
94+
95+ sil @differential_function_convert_multiple_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> @callee_guaranteed (@in_guaranteed Float) -> Float {
96+ bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
97+ %thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float
98+
99+ %pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
100+ %conv_pa = convert_function %pa to $@callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <Float>
101+
102+ %diff_fn = differentiable_function [parameters 0] [results 0] %conv_pa with_derivative {
103+ undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (@in_guaranteed τ_0_0) -> Float for <τ_0_1>) for <Float, Float>,
104+ undef : $@callee_guaranteed @substituted <τ_0_0, τ_0_1> (@in_guaranteed τ_0_0) -> (Float, @owned @callee_guaranteed @substituted <τ_0_0> (Float) -> @out τ_0_0 for <τ_0_1>) for <Float, Float>
105+ }
106+
107+ debug_value %diff_fn, let, name "f", argno 1
108+
109+ %conv_diff = convert_function %diff_fn to $@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
110+ %conv_orig = differentiable_function_extract [original] %conv_diff
111+
112+ %blackhole = function_ref @blackhole : $@convention(thin) (@differentiable(reverse) @callee_guaranteed @substituted<T> (@in_guaranteed T) -> Float for <Float>) -> ()
113+ apply %blackhole(%diff_fn) : $@convention(thin) (@differentiable(reverse) @callee_guaranteed @substituted<T> (@in_guaranteed T) -> Float for <Float>) -> ()
114+
115+ return %conv_orig : $@callee_guaranteed (@in_guaranteed Float) -> Float
116+ }
117+
118+ // CHECK-LABEL: sil @differential_function_convert_multiple_use
119+ // CHECK: convert_function
120+ // CHECK: differentiable_function
121+ // CHECK: convert_function
122+ // CHECK: differentiable_function_extract
123+
124+ // MARK: `convert_escape_to_noescape` hoisting
125+
126+ sil @blackhole2 : $(@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float) -> ()
127+
128+ // Here we should be able to unfold partial_apply down to direct function call
129+
130+ sil @differential_function_noescape_single_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> () {
131+ bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
132+ %thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float
133+
134+ %pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
135+
136+ %diff_fn = differentiable_function [parameters 0] [results 0] %pa with_derivative {
137+ undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (@in_guaranteed Float) -> Float),
138+ undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (Float) -> @out Float)
139+ }
140+
141+ debug_value %diff_fn, let, name "f", argno 1
142+
143+ %conv_diff = convert_escape_to_noescape %diff_fn to $@noescape @differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
144+ %conv_orig = differentiable_function_extract [original] %conv_diff
145+
146+ %arg = alloc_stack $Float
147+ apply %conv_orig(%arg) : $@noescape @callee_guaranteed (@in_guaranteed Float) -> Float
148+
149+ dealloc_stack %arg : $*Float
150+ strong_release %pa
151+
152+ %res = tuple ()
153+ return %res : $()
154+ }
155+
156+ // CHECK-LABEL: sil @differential_function_noescape_single_use
157+ // CHECK: bb0(%[[ORIG_FN:.*]] : $@convention(thin) (Float) -> Float, %[[THUNK:.*]] : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float)
158+ // CHECK: %[[TT_CONV:.*]] = thin_to_thick_function %[[ORIG_FN]] : $@convention(thin) (Float) -> Float to $@callee_guaranteed (Float) -> Float
159+ // CHECK: %[[ARG:.*]] = alloc_stack $Float
160+ // CHECK: apply %[[THUNK]](%[[ARG]], %[[TT_CONV]]) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
161+
162+
163+ // differentiable_function has multiple uses, so we cannot commute it with convert_escape_to_noescape, check that all instructions are there
164+
165+ sil @differential_function_noescape_multiple_use : $@convention(thin) (@convention(thin) (Float) -> Float, @convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float) -> () {
166+ bb0(%orig: $@convention(thin) (Float) -> Float, %thunk: $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float):
167+ %thick_orig = thin_to_thick_function %orig to $@callee_guaranteed (Float) -> Float
168+
169+ %pa = partial_apply [callee_guaranteed] %thunk(%thick_orig) : $@convention(thin) (@in_guaranteed Float, @guaranteed @callee_guaranteed (Float) -> Float) -> Float
170+
171+ %diff_fn = differentiable_function [parameters 0] [results 0] %pa with_derivative {
172+ undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (@in_guaranteed Float) -> Float),
173+ undef : $@callee_guaranteed (@in_guaranteed Float) -> (Float, @owned @callee_guaranteed (Float) -> @out Float)
174+ }
175+
176+ debug_value %diff_fn, let, name "f", argno 1
177+
178+ %conv_diff = convert_escape_to_noescape %diff_fn to $@noescape @differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float
179+ %conv_orig = differentiable_function_extract [original] %conv_diff
180+
181+ %arg = alloc_stack $Float
182+ apply %conv_orig(%arg) : $@noescape @callee_guaranteed (@in_guaranteed Float) -> Float
183+
184+ %blackhole = function_ref @blackhole2 : $@convention(thin) (@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float) -> ()
185+ apply %blackhole(%diff_fn) : $@convention(thin) (@differentiable(reverse) @callee_guaranteed (@in_guaranteed Float) -> Float) -> ()
186+
187+ dealloc_stack %arg : $*Float
188+ strong_release %pa
189+
190+ %res = tuple ()
191+ return %res : $()
192+ }
193+
194+ // CHECK-LABEL: sil @differential_function_noescape_multiple_use
195+ // CHECK: differentiable_function
196+ // CHECK: convert_escape_to_noescape
197+ // CHECK: differentiable_function_extract
0 commit comments