In [11]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Nov 22 20:18:19 2021

@author: xuquanfeng
"""
from PIL import Image
import torch
from torchvision import datasets,transforms,utils,models
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import random
import os
import datetime
import torchvision
import torch.nn.functional as F
from astropy.io import fits
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from torch import optim
# Set the random seeds
def setup_seed(seed):
    """Seed every random number generator used here and pin cuDNN settings.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU, the current CUDA device and all CUDA devices), exports
    ``PYTHONHASHSEED`` and configures cuDNN as enabled, deterministic and
    without benchmark auto-tuning.

    Args:
        seed: Integer seed handed to each generator.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    cudnn.enabled = True

    # Only affects interpreters started after this point (e.g. child processes);
    # the hash seed of the running interpreter is already fixed.
    os.environ['PYTHONHASHSEED'] = str(seed)


setup_seed(10)
# Output directories and hyper parameters.
# Create each output directory only if it does not exist yet.
for out_dir in ('./model', './train_proces'):
    if not os.path.exists(out_dir):
        os.mkdir(out_dir)

num_epochs = 20          # number of passes over the data
batch_size = 128         # samples fed to the model per step
learning_rate = 0.00001  # optimizer learning rate
num_var = 40             # passed around as the latent size; also used in the result file name
momentum = 0.8
k = 1                    # run index; used in the result file name

class MyDataset(torch.utils.data.Dataset):
    """Dataset of paired FITS images described by a plain-text listing file.

    Every non-blank line of the listing must hold at least two
    whitespace-separated tokens: the path of the first image and the path of
    its paired image. Extra tokens on a line are ignored.

    Args:
        datatxt: Path of the listing file.
        transform: Optional callable applied to each image array. Its result
            must support ``.permute`` (e.g. ``transforms.ToTensor()``).
        target_transform: Stored on the instance but not used by this class.
    """

    def __init__(self, datatxt, transform=None, target_transform=None):
        super(MyDataset, self).__init__()
        imgs = []
        # The context manager closes the listing file even if parsing fails.
        with open(datatxt, 'r') as fh:
            for line in fh:
                words = line.rstrip().split()
                if not words:
                    # Skip blank lines (e.g. a trailing empty line) instead of
                    # failing with IndexError.
                    continue
                imgs.append((words[0], words[1]))
        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform

    @staticmethod
    def _read_fits(path):
        """Return the primary HDU of ``path`` as a float32 array copy."""
        # np.array copies the data, so the array stays valid after the file is
        # closed on leaving the ``with`` block.
        with fits.open(path) as hdu:
            return np.array(hdu[0].data, dtype=np.float32)

    def __getitem__(self, index):
        """Load one image pair.

        Returns:
            ``(img, oimg, fn)`` where ``fn`` is the ``(path, paired_path)``
            tuple from the listing and ``img`` / ``oimg`` are the
            corresponding images.
        """
        fn = self.imgs[index]
        img = self._read_fits(fn[0])
        oimg = self._read_fits(fn[1])
        if self.transform is not None:
            # The first two axes of the transformed tensor are swapped.
            # NOTE(review): presumably this restores a channel-first layout
            # after ToTensor on channel-first FITS data; confirm.
            img = self.transform(img).permute(1, 0, 2)
            oimg = self.transform(oimg).permute(1, 0, 2)
        return img, oimg, fn

    def __len__(self):
        """Number of image pairs (not the number of batches of a loader)."""
        return len(self.imgs)


# Report the library versions this notebook is running against.
for _label, _module in (("PyTorch Version: ", torch),
                        ("Torchvision Version: ", torchvision)):
    print(_label, _module.__version__)

class VAE(nn.Module):
    """Convolutional VAE whose decoder un-pools with the encoder's pool indices.

    The encoder has three stages of ``Conv2d(stride=2) -> MaxPool2d(2) ->
    LeakyReLU``. Each stage divides the spatial size by four, and the linear
    heads expect ``128 * 16`` features, i.e. a 4x4 map, which corresponds to
    256x256 inputs. The decoder mirrors this with ``MaxUnpool2d ->
    ConvTranspose2d(stride=2) -> LeakyReLU`` stages, each of which re-uses the
    max-pool indices recorded by the matching encoder stage.

    Because of that pairing, ``decode`` only works after ``encode`` has run on
    a batch of the same size: the indices are kept in ``self.idx`` and are
    overwritten by every ``encode`` call.

    The sub-module layout and attribute names are kept stable so that
    previously saved checkpoints of this class keep loading.

    Args:
        num_var: Size of the latent vector.
    """

    def __init__(self, num_var):
        super(VAE, self).__init__()
        hidden_dims = [32, 64, 128]
        self.hidden_dims = hidden_dims
        latent_dim = num_var

        # ---- Encoder -------------------------------------------------------
        encoder_stages = []
        in_channels = 3
        for h_dim in hidden_dims:
            encoder_stages.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    # return_indices=True makes the pool return (output, indices);
                    # encode() unpacks this pair by hand.
                    nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2), return_indices=True),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*encoder_stages)
        self.fc_mu = nn.Linear(hidden_dims[-1] * 16, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1] * 16, latent_dim)

        # ---- Decoder -------------------------------------------------------
        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 16)

        # In-place reverse: ``self.hidden_dims`` refers to the same list, so it
        # ends up in decoder order ([128, 64, 32]). Only its length is read by
        # encode()/decode().
        hidden_dims.reverse()

        decoder_stages = []
        for i in range(len(hidden_dims) - 1):
            decoder_stages.append(
                nn.Sequential(
                    nn.MaxUnpool2d((2, 2), stride=(2, 2)),
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*decoder_stages)
        self.final_layer = nn.Sequential(
            nn.MaxUnpool2d((2, 2), stride=(2, 2)),
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=3,
                      kernel_size=3, padding=1),
            nn.ReLU())

    def encode(self, x):
        """Encode a batch into the parameters of the latent Gaussian.

        Side effect: stores the max-pool indices of every stage in
        ``self.idx`` for the following ``decode`` call.

        Returns:
            ``(mu, log_var)``, each of shape ``(batch, latent_dim)``.
        """
        result = x
        idx = []
        for i in range(len(self.hidden_dims)):
            stage = self.encoder[i]
            # Conv + pool: the pool also returns the indices.
            result, indices = stage[:2](result)
            idx.append(indices)
            result = stage[2](result)
        self.idx = idx
        result = torch.flatten(result, start_dim=1)

        mu = self.fc_mu(result)
        log_var = self.fc_var(result)
        return mu, log_var

    def reparametrize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, 1)``.

        The noise is drawn on the CPU as float32 and then moved to the device
        of ``logvar``. Taking the device from the tensor removes the dependency
        on a module-level ``device`` variable, which may not exist yet or may
        not match where the model actually lives.
        """
        std = logvar.mul(0.5).exp_()
        eps = torch.empty(std.size(), dtype=torch.float32).normal_().to(std.device)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        """Decode latent vectors using the indices saved by the last ``encode``.

        Raises:
            AttributeError: If ``encode`` has not been called yet, because
                ``self.idx`` does not exist.
        """
        result = self.decoder_input(z)
        # The channel count is inferred (hidden_dims[-1] * 16 features -> C x 4 x 4).
        result = result.view(len(result), -1, 4, 4)
        n_stages = len(self.hidden_dims)
        for i in range(n_stages - 1):
            # Decoder stage i undoes encoder stage n_stages - 1 - i.
            result = self.decoder[i][0](result, self.idx[n_stages - 1 - i])
            result = self.decoder[i][1:](result)
        result = self.final_layer[0](result, self.idx[0])
        result = self.final_layer[1:](result)
        return result

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the batch ``x``."""
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return self.decode(z), mu, logvar
torch.cuda.empty_cache()

# Device configuration: use the first CUDA device when available, else the CPU.
# It is defined before loading so the checkpoint can be mapped onto it.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# The loaded object is used directly as the model, so the VAE class above must
# be defined before this line runs. map_location lets a checkpoint saved on a
# GPU load on a CPU-only host.
# SECURITY: torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted source.
# NOTE(review): the path is absolute and machine-specific; consider making it
# configurable.
model = torch.load('/data/xqf/VAE2/model/vae_40_best.pth', map_location=device)

# print(model)
model = model.to(device)


PyTorch Version:  1.10.0+cu113
Torchvision Version:  0.11.1+cu113


In [12]:
def _fspecial_gauss_1d(size, sigma):
    coords = torch.arange(size).to(dtype=torch.float)
    coords -= size//2
    g = torch.exp(-(coords**2) / (2*sigma**2))
    g /= g.sum()
    return g.unsqueeze(0).unsqueeze(0)
    
def gaussian_filter(input, win):
    """Filter ``input`` with a per-channel kernel applied along two axes in turn.

    The first pass convolves with ``win`` as given; the second uses ``win``
    with its last two axes swapped. No padding is applied, so each pass
    shrinks the filtered axis by ``kernel_length - 1``.

    Args:
        input: 4-D tensor ``(N, C, H, W)``.
        win: Depthwise kernel with ``C`` output channels, e.g. of shape
            ``(C, 1, 1, k)`` as built in ``ssim``.

    Returns:
        The filtered tensor.
    """
    # The unpacking also enforces that the input is 4-D.
    _, channels, _, _ = input.shape
    filtered = input
    for kernel in (win, win.transpose(2, 3)):
        filtered = F.conv2d(filtered, kernel, stride=1, padding=0, groups=channels)
    return filtered

def _ssim(X, Y, win, data_range=1023, size_average=True, full=False):
    """Compute SSIM (and its contrast-structure term) between two batches.

    Args:
        X, Y: 4-D tensors ``(N, C, H, W)`` to compare.
        win: Filter window handed to ``gaussian_filter``; it is moved to the
            device and dtype of ``X`` first.
        data_range: Dynamic range of the pixel values; scales the two
            stabilising constants.
        size_average: If True reduce to a single scalar over the whole batch,
            otherwise return one value per sample (mean over C, H, W).
        full: If True also return the contrast-structure value.

    Returns:
        ``ssim_val`` or, when ``full`` is True, ``(ssim_val, cs)``.
    """
    k1, k2 = 0.01, 0.03
    # Unpacking four values enforces a 4-D input; the values are not used.
    _n, _c, _h, _w = X.shape
    compensation = 1.0

    c1 = (k1 * data_range) ** 2
    c2 = (k2 * data_range) ** 2

    kernel = win.to(X.device, dtype=X.dtype)

    # Local means.
    mean_x = gaussian_filter(X, kernel)
    mean_y = gaussian_filter(Y, kernel)
    mean_x_sq = mean_x.pow(2)
    mean_y_sq = mean_y.pow(2)
    mean_xy = mean_x * mean_y

    # Local variances and covariance: E[ab] - E[a]E[b].
    var_x = compensation * (gaussian_filter(X * X, kernel) - mean_x_sq)
    var_y = compensation * (gaussian_filter(Y * Y, kernel) - mean_y_sq)
    cov_xy = compensation * (gaussian_filter(X * Y, kernel) - mean_xy)

    cs_map = (2 * cov_xy + c2) / (var_x + var_y + c2)
    luminance = (2 * mean_xy + c1) / (mean_x_sq + mean_y_sq + c1)
    ssim_map = luminance * cs_map

    if size_average:
        ssim_val, cs = ssim_map.mean(), cs_map.mean()
    else:
        # Reduce W, then H, then C, leaving one value per sample.
        ssim_val = ssim_map.mean(-1).mean(-1).mean(-1)
        cs = cs_map.mean(-1).mean(-1).mean(-1)

    return (ssim_val, cs) if full else ssim_val

def ssim(X, Y, win_size=11, win_sigma=10, win=None, data_range=1, size_average=True, full=False):
    """Structural similarity between two image batches.

    Args:
        X, Y: 4-D tensors ``(N, C, H, W)`` with identical type and shape.
        win_size: Length of the Gaussian window built when ``win`` is None;
            must be odd.
        win_sigma: Standard deviation of that Gaussian window.
        win: Optional pre-built window; when given, ``win_size`` and
            ``win_sigma`` are not used to build one.
        data_range: Dynamic range of the pixel values.
        size_average: If True return the mean over the batch, otherwise one
            value per sample.
        full: If True also return the contrast-structure term.

    Returns:
        ``ssim_val`` or, when ``full`` is True, ``(ssim_val, cs)``.

    Raises:
        ValueError: If the inputs are not 4-D, differ in type or shape, or
            ``win_size`` is even.
    """
    if len(X.shape) != 4:
        raise ValueError('Input images must 4-d tensor.')

    if not X.type() == Y.type():
        raise ValueError('Input images must have the same dtype.')

    if not X.shape == Y.shape:
        raise ValueError('Input images must have the same dimensions.')

    if not (win_size % 2 == 1):
        raise ValueError('Window size must be odd.')

    if win is None:
        # One copy of the 1-D kernel per channel: (C, 1, 1, win_size).
        win = _fspecial_gauss_1d(win_size, win_sigma)
        win = win.repeat(X.shape[1], 1, 1, 1)

    # Always request per-sample values and both terms; the batch reduction is
    # done below.
    ssim_val, cs = _ssim(X, Y,
                         win=win,
                         data_range=data_range,
                         size_average=False,
                         full=True)
    if size_average:
        ssim_val = ssim_val.mean()
        cs = cs.mean()

    if full:
        return ssim_val, cs
    return ssim_val


In [13]:
# Evaluation pass: reconstruct every listed image pair with the loaded model
# and record, per pair, one row laid out as
#   [path, ssim, mu_0..mu_n, paired_path, paired_ssim, paired_mu_0..paired_mu_n]
# NOTE: the model's forward pass samples the latent vector, so the SSIM values
# depend on the random state.
from tqdm import tqdm

train_data = MyDataset(datatxt='train_tal.txt', transform=transforms.ToTensor())

train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=batch_size,
    shuffle=False,
    num_workers=20,
)

if not os.path.exists('./result'):
    os.mkdir('./result')

model.eval()
sssi = []
with torch.no_grad():
    for batch_idx, data in enumerate(tqdm(train_loader)):
        img, oimg, fn = data
        img = Variable(img).to(device)
        oimg = Variable(oimg).to(device)

        cimg, mu, lov = model(img)
        ocimg, omu, olov = model(oimg)

        for i in range(len(img)):
            qw = []
            # First the image itself, then its paired image.
            for path, source, recon, latent in (
                (fn[0][i], img, cimg, mu),
                (fn[1][i], oimg, ocimg, omu),
            ):
                ssim_val = ssim(source[i].unsqueeze(0), recon[i].unsqueeze(0),
                                data_range=1, size_average=True)
                qw.append(path)
                qw.append(ssim_val.cpu().detach().numpy())
                qw.extend(latent[i].cpu().detach().numpy())
            sssi.append(qw)

    # Rows mix path strings and numbers; NumPy picks the array dtype.
    dd = np.array(sssi)
    print(len(dd))
    np.save('./result/notf_resu_' + str(num_var) + '_' + str(k) + '_all.npy', dd)

100%|██████████| 27/27 [00:14<00:00,  1.92it/s]


3434
