@@ -279,10 +279,25 @@ export class MathBackendWebGL extends KernelBackend {
279279 }
280280 const dataId = { } ;
281281 this . texData . set (
282- dataId , { shape, dtype, values, usage : TextureUsage . UPLOAD } ) ;
282+ dataId ,
283+ { shape, dtype, values, usage : TextureUsage . UPLOAD , refCount : 1 } ) ;
283284 return dataId ;
284285 }
285286
287+ /** Increase refCount of a `TextureData`. */
288+ incRef ( dataId : DataId ) : void {
289+ const texData = this . texData . get ( dataId ) ;
290+ texData . refCount ++ ;
291+ }
292+
293+ /** Decrease refCount of a `TextureData`. */
294+ decRef ( dataId : DataId ) : void {
295+ if ( this . texData . has ( dataId ) ) {
296+ const texData = this . texData . get ( dataId ) ;
297+ texData . refCount -- ;
298+ }
299+ }
300+
286301 move ( dataId : DataId , values : BackendValues , shape : number [ ] , dtype : DataType ) :
287302 void {
288303 if ( env ( ) . getBool ( 'DEBUG' ) ) {
@@ -294,12 +309,31 @@ export class MathBackendWebGL extends KernelBackend {
294309 `Please use tf.complex(real, imag).` ) ;
295310 }
296311 this . texData . set (
297- dataId , { shape, dtype, values, usage : TextureUsage . UPLOAD } ) ;
312+ dataId ,
313+ { shape, dtype, values, usage : TextureUsage . UPLOAD , refCount : 1 } ) ;
314+ }
315+
316+ disposeIntermediateTensorInfo ( tensorInfo : TensorInfo ) : void {
317+ const dataId = tensorInfo . dataId ;
318+
319+ if ( this . texData . has ( dataId ) ) {
320+ const textureData = this . texData . get ( dataId ) ;
321+
322+ textureData . refCount -- ;
323+
324+ if ( textureData . refCount < 1 ) {
325+ this . disposeData ( dataId ) ;
326+ }
327+ }
298328 }
299329
300330 readSync ( dataId : DataId ) : BackendValues {
301331 const texData = this . texData . get ( dataId ) ;
302332 const { values, dtype, complexTensors, slice, shape, isPacked} = texData ;
333+
334+ // The presence of `slice` indicates this tensor is a shallow slice of a
335+ // different tensor, and is using that original tensor's texture. Run
336+ // `clone` in order to copy that texture and read from it.
303337 if ( slice != null ) {
304338 let program ;
305339 if ( isPacked ) {
@@ -310,7 +344,7 @@ export class MathBackendWebGL extends KernelBackend {
310344 const res =
311345 this . runWebGLProgram ( program , [ { dataId, shape, dtype} ] , dtype ) ;
312346 const data = this . readSync ( res . dataId ) ;
313- this . disposeData ( res . dataId ) ;
347+ this . disposeIntermediateTensorInfo ( res ) ;
314348 return data ;
315349 }
316350 if ( values != null ) {
@@ -348,6 +382,9 @@ export class MathBackendWebGL extends KernelBackend {
348382 const texData = this . texData . get ( dataId ) ;
349383 const { values, shape, slice, dtype, complexTensors, isPacked} = texData ;
350384
385+ // The presence of `slice` indicates this tensor is a shallow slice of a
386+ // different tensor, and is using that original tensor's texture. Run
387+ // `clone` in order to copy that texture and read from it.
351388 if ( slice != null ) {
352389 let program ;
353390 if ( isPacked ) {
@@ -358,7 +395,7 @@ export class MathBackendWebGL extends KernelBackend {
358395 const res =
359396 this . runWebGLProgram ( program , [ { dataId, shape, dtype} ] , dtype ) ;
360397 const data = this . read ( res . dataId ) ;
361- this . disposeData ( res . dataId ) ;
398+ this . disposeIntermediateTensorInfo ( res ) ;
362399 return data ;
363400 }
364401
@@ -408,7 +445,7 @@ export class MathBackendWebGL extends KernelBackend {
408445 vals = this . gpgpu . downloadFloat32MatrixFromBuffer ( buffer , size ) ;
409446 }
410447 if ( tmpDownloadTarget != null ) {
411- this . disposeData ( tmpDownloadTarget . dataId ) ;
448+ this . disposeIntermediateTensorInfo ( tmpDownloadTarget ) ;
412449 }
413450 const dTypeVals = this . convertAndCacheOnCPU ( dataId , vals ) ;
414451
@@ -454,7 +491,7 @@ export class MathBackendWebGL extends KernelBackend {
454491 tmpData . texture , ...tex_util . getDenseTexShape ( shape ) )
455492 . subarray ( 0 , size ) ;
456493
457- this . disposeData ( tmpTarget . dataId ) ;
494+ this . disposeIntermediateTensorInfo ( tmpTarget ) ;
458495
459496 return vals ;
460497 }
@@ -474,7 +511,7 @@ export class MathBackendWebGL extends KernelBackend {
474511 . downloadByteEncodedFloatMatrixFromOutputTexture (
475512 tmpData . texture , tmpData . texShape [ 0 ] , tmpData . texShape [ 1 ] )
476513 . subarray ( 0 , size ) ;
477- this . disposeData ( output . dataId ) ;
514+ this . disposeIntermediateTensorInfo ( output ) ;
478515
479516 return vals ;
480517 }
@@ -1820,21 +1857,20 @@ export class MathBackendWebGL extends KernelBackend {
18201857 ! reshapeWillBeExpensive ) {
18211858 const targetShape = isChannelsLast ? xShape [ 0 ] * xShape [ 1 ] * xShape [ 2 ] :
18221859 xShape [ 0 ] * xShape [ 2 ] * xShape [ 3 ] ;
1823- const xReshaped = this . reshape ( x , [ 1 , targetShape , convInfo . inChannels ] ) ;
1860+ const xReshaped = reshape ( x , [ 1 , targetShape , convInfo . inChannels ] ) ;
18241861 const filterReshaped =
1825- this . reshape ( filter , [ 1 , convInfo . inChannels , convInfo . outChannels ] ) ;
1826-
1827- return this . reshape < Rank . R4 > (
1828- this . fusedBatchMatMul ( {
1829- a : xReshaped as Tensor3D ,
1830- b : filterReshaped as Tensor3D ,
1831- transposeA,
1832- transposeB,
1833- bias,
1834- activation,
1835- preluActivationWeights
1836- } ) ,
1837- convInfo . outShape ) ;
1862+ reshape ( filter , [ 1 , convInfo . inChannels , convInfo . outChannels ] ) ;
1863+
1864+ const result = this . fusedBatchMatMul ( {
1865+ a : xReshaped as Tensor3D ,
1866+ b : filterReshaped as Tensor3D ,
1867+ transposeA,
1868+ transposeB,
1869+ bias,
1870+ activation,
1871+ preluActivationWeights
1872+ } ) ;
1873+ return reshape ( result , convInfo . outShape ) ;
18381874 }
18391875
18401876 // Following optimization is specific to packed |x| with odd row count
@@ -1869,7 +1905,7 @@ export class MathBackendWebGL extends KernelBackend {
18691905 ( ) => `packed reshape ${ xTexData . shape } to ${
18701906 xReshaped . shape } isn't free`) ;
18711907 const filterReshaped =
1872- this . reshape ( filter , [ 1 , convInfo . inChannels , convInfo . outChannels ] ) ;
1908+ reshape ( filter , [ 1 , convInfo . inChannels , convInfo . outChannels ] ) ;
18731909
18741910 const pointwiseConv = this . fusedBatchMatMul ( {
18751911 a : xReshaped as Tensor3D ,
@@ -2182,19 +2218,6 @@ export class MathBackendWebGL extends KernelBackend {
21822218 return result as Tensor5D ;
21832219 }
21842220
2185- reshape < R extends Rank > ( x : Tensor , shape : ShapeMap [ R ] ) : Tensor < R > {
2186- const texData = this . texData . get ( x . dataId ) ;
2187-
2188- if ( texData . isPacked && ! webgl_util . isReshapeFree ( x . shape , shape ) &&
2189- ! ( texData . texture !== null &&
2190- webgl_util . isReshapeFree ( texData . shape , shape ) ) ) {
2191- const info = this . packedReshape ( x , shape ) ;
2192- return engine ( ) . makeTensorFromDataId (
2193- info . dataId , info . shape , info . dtype ) as Tensor < R > ;
2194- }
2195- return backend_util . reshapeTensor ( x , shape ) ;
2196- }
2197-
21982221 resizeBilinear (
21992222 x : Tensor4D , newHeight : number , newWidth : number ,
22002223 alignCorners : boolean ) : Tensor4D {
@@ -2575,7 +2598,7 @@ export class MathBackendWebGL extends KernelBackend {
25752598 gpgpu_math . runProgram (
25762599 this . gpgpu , binary , inputsData , outputData , customSetup ) ;
25772600
2578- dataToDispose . forEach ( info => this . disposeData ( info . dataId ) ) ;
2601+ dataToDispose . forEach ( info => this . disposeIntermediateTensorInfo ( info ) ) ;
25792602
25802603 if ( shouldTimeProgram ) {
25812604 query = this . endTimer ( query ) ;
@@ -2586,7 +2609,7 @@ export class MathBackendWebGL extends KernelBackend {
25862609 if ( ! env ( ) . getBool ( 'WEBGL_LAZILY_UNPACK' ) && outData . isPacked &&
25872610 preventEagerUnpackingOfOutput === false ) {
25882611 const unpacked = this . unpackTensor ( output ) ;
2589- this . disposeData ( output . dataId ) ;
2612+ this . disposeIntermediateTensorInfo ( output ) ;
25902613 return unpacked ;
25912614 }
25922615 return output ;
@@ -2733,7 +2756,7 @@ export class MathBackendWebGL extends KernelBackend {
27332756 texData . isPacked = outputTexData . isPacked ;
27342757 texData . usage = outputTexData . usage ;
27352758
2736- this . disposeData ( tempDenseInputHandle . dataId ) ;
2759+ this . disposeIntermediateTensorInfo ( tempDenseInputHandle ) ;
27372760 this . texData . delete ( encodedOutputTarget . dataId ) ;
27382761
27392762 // Once uploaded, don't store the values on cpu.
0 commit comments