In [11]:
from turtle import forward
import torch
from torch import nn, optim 
from torch.autograd import Variable 
import os 
from utils import * 
get_ipython().run_line_magic('matplotlib', 'inline')
import numpy as np 
import scipy.io as scio
import matplotlib.pyplot as plt 
import scipy.io
import math
import time
from math import exp
import torch.nn.functional as F

In [12]:
torch.cuda.set_device(0)
torch.cuda.empty_cache()

# Seed every RNG this notebook touches (torch CPU + all CUDA devices, numpy)
# so that model initialisation, and therefore the reported PSNR, is
# reproducible across kernel restarts.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Hyper-parameter sweep (tune for best performance):
#   shrink - divisors of the spatial size; the factor rank is n_2 // shrink
#   c_all  - regularisation weights (gamma) on the L1 finite-difference terms
shrink = [20]
c_all = [1e-6]

max_iter = 12000   # optimisation steps per (shrink, gamma) configuration
lr_real = 0.001    # Adam learning rate

In [13]:
# Helper functions for computing PSNR and SSIM
def torch_psnr(img, ref):  # input [28,256,256]
    """Mean per-band PSNR (in dB) between two band-first cubes.

    Both inputs are indexed as (bands, H, W), e.g. [28, 256, 256], and are
    presumably scaled to [0, 1]. Each cube is multiplied by 256 and rounded
    before comparison, while the peak used in the formula is 255. This
    convention is kept as-is so the numbers stay comparable with earlier runs.
    A band with zero error contributes +inf.
    """
    quant_img = torch.round(img * 256)
    quant_ref = torch.round(ref * 256)
    n_bands = quant_img.shape[0]
    peak_sq = 255 * 255
    band_scores = (
        10 * torch.log10(peak_sq / torch.mean((quant_img[b, :, :] - quant_ref[b, :, :]) ** 2))
        for b in range(n_bands)
    )
    # sum() starts from integer 0 and adds the band scores in order.
    return sum(band_scores) / n_bands

def gaussian(window_size, sigma):
    """Return a normalised 1-D Gaussian kernel with ``window_size`` taps.

    The kernel is centred on index ``window_size // 2``, has standard
    deviation ``sigma``, and its entries sum to 1. It is built with
    ``torch.Tensor`` and so uses the default floating dtype.
    """
    centre = window_size // 2
    denom = float(2 * sigma ** 2)
    weights = torch.Tensor([exp(-(k - centre) ** 2 / denom) for k in range(window_size)])
    return weights / weights.sum()

def create_window(window_size, channel, sigma=1.5):
    """Build a depthwise 2-D Gaussian filter bank for SSIM.

    Args:
        window_size: side length of the square kernel.
        channel: number of identical kernel copies, one per input channel
            (intended for ``F.conv2d(..., groups=channel)``).
        sigma: Gaussian standard deviation. Defaults to 1.5, the value that
            used to be hard-coded.

    Returns:
        Float tensor of shape (channel, 1, window_size, window_size). Each
        2-D kernel is the outer product of a normalised 1-D Gaussian with
        itself, so it sums to 1.
    """
    kernel_1d = gaussian(window_size, sigma).unsqueeze(1)                        # (W, 1)
    kernel_2d = kernel_1d.mm(kernel_1d.t()).float().unsqueeze(0).unsqueeze(0)   # (1, 1, W, W)
    # torch.autograd.Variable has been a deprecated no-op wrapper since
    # PyTorch 0.4; returning the plain tensor behaves identically.
    return kernel_2d.expand(channel, 1, window_size, window_size).contiguous()


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

def ssim(img1, img2, window_size=11, size_average=True):
    """SSIM between two 4-D (N, C, H, W) tensors using a Gaussian window.

    Builds a per-channel window of side ``window_size``. If img1 is on the
    GPU, the window is moved to img1's CUDA device. It is then cast to img1's
    tensor type and the work is delegated to ``_ssim``.
    """
    _, n_channels, _, _ = img1.size()
    kernel = create_window(window_size, n_channels)
    kernel = kernel.cuda(img1.get_device()) if img1.is_cuda else kernel
    return _ssim(img1, img2, kernel.type_as(img1), window_size, n_channels, size_average)

def torch_ssim(img, ref):  # input [28,256,256]
    """SSIM between two band-first cubes indexed as (bands, H, W).

    A leading batch axis is added so each band is treated as a channel, and
    the scalar mean SSIM from ``ssim`` is returned.
    """
    batched_img = img.unsqueeze(0)
    batched_ref = ref.unsqueeze(0)
    return ssim(batched_img, batched_ref)

In [14]:

class diagonal_S_1(nn.Module):
    """Turn a stack of vectors into a stack of diagonal matrices.

    ``forward(s)`` takes ``s`` with at least ``n3`` rows of length k and
    returns a tensor of shape (n3, n1, n2) whose i-th slice is
    ``diag(s[i])``. This requires n1 == n2 == k; otherwise ValueError is
    raised.
    """

    def __init__(self, n1, n2, n3):
        super(diagonal_S_1, self).__init__()
        self.n1 = n1
        self.n2 = n2
        self.n3 = n3

    def forward(self, s):
        # Only the first n3 rows are used, as in the previous per-slice loop.
        # diag_embed keeps s's device and dtype. The old version always
        # allocated a float32 CUDA buffer, which crashed on CPU-only machines
        # and silently downcast other dtypes.
        diag = torch.diag_embed(s[: self.n3])
        expected = (self.n3, self.n1, self.n2)
        if tuple(diag.shape) != expected:
            raise ValueError(
                "diagonal_S_1: expected a diagonal stack of shape %s, got %s from s of shape %s"
                % (str(expected), str(tuple(diag.shape)), str(tuple(s.shape)))
            )
        return diag

class OTLinear_new(nn.Module):
    """Learnable orthogonal matrix built as a product of Householder reflections.

    ``forward()`` takes no input and returns the (shape, shape) matrix
    H_1 H_2 ... H_n, where H_j = I - 2 u_j u_j^T and u_j is the j-th row of
    ``self.u`` normalised to unit length. A product of reflections is
    orthogonal.
    """

    def __init__(self, shape):
        super().__init__()
        self.n_dim = shape
        self.num_householders = shape
        self.u = nn.Parameter(torch.randn(self.num_householders, self.n_dim))
        # Frozen identity. It is kept as an nn.Parameter (not a buffer) so the
        # module's parameter list and state_dict layout stay unchanged.
        self.I = nn.Parameter(torch.eye(self.n_dim, self.n_dim).unsqueeze(0))
        self.I.requires_grad = False

    def get_weight(self):
        """Return the (n_dim, n_dim) orthogonal matrix."""
        u = F.normalize(self.u, dim=1)                        # (n, d), unit-norm rows
        w = self.I - 2 * u.unsqueeze(-1) @ u.unsqueeze(1)     # (n, d, d), one reflection per row
        reflections = list(w.unbind(0))
        if len(reflections) == 1:
            # linalg.multi_dot needs at least two operands; a single
            # reflection is already the product.
            return reflections[0].clone()
        # torch.chain_matmul is deprecated in favour of torch.linalg.multi_dot.
        return torch.linalg.multi_dot(reflections)

    def forward(self):
        return self.get_weight()
    
class orth_linear(nn.Module):
    """Bias-free linear layer whose weight is an orthogonal matrix.

    The weight is the product H_1 H_2 ... H_n of n = ``shape`` Householder
    reflections H_j = I - 2 u_j u_j^T, where u_j is the j-th row of ``self.u``
    normalised to unit length, so it is orthogonal by construction.
    ``forward`` applies it to the last axis of its input.
    """

    def __init__(self, shape):
        super().__init__()
        self.n_dim = shape
        self.num_householders = shape
        self.u = nn.Parameter(torch.randn(self.num_householders, self.n_dim))
        # Frozen identity, deliberately kept as an nn.Parameter rather than a
        # buffer: turning it into a buffer would shift positions in
        # model.parameters() for any enclosing model.
        self.I = nn.Parameter(torch.eye(self.n_dim, self.n_dim).unsqueeze(0))
        self.I.requires_grad = False

    def get_weight(self):
        """Return the (n_dim, n_dim) orthogonal weight matrix."""
        u = F.normalize(self.u, dim=1)                        # (n, d), unit-norm rows
        w = self.I - 2 * u.unsqueeze(-1) @ u.unsqueeze(1)     # (n, d, d), one reflection per row
        reflections = list(w.unbind(0))
        if len(reflections) == 1:
            # linalg.multi_dot needs at least two operands; a single
            # reflection is already the product.
            return reflections[0].clone()
        # torch.chain_matmul is deprecated in favour of torch.linalg.multi_dot.
        return torch.linalg.multi_dot(reflections)

    def forward(self, input):
        """Apply the orthogonal weight along the last axis (``input @ W^T``)."""
        return F.linear(input, self.get_weight())


class SVD_net(nn.Module):
    """Low-rank, SVD-style generator of an image cube.

    For each of the n_3 slices it forms ``A_hat @ diag(s_hat) @ B_hat^T`` with
    rank ``n_2 // shrink``. The result is then passed through ``net_t``: a
    permutation, an orthogonal map and a bias-free linear layer, both of
    size n_3.

    NOTE: the order in which parameters and sub-modules are registered
    (A, s, B, net_t, then netA, netB, diagnet, adaptive_thres_net), and the
    point where reset_parameters() runs, fix both the ``model.parameters()``
    order (callers may index into it by position) and the RNG draw order.
    Keep them unchanged.
    """

    def __init__(self, n_1, n_2, n_3, shrink, shape):
        super(SVD_net, self).__init__()
        rank = n_2 // shrink
        self.shape = shape  # stored; not used inside this class
        # Raw factors; their values are drawn in reset_parameters().
        self.A = nn.Parameter(torch.Tensor(n_1, rank, n_3))
        self.s = nn.Parameter(torch.Tensor(rank, n_3))
        self.B = nn.Parameter(torch.Tensor(n_2, rank, n_3))

        # NOTE(review): permute_change comes from utils; presumably it wraps
        # x.permute(*dims). Confirm before relying on the axis order.
        self.net_t = nn.Sequential(
            permute_change(1, 2, 0),
            orth_linear(n_3),
            nn.Linear(n_3, n_3, bias=False),
        )
        self.reset_parameters()

        self.netA = nn.Sequential(orth_linear(n_3), permute_change(2, 0, 1))
        self.netB = nn.Sequential(orth_linear(n_3), permute_change(2, 0, 1))

        self.diagnet = diagonal_S_1(rank, rank, n_3)
        self.adaptive_thres_net = nn.Sequential(
            nn.Linear(rank, rank, bias=False),
            nn.LeakyReLU(),
            nn.Linear(rank, rank, bias=False),
        )

    def reset_parameters(self):
        """Redraw A, B and s (in that order) from U(-b, b), with b = 1/sqrt(A.size(2))."""
        bound = 1. / math.sqrt(self.A.size(2))
        for factor in (self.A, self.B, self.s):
            factor.data.uniform_(-bound, bound)

    def forward(self):
        """Return ``(cube, A_hat, B_hat_T)``.

        ``A_hat`` is netA(A), ``B_hat_T`` is netB(B) with its last two axes
        swapped, and ``cube`` is net_t(A_hat @ diagnet(adaptive_thres_net(s^T)) @ B_hat_T).
        """
        A_hat = self.netA(self.A)
        B_hat_T = self.netB(self.B).permute(0, 2, 1)
        singular = self.diagnet(self.adaptive_thres_net(self.s.t()))
        low_rank = torch.matmul(torch.matmul(A_hat, singular), B_hat_T)
        return self.net_t(low_rank), A_hat, B_hat_T

In [15]:
truth_np = scipy.io.loadmat("./data/snapshot/truth/scene04.mat")
truth = torch.from_numpy(truth_np['img']).cuda()

# load measurement and 3d-shift-mask
file_name = "./data/snapshot/measurement_mask/scene04_mask_And_meas.mat"
mat = scipy.io.loadmat(file_name)
X_np = mat["Nhsi"]
X = torch.from_numpy(X_np).type(dtype).cuda()  # measurement
shift_3d_mask_np = mat["mask"]
shift_3d_mask = torch.from_numpy(shift_3d_mask_np).type(dtype).cuda()  # 3d-shift-mask

RESULT_ROOT = "./results/snapshot/ours/scene04/"
SHIFT_STEP = 2        # columns each band is displaced relative to the previous one
LOG_EVERY = 1000      # evaluate and log PSNR every this many iterations
SAVE_PSNR_DB = 44.0   # only reconstructions at or above this PSNR are saved


def shift_bands(cube, step=SHIFT_STEP):
    """Place band i of an (H, W, C) cube at columns [i*step, i*step + W).

    The canvas is zero-filled with width W + step*(C-1) and is allocated on
    the cube's device with the cube's dtype. Assignment into the canvas keeps
    the autograd graph to ``cube``.
    """
    n_rows, n_cols, n_bands = cube.shape
    shifted = cube.new_zeros((n_rows, n_cols + step * (n_bands - 1), n_bands))
    for band in range(n_bands):
        shifted[:, band * step:band * step + n_cols, band] = cube[:, :, band]
    return shifted


F_norm = nn.MSELoss()

for r in shrink:
    r_folder = RESULT_ROOT + "r" + str(r) + "/"
    # makedirs also creates RESULT_ROOT when missing (os.mkdir would fail).
    os.makedirs(r_folder, exist_ok=True)

    for gamma in c_all:
        folder_path = r_folder + "gamma" + str(gamma) + "/"
        os.makedirs(folder_path, exist_ok=True)

        model = SVD_net(256, 256, 28, r, 256).type(dtype)

        params = list(model.parameters())
        n_params = sum(np.prod(list(p.size())) for p in params)
        print('Number of params: %d' % n_params)
        optimizer = optim.Adam(params, lr=lr_real, weight_decay=10e-8)  # 10e-8 == 1e-7
        # Mixing weights penalised by the row-difference term. Addressed by
        # name rather than by the fragile positional index params[5]
        # (the same tensor: A, s, B, net_t[1].u, net_t[1].I, net_t[2].weight).
        H_k = model.net_t[2].weight

        t0 = time.time()
        for it in range(max_iter):  # `it`, not `iter`, to avoid shadowing the builtin
            X_Out_real, A_hat, B_hat = model()
            out_shift = shift_bands(X_Out_real)

            # Data term: shifted cube x mask, summed over bands, vs. measurement.
            loss_f = F_norm(torch.sum(out_shift * shift_3d_mask, dim=2), X)
            # L1 norms of finite differences, weighted by gamma.
            loss_A = gamma * torch.norm(A_hat[:, 1:, :] - A_hat[:, :-1, :], 1)
            loss_B = gamma * torch.norm(B_hat[:, :, 1:] - B_hat[:, :, :-1], 1)
            loss_H_k = gamma * torch.norm(H_k[1:, :] - H_k[:-1, :], 1)
            loss_total = loss_f + loss_A + loss_H_k + loss_B

            optimizer.zero_grad()
            loss_total.backward()
            optimizer.step()

            if (it + 1) % LOG_EVERY == 0:
                # Evaluation only: no autograd graph is needed.
                with torch.no_grad():
                    img = X_Out_real.permute(2, 0, 1)
                    gt = truth.permute(2, 0, 1)
                    psnr = torch_psnr(img, gt)
                psnr_db = psnr.item()
                if psnr_db >= SAVE_PSNR_DB:
                    scio.savemat(folder_path + str(it) + "_" + str(psnr_db) + "_scene04_result.mat",
                                 {'scene04_result': X_Out_real.detach().cpu().numpy()})
                log_line = "r:" + str(r) + ",gamma:" + str(gamma) + ",iter:" + str(it) + ",psnr:" + str(psnr_db) + '\n'
                with open(RESULT_ROOT + 'psnr.txt', 'a', encoding='utf-8') as f:
                    f.write(log_line)
                print(log_line)

Number of params: 178144
r:20,gamma:1e-06,iter:999,psnr:38.76161193847656

r:20,gamma:1e-06,iter:1999,psnr:42.30829620361328

r:20,gamma:1e-06,iter:2999,psnr:43.29865264892578

r:20,gamma:1e-06,iter:3999,psnr:43.58987045288086

r:20,gamma:1e-06,iter:4999,psnr:43.684505462646484

r:20,gamma:1e-06,iter:5999,psnr:43.72355270385742

r:20,gamma:1e-06,iter:6999,psnr:43.774356842041016

r:20,gamma:1e-06,iter:7999,psnr:43.83332824707031

r:20,gamma:1e-06,iter:8999,psnr:43.87248611450195



KeyboardInterrupt: 