<a href="https://colab.research.google.com/github/tsakailab/sandbox/blob/master/pytorch_nuclear_svt_cuda.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
# Pick the compute device once; later cells read this module-level `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    print(torch.cuda.get_device_name(torch.cuda.current_device()))
    print(torch.cuda.memory_allocated())
    # `memory_cached` is the deprecated name of `memory_reserved`; use the new
    # name when this PyTorch build provides it and fall back otherwise.
    _reserved_bytes = getattr(torch.cuda, 'memory_reserved', None) or torch.cuda.memory_cached
    print(_reserved_bytes())
    torch.cuda.empty_cache()

Tesla P100-PCIE-16GB
0
0


## Differentiable (backprop-compatible) nuclear norm and singular value thresholding

In [0]:
class _eigpack(torch.nn.Module):
    def __init__(self):
        super(_eigpack, self).__init__()

    def compute_sv(self, a, rndvec=None, scale=1.):
        if rndvec is None:
            #rndvec = torch.ones(a.shape[-1])
            rndvec = torch.rand(a.shape[-1], device=device) * scale
        s, v = torch.symeig(torch.matmul(torch.transpose(a,-2,-1),a)+torch.diag(rndvec), eigenvectors=True)
        return torch.sqrt(torch.abs(s-rndvec)), v

    def svd(self, a, rndvec=None, scale=1.):
        s, v = self.compute_sv(a, rndvec, scale)
        u = torch.matmul(torch.matmul(a,v),torch.diag_embed(1./s))
        return u, s, v


class NuclearLoss(_eigpack):
    """Loss module that returns the sum of the singular values given by ``compute_sv``."""

    def __init__(self):
        super().__init__()

    def forward(self, a, rndvec=None, scale=1.):
        """Sum every singular value of ``a`` (over all batch entries) into one scalar.

        ``rndvec`` and ``scale`` are forwarded unchanged to ``compute_sv``.
        """
        singular_values, _ = self.compute_sv(a, rndvec, scale)
        return singular_values.sum()


class SVT(_eigpack):
    """Singular value thresholding: apply ``prox`` to the singular values of a matrix.

    Args:
        prox: callable ``prox(s, th)`` applied to the singular values. When
            None, soft thresholding ``sign(s) * max(|s| - th, 0)`` is used.
    """

    def __init__(self, prox=None):
        super(SVT, self).__init__()
        # A named function rather than a lambda, which also keeps the module picklable.
        self.prox = self._soft_threshold if prox is None else prox

    @staticmethod
    def _soft_threshold(z, th):
        """Soft thresholding ``sign(z) * max(|z| - th, 0)``."""
        shrunk = z.abs() - th
        # Build the zero on shrunk's own device/dtype instead of a module-level
        # `device`, so inputs on any device work.
        return z.sign() * shrunk.max(torch.zeros_like(shrunk))

    def forward(self, q, th, rndvec=None):
        """Return ``u diag(prox(s, th)) v^T`` for the decomposition of ``q`` from ``svd``.

        ``th`` is passed both to ``prox`` as the threshold and to ``svd`` as the
        ``scale`` of the random diagonal shift.
        """
        u, s, v = self.svd(q, rndvec=rndvec, scale=th)
        # Scaling the columns of u equals multiplying by diag(prox(s, th)).
        shrunk_u = u * self.prox(s, th).unsqueeze(-2)
        return torch.matmul(shrunk_u, torch.transpose(v, -2, -1))


## Operation check

In [0]:
# Batch of random test matrices, created directly as a leaf tensor that tracks
# gradients (the torch.autograd.Variable wrapper is deprecated).
D = torch.randn(128, 40000, 60, device=device, requires_grad=True)
# Small alternative for quick experiments:
#D = torch.randn(2, 5, 3, device=device, requires_grad=True)

In [7]:
# SVT
from time import time
# CUDA kernels run asynchronously, so drain pending work before starting the
# clock and wait for the SVT kernels to finish before reading it.
if D.is_cuda:
    torch.cuda.synchronize()
t0 = time()
svt = SVT().to(device)
Dt = svt(D, 1.)
if D.is_cuda:
    torch.cuda.synchronize()
print('done in %.2fms' % ((time() - t0)*1000))

# check the singular values
print(torch.svd(D)[1])
print(torch.svd(Dt)[1])

done in 159.58ms
tensor([[207.8896, 207.2940, 206.6304,  ..., 193.8755, 193.4216, 193.2918],
        [208.0980, 206.8282, 206.2639,  ..., 193.5231, 193.4404, 192.4615],
        [207.0473, 206.6014, 206.4265,  ..., 193.2453, 193.1632, 192.8814],
        ...,
        [207.9840, 207.3488, 206.4899,  ..., 193.4875, 193.1573, 192.5980],
        [207.0945, 206.9053, 206.4542,  ..., 193.6834, 192.9611, 192.5341],
        [207.5618, 206.9408, 206.6870,  ..., 193.6608, 193.2153, 192.4544]],
       device='cuda:0', grad_fn=<SvdBackward>)
tensor([[206.8894, 206.2939, 205.6306,  ..., 192.8754, 192.4216, 192.2918],
        [207.0979, 205.8280, 205.2638,  ..., 192.5231, 192.4404, 191.4616],
        [206.0472, 205.6013, 205.4265,  ..., 192.2453, 192.1633, 191.8815],
        ...,
        [206.9840, 206.3488, 205.4898,  ..., 192.4875, 192.1573, 191.5981],
        [206.0946, 205.9054, 205.4542,  ..., 192.6833, 191.9612, 191.5342],
        [206.5618, 205.9409, 205.6869,  ..., 192.6607, 192.2154, 191.4544

In [8]:
# nuclear norm and its backprop
from time import time
# backward() adds to D.grad, so clear any gradient left by an earlier run;
# otherwise re-executing this cell yields a multiple of the gradient.
D.grad = None
# CUDA kernels run asynchronously: synchronize around the timed region.
if D.is_cuda:
    torch.cuda.synchronize()
t0 = time()
loss_nu = NuclearLoss().to(device)
loss = loss_nu(D)
loss.backward(retain_graph=True)
if D.is_cuda:
    torch.cuda.synchronize()
print('done in %.2fms' % ((time() - t0)*1000))
print(loss.item())

done in 131.44ms
1535599.75


In [9]:
# gradient descent vs. proximal operation
# Singular values (element [1] of torch.svd's result) after one unit-step
# gradient update of D, followed by those of the SVT output Dt.
# NOTE(review): backward() accumulates into D.grad, so if the backward cell was
# executed more than once D.grad is a multiple of the gradient -- confirm
# before comparing the two results.
print(torch.svd(D-D.grad)[1])
print(torch.svd(Dt)[1])

# The gradient itself next to the change that SVT applied to D.
print(D.grad)
print(D-Dt)

tensor([[205.8895, 205.2940, 204.6305,  ..., 191.8755, 191.4216, 191.2918],
        [206.0981, 204.8280, 204.2637,  ..., 191.5230, 191.4404, 190.4617],
        [205.0473, 204.6013, 204.4264,  ..., 191.2453, 191.1633, 190.8814],
        ...,
        [205.9840, 205.3488, 204.4898,  ..., 191.4875, 191.1573, 190.5981],
        [205.0945, 204.9054, 204.4541,  ..., 191.6834, 190.9611, 190.5342],
        [205.5617, 204.9408, 204.6870,  ..., 191.6608, 191.2152, 190.4542]],
       device='cuda:0', grad_fn=<SvdBackward>)
tensor([[206.8894, 206.2939, 205.6306,  ..., 192.8754, 192.4216, 192.2918],
        [207.0979, 205.8280, 205.2638,  ..., 192.5231, 192.4404, 191.4616],
        [206.0472, 205.6013, 205.4265,  ..., 192.2453, 192.1633, 191.8815],
        ...,
        [206.9840, 206.3488, 205.4898,  ..., 192.4875, 192.1573, 191.5981],
        [206.0946, 205.9054, 205.4542,  ..., 192.6833, 191.9612, 191.5342],
        [206.5618, 205.9409, 205.6869,  ..., 192.6607, 192.2154, 191.4544]],
       device