In [12]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Nov 22 20:18:19 2021

@author: xuquanfeng
"""
from PIL import Image
import torch
from torchvision import datasets,transforms,utils,models
from VAE_model.models import VAE, MyDataset
import torch.nn as nn
from torch.autograd import Variable
import numpy as np
import random
import os
import datetime
import torchvision
import torch.nn.functional as F
from astropy.io import fits
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from torch import optim
def setup_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds torch (CPU plus the current and all CUDA devices), NumPy's legacy
    global RNG and Python's ``random`` module. Forces cuDNN into deterministic
    mode with autotuning off, and exports ``PYTHONHASHSEED``.

    NOTE: ``PYTHONHASHSEED`` only takes effect for interpreters started after
    it is set; it does not change hash randomisation of the current process.

    Args:
        seed: integer seed shared by all generators.
    """
    for seeder in (torch.manual_seed,
                   torch.cuda.manual_seed,
                   torch.cuda.manual_seed_all,
                   np.random.seed,
                   random.seed):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False      # benchmark mode may select different kernels run to run
    cudnn.deterministic = True   # restrict cuDNN to deterministic algorithms
    cudnn.enabled = True

    os.environ['PYTHONHASHSEED'] = str(seed)
setup_seed(10)

# Output directories for checkpoints and training logs
os.makedirs('./model', exist_ok=True)
os.makedirs('./train_proces', exist_ok=True)

# Hyper parameters
num_epochs = 20          # number of training epochs
batch_size = 128         # samples per mini-batch
learning_rate = 0.00001  # learning rate
num_var = 40             # tag used in output file names (presumably the latent size — confirm)
momentum = 0.8           # NOTE(review): not used by any cell shown here (training uses Adam)
k = 1                    # weight of the KL term in loss_function; also tags output file names

print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

# Device configuration: use the first CUDA GPU when one is available
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

torch.cuda.empty_cache()
# Pretrained full-model pickle; set VAE_PRETRAINED_PATH to point elsewhere.
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
# NOTE: torch.load unpickles arbitrary objects — only load trusted files.
pretrained_path = os.environ.get('VAE_PRETRAINED_PATH', '/data/xqf/VAE1/model/vae_40_best.pth')
model = torch.load(pretrained_path, map_location=device)
model = model.to(device)

# Summed (not averaged) squared error; same as the deprecated size_average=False
reconstruction_function = nn.MSELoss(reduction='sum')

def loss_function(recon_x, x, mu, logvar):
    """VAE objective: reconstruction error plus ``k``-weighted KL divergence.

    Args:
        recon_x: generated (reconstructed) images.
        x: original images that ``recon_x`` should match.
        mu: latent mean.
        logvar: latent log variance.

    Returns:
        Scalar tensor ``reconstruction_function(recon_x, x) + k * KLD`` with
        KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2).

    NOTE: the reconstruction term is whatever the module-level
    ``reconstruction_function`` is (a summed MSE in this notebook), not a
    binary cross-entropy. ``k`` is also read from module scope at call time.
    """
    recon_term = reconstruction_function(recon_x, x)
    # Element-wise 1 + log(sigma^2) - mu^2 - sigma^2, with sigma^2 = exp(logvar)
    kld_terms = (mu.pow(2) + logvar.exp()).neg().add(1).add(logvar)
    kld = torch.sum(kld_terms) * -0.5
    return recon_term + k * kld


PyTorch Version:  1.10.0+cu113
Torchvision Version:  0.11.1+cu113


In [10]:
# Fine-tune the pretrained encoder so that the latent mean of each image
# matches the latent mean of its paired image (oimg), scored with the summed
# MSE in reconstruction_function. The best epoch is checkpointed to ./model.
log_path = './train_proces/train_' + str(num_var) + '_' + str(k) + '.txt'
best_model_path = './model/vae_' + str(num_var) + '_' + str(k) + '_best.pth'
final_model_path = './model/vae_' + str(num_var) + '_' + str(k) + '.pth'
SAVE_FINAL_MODEL = False  # set True to also keep the last-epoch weights

train_data = MyDataset(datatxt='train_tal.txt', transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True, num_workers=20)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
n_samples = len(train_loader.dataset)


def log_line(message, log_file):
    """Print ``message`` and append it, newline-terminated, to ``log_file``."""
    print(message)
    log_file.write(message + '\n')


start_time = datetime.datetime.now()
# Start from +inf so the epoch-0 model is checkpointed as well
best_loss = float('inf')
# `with` closes the log even if training is interrupted mid-run
with open(log_path, 'w') as train_log:
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for batch_idx, (img, oimg, _) in enumerate(train_loader):
            img = img.to(device)
            oimg = oimg.to(device)
            optimizer.zero_grad()
            mu, _ = model.encode(img)
            omu, _ = model.encode(oimg)
            loss = reconstruction_function(mu, omu)
            loss.backward()
            train_loss += loss.item()
            optimizer.step()
            if batch_idx % 10 == 0:
                # total_seconds() keeps whole days, unlike timedelta.seconds
                elapsed = (datetime.datetime.now() - start_time).total_seconds()
                log_line('Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f} time:{:.2f}s'.format(
                    epoch,
                    batch_idx * len(img),
                    n_samples,
                    100. * batch_idx / len(train_loader),
                    loss.item() / len(img),
                    elapsed), train_log)
        avg_loss = train_loss / n_samples
        if avg_loss < best_loss:
            best_loss = avg_loss
            log_line('Save Best Model!', train_log)
            torch.save(model, best_model_path)
        log_line('====> Epoch: {} Average loss: {:.4f}'.format(epoch, avg_loss), train_log)

if SAVE_FINAL_MODEL:
    torch.save(model, final_model_path)


====> Epoch: 0 Average loss: 103.7544
Save Best Model!
====> Epoch: 1 Average loss: 54.6232
Save Best Model!
====> Epoch: 2 Average loss: 33.2189
Save Best Model!
====> Epoch: 3 Average loss: 23.0032
Save Best Model!
====> Epoch: 4 Average loss: 16.7678
Save Best Model!
====> Epoch: 5 Average loss: 13.3860
Save Best Model!
====> Epoch: 6 Average loss: 10.7953
Save Best Model!
====> Epoch: 7 Average loss: 8.8102
Save Best Model!
====> Epoch: 8 Average loss: 7.6619
Save Best Model!
====> Epoch: 9 Average loss: 6.4931
Save Best Model!
====> Epoch: 10 Average loss: 5.6143
Save Best Model!
====> Epoch: 11 Average loss: 5.1796
Save Best Model!
====> Epoch: 12 Average loss: 4.4631
Save Best Model!
====> Epoch: 13 Average loss: 3.9613
Save Best Model!
====> Epoch: 14 Average loss: 3.7232
Save Best Model!
====> Epoch: 15 Average loss: 3.3352
Save Best Model!
====> Epoch: 16 Average loss: 3.1952
Save Best Model!
====> Epoch: 17 Average loss: 2.8434
Save Best Model!
====> Epoch: 18 Average loss: 

In [15]:
# Encode every training pair in file order (no shuffling) and save one row per
# sample: [name, mu(img)..., oname, mu(oimg)...].
from tqdm import tqdm

train_data = MyDataset(datatxt='train_tal.txt', transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=False, num_workers=20)

os.makedirs('./result', exist_ok=True)
model.eval()
sssi = []
with torch.no_grad():
    for img, oimg, fn in tqdm(train_loader):
        mu, _ = model.encode(img.to(device))
        omu, _ = model.encode(oimg.to(device))
        # One device->host copy per batch instead of one per sample
        mu_np = mu.cpu().numpy()
        omu_np = omu.cpu().numpy()
        # Presumably fn collates to (names, paired_names) — TODO confirm against MyDataset
        names, paired_names = fn[0], fn[1]
        for i in range(len(img)):
            row = [names[i]]
            row.extend(mu_np[i])
            # append, not extend: extend would split a filename string into characters
            row.append(paired_names[i])
            row.extend(omu_np[i])
            sssi.append(row)

# Rows mix names and floats, so NumPy coerces the whole array to a string dtype
dd = np.array(sssi)
print(len(dd))
np.save('./result/resu_' + str(num_var) + '_' + str(k) + '_all.npy', dd)

100%|██████████| 27/27 [00:06<00:00,  4.43it/s]


3434
