In [8]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Nov 22 20:18:19 2021

@author: xuquanfeng
"""
from PIL import Image
import torch
from torchvision import datasets,transforms,utils,models
from VAE_model.models import VAE, MyDataset
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import random
import os
import datetime
import torchvision
import torch.nn.functional as F
from astropy.io import fits
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from torch import optim
# Fix the random number generators used in this notebook so runs are repeatable.
def setup_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU + CUDA) RNGs and force deterministic cuDNN.

    Args:
        seed: integer seed applied to every generator.

    NOTE(review): PYTHONHASHSEED is read when the interpreter starts, so
    setting it here presumably only affects subprocesses launched afterwards.
    """
    # Library-level generators (independent of one another).
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)       # current CUDA device
    torch.cuda.manual_seed_all(seed)   # every CUDA device
    # Keep cuDNN enabled, but trade autotuning speed for reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.enabled = True
    cudnn.benchmark = False
    cudnn.deterministic = True
    os.environ['PYTHONHASHSEED'] = str(seed)


setup_seed(10)
# Hyper parameters
# Output folders for checkpoints and training logs (no-op if they already exist).
os.makedirs('./model', exist_ok=True)
os.makedirs('./train_proces', exist_ok=True)
num_epochs = 20          # number of training epochs
batch_size = 128         # samples per mini-batch
learning_rate = 0.00001  # Adam learning rate
num_var = 40             # used in output file names; presumably the latent size (the inference cell writes 40 mu values per row) — confirm
momentum = 0.8           # NOTE(review): not used by any visible cell (the optimizer below is Adam)
k = 1                    # weight of the KL term in loss_function

print("PyTorch Version: ", torch.__version__)
print("Torchvision Version: ", torchvision.__version__)

# Device configuration: use the first GPU when CUDA is available, else the CPU.
# Defined before loading so the checkpoint can be mapped onto it.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Pre-trained VAE that the next cell fine-tunes.
pretrained_path = '/data/xqf/VAE2/model/vae_40_best_1.pth'

torch.cuda.empty_cache()
# torch.load unpickles a whole nn.Module: the VAE class must be importable, and
# the file must come from a trusted source (unpickling can execute code).
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
model = torch.load(pretrained_path, map_location=device)
model = model.to(device)

# Summed (not averaged) squared error. reduction='sum' is the supported
# spelling of the deprecated size_average=False and yields the same value.
reconstruction_function = nn.MSELoss(reduction='sum')

def loss_function(recon_x, x, mu, logvar, kl_weight=None, recon_loss_fn=None):
    """VAE objective: reconstruction loss plus a weighted KL-divergence term.

    Args:
        recon_x: reconstructed images produced by the model.
        x: original input images (same shape as ``recon_x``).
        mu: latent means from the encoder.
        logvar: latent log-variances from the encoder (same shape as ``mu``).
        kl_weight: multiplier for the KL term. ``None`` uses the notebook-level
            ``k`` (read at call time), so existing 4-argument calls are unchanged.
        recon_loss_fn: callable ``(recon_x, x) -> scalar tensor``. ``None`` uses
            the notebook-level ``reconstruction_function`` (read at call time).

    Returns:
        Scalar tensor ``recon_loss_fn(recon_x, x) + kl_weight * KLD`` where
        ``KLD = -0.5 * sum(1 + logvar - mu**2 - exp(logvar))``.
    """
    if kl_weight is None:
        kl_weight = k
    if recon_loss_fn is None:
        recon_loss_fn = reconstruction_function
    # Reconstruction term (a summed MSE with the notebook's default function).
    recon_loss = recon_loss_fn(recon_x, x)
    # Closed-form KL(N(mu, sigma^2) || N(0, I)) with sigma^2 = exp(logvar);
    # built without in-place ops so no intermediate can alias the inputs.
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_weight * kld


PyTorch Version:  1.10.0+cu113
Torchvision Version:  0.11.1+cu113


In [2]:
# Fine-tune the loaded VAE; keep the checkpoint with the lowest average training loss.
train_data = MyDataset(datatxt='train_t.txt', transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size,
                                           shuffle=True, num_workers=20)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

log_path = './train_proces/train_' + str(num_var) + '_' + str(k) + '.txt'
best_path = './model/vae_' + str(num_var) + '_' + str(k) + '_best.pth'
n_samples = len(train_loader.dataset)
strattime = datetime.datetime.now()
# Start at +inf so epoch 0 is compared like any other epoch and a checkpoint
# always exists, even if the first epoch turns out to be the best one.
best_loss = float('inf')

# `with` closes (and flushes) the log even if training is interrupted.
with open(log_path, 'w') as log_file:
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        for batch_idx, (img, fn) in enumerate(train_loader):
            img = img.to(device)
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(img)
            loss = loss_function(recon_batch, img, mu, logvar)
            loss.backward()
            train_loss += loss.item()
            optimizer.step()
            if batch_idx % 10 == 0:
                # total_seconds(): timedelta.seconds would drop whole days.
                elapsed = (datetime.datetime.now() - strattime).total_seconds()
                msg = 'Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f} time:{:.2f}s'.format(
                    epoch,
                    batch_idx * len(img),
                    n_samples,
                    100. * batch_idx / len(train_loader),
                    loss.item() / len(img),
                    elapsed)
                print(msg)
                log_file.write(msg + '\n')
        avg_loss = train_loss / n_samples
        if avg_loss < best_loss:
            best_loss = avg_loss
            msg = 'Save Best Model!'
            print(msg)
            log_file.write(msg + '\n')
            torch.save(model, best_path)
        msg = '====> Epoch: {} Average loss: {:.4f}'.format(epoch, avg_loss)
        print(msg)
        log_file.write(msg + '\n')
        # Persist progress each epoch in case the kernel is killed mid-run.
        log_file.flush()


====> Epoch: 0 Average loss: 3671.7266
Save Best Model!
====> Epoch: 1 Average loss: 3288.7866
Save Best Model!
====> Epoch: 2 Average loss: 3156.9852
Save Best Model!
====> Epoch: 3 Average loss: 2988.3582
Save Best Model!
====> Epoch: 4 Average loss: 2895.0506
Save Best Model!
====> Epoch: 5 Average loss: 2805.1460
Save Best Model!
====> Epoch: 6 Average loss: 2791.3678
Save Best Model!
====> Epoch: 7 Average loss: 2779.9828
Save Best Model!
====> Epoch: 8 Average loss: 2735.4346
Save Best Model!
====> Epoch: 9 Average loss: 2722.6169
Save Best Model!
====> Epoch: 10 Average loss: 2694.9898
Save Best Model!
====> Epoch: 11 Average loss: 2624.9683
Save Best Model!
====> Epoch: 12 Average loss: 2589.9671
Save Best Model!
====> Epoch: 13 Average loss: 2516.1009
====> Epoch: 14 Average loss: 2520.6284
Save Best Model!
====> Epoch: 15 Average loss: 2496.1343
Save Best Model!
====> Epoch: 16 Average loss: 2486.8676
Save Best Model!
====> Epoch: 17 Average loss: 2462.4820
Save Best Model!
=

In [9]:
# Encode every image listed in train_t1.txt and save rows [filename, mu_1, ..., mu_n].
from tqdm import tqdm

train_data = MyDataset(datatxt='train_t1.txt', transform=transforms.ToTensor())
# shuffle=False keeps the output rows in the same order as the list file.
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size,
                                           shuffle=False, num_workers=20)

os.makedirs('./result', exist_ok=True)
model.eval()
sssi = []
with torch.no_grad():
    for batch_idx, (img, fn) in enumerate(tqdm(train_loader)):
        mu, logvar = model.encode(img.to(device))
        # One device->host copy of the latents per batch, instead of copying
        # the whole image batch back and then each latent row separately.
        mu_np = mu.cpu().numpy()
        for name, latent in zip(fn, mu_np):
            sssi.append([name, *latent])

dd = np.array(sssi)
print(len(dd))
# NOTE: mixing the filename with floats makes NumPy store every column as a
# string; consumers of this .npy must cast columns 1: back to float.
np.save('./result/resu_' + str(num_var) + '_' + str(k) + '_in_no.npy', dd)

100%|██████████| 27/27 [00:06<00:00,  4.40it/s]


3434


In [4]:
# Inspect the latent-feature array: its shape first, then a summary of its contents.
for item in (dd.shape, dd):
    print(item)

(3434, 41)
[['/data/GZ_Decals/MGS_out_DECaLS/175.89149808110213_21.67228737779398_0.262_grz_.fits'
  '-0.2395398' '0.30021068' ... '0.03285286' '-0.15634988' '0.2659969']
 ['/data/GZ_Decals/MGS_out_DECaLS/176.1464259029576_6.42738628465911_0.262_grz_.fits'
  '0.06613478' '-0.045190185' ... '0.25628704' '-0.060526717'
  '-0.6036183']
 ['/data/GZ_Decals/MGS_out_DECaLS/216.73248662726843_-2.558336707515353_0.262_grz_.fits'
  '5.0170693' '3.13536' ... '-1.669854' '-0.7415669' '3.479699']
 ...
 ['/data/GZ_Decals/MGS_out_DECaLS/185.8959342163261_12.898011042802942_0.262_grz_.fits'
  '4.595039' '0.9067448' ... '-2.7144554' '-8.880546' '-2.725317']
 ['/data/GZ_Decals/MGS_out_DECaLS/130.09162761307095_27.113144869991118_0.262_grz_.fits'
  '0.1841171' '-0.02502913' ... '0.4120853' '0.27991506' '0.40033835']
 ['/data/GZ_Decals/MGS_out_DECaLS/209.69722323971683_-2.796119143808119_0.262_grz_.fits'
  '0.89246887' '-2.2485366' ... '-0.73100126' '0.91381544' '1.7066566']]


In [7]:
# Inspect the latent-feature array: its shape first, then a summary of its contents.
for item in (dd.shape, dd):
    print(item)

(3434, 41)
[['/data/GZ_Decals/MGS_out_DECaLS/175.89149808110213_21.67228737779398_0.262_grz_.fits'
  '-0.28436983' '0.30952153' ... '0.030438695' '-0.17496628' '0.2859994']
 ['/data/GZ_Decals/MGS_out_DECaLS/176.1464259029576_6.42738628465911_0.262_grz_.fits'
  '-0.04641933' '0.2062142' ... '0.25107816' '-0.08299783' '-0.62852347']
 ['/data/GZ_Decals/MGS_out_DECaLS/216.73248662726843_-2.558336707515353_0.262_grz_.fits'
  '6.4118094' '4.4545727' ... '-2.5403154' '0.4222854' '3.94778']
 ...
 ['/data/GZ_Decals/MGS_out_DECaLS/185.8959342163261_12.898011042802942_0.262_grz_.fits'
  '4.196303' '1.0886736' ... '-1.9845552' '-7.5776615' '0.40705034']
 ['/data/GZ_Decals/MGS_out_DECaLS/130.09162761307095_27.113144869991118_0.262_grz_.fits'
  '0.29514536' '0.16261218' ... '0.36543643' '0.21575616' '0.37621558']
 ['/data/GZ_Decals/MGS_out_DECaLS/209.69722323971683_-2.796119143808119_0.262_grz_.fits'
  '1.0389524' '-2.2737217' ... '-0.7585088' '0.94853944' '1.6581957']]


In [11]:
# Inspect the latent-feature array: its shape first, then a summary of its contents.
for item in (dd.shape, dd):
    print(item)

(3434, 41)
[['/data/GZ_Decals/nomerge/175.8905862789053_21.66981935317594_0.fits'
  '0.006995933' '1.0993093' ... '0.3568548' '-0.008556896' '0.49385735']
 ['/data/GZ_Decals/nomerge/176.14777944132385_6.43021218747715_0.fits'
  '-0.45819253' '0.1510943' ... '-0.019727178' '-0.26700974' '0.3930761']
 ['/data/GZ_Decals/nomerge/216.73420321996912_-2.5567049269072277_0.fits'
  '-3.1298068' '-0.26893172' ... '1.1139963' '1.4247977' '4.353631']
 ...
 ['/data/GZ_Decals/nomerge/185.89469726554452_12.897557445261423_0.fits'
  '0.84842736' '0.8459865' ... '-6.7031136' '-3.6388366' '5.3683214']
 ['/data/GZ_Decals/nomerge/130.09191714259262_27.113977190304865_0.fits'
  '-0.36364064' '0.03688038' ... '0.25311407' '0.038813382' '0.21880837']
 ['/data/GZ_Decals/nomerge/209.6963186915605_-2.794944672696973_0.fits'
  '-0.8575084' '-0.14070234' ... '0.79517335' '-0.75291526' '1.2931647']]
