<a href="https://colab.research.google.com/github/tsakailab/sandbox/blob/master/pytorch_nuclear_svt_cuda_wowa.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
# Select the compute device: use the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # Report which GPU is in use and how much memory PyTorch currently holds.
    print(torch.cuda.get_device_name(torch.cuda.current_device()))
    print(torch.cuda.memory_allocated())
    # `memory_cached` was renamed to `memory_reserved`; prefer the new name and
    # fall back to the old one only on PyTorch versions that lack it.
    if hasattr(torch.cuda, 'memory_reserved'):
        print(torch.cuda.memory_reserved())
    else:
        print(torch.cuda.memory_cached())
    # Hand unused cached blocks back to the driver before the large allocations below.
    torch.cuda.empty_cache()

Tesla P100-PCIE-16GB
0
0


## Nuclear norm and singular value thresholding

In [0]:
def soft(z, th):
    """Element-wise soft-thresholding (shrinkage): ``sign(z) * max(|z| - th, 0)``.

    Args:
        z: Tensor to shrink.
        th: Threshold; a Python number or a tensor broadcastable against ``z``.

    Returns:
        Tensor with the same shape as ``z`` (after broadcasting with ``th``)
        whose magnitudes are reduced by ``th`` and floored at zero.
    """
    # clamp(min=0) runs on whatever device/dtype the input already has, so no
    # helper zero tensor tied to a global device has to be allocated per call.
    return z.sign() * (z.abs() - th).clamp(min=0)

class NuclearLoss(torch.nn.Module):
    """Weighted nuclear-norm loss: the sum of singular values times a weight.

    Args:
        lw: Loss weight. May be a scalar tensor or a plain Python number;
            ``None`` (the default) means a weight of 1.0. When a tensor is
            given, its ``requires_grad`` flag decides whether the weight is
            trainable. The default weight is created on the CPU; move the
            module with ``.to(device)`` like any other ``nn.Module``.
    """

    def __init__(self, lw=None):
        super().__init__()
        # Build the default per instance. A tensor used as a default argument
        # is created once at class-definition time and its storage would be
        # shared by every instance's parameter.
        if lw is None:
            lw = torch.tensor(1.0)
        elif not torch.is_tensor(lw):
            lw = torch.tensor(float(lw))
        self.lw = torch.nn.Parameter(lw, requires_grad=lw.requires_grad)

    def forward(self, input):
        """Return ``lw`` times the sum of all singular values of ``input``.

        The sum runs over every singular value that ``torch.svd`` returns, so
        for a batched input the result is a single scalar covering the batch.
        """
        singular_values = torch.svd(input)[1]
        return torch.sum(singular_values) * self.lw


class SVT(torch.nn.Module):
    """Singular value thresholding.

    Decomposes the input with ``torch.svd``, applies a proximal operator to
    the singular values, and recomposes the matrix.

    Args:
        prox: Callable ``prox(s, th)`` applied to the singular values. When
            ``None``, the soft-thresholding function ``soft`` defined in this
            notebook is used.
    """

    def __init__(self, prox=None):
        super().__init__()
        self.prox = soft if prox is None else prox

    def forward(self, q, th):
        """Return ``U @ diag(prox(S, th)) @ V^T`` where ``q = U @ diag(S) @ V^T``.

        Args:
            q: Matrix or batch of matrices (the trailing two dimensions are
                the matrix dimensions).
            th: Threshold forwarded unchanged to ``prox``.
        """
        u, s, v = torch.svd(q)
        shrunk = self.prox(s, th)
        # Scaling the columns of U by the thresholded singular values equals
        # U @ diag(shrunk) but avoids building the diagonal matrix and saves
        # one matrix product.
        return torch.matmul(u * shrunk.unsqueeze(-2), v.transpose(-2, -1))


## Operation check

In [0]:
torch.cuda.empty_cache()
# Random test input: a batch of 32 matrices of size 40000 x 60, created directly
# as a leaf tensor that tracks gradients (torch.autograd.Variable is a deprecated
# wrapper and is no longer needed for this).
D = torch.randn(32, 40000, 60, device=device, requires_grad=True)
# Tiny alternative for quick experiments:
#D = torch.randn(2, 5, 3, device=device, requires_grad=True)

In [4]:
# SVT
from time import perf_counter

svt = SVT().to(device)

# CUDA kernels are launched asynchronously, so wait for the GPU before starting
# and before stopping the clock; otherwise the measured time may not include
# all of the work.
if device.type == 'cuda':
    torch.cuda.synchronize()
t0 = perf_counter()
Dt = svt(D, 200.0)
#Dt = svt(Dt, 1.0)
if device.type == 'cuda':
    torch.cuda.synchronize()
print('done in %.2fms' % ((perf_counter() - t0)*1000))

# check the singular values: after thresholding they should be those of D
# reduced by the threshold and floored at zero.
# These SVDs are only for inspection, so skip building an autograd graph.
with torch.no_grad():
    print(torch.svd(D)[1])
    print(torch.svd(Dt)[1])

done in 2757.77ms
tensor([[207.2374, 206.8710, 206.6576,  ..., 193.6122, 193.4442, 193.0820],
        [207.4635, 206.8243, 206.5444,  ..., 193.6313, 193.3822, 192.7170],
        [207.2963, 206.6805, 206.3594,  ..., 193.7258, 193.2210, 192.9441],
        ...,
        [207.5545, 207.4174, 206.4594,  ..., 193.5102, 193.2692, 192.7632],
        [207.3986, 207.1288, 206.5043,  ..., 193.3150, 193.1795, 192.6274],
        [207.2240, 207.1319, 206.5300,  ..., 193.9548, 193.0745, 192.7691]],
       device='cuda:0', grad_fn=<SvdBackward>)
tensor([[7.2374e+00, 6.8710e+00, 6.6576e+00,  ..., 3.9930e-07, 3.9000e-07,
         3.8349e-07],
        [7.4635e+00, 6.8243e+00, 6.5444e+00,  ..., 4.0930e-07, 4.0408e-07,
         3.9271e-07],
        [7.2963e+00, 6.6805e+00, 6.3594e+00,  ..., 4.0619e-07, 3.8825e-07,
         3.7562e-07],
        ...,
        [7.5545e+00, 7.4174e+00, 6.4594e+00,  ..., 4.0416e-07, 3.9641e-07,
         3.4358e-07],
        [7.3986e+00, 7.1288e+00, 6.5043e+00,  ..., 4.3559e-07, 4

In [5]:
# nuclear norm and its backprop
from time import perf_counter

loss_nu = NuclearLoss(lw=torch.tensor(2.0)).to(device)

# backward() accumulates into D.grad; clear it so that re-running this cell
# does not add a second copy of the gradient to the first.
D.grad = None

# CUDA kernels are launched asynchronously, so wait for the GPU before starting
# and before stopping the clock.
if device.type == 'cuda':
    torch.cuda.synchronize()
t0 = perf_counter()
loss = loss_nu(D)
loss.backward(retain_graph=True)
if device.type == 'cuda':
    torch.cuda.synchronize()
print('done in %.2fms' % ((perf_counter() - t0)*1000))
print(loss.item())

done in 1916.96ms
767889.8125


In [6]:
# gradient descent vs. proximal operation
# Alternative check, kept for reference: compare the singular values after one
# gradient step with the scaled singular values of the thresholded matrix.
#lr = 0.1
#print(torch.svd(D - lr * D.grad)[1])
#print(lr * torch.svd(Dt)[1])

# Gradient of the nuclear-norm loss with respect to D.
nuclear_grad = D.grad
# What the thresholding step removed from D.
svt_residual = D - Dt

print(nuclear_grad)
print(svt_residual)
# Element-wise ratio between the removed part and the gradient.
print(svt_residual / nuclear_grad)

tensor([[[ 2.4121e-02, -1.4234e-03, -4.9831e-03,  ..., -1.3512e-02,
          -1.3425e-02, -1.2416e-02],
         [ 8.7790e-03, -8.2577e-03,  3.5128e-03,  ...,  9.8820e-03,
           2.2432e-03,  4.9524e-04],
         [ 1.3971e-03, -1.0859e-02,  1.0679e-02,  ..., -4.3868e-03,
           6.6451e-03,  5.6847e-03],
         ...,
         [ 5.7985e-03,  2.3280e-02, -1.0705e-02,  ..., -1.5412e-02,
          -1.2972e-02, -1.1951e-03],
         [ 8.7798e-03, -6.0223e-03, -6.3771e-03,  ...,  1.9041e-03,
          -2.1011e-02,  7.4905e-04],
         [-7.0064e-03,  1.2021e-02, -5.3020e-04,  ...,  7.7135e-03,
           2.3314e-03,  1.0743e-03]],

        [[-1.4096e-03,  1.3541e-02, -1.9872e-02,  ...,  1.6624e-02,
           1.7633e-03,  2.0771e-03],
         [-8.6391e-03,  3.7311e-03, -1.6744e-02,  ...,  9.6824e-03,
          -3.1454e-03,  4.7155e-03],
         [-7.5369e-03,  1.3509e-02, -8.2720e-03,  ..., -1.6986e-02,
          -1.5351e-02, -6.3151e-03],
         ...,
         [ 1.7602e-02,  7