# Use the Trained Multichannel GAN

Last edited: Nov 3, 2020

In [1]:
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torch.nn.functional as F
from photutils import create_matching_kernel
from skimage.transform import downscale_local_mean
from scipy.ndimage import zoom
import matplotlib.pyplot as plt
import astropy.io.fits as pyfits
import numpy as np
import h5py

import astropy.wcs as wcs
from astropy import units as u
from astropy.coordinates import SkyCoord
from scipy import ndimage
from PIL import Image

import cv2


%matplotlib inline

#### Load the GAN generator and its trained weights

In [3]:
# Inference runs entirely on the CPU.
device = torch.device("cpu")

# Hyper-parameters shared with the generator definition below. ``ngf`` and
# ``nc`` fix the layer shapes of ``Shoobygen``, so they must agree with the
# checkpoint that gets loaded into it.
ngpu = 3    # GPUs for data_parallel; only used when the input is on CUDA
nz = 100    # not referenced by the cells shown here
ngf = 64    # base number of generator feature maps
ndf = 64    # not referenced by the cells shown here
nc = 7      # image channels, one per band


class Shoobygen(nn.Module):
    """Generator mapping low-resolution multiband images to high-resolution ones.

    Spatially, an ``H x W`` input becomes ``(3H - 2) x (3W - 2)``, so a
    ``(N, C, 22, 22)`` batch comes out as ``(N, C, 64, 64)``. The final
    ``Tanh`` bounds every output pixel to ``[-1, 1]``.

    Args:
        ngpu: Number of GPUs used by ``forward`` (via
            ``nn.parallel.data_parallel``) when the input is a CUDA tensor and
            ``ngpu > 1``; otherwise ``self.main`` is called directly.
        n_channels: Number of input and output channels. ``None`` (default)
            falls back to the module-level ``nc`` at construction time.
        n_filters: Base feature-map width. ``None`` (default) falls back to
            the module-level ``ngf`` at construction time.

    The layer order inside ``self.main`` determines the ``state_dict`` keys
    (``main.0.weight``, ...). Changing it would stop saved checkpoints such as
    ``netG_epoch_999.pth`` from loading.
    """

    def __init__(self, ngpu, n_channels=None, n_filters=None):
        super(Shoobygen, self).__init__()
        # The globals are only consulted when no explicit value is given, so
        # existing ``Shoobygen(ngpu)`` calls behave exactly as before.
        channels = nc if n_channels is None else n_channels
        width = ngf if n_filters is None else n_filters
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # H -> H - 3
            nn.Conv2d(channels, width * 4, 4, 1, 0, bias=False),
            nn.BatchNorm2d(width * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # H -> 3H (stride-3 upsampling)
            nn.ConvTranspose2d(width * 4, width * 8, 3, 3, 0, bias=False),
            nn.BatchNorm2d(width * 8),
            nn.ReLU(True),
            # H -> H + 4
            nn.ConvTranspose2d(width * 8, width * 4, 7, 1, 1, bias=False),
            nn.BatchNorm2d(width * 4),
            nn.ReLU(True),
            # H -> H + 3, back to image channels
            nn.ConvTranspose2d(width * 4, channels, 4, 1, 0, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        """Run the generator on a ``(N, channels, H, W)`` batch."""
        if input.is_cuda and self.ngpu > 1:
            return nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        return self.main(input)


# Build the generator in inference mode. In eval mode the BatchNorm layers
# normalize with the running statistics stored in the checkpoint, rather than
# with the statistics of the batch they are fed (the notebook feeds one image
# at a time). Eval mode also stops every forward pass from overwriting those
# running statistics.
netS = Shoobygen(ngpu).to(device)
netS.eval()
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
netS.load_state_dict(torch.load('netG_epoch_999.pth', map_location='cpu'))


<All keys matched successfully>

Read GOODS-S sample data in seven bands, simulate low-resolution images, and reconstruct them with the GAN

In [4]:
from galblend import *  # presumably provides galblend(), used below
import torchvision.transforms as transforms
device = torch.device("cpu")
nc = 7

# High-resolution PSF file for each band, in image-channel order.
# NOTE(review): 'psf_i.fits' is listed for both the third and fourth channel;
# confirm that is intended.
hi_psfs = ['psf_b.fits', 'psf_v.fits', 'psf_i.fits', 'psf_i.fits',
           'psf_z.fits', 'psf_j.fits', 'psf_h.fits']
# Every band is matched to the same Subaru i-band PSF.
lo_psfs = ['PSF_subaru_i.fits'] * 7

# One 41x41 matching kernel per band, stored as kernel[:, :, 0, band].
kernel = np.zeros((41, 41, 1, len(hi_psfs)))
for i, (hi_name, lo_name) in enumerate(zip(hi_psfs, lo_psfs)):
    # Average 3x3 pixel blocks, then trim the border. The assignment below
    # needs a 41x41 kernel, so the downscaled PSF must be 56x56 (7 + 41 + 8).
    psf = pyfits.getdata('../psfs/' + hi_name)
    psf = downscale_local_mean(psf, (3, 3))
    psf = psf[7:-8, 7:-8]

    psf_hsc = pyfits.getdata('../psfs/' + lo_name)
    psf_hsc = psf_hsc[1:42, 1:42]  # 41x41 cutout
    kern = create_matching_kernel(psf, psf_hsc)
    # The channel axis has length 1, so the kernel is written directly
    # (the former np.repeat(..., 1, axis=2) was a no-op).
    kernel[:, :, 0, i] = kern

# (41, 41, 1, 7) -> (1, 7, 41, 41) float32 tensor on `device`.
kernel = torch.Tensor(kernel).permute(2, 3, 0, 1).float().to(device)

# ToTensor maps a uint8 HxW array to a (1, H, W) float tensor in [0, 1], and
# Normalize(0.5, 0.5) then maps it to [-1, 1], the generator's Tanh range.
tfms = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
# NOTE(review): this cell's saved output ends with "too many values to unpack
# (expected 2)". If it is raised here, galblend returns more than two values;
# confirm its return signature.
imaggge, data = galblend(gals=2, lim_hmag=25, plot_it=False)
pashe = np.zeros((1, 7, 64, 64))
for chi in range(7):
    # A 0-degree rotation is effectively a no-op; presumably kept as a hook
    # for augmentation.
    s = ndimage.rotate(imaggge[chi, :, :], 0, mode='nearest', reshape=False)
    da = np.arcsinh(s)
    # Stretch to 0..255. The divisor is da.max() + 0.1, not the full range, so
    # the result exceeds 255 whenever da.min() < -0.1. Clip before the uint8
    # cast, which would otherwise wrap around (or be undefined) for
    # out-of-range floats. Values already inside [0, 255] are unchanged.
    stretched = 255.0 / (da.max() + 0.1) * (da - da.min())
    pash = np.clip(stretched, 0, 255).astype(np.uint8)
    # tfms returns (1, 64, 64); the leading unit axis is dropped on assignment.
    pashe[0, chi, :, :] = tfms(pash)
mm = np.zeros((1, 7, 64, 64))
mm[0, ...] = pashe
real_cpu = torch.Tensor(mm).float()
#im = real_cpu+0.1*torch.rand_like(real_cpu)


# Simulate low-resolution observations: convolve each high-resolution band
# with its matching kernel, resample to the coarse pixel grid and add noise.
img2 = torch.tensor(np.zeros((1, 7, 22, 22)))
for ch in range(real_cpu.shape[1]):
    imagetoconvolv = real_cpu[:, ch, :, :].reshape(-1, 1, 64, 64)
    kerneltoconvolv = kernel[:, ch, :, :].reshape(-1, 1, 41, 41)
    # 64x64 image, 41x41 kernel, padding 21 -> 66x66 output.
    # NOTE(review): F.conv2d computes a cross-correlation (the kernel is not
    # flipped). That equals a true convolution only for point-symmetric
    # kernels; confirm this matches how the training data were degraded.
    a = F.conv2d(imagetoconvolv, kerneltoconvolv, padding=21)
    # Resample 66x66 -> 22x22 to fix the pixel scale. F.upsample is
    # deprecated; it forwarded to F.interpolate with align_corners implicitly
    # False (emitting a warning), so passing False here gives the same result.
    low = F.interpolate(a, scale_factor=1 / 3, mode='bilinear', align_corners=False)
    img2[:, ch, :, :] = low.reshape(-1, 22, 22)
    # Additive uniform noise in [0, 0.25). It is unseeded, so results vary
    # between runs.
    img2[:, ch, :, :] = img2[:, ch, :, :] + 0.25 * torch.rand_like(img2[:, ch, :, :])

img = img2.view(-1, 7, 22, 22)
img = img.float()

# Inference only: no autograd graph is needed.
with torch.no_grad():
    fake = netS(img)
print(fake.shape)
fd = fake.detach()
fd = fd.cpu()

plt.figure(figsize=(16,10))

# One column per band; rows are the high-res input, the degraded low-res
# image, the GAN reconstruction and the (high-res minus GAN) residual.
filts = ['450nm','606nm','750nm','814nm','850nm','1250nm','1600nm']
for i in range(7):
    panels = [
        (mm[0, i, :, :], 'High Res', {}),
        (img[0, i, :, :], 'Low Res', {}),
        (fd[0, i, :, :], 'GAN Res', {}),
        (real_cpu[0, i, :, :] - fd[0, i, :, :], 'Residual', {'cmap': 'gray'}),
    ]
    for row, (panel, row_label, extra_kw) in enumerate(panels):
        plt.subplot(4, 7, 7 * row + i + 1)
        plt.imshow(panel, origin='lower', **extra_kw)
        if row == 0:
            # Only the top row carries the band label.
            plt.text(10, 10, filts[i], color='y', size=16)
        plt.xticks([])
        plt.yticks([])
        if i == 0:
            # Only the first column carries the row label.
            plt.ylabel(row_label, size=20)

plt.tight_layout()
#plt.savefig('../plots/multi.png')

  "See the documentation of nn.Upsample for details.".format(mode))


ValueError: too many values to unpack (expected 2)

In [None]:
# Call galblend again with plotting enabled (note lim_hmag=24 here versus 25
# in the cell above).
galblend(gals=2,lim_hmag=24,plot_it=True)
