Closed
Description
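# Imports assumed from the call signatures (NumPy's svd and normal).
import torch
from numpy.linalg import svd
from numpy.random import normal

def _get_orthogonal_init_weights(weights):
    fan_out = weights.size(0)
    fan_in = weights.size(1) * weights.size(2) * weights.size(3)
    # Reduced SVD of a random Gaussian matrix of shape (fan_out, fan_in).
    u, _, v = svd(normal(0.0, 1.0, (fan_out, fan_in)), full_matrices=False)
    # Return whichever factor has the shape (fan_out, fan_in).
    if u.shape == (fan_out, fan_in):
        return torch.Tensor(u.reshape(weights.size()))
    else:
        return torch.Tensor(v.reshape(weights.size()))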
What is the purpose of the above operation?
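For reference, here is a minimal NumPy-only sketch of what the snippet appears to compute. It is not the repository's code: the helper name `orthogonal_matrix`, the `seed` parameter, and the example shapes are hypothetical. It shows that the SVD factor matching the shape `(fan_out, fan_in)` has orthonormal rows when the matrix is wide and orthonormal columns when it is tall, which is the property usually meant by "orthogonal initialization".

```python
import numpy as np

def orthogonal_matrix(fan_out, fan_in, seed=0):
    """Hypothetical helper mirroring the SVD step of the snippet above."""
    rng = np.random.default_rng(seed)
    a = rng.normal(0.0, 1.0, (fan_out, fan_in))
    # Reduced SVD: u is (fan_out, k), vh is (k, fan_in), k = min(fan_out, fan_in).
    u, _, vh = np.linalg.svd(a, full_matrices=False)
    # Exactly one factor has shape (fan_out, fan_in), unless the matrix is
    # square, in which case both do and u is returned.
    return u if u.shape == (fan_out, fan_in) else vh

# Wide case, e.g. 4 output channels and 2 * 3 * 3 = 18 inputs per filter.
wide = orthogonal_matrix(4, 18)
wide_ok = np.allclose(wide @ wide.T, np.eye(4))   # rows are orthonormal

# Tall case: more outputs than inputs.
tall = orthogonal_matrix(18, 4)
tall_ok = np.allclose(tall.T @ tall, np.eye(4))   # columns are orthonormal

print(wide.shape, wide_ok)
print(tall.shape, tall_ok)
```

In the original function the selected factor is then reshaped back to the 4-D shape of `weights`, so each convolution filter is one row of this matrix.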