<a href="https://colab.research.google.com/github/tsakailab/sandbox/blob/master/pytorch_nuclear_svt_cuda.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
# Pick the compute device used by the rest of the notebook: CUDA when a GPU is
# visible to PyTorch, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # Report which GPU is in use and how much memory this process holds on it.
    print(torch.cuda.get_device_name(torch.cuda.current_device()))
    print(torch.cuda.memory_allocated())
    # memory_cached() is the deprecated name of memory_reserved(); keep the old
    # call only as a fallback for PyTorch versions that predate the rename.
    if hasattr(torch.cuda, 'memory_reserved'):
        print(torch.cuda.memory_reserved())
    else:
        print(torch.cuda.memory_cached())
    # Release cached blocks held by the caching allocator.
    torch.cuda.empty_cache()

Tesla P4
0
0


## Differentiable (backprop-capable) nuclear norm and singular value thresholding

In [0]:
# Default magnitude of the random diagonal shift added in `_eigpack.compute_sv`.
_scale = 1e-6


def _soft_threshold(z, th):
    """Default prox: shrink the magnitude of `z` by `th`, clipping at zero.

    Uses `torch.clamp`, so the result follows the device and dtype of `z`
    instead of relying on a notebook-global `device`.
    """
    return z.sign() * torch.clamp(z.abs() - th, min=0)


class _eigpack(torch.nn.Module):
    """Singular values / SVD of `a` from the eigendecomposition of `a^T a`.

    All methods operate on the last two dimensions of `a`; leading dimensions
    are treated as batch dimensions by the underlying torch routines.
    """

    def __init__(self):
        super(_eigpack, self).__init__()

    @staticmethod
    def _eigh(m):
        """Eigendecomposition of the symmetric matrix `m`.

        `torch.symeig` is deprecated and absent from recent PyTorch releases,
        so `torch.linalg.eigh` is preferred. `UPLO='U'` reads the upper
        triangle, matching the default of `torch.symeig`. The old routine is
        kept as a fallback for PyTorch versions without `torch.linalg.eigh`.
        """
        linalg = getattr(torch, 'linalg', None)
        if linalg is not None and hasattr(linalg, 'eigh'):
            return linalg.eigh(m, UPLO='U')
        return torch.symeig(m, eigenvectors=True)

    def compute_sv(self, a, rndvec=None, scale=_scale):
        """Return `(s, v)`: singular values of `a` and eigenvectors of `a^T a`.

        Args:
            a: tensor of shape (..., m, n).
            rndvec: optional diagonal shift of shape (n,) or (..., n) added to
                `a^T a` before the eigendecomposition and subtracted from the
                eigenvalues afterwards. When None, a sorted random vector in
                [0, scale) is drawn on the device and in the dtype of `a`.
                NOTE(review): the shift presumably separates repeated
                eigenvalues so that backprop stays finite; confirm.
            scale: upper bound of the random shift (only used when `rndvec`
                is None).

        Returns:
            s: tensor (..., n), `sqrt(|eigenvalue - rndvec|)`, in the order
                returned by the eigensolver.
            v: tensor (..., n, n) whose columns are the eigenvectors.
        """
        if rndvec is None:
            # Created next to `a` rather than on a global device.
            rndvec = torch.sort(
                torch.rand(a.shape[-1], device=a.device, dtype=a.dtype) * scale)[0]
        gram = torch.matmul(torch.transpose(a, -2, -1), a)
        # diag_embed equals torch.diag for a 1-D shift and also accepts a batched one.
        s, v = self._eigh(gram + torch.diag_embed(rndvec))
        return torch.sqrt(torch.abs(s - rndvec)), v

    def svd(self, a, rndvec=None, scale=_scale):
        """Return `(u, s, v)` with `u = a v diag(1/s)`.

        Columns of `u` that belong to a zero singular value are not finite,
        because they are obtained by dividing by `s`.
        """
        s, v = self.compute_sv(a, rndvec, scale)
        # Column-wise division; equivalent to right-multiplying by diag(1/s).
        u = torch.matmul(a, v) / s.unsqueeze(-2)
        return u, s, v


class NuclearLoss(_eigpack):
    """Sum of the singular values of `a`, taken over every matrix in the batch."""

    def __init__(self):
        super(NuclearLoss, self).__init__()

    def forward(self, a, rndvec=None, scale=_scale):
        """Return a scalar tensor: the sum of all singular values of `a`."""
        return torch.sum(self.compute_sv(a, rndvec, scale)[0])


class SVT(_eigpack):
    """Singular value thresholding: apply `prox` to the singular values of `q`.

    Args:
        prox: callable `(s, th) -> tensor` applied to the singular values.
            Defaults to soft-thresholding.
    """

    def __init__(self, prox=None):
        super(SVT, self).__init__()
        self.prox = _soft_threshold if prox is None else prox

    def forward(self, q, th, rndvec=None, scale=_scale):
        """Return `u diag(prox(s, th)) v^T` for the decomposition of `q`."""
        u, s, v = self.svd(q, rndvec=rndvec, scale=scale)
        # Scaling the columns of u equals right-multiplying by diag(prox(s, th)).
        shrunk = self.prox(s, th)
        return torch.matmul(u * shrunk.unsqueeze(-2), torch.transpose(v, -2, -1))


## Operation check

In [0]:
torch.cuda.empty_cache()
# Test input: a batch of two 5x3 matrices that tracks gradients.
# torch.autograd.Variable is deprecated; factory functions accept requires_grad directly.
# Larger shape tried previously:
#D = torch.randn(32, 40000, 60, device=device, requires_grad=True)
D = torch.randn(2, 5, 3, device=device, requires_grad=True)

In [77]:
# SVT: threshold the singular values of D by 1.0 and time the call
from time import time
# CUDA kernels are launched asynchronously, so synchronize before starting and
# before stopping the clock; otherwise only the launch overhead is measured.
if D.is_cuda:
    torch.cuda.synchronize()
t0 = time()
svt = SVT().to(device)
Dt = svt(D, 1.0)
if D.is_cuda:
    torch.cuda.synchronize()
print('done in %.2fms' % ((time() - t0)*1000))

# check the singular values: those of Dt should be those of D reduced by 1.0 (floored at 0)
print(torch.svd(D)[1])
print(torch.svd(Dt)[1])

done in 9.51ms
tensor([[3.4395, 2.1921, 1.1903],
        [2.8233, 2.4473, 0.2503]], device='cuda:0', grad_fn=<SvdBackward>)
tensor([[2.4395e+00, 1.1921e+00, 1.9035e-01],
        [1.8233e+00, 1.4473e+00, 4.1797e-08]], device='cuda:0',
       grad_fn=<SvdBackward>)


In [78]:
# nuclear norm and its backprop
from time import time
# backward() accumulates into D.grad, so re-running this cell would add another
# copy of the gradient each time. Reset it so the cell is idempotent.
D.grad = None
# CUDA kernels are launched asynchronously; synchronize around the timed region.
if D.is_cuda:
    torch.cuda.synchronize()
t0 = time()
loss_nu = NuclearLoss().to(device)
loss = loss_nu(D)
loss.backward(retain_graph=True)
if D.is_cuda:
    torch.cuda.synchronize()
print('done in %.2fms' % ((time() - t0)*1000))
print(loss.item())

done in 3.98ms
12.342939376831055


In [79]:
# gradient descent vs. proximal operation
#lr = 0.1
#print(torch.svd(D - lr * D.grad)[1])
#print(lr * torch.svd(Dt)[1])

# Show, one after another: the nuclear-norm gradient stored in D.grad, the part
# of D removed by the thresholding (D - Dt), and their element-wise ratio.
print(D.grad, D - Dt, (D - Dt) / D.grad, sep='\n')

tensor([[[-0.2418, -1.0418,  0.2763],
         [-1.0853, -0.5995,  2.4643],
         [ 1.9324,  1.0628,  0.4216],
         [ 2.0036, -1.5964,  0.8768],
         [-0.1230, -1.9690, -1.3800]],

        [[ 0.3074,  2.4197,  0.7118],
         [-0.8496,  1.6001, -0.2949],
         [-2.6790, -0.0276,  0.3750],
         [ 0.9675,  0.6638, -0.2327],
         [-0.2654,  0.3788, -2.8656]]], device='cuda:0')
tensor([[[-0.0806, -0.3473,  0.0921],
         [-0.3618, -0.1998,  0.8214],
         [ 0.6441,  0.3543,  0.1405],
         [ 0.6679, -0.5321,  0.2923],
         [-0.0410, -0.6563, -0.4600]],

        [[-0.2431,  0.4898,  0.0691],
         [-0.3402,  0.4811, -0.1260],
         [-0.5896,  0.2689,  0.2727],
         [ 0.1465,  0.0599, -0.1632],
         [ 0.0695,  0.2711, -0.8783]]], device='cuda:0',
       grad_fn=<SubBackward0>)
tensor([[[ 0.3333,  0.3333,  0.3333],
         [ 0.3333,  0.3333,  0.3333],
         [ 0.3333,  0.3333,  0.3333],
         [ 0.3333,  0.3333,  0.3333],
         [ 0.33