<a href="https://colab.research.google.com/github/tsakailab/sandbox/blob/master/20201201_LowRankSparseLungSoundSeparation_Unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
# Select the compute device: the current CUDA device when one is available,
# otherwise the CPU. On GPU, report the device name and memory usage, then
# release cached allocator blocks.
device = torch.device('cpu')
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(torch.cuda.get_device_name(torch.cuda.current_device()))
    print(torch.cuda.memory_allocated())
    # memory_cached() is a deprecated alias of memory_reserved().
    print(torch.cuda.memory_reserved())
    torch.cuda.empty_cache()

Tesla T4
0
0




In [None]:
class UNet(torch.nn.Module):
    """Three-level 1-D U-Net with skip connections.

    Maps a tensor of shape (B, in_channels, L) to a tensor of the same shape.
    The encoder applies a double-conv block at each level and halves the
    length with average pooling between levels (three times). The decoder
    upsamples by 2, concatenates the matching encoder features along the
    channel axis, and applies another double-conv block. A final 1x1 conv
    maps back to ``in_channels``.

    Lengths that are not multiples of 8 are supported: each upsampled tensor
    is zero-padded (or cropped) on the right to match its skip connection.
    L must be at least 8 so that the deepest level is non-empty.

    Args:
        in_channels: channels of the input (and of the output).
        out_channels0: feature channels at full length (level 0).
        out_channels1: feature channels at length L/2 (level 1).
        out_channels2: feature channels at length L/4 (level 2).
        out_channels3: feature channels at length L/8 (level 3, bottleneck).
    """

    def __init__(self, in_channels, out_channels0, out_channels1, out_channels2, out_channels3):
        super().__init__()

        def Convs(in_channels, out_channels):
            # Two (3-tap Conv1d with same padding -> BatchNorm -> ReLU) stages; length preserved.
            return torch.nn.Sequential(
                torch.nn.Conv1d(in_channels, out_channels, 3, padding=1),
                torch.nn.BatchNorm1d(out_channels),
                torch.nn.ReLU(inplace=True),
                torch.nn.Conv1d(out_channels, out_channels, 3, padding=1),
                torch.nn.BatchNorm1d(out_channels),
                torch.nn.ReLU(inplace=True)
            )

        def upsample(in_channels, out_channels, scale_factor, mode):
            # Interpolate along the length axis, then a 1x1 conv to change the channel count.
            return torch.nn.Sequential(
                torch.nn.Upsample(scale_factor=scale_factor, mode=mode),
                torch.nn.Conv1d(in_channels, out_channels, 1, padding=0)
            )

        def ConvOut(in_channels, out_channels):
            # 1x1 projection back to the requested output channel count.
            return torch.nn.Sequential(
                torch.nn.Conv1d(in_channels, out_channels, 1, padding=0)
            )

        self.convs0enc = Convs(in_channels, out_channels0)    # length L
        self.convs1enc = Convs(out_channels0, out_channels1)  # length L/2
        self.convs2enc = Convs(out_channels1, out_channels2)  # length L/4
        self.convs3enc = Convs(out_channels2, out_channels3)  # length L/8
        self.downsample = torch.nn.AvgPool1d(2)
        # 3-D (B, C, L) input needs mode 'linear' ('bilinear' is for 4-D input).
        # Each upsampler must output the channel count of the skip it is
        # concatenated with, so the decoder convs see 2 * out_channelsK channels.
        self.upsample32 = upsample(out_channels3, out_channels2, 2, 'linear')
        self.upsample21 = upsample(out_channels2, out_channels1, 2, 'linear')
        self.upsample10 = upsample(out_channels1, out_channels0, 2, 'linear')

        self.convs2dec = Convs(out_channels2 + out_channels2, out_channels2)  # length L/4
        self.convs1dec = Convs(out_channels1 + out_channels1, out_channels1)  # length L/2
        self.convs0dec = Convs(out_channels0 + out_channels0, out_channels0)  # length L
        self.conv_out = ConvOut(out_channels0, in_channels)                  # length L

    @staticmethod
    def _match_length(x, ref):
        """Zero-pad or crop ``x`` on the right of its last axis to ``ref``'s length."""
        diff = ref.shape[-1] - x.shape[-1]
        if diff > 0:
            # AvgPool1d floors odd lengths, so upsampling can come back one sample short.
            return torch.nn.functional.pad(x, (0, diff))
        if diff < 0:
            return x[..., :ref.shape[-1]]
        return x

    def forward(self, x):
        """Run the U-Net on ``x`` of shape (B, in_channels, L); returns the same shape."""
        # Encoder: each level pools the previous level's features (not the raw input).
        convs0e = self.convs0enc(x)
        convs1e = self.convs1enc(self.downsample(convs0e))
        convs2e = self.convs2enc(self.downsample(convs1e))
        convs3e = self.convs3enc(self.downsample(convs2e))

        # Decoder: upsample, align length with the skip, concatenate along channels.
        x = self._match_length(self.upsample32(convs3e), convs2e)
        x = self.convs2dec(torch.cat([x, convs2e], dim=1))
        x = self._match_length(self.upsample21(x), convs1e)
        x = self.convs1dec(torch.cat([x, convs1e], dim=1))
        x = self._match_length(self.upsample10(x), convs0e)
        x = self.convs0dec(torch.cat([x, convs0e], dim=1))
        s = self.conv_out(x)

        return s

In [2]:
## L1 loss
class L1LossFunc(torch.autograd.Function):
    """Weighted L1 norm ``sum(|x| * lw)`` with a proximal-style backward pass.

    Backward does not return the subgradient ``lw * sign(x)``. For the input
    it returns ``x - soft(x, lw)`` scaled by the upstream gradient, and for
    the weight it returns ``|x|`` scaled the same way (autograd sums this down
    to ``lw``'s shape when ``lw`` is a scalar). ``soft`` is defined elsewhere
    in the notebook; presumably elementwise soft-thresholding (TODO confirm).
    """

    @staticmethod
    def forward(ctx, input, lw):
        ctx.save_for_backward(input, lw)
        weighted = input.abs() * lw
        return weighted.sum()

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        upstream = grad_output
        residual = x - soft(x, weight)
        return residual * upstream, x.abs() * upstream

class L1Loss(torch.nn.Module):
    """Module wrapper around ``L1LossFunc`` that stores the weight as a Parameter.

    Args:
        lw: weight tensor. The Parameter wraps it (sharing its storage) and
            copies its ``requires_grad`` flag. When omitted, a fresh scalar
            1.0 is created on the module-level ``device`` for this instance.
    """

    def __init__(self, lw=None):
        super(L1Loss, self).__init__()
        if lw is None:
            # Build the default per instance. A tensor default argument would be
            # created once at class-definition time, and every default instance's
            # Parameter would share its storage.
            lw = torch.tensor(1.0, device=device)
        self.fn = L1LossFunc.apply
        self.lw = torch.nn.Parameter(lw, requires_grad=lw.requires_grad)

    def forward(self, input):
        """Return ``sum(|input| * lw)`` computed by ``L1LossFunc``."""
        return self.fn(input, self.lw)


## Nuclear loss
class NuclearLossFunc(torch.autograd.Function):
    """Weighted nuclear norm ``sum(singular_values(x) * lw)`` with a proximal-style backward.

    Forward uses the reduced SVD (batched over leading dims, if any). Backward
    returns ``x - U diag(soft(s, lw)) V^H`` scaled by the upstream gradient
    for the input, i.e. the residual of thresholding the singular values. For
    the weight it returns ``s`` scaled by the upstream gradient. ``soft`` is
    defined elsewhere in the notebook; presumably soft-thresholding (TODO confirm).
    """

    @staticmethod
    def forward(ctx, input, lw):
        # torch.svd is deprecated; torch.linalg.svd returns (U, S, Vh) with Vh = V^H.
        u, s, vh = torch.linalg.svd(input, full_matrices=False)
        ctx.save_for_backward(input, lw, u, s, vh)
        return torch.sum(s * lw)

    @staticmethod
    def backward(ctx, grad_output):
        input, lw, u, s, vh = ctx.saved_tensors
        # Rebuild the matrix with thresholded singular values.
        svt_input = u @ torch.diag_embed(soft(s, lw)) @ vh
        return (input - svt_input) * grad_output, s * grad_output

class NuclearLoss(torch.nn.Module):
    """Module wrapper around ``NuclearLossFunc`` that stores the weight as a Parameter.

    Args:
        lw: weight tensor. The Parameter wraps it (sharing its storage) and
            copies its ``requires_grad`` flag. When omitted, a fresh scalar
            1.0 is created on the module-level ``device`` for this instance.
    """

    def __init__(self, lw=None):
        super(NuclearLoss, self).__init__()
        if lw is None:
            # Build the default per instance. A tensor default argument would be
            # created once at class-definition time, and every default instance's
            # Parameter would share its storage.
            lw = torch.tensor(1.0, device=device)
        self.fn = NuclearLossFunc.apply
        self.lw = torch.nn.Parameter(lw, requires_grad=lw.requires_grad)

    def forward(self, input):
        """Return the ``lw``-weighted sum of singular values computed by ``NuclearLossFunc``."""
        return self.fn(input, self.lw)


def anucloss(x):
    """Return the sum of all singular values of ``x`` (nuclear norm, summed over batch dims).

    Uses ``torch.linalg.svdvals``: ``torch.svd`` is deprecated, and svdvals'
    gradients are numerically stable, unlike the full SVD's.
    """
    return torch.sum(torch.linalg.svdvals(x))


def al1loss(x):
    """Return the unweighted L1 norm of ``x``: the sum of absolute values of all entries."""
    return torch.sum(torch.abs(x))

In [None]:
# 1-D U-Net with 1 input channel and 6/12/24/48 feature channels per level.
modelUnet = UNet(1, 6, 12, 24, 48)
model = lambda x: modelUnet(x)
#model = lambda x: modelUnet(x.T.view(as_imgseq)).view(DDseq.shape[0],-1).T

ln = 1.      # weight of the nuclear-norm (low-rank) term
ls = 0.1     # weight of the L1 (sparse) term; alternative: ls = 1./np.sqrt(max(m,n)) / 16
alpha = 0.5  # scale factor applied to both loss weights below
nucloss = NuclearLoss(lw=torch.tensor(alpha*ln))
l1loss = L1Loss(lw=torch.tensor(alpha*ls))

In [None]:
# `spectrogram` comes from another cell; presumably shaped (B, 1, F) for modelUnet — TODO confirm.
D = spectrogram

lr = {'Adam': 3e-6, 'SGD': 3e-6}      # learning rate per optimizer

# Warm-up: a few Adam steps on the joint objective
# alpha * (ln * ||D - S||_* + ls * ||S||_1).
optimizerM = torch.optim.Adam(modelUnet.parameters(), lr=lr['Adam'])
history_loss_Unet = []

num_iter = 2
display_step = max(1, num_iter // 50)  # used as a modulus below, so it must be >= 1
for it in range(num_iter):
    # Re-run the model every iteration: optimizerM.step() updates the weights in
    # place, so a graph built before the step cannot be backpropagated again.
    S = model(D)
    loss = alpha*ln*anucloss(D - S) + alpha*ls*al1loss(S)
    optimizerM.zero_grad()
    loss.backward()
    optimizerM.step()
    total_loss = loss.item() / alpha

    history_loss_Unet.append(total_loss)

    if (it + 1) % display_step == 0:
        print('[{:3d}/{}]: loss = {:.4f},  '.format(it+1, num_iter, total_loss))


# Alternating SGD updates. Each iteration takes one step on the nuclear-norm
# term and then one on the L1 term, both expressed around the snapshot Sk of
# the current model output.
num_iter = 1000
display_step = max(1, num_iter // 50)  # used as a modulus below, so it must be >= 1
optimizerM = torch.optim.SGD(modelUnet.parameters(), lr=lr['SGD'])
for it in range(num_iter):

    # Graph-free snapshot of the current output and of the implied remainder D - Sk.
    with torch.no_grad():
        Sk = model(D)
    Lk = (D - Sk).detach()

    # Step 1: nuclear-norm term.
    S = model(D)
    loss = nucloss(Lk + alpha * (Sk - S))
    optimizerM.zero_grad()
    loss.backward()
    optimizerM.step()
    total_loss = loss.item() / alpha

    # Step 2: L1 term. S is recomputed because step 1 updated the weights in
    # place; backpropagating through the old graph would raise an autograd error.
    S = model(D)
    loss = l1loss(S + alpha * (Sk - S))           # good
    #loss = l1loss(S)                              # excellent if not noisy
    optimizerM.zero_grad()
    loss.backward()
    optimizerM.step()
    total_loss += loss.item() / alpha

    history_loss_Unet.append(total_loss)

    if (it + 1) % display_step == 0:
        print ('[{:3d}/{}]: loss = {:.4f},  '.format(it+1, num_iter, total_loss))
