In [2]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Nov 22 20:18:19 2021

@author: xuquanfeng
"""
from PIL import Image
import torch
from torchvision import datasets,transforms,utils,models
from VAE_model.models import VAE, MyDataset
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import random
import os
import datetime
import torchvision
import torch.nn.functional as F
from astropy.io import fits
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from torch import optim
# Set random seeds
def setup_seed(seed):
    """Seed every RNG this notebook touches and force deterministic cuDNN.

    Seeds torch (CPU and CUDA), numpy's global RNG and Python's ``random``
    module with ``seed``; disables cuDNN autotuning, requests deterministic
    cuDNN kernels (cuDNN itself stays enabled) and exports PYTHONHASHSEED.

    NOTE(review): PYTHONHASHSEED is read at interpreter start-up, so setting it
    here presumably only affects subprocesses launched afterwards -- confirm.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    cudnn.enabled = True

    os.environ['PYTHONHASHSEED'] = str(seed)


setup_seed(10)
# Hyper parameters and output directories
os.makedirs('./model', exist_ok=True)
os.makedirs('./train_proces', exist_ok=True)
num_epochs = 20          # number of training epochs
batch_size = 128         # samples per mini-batch
learning_rate = 0.00001  # Adam learning rate
num_var = 40             # used in checkpoint/result file names; presumably the latent size -- TODO confirm
momentum = 0.8           # NOTE(review): not passed to the Adam optimizer in the training cell
k = 1                    # weight of the KL term in loss_function1

print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

# Pretrained model to fine-tune: a whole model object loaded with torch.load
# (the training cell saves its checkpoints the same way, via torch.save(model, ...)).
PRETRAINED_MODEL_PATH = '/data/xqf/VAE3/model/vae_40_best.pth'

# Device configuration: use CUDA when available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

torch.cuda.empty_cache()
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only host.
model = torch.load(PRETRAINED_MODEL_PATH, map_location=device)
model = model.to(device)

# Summed (not averaged) squared error; `reduction='sum'` replaces the
# deprecated `size_average=False` with identical semantics.
reconstruction_function = nn.MSELoss(reduction='sum')

def loss_function1(recon_x, x, mu, logvar):
    """Single-view VAE objective: reconstruction error plus a ``k``-weighted KL term.

    Parameters
    ----------
    recon_x : images generated (reconstructed) by the model.
    x : original input images.
    mu : latent means.
    logvar : latent log-variances.

    Returns
    -------
    ``reconstruction_function(recon_x, x) + k * KLD`` where
    ``KLD = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))``.
    ``reconstruction_function`` and ``k`` are module-level names.
    """
    recon_term = reconstruction_function(recon_x, x)
    # Per-element 1 + log(sigma^2) - mu^2 - sigma^2, built without in-place ops
    # but in the same arithmetic order as before.
    kld_terms = (-(mu.pow(2) + logvar.exp()) + 1) + logvar
    kld = torch.sum(kld_terms) * -0.5
    return recon_term + k * kld


PyTorch Version:  1.10.0+cu113
Torchvision Version:  0.11.1+cu113




In [3]:
class MMDLoss(nn.Module):
    '''
    Maximum Mean Discrepancy (MMD) between source-domain and target-domain data.

    Params:
    kernel_type: 'rbf' (sum of Gaussian kernels, default) or 'linear'
        (squared distance between the two sample means).
    kernel_mul: ratio between successive Gaussian bandwidths.
    kernel_num: number of Gaussian kernels (bandwidths) that are summed.
    fix_sigma: fixed base bandwidth for the Gaussian kernels; when None
        (default) it is estimated from the mean pairwise squared distance
        of the pooled source+target samples.
    **kwargs: accepted and ignored.
    Forward inputs:
    source: source-domain data, shape (n, len(x))
    target: target-domain data, shape (m, len(y)); for 'rbf' the feature
        dimension must match source's because the two are concatenated.
    Return:
    loss: MMD loss (0-d tensor)
    Raises:
    ValueError: from forward() if kernel_type is neither 'rbf' nor 'linear'.
    '''
    def __init__(self, kernel_type='rbf', kernel_mul=2.0, kernel_num=5, fix_sigma=None, **kwargs):
        super(MMDLoss, self).__init__()
        self.kernel_num = kernel_num
        self.kernel_mul = kernel_mul
        # Previously hard-coded to None, which silently ignored the argument.
        self.fix_sigma = fix_sigma
        self.kernel_type = kernel_type

    def guassian_kernel(self, source, target, kernel_mul, kernel_num, fix_sigma):
        """Return the sum of ``kernel_num`` Gaussian kernel matrices over cat(source, target).

        The result is an (n + m, n + m) matrix. Bandwidths are
        ``base * kernel_mul**i / kernel_mul**(kernel_num // 2)`` for
        ``i in range(kernel_num)``, where ``base`` is ``fix_sigma`` if given,
        else the mean off-diagonal squared distance. (The method name keeps
        its original spelling so existing callers keep working.)
        """
        n_samples = int(source.size()[0]) + int(target.size()[0])
        total = torch.cat([source, target], dim=0)
        total0 = total.unsqueeze(0).expand(
            int(total.size(0)), int(total.size(0)), int(total.size(1)))
        total1 = total.unsqueeze(1).expand(
            int(total.size(0)), int(total.size(0)), int(total.size(1)))
        # L2_distance[i, j] = squared Euclidean distance between total[i] and total[j]
        L2_distance = ((total0-total1)**2).sum(2)
        if fix_sigma is not None:
            bandwidth = fix_sigma
        else:
            # Detached so the data-driven bandwidth acts as a constant for autograd.
            bandwidth = torch.sum(L2_distance.detach()) / (n_samples**2-n_samples)
        # Out-of-place so a caller-supplied tensor fix_sigma is never mutated.
        bandwidth = bandwidth / (kernel_mul ** (kernel_num // 2))
        bandwidth_list = [bandwidth * (kernel_mul**i)
                          for i in range(kernel_num)]
        kernel_val = [torch.exp(-L2_distance / bandwidth_temp)
                      for bandwidth_temp in bandwidth_list]
        return sum(kernel_val)

    def linear_mmd2(self, f_of_X, f_of_Y):
        """Squared Euclidean distance between the feature means of X and Y."""
        delta = f_of_X.float().mean(0) - f_of_Y.float().mean(0)
        # dot of a 1-D tensor with itself; `.T` on 1-D tensors is deprecated.
        return delta.dot(delta)

    def forward(self, source, target):
        if self.kernel_type == 'linear':
            return self.linear_mmd2(source, target)
        if self.kernel_type == 'rbf':
            n_source = int(source.size()[0])
            kernels = self.guassian_kernel(
                source, target, kernel_mul=self.kernel_mul, kernel_num=self.kernel_num, fix_sigma=self.fix_sigma)
            # Mean within-domain kernel values minus mean cross-domain values
            # (diagonal terms included).
            XX = torch.mean(kernels[:n_source, :n_source])
            YY = torch.mean(kernels[n_source:, n_source:])
            XY = torch.mean(kernels[:n_source, n_source:])
            YX = torch.mean(kernels[n_source:, :n_source])
            return torch.mean(XX + YY - XY - YX)
        # Previously fell through and returned None, failing later with an obscure error.
        raise ValueError(
            "Unsupported kernel_type: {!r} (expected 'rbf' or 'linear')".format(self.kernel_type))
def loss_function(recon_x, x, mu, logvar,recon_x1, x1, mu1, logvar1):
    """Joint objective for two views of the same batch.

    recon_x, x, mu, logvar: generated images, original images, latent mean
        and latent log variance for the first view.
    recon_x1, x1, mu1, logvar1: the same four quantities for the second view.

    Returns the sum of ``loss_function1`` for each view plus 0.5 times the
    MMD (default ``MMDLoss()``, i.e. RBF kernels) between the two latent-mean
    batches ``mu`` and ``mu1``.
    """
    view_a_loss = loss_function1(recon_x, x, mu, logvar)
    view_b_loss = loss_function1(recon_x1, x1, mu1, logvar1)
    latent_gap = MMDLoss()(source=mu, target=mu1)
    return view_a_loss + view_b_loss + 0.5 * latent_gap

In [4]:
# Fine-tune the model on paired views (img, oimg) with the joint two-view
# VAE + MMD objective. Progress goes to stdout and ./train_proces/train_1.txt;
# the epoch with the lowest average per-sample loss is saved to ./model/.
train_data = MyDataset(datatxt='train_tal1.txt', transform=transforms.ToTensor())

train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size = batch_size, shuffle=True,num_workers=20)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

best_model_path = './model/vae_'+str(num_var)+'_'+str(k)+'_best.pth'
# Start from +inf so the first epoch is checkpointed too (it never was before).
best_loss = float('inf')
strattime = datetime.datetime.now()
# `with` closes (and flushes) the log even if training is interrupted.
with open('./train_proces/train_1.txt', 'w') as train_loss11:
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for batch_idx, (img, oimg, _) in enumerate(train_loader):
            img = img.to(device)
            oimg = oimg.to(device)
            optimizer.zero_grad()
            cimg, mu, lov = model(img)
            ocimg, omu, olov = model(oimg)
            loss = loss_function(cimg, img, mu, lov, ocimg, oimg, omu, olov)
            loss.backward()
            train_loss += loss.item()
            optimizer.step()
            if batch_idx % 10 == 0:
                endtime = datetime.datetime.now()
                asd = str('Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f} time:{:.2f}s'.format(
                    epoch,
                    batch_idx * len(img),
                    len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    loss.item() / len(img),
                    # total_seconds(): `.seconds` drops days and the fractional part
                    (endtime-strattime).total_seconds()))
                print(asd)
                train_loss11.write(asd+'\n')
        epoch_loss = train_loss / len(train_loader.dataset)
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            asds = 'Save Best Model!'
            print(asds)
            train_loss11.write(asds+'\n')
            torch.save(model, best_model_path)
        asds = str('====> Epoch: {} Average loss: {:.4f}'.format(
            epoch, epoch_loss))
        print(asds)
        train_loss11.write(asds+'\n')

====> Epoch: 0 Average loss: 7271.3367
Save Best Model!
====> Epoch: 1 Average loss: 5447.1001
Save Best Model!
====> Epoch: 2 Average loss: 4489.5600
Save Best Model!
====> Epoch: 3 Average loss: 3892.6618
Save Best Model!
====> Epoch: 4 Average loss: 3514.3652
Save Best Model!
====> Epoch: 5 Average loss: 3258.8613
Save Best Model!
====> Epoch: 6 Average loss: 3074.4041
Save Best Model!
====> Epoch: 7 Average loss: 2945.1553
Save Best Model!
====> Epoch: 8 Average loss: 2847.8010
Save Best Model!
====> Epoch: 9 Average loss: 2769.1546
Save Best Model!
====> Epoch: 10 Average loss: 2702.9906
Save Best Model!
====> Epoch: 11 Average loss: 2644.8528
Save Best Model!
====> Epoch: 12 Average loss: 2595.2067
Save Best Model!
====> Epoch: 13 Average loss: 2551.0676
Save Best Model!
====> Epoch: 14 Average loss: 2512.3166
Save Best Model!
====> Epoch: 15 Average loss: 2477.4834
Save Best Model!
====> Epoch: 16 Average loss: 2445.3356
Save Best Model!
====> Epoch: 17 Average loss: 2418.0844
S

In [5]:
def _fspecial_gauss_1d(size, sigma):
    coords = torch.arange(size).to(dtype=torch.float)
    coords -= size//2
    g = torch.exp(-(coords**2) / (2*sigma**2))
    g /= g.sum()
    return g.unsqueeze(0).unsqueeze(0)
    
def gaussian_filter(input, win):
    """Separable per-channel blur of a 4-D batch (N, C, H, W).

    ``win`` is applied as-is, then again with its last two axes swapped; with
    the (C, 1, 1, win_size) window that ``ssim`` builds, that is one pass along
    width and one along height. ``groups=C`` filters each channel on its own.
    No padding is used, so the spatial size shrinks.
    """
    _, n_channels, _, _ = input.shape  # unpacking rejects non-4-D input
    first_pass = F.conv2d(input, win, stride=1, padding=0, groups=n_channels)
    swapped_win = win.transpose(2, 3)
    return F.conv2d(first_pass, swapped_win, stride=1, padding=0, groups=n_channels)

def _ssim(X, Y, win, data_range=1023, size_average=True, full=False):
    K1 = 0.01
    K2 = 0.03
    batch, channel, height, width = X.shape
    compensation = 1.0

    C1 = (K1 * data_range)**2
    C2 = (K2 * data_range)**2

    win = win.to(X.device, dtype=X.dtype)

    mu1 = gaussian_filter(X, win)
    mu2 = gaussian_filter(Y, win)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = compensation * ( gaussian_filter(X * X, win) - mu1_sq )
    sigma2_sq = compensation * ( gaussian_filter(Y * Y, win) - mu2_sq )
    sigma12   = compensation * ( gaussian_filter(X * Y, win) - mu1_mu2 )

    cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
    ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map

    if size_average:
        ssim_val = ssim_map.mean()
        cs = cs_map.mean()
    else:
        ssim_val = ssim_map.mean(-1).mean(-1).mean(-1)  # reduce along CHW
        cs = cs_map.mean(-1).mean(-1).mean(-1)

    if full:
        return ssim_val, cs
    else:
        return ssim_val

def ssim(X, Y, win_size=11, win_sigma=10, win=None, data_range=1, size_average=True, full=False):
    """Structural similarity (SSIM) between two batches of images.

    Parameters
    ----------
    X, Y : 4-D tensors (N, C, H, W) with identical shape and tensor type.
    win_size : odd length of the Gaussian window built when ``win`` is None.
        It is validated even when ``win`` is supplied.
    win_sigma : standard deviation of that Gaussian window.
    win : optional pre-built window of shape (C, 1, 1, win_size).
    data_range : value range used for the SSIM stabilising constants.
    size_average : True -> mean over the batch; False -> one value per image.
    full : also return the contrast-structure term ``cs``.

    Returns
    -------
    ``ssim_val`` or ``(ssim_val, cs)`` when ``full`` is True.

    Raises
    ------
    ValueError : non-4-D input, mismatched type or shape, or even ``win_size``.
    """
    if len(X.shape) != 4:
        raise ValueError('Input images must be 4-d tensors.')

    if not X.type() == Y.type():
        raise ValueError('Input images must have the same dtype.')

    if not X.shape == Y.shape:
        raise ValueError('Input images must have the same dimensions.')

    if not (win_size % 2 == 1):
        raise ValueError('Window size must be odd.')

    if win is None:
        # (1, 1, win_size) Gaussian repeated per channel -> (C, 1, 1, win_size)
        win = _fspecial_gauss_1d(win_size, win_sigma)
        win = win.repeat(X.shape[1], 1, 1, 1)

    # Always get per-image values plus cs from _ssim, then reduce here.
    ssim_val, cs = _ssim(X, Y,
                         win=win,
                         data_range=data_range,
                         size_average=False,
                         full=True)
    if size_average:
        ssim_val = ssim_val.mean()
        cs = cs.mean()

    if full:
        return ssim_val, cs
    return ssim_val


In [7]:
# Evaluate the model on the training list in a fixed order (shuffle=False) and
# save, per sample pair, a row:
#   [fn[0][i], ssim_view1, *mu_view1, fn[1][i], ssim_view2, *mu_view2]
# fn[0] / fn[1] presumably hold the two views' file names -- TODO confirm against MyDataset.
train_data = MyDataset(datatxt='train_tal1.txt', transform=transforms.ToTensor())

train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size = batch_size, shuffle=False,num_workers=20)

os.makedirs('./result', exist_ok=True)
model.eval()
from tqdm import tqdm
sssi = []
with torch.no_grad():
    for batch_idx, data in enumerate(tqdm(train_loader)):
        img, oimg, fn = data
        img = img.to(device)
        oimg = oimg.to(device)

        cimg, mu, lov = model(img)
        ocimg, omu, olov = model(oimg)

        # One batched SSIM call per view; size_average=False gives one score
        # per image (same per-image quantity the old per-sample loop computed).
        ssim_img = ssim(img, cimg, data_range=1, size_average=False).cpu().numpy()
        ssim_oimg = ssim(oimg, ocimg, data_range=1, size_average=False).cpu().numpy()
        # Move latent means to the CPU once per batch instead of once per sample.
        mu_np = mu.cpu().numpy()
        omu_np = omu.cpu().numpy()

        for i in range(len(img)):
            qw = [fn[0][i], ssim_img[i]]
            qw.extend(mu_np[i])
            qw.append(fn[1][i])
            qw.append(ssim_oimg[i])
            qw.extend(omu_np[i])
            sssi.append(qw)

# NOTE(review): rows mix strings and floats, so np.array presumably yields a
# string-dtype array; downstream loaders must parse the numbers back -- confirm.
dd = np.array(sssi)
print(len(dd))
np.save('./result/resu_'+str(num_var)+'_'+str(k)+'_all_10_15.npy',dd)

100%|███████████████████████████████████████████| 29/29 [00:13<00:00,  2.15it/s]


3618
