In [1]:
import codecs
import os
import random
from astropy.io import fits
import numpy as np
import matplotlib.pyplot as plt

from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import Dataset

In [2]:
def path_list(paths):
    '''
    Return the full paths of the files in directory `paths`, sorted numerically.

    paths: directory containing files named `<integer>.fits...`; works with or
           without a trailing path separator.
    Returns a list of path strings ordered by the integer that precedes '.fits'
    in each file name (so '2.fits' comes before '10.fits').
    Raises ValueError if a file name does not start with an integer before '.fits'.
    '''
    names = sorted(os.listdir(paths), key=lambda name: int(name.split('.fits')[0]))
    # Join directory and file name as separate components so the separator is
    # inserted when `paths` lacks a trailing slash.
    return [os.path.join(paths, name) for name in names]

def load_image(path):
    '''
    Read the primary HDU of a FITS file as a 128x128 float32 array.

    path: path of the FITS file to open.
    Returns a new float32 ndarray of shape (128, 128).
    '''
    # The context manager closes the file even if reading/conversion raises.
    with fits.open(path) as hdu:
        # np.array copies the data, so the result stays valid after the file closes.
        img = np.array(hdu[0].data, dtype=np.float32)
    # In-place ndarray.resize: a no-op for 128x128 data; for any other size it
    # truncates or zero-pads the flat buffer (it does not resample the image).
    img.resize(128, 128)
    return img

def load_params(input_txt):
    '''
    Read the third whitespace-separated field of every line in a UTF-8 text file.

    input_txt: path of the parameter file.
    Returns a list of strings (the fields are not converted to numbers here).
    Raises IndexError if a line has fewer than three fields.
    '''
    # The context manager closes the file even if a malformed line raises.
    with codecs.open(input_txt, mode='r', encoding='utf-8') as f:
        return [line.split()[2] for line in f.readlines()]

def maxminnorm(array):
    '''
    Min-max scale every column of `array` to the range [0, 1].

    array: 2-D array-like; the minimum and maximum are taken along axis 0,
           i.e. each column is scaled independently of the others.
    Returns a float32 ndarray with the same shape as `array`.
    A constant column (max == min) maps to 0 instead of producing NaN/inf
    through a division by zero.
    '''
    array = np.asarray(array)
    col_min = array.min(axis=0)
    col_range = array.max(axis=0) - col_min
    # Replace zero ranges by 1 so constant columns give (x - min) / 1 == 0.
    safe_range = np.where(col_range == 0, 1, col_range)
    return ((array - col_min) / safe_range).astype(np.float32)

def addnoise(img):
    '''
    Add standard Gaussian noise (mean=0, variance=1.0) to an image.

    img: array-like image; the noise is drawn with the same shape as `img`,
         so any image size works (for 128x128 input the draws are the same
         as with a fixed 128x128 noise field).
    Returns noise + img as a new array; noise comes from NumPy's global
    random state via np.random.randn.
    '''
    noise = np.random.randn(*np.shape(img))
    return noise + img

def batch_data(batch, batch_size):
    '''
    Convert one loader batch into float32 image and label tensors.

    batch:      pair (images, labels); images is a tensor whose first dimension
                indexes the samples, labels is anything np.asarray can convert
                to float32 (e.g. a sequence of numeric strings).
    batch_size: nominal batch size; kept for backward compatibility. The label
                tensor is shaped from the number of samples actually present,
                so a final batch smaller than batch_size no longer fails.
    Returns (images, labels) with labels of shape (n_samples, -1).
    '''
    images = batch[0].to(torch.float32)

    labels = torch.from_numpy(np.asarray(batch[1], dtype=np.float32))
    # Use the real sample count rather than the nominal batch_size.
    return images, labels.view(images.shape[0], -1)

In [9]:
def mean_std(img):
    '''
    Return (mean, std) over all elements of `img`.

    The standard deviation is the sample estimate (ddof=1).
    '''
    return np.mean(img), np.std(img, ddof=1)
def meanstd(img_list, img_num=1000):
    '''
    Estimate the dataset mean and std from a random sample of images.

    Each sampled image goes through the same pipeline as training items
    (load_image -> addnoise -> maxminnorm) before its statistics are taken.

    img_list: list of image paths to sample from.
    img_num:  maximum number of images to sample (default 1000); clamped to
              len(img_list) so small datasets do not make random.sample fail.
    Returns (average of per-image means, average of per-image stds).
    Raises ValueError if no image can be sampled (empty list or img_num < 1).
    '''
    sample_size = min(img_num, len(img_list))
    if sample_size < 1:
        raise ValueError('meanstd needs at least one image to sample')

    mean_total = 0
    std_total = 0
    for path in random.sample(img_list, sample_size):
        img_norm = maxminnorm(addnoise(load_image(path)))
        # Compute both statistics in one call instead of once per value.
        mean, std = mean_std(img_norm)
        mean_total += mean
        std_total += std
    return mean_total / sample_size, std_total / sample_size

In [13]:
# Collect the training image paths from the local data directory.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
img_list = path_list('/home/zqw/GalSim/a_mycode/data/train_data/')
# Show the (mean, std) estimate as the cell output.
# NOTE(review): meanstd samples images at random and adds random noise, and no
# seed is set in the cells shown here, so the value presumably differs between runs.
meanstd(img_list)

(0.23363281835243105, 0.16072264671325684)

In [11]:
class Mydataset(Dataset):
    '''
    Dataset pairing image files with the labels read from a parameter file.

    Every item is loaded from disk, perturbed with addnoise, min-max scaled
    with maxminnorm, randomly flipped vertically and normalized with the
    mean/std estimated once when the dataset is constructed.
    '''
    def __init__(self, file_path, label_path):
        '''
        file_path:  directory handed to path_list to enumerate the image files.
        label_path: text file handed to load_params to read the labels.
        '''
        # parameters
        self.label = load_params(label_path)
        # The dataset length follows the number of labels, not the number of images.
        self.data_len = len(self.label)

        # fits_images_train
        self.img_list = path_list(file_path)
        # One call only: meanstd draws a random sample with random noise, so
        # taking mean and std from separate calls would mix two different
        # samples and read the images from disk twice.
        self.mean, self.std = meanstd(self.img_list)
        self.transforms = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(self.mean, self.std)])

    def __getitem__(self, index):
        '''
        Return (transformed image, label) for position `index`.

        The label is returned exactly as stored by load_params (no numeric
        conversion happens here).
        '''
        item = self.img_list[index]
        label = self.label[index]

        img = load_image(item)
        img = maxminnorm(addnoise(img))
        img = self.transforms(img)
        return img, label

    def __len__(self):
        '''Number of samples, i.e. the number of labels read at construction.'''
        return self.data_len


In [12]:
class Net(nn.Module):
    '''
    Convolutional regressor: four double-convolution stages followed by a
    fully connected head producing one value per sample.

    Spatial size per stage for a 128x128 single-channel input
    (each stage is conv -> conv -> 2x2 max-pool):
        conv1: 128 -> 129 -> 130 -> 65
        conv2:  65 ->  66 ->  67 -> 33
        conv3:  33 ->  33 ->  33 -> 16
        conv4:  16 ->  17 ->  18 ->  9
    The head therefore expects 256 * 9 * 9 features.
    '''

    @staticmethod
    def _conv_stage(c_in, c_mid, c_out, kernel, pad):
        '''Build conv -> BN -> ReLU -> conv -> BN -> ReLU -> 2x2 max-pool.'''
        return nn.Sequential(
            nn.Conv2d(c_in, c_mid, kernel, stride=1, padding=pad),
            nn.BatchNorm2d(c_mid),
            nn.ReLU(),
            nn.Conv2d(c_mid, c_out, kernel, stride=1, padding=pad),
            nn.BatchNorm2d(c_out),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        )

    def __init__(self):
        super().__init__()
        # Stages are created in order conv1..conv4, then fc, so parameter
        # names and random initialization order stay the same.
        self.conv1 = self._conv_stage(1, 2, 4, kernel=4, pad=2)
        self.conv2 = self._conv_stage(4, 8, 16, kernel=4, pad=2)
        self.conv3 = self._conv_stage(16, 32, 64, kernel=3, pad=1)
        self.conv4 = self._conv_stage(64, 128, 256, kernel=2, pad=1)

        # fully connected head: 256*9*9 -> 128 -> 64 -> 1
        self.fc = nn.Sequential(
            nn.Linear(256 * 9 * 9, 128),
            nn.BatchNorm1d(128),
            nn.Linear(128, 64),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        '''
        Run the four convolution stages, flatten, and apply the head.

        Returns a tensor of shape (batch, 1).
        NOTE(review): the final ReLU clamps every prediction to >= 0 —
        confirm the regression targets are non-negative.
        '''
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = stage(x)
        flat = x.reshape(-1, 256 * 9 * 9)
        return F.relu(self.fc(flat))