In [51]:
from turtle import forward
import torch
from torch import nn, optim 
from torch.autograd import Variable 
import os 
from utils import * 
get_ipython().run_line_magic('matplotlib', 'inline')
import numpy as np 
import scipy.io as scio
import matplotlib.pyplot as plt 
import scipy.io
import math
import time
from math import exp
import torch.nn.functional as F

In [52]:
# Experiment configuration (parameter tuning).
# NOTE(review): device index 1 requires at least two CUDA devices; change GPU_ID on a single-GPU machine.
GPU_ID = 1
torch.cuda.set_device(GPU_ID)
torch.cuda.empty_cache()

# Seed the RNGs so that model initialization (torch.randn / uniform_ in the
# model constructors) is reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Noise settings of the benchmark cases ("G" presumably = Gaussian noise level; confirm):
#   Case1: 0.2 G
#   Case2: 0.3 G
case_folder = ['Case1']
data = ['scene10']
# PSNR thresholds, one per (scene, case) pair in the order the experiment loop visits them.
base = [30]

# Tune for better performance.
shrink = [26]       # candidate ranks r of the factorization
c_all = [6e-4]      # candidate regularization weights gamma
max_iter = 1500
lr_real = 0.003

In [53]:
# Helper functions for computing PSNR and SSIM
def torch_psnr(img, ref, scale=256, peak=255):  # input [28,256,256]
    """Mean per-band PSNR between two channel-first image stacks.

    Both inputs are multiplied by ``scale`` and rounded before the squared
    error is taken. Each band's PSNR uses ``peak`` as the maximum signal
    value, and the per-band values are averaged.

    The defaults (scale=256, peak=255) reproduce the convention used for the
    numbers reported in this notebook. Note that this pair is not the
    self-consistent 8-bit choice (255, 255); change it only if you do not
    need comparability with those numbers.

    Args:
        img: reconstructed stack indexed as ``img[band, ...]`` (e.g. [28, 256, 256]);
            values presumably in [0, 1].
        ref: reference stack with the same shape as ``img``.
        scale: multiplier applied before rounding.
        peak: peak signal value used in the PSNR formula.

    Returns:
        0-dim tensor holding the band-averaged PSNR in dB.
    """
    img_q = (img * scale).round()
    ref_q = (ref * scale).round()
    # Per-band MSE over all remaining axes, vectorized instead of a Python loop.
    mse = ((img_q - ref_q) ** 2).flatten(1).mean(dim=1)
    return (10 * torch.log10((peak * peak) / mse)).mean()

def gaussian(window_size, sigma):
    """Return a normalized 1-D Gaussian kernel of length ``window_size``.

    The kernel is centred on index ``window_size // 2``. Weights are evaluated
    with Python floats, packed into a float tensor, and scaled to sum to one.
    """
    center = window_size // 2
    two_sigma_sq = float(2 * sigma ** 2)
    weights = torch.Tensor([exp(-((k - center) ** 2) / two_sigma_sq) for k in range(window_size)])
    return weights / weights.sum()

def create_window(window_size, channel, sigma=1.5):
    """Build a depthwise 2-D Gaussian window for grouped convolution.

    Args:
        window_size: side length of the square window.
        channel: number of channels. The window is replicated once per channel
            so it can be used with ``F.conv2d(..., groups=channel)``.
        sigma: standard deviation of the Gaussian. The default of 1.5 is the
            value that was previously hard-coded.

    Returns:
        Float tensor of shape (channel, 1, window_size, window_size).
    """
    kernel_1d = gaussian(window_size, sigma).unsqueeze(1)
    # The outer product of the 1-D kernel with itself gives the separable 2-D kernel.
    kernel_2d = kernel_1d.mm(kernel_1d.t()).float().unsqueeze(0).unsqueeze(0)
    # torch.autograd.Variable is deprecated (a no-op wrapper), so return a plain tensor.
    return kernel_2d.expand(channel, 1, window_size, window_size).contiguous()


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

def ssim(img1, img2, window_size=11, size_average=True):
    """SSIM between two 4-D batches laid out as (batch, channel, H, W).

    A Gaussian window is built for the channel count of ``img1``. The window
    is moved to the CUDA device of ``img1`` (if any) and cast to its tensor
    type, then the work is delegated to ``_ssim``.
    """
    _, n_channels, _, _ = img1.size()
    kernel = create_window(window_size, n_channels)
    if img1.is_cuda:
        kernel = kernel.cuda(img1.get_device())
    kernel = kernel.type_as(img1)
    return _ssim(img1, img2, kernel, window_size, n_channels, size_average=size_average)

def torch_ssim(img, ref):  # input [28,256,256]
    """SSIM of two channel-first stacks, treated as a batch of one image."""
    batched_img = img.unsqueeze(0)
    batched_ref = ref.unsqueeze(0)
    return ssim(batched_img, batched_ref)

In [54]:
class diagonal_S_1(nn.Module):
    """Turn each row of a 2-D tensor into a diagonal matrix.

    ``forward(s)`` returns ``out`` of shape (n3, n1, n2) with
    ``out[i] = diag(s[i])`` for ``i < n3``. This only fits when ``s`` has at
    least n3 rows of length n1 and n1 == n2.
    """

    def __init__(self, n1, n2, n3):
        super(diagonal_S_1, self).__init__()
        self.n1 = n1
        self.n2 = n2
        self.n3 = n3

    def forward(self, s):
        """Stack ``diag(s[i])`` for i in range(n3) into an (n3, n1, n2) tensor.

        The result is built directly from ``s``, so it follows ``s``'s device
        and dtype; there is no hard-coded ``.cuda()`` allocation, and it also
        works on CPU. Gradients flow back to ``s``. Rows of ``s`` beyond the
        first n3 are ignored.

        Raises:
            ValueError: if the diagonal stack does not have shape (n3, n1, n2).
        """
        out = torch.diag_embed(s[: self.n3])
        expected = (self.n3, self.n1, self.n2)
        if tuple(out.shape) != expected:
            raise ValueError(
                "expected diagonal stack of shape %s, got %s" % (expected, tuple(out.shape))
            )
        return out

class OTLinear_new(nn.Module):
    """Orthogonal matrix parameterized as a product of Householder reflections.

    ``forward()`` takes no input and returns the (shape, shape) matrix
    ``H_0 @ H_1 @ ... @ H_{shape-1}``, where ``H_k = I - 2 u_k u_k^T`` and
    ``u_k`` is the k-th row of ``self.u`` normalized to unit length.
    NOTE(review): not referenced in the visible notebook; ``orth_linear``
    builds the same matrix and applies it to an input.
    """

    def __init__(self, shape):
        super().__init__()
        self.n_dim = shape
        self.num_householders = shape
        # One (unnormalized) Householder vector per row.
        self.u = nn.Parameter(torch.randn(self.num_householders, self.n_dim))
        # Frozen identity, deliberately kept as an nn.Parameter rather than a
        # buffer: callers may rely on the length/order of parameters().
        self.I = nn.Parameter(torch.eye(self.n_dim, self.n_dim).unsqueeze(0))
        self.I.requires_grad = False

    def get_weight(self):
        """Return the product of the Householder reflections as a square matrix."""
        u = F.normalize(self.u, dim=1)
        # Batched reflections: w[k] = I - 2 u_k u_k^T, shape (num_householders, n_dim, n_dim).
        w = self.I - 2 * u.unsqueeze(-1) @ u.unsqueeze(1)
        # Multiply left to right. This replaces torch.chain_matmul, which is deprecated.
        reflections = w.unbind(0)
        weight = reflections[0]
        for reflection in reflections[1:]:
            weight = weight @ reflection
        return weight

    def forward(self):
        return self.get_weight()
    
class orth_linear(nn.Module):
    """Bias-free linear layer whose weight is a product of Householder reflections.

    The weight is ``H_0 @ H_1 @ ... @ H_{shape-1}``, where
    ``H_k = I - 2 u_k u_k^T`` and ``u_k`` is the unit-normalized k-th row of
    ``self.u``. A product of reflections is orthogonal, so the layer
    preserves Euclidean norms along the last axis of its input.
    """

    def __init__(self, shape):
        super().__init__()
        self.n_dim = shape
        self.num_householders = shape
        # One (unnormalized) Householder vector per row.
        self.u = nn.Parameter(torch.randn(self.num_householders, self.n_dim))
        # Frozen identity, deliberately kept as an nn.Parameter rather than a
        # buffer: callers may rely on the length/order of parameters().
        self.I = nn.Parameter(torch.eye(self.n_dim, self.n_dim).unsqueeze(0))
        self.I.requires_grad = False

    def get_weight(self):
        """Return the (n_dim, n_dim) orthogonal weight matrix."""
        u = F.normalize(self.u, dim=1)
        # Batched reflections: w[k] = I - 2 u_k u_k^T, shape (num_householders, n_dim, n_dim).
        w = self.I - 2 * u.unsqueeze(-1) @ u.unsqueeze(1)
        # Multiply left to right. This replaces torch.chain_matmul, which is deprecated.
        reflections = w.unbind(0)
        weight = reflections[0]
        for reflection in reflections[1:]:
            weight = weight @ reflection
        return weight

    def forward(self, input):
        """Apply the orthogonal weight over the last axis: ``input @ weight.T``."""
        weight = self.get_weight()
        return F.linear(input, weight)


class SVD_net(nn.Module): 
    """Input-free, SVD-style low-rank generator; ``forward()`` builds the output tensor.

    Learnable factors (shapes from the constructor):
        A: (n_1, r, n_3), s: (r, n_3), B: (n_2, r, n_3).

    ``forward`` computes A_hat @ diag(f(s)) @ B_hat_T batched over the leading
    axis. Here A_hat and B_hat are A and B passed through an orth_linear layer,
    which applies an orthogonal matrix over the last axis (size n_3). f is the
    small MLP ``adaptive_thres_net`` applied to the rows of s.t(). The product
    is then passed through ``net_t``.

    NOTE(review): the order in which parameters and submodules are registered
    below is significant. The notebook's training loop looks up entries of
    model.parameters() by position, and the RNG draws made in the submodule
    constructors happen in this order.
    """
    def __init__(self,n_1,n_2,n_3,r,shape):
        super(SVD_net, self).__init__()
        # NOTE(review): `shape` is stored but not read anywhere in the visible notebook.
        self.shape = shape
        # Uninitialized storage; filled by reset_parameters() below.
        self.A = nn.Parameter(torch.Tensor(n_1,r,n_3))
        self.s = nn.Parameter(torch.Tensor(r,n_3))
        self.B = nn.Parameter(torch.Tensor(n_2,r,n_3))
        
        # Output transform. permute_change comes from utils and presumably permutes
        # axes to (1, 2, 0), so that orth_linear and the bias-free Linear act on the
        # last axis (size n_3) — confirm against utils.
        self.net_t = nn.Sequential(
            permute_change(1,2,0),
            orth_linear(n_3),
            nn.Linear(n_3, n_3, bias=False)
        ) 
        # Initializes A, B and s. Called before netA/netB are built, so these
        # uniform_ draws happen before those modules' torch.randn draws.
        self.reset_parameters()
        
        # Orthogonal transform over the last axis (n_3) of A / B, followed by an
        # axis permutation (presumably moving n_3 to the front — confirm in utils).
        self.netA = nn.Sequential(
            orth_linear(n_3),
            permute_change(2,0,1)
        )
        self.netB = nn.Sequential(
            orth_linear(n_3),
            permute_change(2,0,1)
        )
        
        # Turns an (n_3, r) tensor into an (n_3, r, r) stack of diagonal matrices.
        self.diagnet = diagonal_S_1(r, r, n_3)
        # Two bias-free r x r linear layers with a LeakyReLU, applied to each row of s.t().
        self.adaptive_thres_net = nn.Sequential(
            nn.Linear(r,r,bias = False),
            nn.LeakyReLU(),
            nn.Linear(r,r,bias = False)
        )

    def reset_parameters(self):
        """Initialize A, B and s uniformly in [-1/sqrt(n_3), 1/sqrt(n_3)] (A.size(2) == n_3)."""
        stdv = 1. / math.sqrt(self.A.size(2))
        self.A.data.uniform_(-stdv, stdv)
        self.B.data.uniform_(-stdv, stdv)
        self.s.data.uniform_(-stdv, stdv)

    def forward(self):
        """Build the output tensor from the learned factors.

        Returns:
            Tuple ``(net_t(x), A_hat, B_hat_T)``, where
            ``x = A_hat @ diag(adaptive_thres_net(s.t())) @ B_hat_T`` is computed
            with batched matmul over the leading axis. A_hat and B_hat_T are also
            returned because the training loop regularizes them directly.
        """
        A_hat = self.netA(self.A)
        # Swap the last two axes so B_hat_T can right-multiply the diagonal stack.
        B_hat_T = (self.netB(self.B)).permute(0,2,1)
        # s.t() is (n_3, r); the MLP keeps that shape, and diagnet makes it (n_3, r, r).
        x = torch.matmul(torch.matmul(A_hat,self.diagnet(self.adaptive_thres_net( (self.s).t() ))), B_hat_T)
        return self.net_t(x), A_hat, B_hat_T

In [55]:
# Run the denoising experiment for every (scene, case, rank r, gamma) combination.
# NOTE(review): `dtype` is not defined in this notebook; presumably it comes from
# `from utils import *` (e.g. a CUDA float tensor type) — confirm.
RESULT_ROOT = "./results/denoising/ours/kaist/"
DATA_ROOT = './data/denoising/kaist/'
LOG_EVERY = 250  # evaluate PSNR every LOG_EVERY iterations

cur = -1
for cur_data in data:
    # os.makedirs(..., exist_ok=True) also creates missing parent directories,
    # whereas os.mkdir fails when e.g. ./results/denoising/ours/kaist/ is absent.
    os.makedirs(RESULT_ROOT + cur_data + "/", exist_ok=True)

    for cur_case in case_folder:
        # Index into `base`: one threshold per (scene, case) pair, in visiting order.
        cur = cur + 1

        file_name = DATA_ROOT + cur_data + '/' + cur_case + '/' + cur_data + '_noise.mat'
        mat = scipy.io.loadmat(file_name)
        X = torch.from_numpy(mat["Nmsi"]).type(dtype).cuda()       # noisy observation
        truth = torch.from_numpy(mat["Omsi"]).type(dtype).cuda()   # clean reference
        # Entries that are exactly zero in the observation are excluded from the
        # fidelity loss. This is loop-invariant, so it is computed once per case.
        mask = (X != 0).type(dtype)

        case_dir = RESULT_ROOT + cur_data + "/" + cur_case + "/"
        os.makedirs(case_dir, exist_ok=True)

        for r in shrink:
            rank_dir = case_dir + "r" + str(r) + "/"
            os.makedirs(rank_dir, exist_ok=True)

            for gamma in c_all:
                folder_path = rank_dir + "gamma" + str(gamma) + "/"
                os.makedirs(folder_path, exist_ok=True)

                model = SVD_net(256, 256, 28, r, 256).type(dtype)
                params = list(model.parameters())
                n_params = sum(np.prod(list(p.size())) for p in params)
                print('Number of params: %d' % n_params)
                optimizer = optim.Adam(params, lr=lr_real, weight_decay=10e-6)

                t0 = time.time()
                for it in range(max_iter):
                    X_Out_real, A_hat, B_hat = model()

                    # Masked L1 data-fidelity term.
                    loss_f = 10e-6 * torch.norm(X_Out_real * mask - X * mask, 1)
                    # L1 norms of first differences of the factors along one axis.
                    loss_A = gamma * torch.norm(A_hat[:, 1:, :] - A_hat[:, :-1, :], 1)
                    loss_B = gamma * torch.norm(B_hat[:, :, 1:] - B_hat[:, :, :-1], 1)
                    # Weight of the final nn.Linear in model.net_t. This is the tensor the
                    # former positional lookup params[5] resolved to, addressed by name so
                    # that it does not break if the parameter list changes.
                    H_k = model.net_t[2].weight
                    loss_H_k = gamma * torch.norm(H_k[1:, :] - H_k[:-1, :], 1)

                    optimizer.zero_grad()
                    loss_total = loss_f + loss_H_k + loss_A + loss_B
                    loss_total.backward()
                    optimizer.step()

                    if (it + 1) % LOG_EVERY == 0:
                        # Evaluation only: no autograd graph is needed.
                        with torch.no_grad():
                            psnr = torch_psnr(X_Out_real.permute(2, 0, 1), truth.permute(2, 0, 1))
                        psnr_val = psnr.item()
                        print("r:" + str(r) + ",gamma:" + str(gamma) + ",iter:" + str(it) + ",psnr:" + str(psnr))

                        if psnr_val >= base[cur]:
                            scio.savemat(
                                folder_path + str(it) + "_" + str(psnr_val) + "_" + cur_data + "_result.mat",
                                {cur_data + '_result': X_Out_real.cpu().detach().numpy()},
                            )
                        if psnr_val >= base[cur] - 1:
                            log_line = ("data:" + cur_data + ",Case:" + cur_case + ",r:" + str(r)
                                        + ",gamma:" + str(gamma) + ",iter:" + str(it)
                                        + ",psnr:" + str(psnr_val) + '\n')
                            with open(case_dir + '/psnr.txt', 'a', encoding='utf-8') as f:
                                f.write(log_line)
                            print(log_line)
                t1 = time.time()
                print(t1 - t0)

Number of params: 380304
r:26,gamma:0.0006,iter:249,psnr:tensor(27.5629, device='cuda:1', grad_fn=<DivBackward0>)
r:26,gamma:0.0006,iter:499,psnr:tensor(28.8396, device='cuda:1', grad_fn=<DivBackward0>)
r:26,gamma:0.0006,iter:749,psnr:tensor(29.8096, device='cuda:1', grad_fn=<DivBackward0>)
data:scene10,Case:Case1,r:26,gamma:0.0006,iter:749,psnr:29.809642791748047

r:26,gamma:0.0006,iter:999,psnr:tensor(30.4591, device='cuda:1', grad_fn=<DivBackward0>)
data:scene10,Case:Case1,r:26,gamma:0.0006,iter:999,psnr:30.4591121673584

r:26,gamma:0.0006,iter:1249,psnr:tensor(30.6035, device='cuda:1', grad_fn=<DivBackward0>)
data:scene10,Case:Case1,r:26,gamma:0.0006,iter:1249,psnr:30.603479385375977

r:26,gamma:0.0006,iter:1499,psnr:tensor(30.5731, device='cuda:1', grad_fn=<DivBackward0>)
data:scene10,Case:Case1,r:26,gamma:0.0006,iter:1499,psnr:30.573068618774414

14.71748948097229
