# -*- coding: utf-8 -*-

"""
Efficiency of writing "sparse" semantics for Adagrad
====================================================

`Issue 1369 <https://github.com/pytorch/pytorch/issues/1369>`__ discussed the additional lines of code
that were introduced while writing "sparse" semantics for Adagrad.
The code doesn't really use sparsity as a compression and optimization technique;
it wants to use masked semantics. We worked around this by introducing one-off semantics and operators
that encode this behavior while forcing users to be aware of storage details such as indices and values.

In particular, we'll point out when sparsity is used as a semantic extension, i.e. unspecified values are not zero,
and when it is just used to compress zeros.
We'll also compare and contrast this with equivalent code written using MaskedTensor.
In the end, the code snippets are repeated without additional comments to show the difference in brevity.

"""

import torch
from torch.masked.maskedtensor import masked_tensor

######################################################################
# Original sparse implementation
# ------------------------------
#
# First, let's look at the current implementation of
# `Adagrad (functional) <https://github.com/pytorch/pytorch/blob/6c2f235d368b697072699e5ca9485fd97d0b9bcc/torch/optim/_functional.py#L16-L51>`__
#

def _make_sparse(grad, grad_indices, values):
    size = grad.size()
    if grad_indices.numel() == 0 or values.numel() == 0:
        return torch.empty_like(grad)
    return torch.sparse_coo_tensor(grad_indices, values, size)

# Some hyperparameters
eps = 1e-10
clr = 0.1

# The parameters are dense; the gradient we construct below is a sparse COO tensor
param = torch.arange(8).reshape(2, 4).float()
i = torch.tensor([[0, 1, 1],
                  [2, 0, 2]])
v = torch.tensor([3, 4, 5], dtype=torch.float32)
grad = torch.sparse_coo_tensor(i, v, [2, 4])
state_sum = torch.full_like(param, 0.5)  # initial value for state sum

print("param:\n", param)
print("grad:\n", grad.to_dense())
print("state_sum:\n", state_sum)

######################################################################
#

grad = grad.coalesce()  # the update is non-linear so indices must be unique
grad_indices = grad._indices()
grad_values = grad._values()

# pow(2) has the same semantics for both sparse and dense memory layouts since 0^2 is zero
state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
# We take care to make std sparse, even though state_sum clearly is not.
# This means that we're only applying the gradient to parts of the state_sum
# for which it is specified. This drives home the point even more that
# the passed gradient is not sparse, but masked.
std = state_sum.sparse_mask(grad)
print("state_sum:\n", state_sum)
print("std:\n", std.to_dense())

######################################################################
# This is where we have a very important divergence.
# The addition of eps should technically be applied to all values, but instead it is only applied to the specified values.
# Here we're using sparsity as a semantic extension and to enforce a certain pattern of defined and undefined values.
# If parts of the gradient's values are zero, they are still included when materialized,
# even though they could be compressed away by other sparse storage layouts.
# This is technically quite brittle, even though one could argue that eps is always very small.
#
# Moreover, an implementation of add_ for sparsity as a storage layout and compression scheme should cause densification,
# but we force it not to.
# For this one-off case that is fine, until we want to introduce new compression schemes
# such as CSR, BSR, or 2:4 block sparsity. We'll then need to introduce separate Tensor types for each
# and write variations for gradients compressed using different storage formats.
# A small numeric illustration of this divergence follows the code below.
#

# We currently dodge all these concerns by using the private method _values().
std_values = std._values().sqrt_().add_(eps)

# We currently don't support div for sparse Tensors because zero / zero is
# not well defined. For a MaskedTensor undefined / undefined is undefined.
param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)
print("param:\n", param)

######################################################################
# MaskedTensor sparse implementation
# ----------------------------------
#
# We've been conflating sparsity as an optimization with sparsity as a semantic extension to PyTorch.
# MaskedTensor proposes to call this semantic extension "masked" rather than "sparse".
# Currently we can't have dense semantics with sparse storage, or masked semantics with dense storage;
# MaskedTensor fixes that because it separates the storage from the semantics.
# Consider the above example using a masked gradient:
#

# Create an entirely new set of parameters to avoid errors
param2 = torch.arange(8).reshape(2, 4).float()
state_sum2 = torch.full_like(param, 0.5)  # initial value for state sum

mask = (grad.to_dense() != 0).to_sparse()
masked_grad = masked_tensor(grad, mask)
print("masked_grad:\n", masked_grad)

######################################################################
#

state_sum2 = state_sum2 + masked_grad.pow(2).data()
std2 = masked_tensor(state_sum2.to_sparse(), mask)

# Let's print both this version and the regular version for easier comparison
print("state_sum:\n", state_sum)
print("std:\n", std)
print("state_sum2:\n", state_sum2)
print("std2:\n", std2)

######################################################################
#

# We can add support for in-place operations later. Notice how this doesn't
# need to access any storage internals and is in general a lot shorter
std2 = std2.sqrt().add(eps)

print("std:\n", std)
print("std2:\n", std2)

# .data() indeed returns a sparse tensor
param2 = param2.add((masked_grad / std2).data(), alpha=-clr)
print("param2:\n", param2)

######################################################################
# Conclusion: Difference in code
# ------------------------------
#
# For reference, this is the regular, dense code path without masked gradients or sparsity:
# ::
#
#     state_sum.addcmul_(grad, grad, value=1)
#     std = state_sum.sqrt().add_(eps)
#     param.addcdiv_(grad, std, value=-clr)
#
# The vanilla tensor implementation for sparse is:
#

grad = grad.coalesce()  # the update is non-linear so indices must be unique
grad_indices = grad._indices()
grad_values = grad._values()
size = grad.size()

state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
std = state_sum.sparse_mask(grad)
std_values = std._values().sqrt_().add_(eps)
param.add_(_make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr)

######################################################################
# while MaskedTensor minimizes the code to the snippet:
#

state_sum2 = state_sum2 + masked_grad.pow(2).data()
std2 = masked_tensor(state_sum2.to_sparse(), mask)
std2 = std2.sqrt().add(eps)
param2 = param2.add((masked_grad / std2).data(), alpha=-clr)

######################################################################
# And for good measure, let's make sure the parameters match:
#

178+ print ("param:\n " , param )
179+ print ("param2:\n " , param2 )