In [331]:
import numpy as np
import cv2

import time
import torch

from torch import nn
from math import sqrt

from typing import Dict

from torch._C import dtype

from torchvision import transforms

from collections import OrderedDict

import tqdm

from typing import Any, BinaryIO, List, Optional, Tuple, Union

import pathlib

from types import FunctionType

from PIL import Image, ImageColor, ImageDraw, ImageFont

import os
import argparse

import random

In [332]:
# --- Model architecture -------------------------------------------------
# `num_layers` hidden SIREN layers, each `layer_size` units wide.
layer_size, num_layers = 28, 5

# --- Image geometry -----------------------------------------------------
# Resolution every reference image is resized to before evaluation.
image_w = image_h = 224

# --- SIREN frequencies (omega_0) ----------------------------------------
# `w0_initial` applies to the first layer, `w0` to all later layers.
w0 = w0_initial = 30.0

# --- Optimisation settings (not used by the evaluation cells below) -----
learning_rate = 5e-4
num_iters = 5000

device = "cuda:0"

In [333]:
class Sine(nn.Module):
    """Element-wise ``sin(w0 * x)`` activation, as used by SIREN.

    Args:
        w0 (float): Frequency scale (omega_0 in the SIREN paper) applied to the
            input before taking the sine. Defaults to 1.
    """

    def __init__(self, w0=1.):
        super(Sine, self).__init__()
        # Stored as a plain attribute (not a Parameter or buffer), so it is not
        # part of the module's state_dict.
        self.w0 = w0

    def forward(self, x):
        scaled = self.w0 * x
        return scaled.sin()


class SirenLayer(nn.Module):
    """One SIREN layer: a linear map followed by an activation.

    The linear weights (and bias, when present) are re-initialised from
    U(-bound, bound), where ``bound`` is ``1 / dim_in`` for the first layer and
    ``sqrt(c / dim_in) / w0`` for every other layer (SIREN initialisation).

    Args:
        dim_in (int): Dimension of input.
        dim_out (int): Dimension of output.
        w0 (float): Omega_0. Used as the frequency of the default Sine
            activation and in the non-first-layer init bound.
        c (float): c value from SIREN paper used for weight initialization.
        is_first (bool): Whether this is first layer of model; selects the
            ``1 / dim_in`` init bound.
        use_bias (bool): Whether the linear map has a bias term.
        activation (torch.nn.Module): Activation applied after the linear map.
            If None, defaults to ``Sine(w0)``.
    """

    def __init__(self, dim_in, dim_out, w0=30., c=6., is_first=False,
                 use_bias=True, activation=None):
        super().__init__()
        self.dim_in = dim_in
        self.is_first = is_first
        self.linear = nn.Linear(dim_in, dim_out, bias=use_bias)

        # SIREN initialisation: weight first, then bias (same draw order as
        # before, so seeded runs are reproducible).
        if is_first:
            bound = 1 / dim_in
        else:
            bound = sqrt(c / dim_in) / w0
        nn.init.uniform_(self.linear.weight, -bound, bound)
        if use_bias:
            nn.init.uniform_(self.linear.bias, -bound, bound)

        if activation is None:
            activation = Sine(w0)
        self.activation = activation

    def forward(self, x):
        return self.activation(self.linear(x))


class Siren(nn.Module):
    """SIREN coordinate network.

    The model is ``num_layers`` SirenLayer blocks (stored in ``self.net``)
    followed by one output SirenLayer (``self.last_layer``). It therefore
    contains ``num_layers + 1`` linear maps in total.

    Args:
        dim_in (int): Dimension of input.
        dim_hidden (int): Width of every hidden layer.
        dim_out (int): Dimension of output.
        num_layers (int): Number of hidden SirenLayers before the output layer.
        w0 (float): Omega 0 for every layer after the first. For the output
            layer it only affects the init bound, since that layer uses
            ``final_activation`` instead of Sine.
        w0_initial (float): Omega 0 for first layer.
        use_bias (bool): Whether each linear map has a bias term.
        final_activation (torch.nn.Module): Activation after the output layer.
            If None, ``nn.Identity()`` is used.
    """

    def __init__(self, dim_in, dim_hidden, dim_out, num_layers, w0=30.,
                 w0_initial=30., use_bias=True, final_activation=None):
        super().__init__()
        # Built in index order so parameter initialisation draws from the RNG
        # in the same sequence as a plain loop would.
        hidden_layers = [
            SirenLayer(
                dim_in=dim_in if idx == 0 else dim_hidden,
                dim_out=dim_hidden,
                w0=w0_initial if idx == 0 else w0,
                use_bias=use_bias,
                is_first=(idx == 0),
            )
            for idx in range(num_layers)
        ]
        self.net = nn.Sequential(*hidden_layers)

        if final_activation is None:
            final_activation = nn.Identity()
        self.last_layer = SirenLayer(
            dim_in=dim_hidden,
            dim_out=dim_out,
            w0=w0,
            use_bias=use_bias,
            activation=final_activation,
        )

    def forward(self, x):
        return self.last_layer(self.net(x))

In [334]:
def to_coordinates_and_features(img):
    """Converts an image to a set of coordinates and features.

    Args:
        img (torch.Tensor): Shape (channels, height, width).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
            coordinates: float tensor of shape (height * width, 2) holding the
                (row, col) index of every pixel in row-major order. Each axis
                is rescaled independently to [-1, 1].
            features: tensor of shape (height * width, channels) holding the
                pixel values, in the same row-major order as ``coordinates``.
    """
    height, width = img.shape[1:]
    # Coordinates are indices of all non zero locations of a tensor of ones of
    # same shape as spatial dimensions of image, i.e. every (row, col) pair.
    coordinates = torch.ones(img.shape[1:]).nonzero(as_tuple=False).float()
    # Normalize each axis by its own extent so both lie in [-.5, .5]. Using the
    # height for both axes would push the width coordinates outside that range
    # for non-square images. max(..., 1) avoids 0/0 = NaN on a size-1 axis.
    extent = torch.tensor(
        [max(height - 1, 1), max(width - 1, 1)], dtype=coordinates.dtype
    )
    coordinates = coordinates / extent - 0.5
    # Convert to range [-1, 1]
    coordinates *= 2
    # Convert image to a tensor of features of shape (num_points, channels)
    features = img.reshape(img.shape[0], -1).T
    return coordinates, features

In [335]:
def psnr(img1, img2, max_val=1.):
    """Calculates PSNR between two images.

    PSNR = 20 * log10(max_val) - 10 * log10(MSE).

    Args:
        img1 (torch.Tensor): First image.
        img2 (torch.Tensor): Second image. Must be on the same device as
            ``img1`` and have a broadcast-compatible shape.
        max_val (float): Peak possible pixel value. The default of 1.0 suits
            images scaled to [0, 1]; use 255 for 8-bit value ranges.

    Returns:
        float: PSNR in dB, or ``inf`` when the images are identical.
    """
    # Accumulate the squared error in float64 so half-precision inputs (for
    # example an fp16 network output) do not lose precision in the mean.
    mse = (img1.double() - img2.double()).pow(2).mean().item()
    if mse == 0:
        # log10(0) is undefined; identical images have infinite PSNR.
        return float('inf')
    return float(20. * np.log10(max_val) - 10. * np.log10(mse))

def get_clamped_psnr(img_recon, img):
    """PSNR between a ground-truth image and a quantized reconstruction.

    The network output can fall outside [0, 1]. It is therefore first passed
    through ``clamp_image``, which clamps it to [0, 1] and rounds it to multiples
    of 1/255 (the result is still a float tensor). The score then presumably
    reflects what an 8-bit copy of the reconstruction would achieve.

    Args:
        img_recon (torch.Tensor): Image reconstructed by model.
        img (torch.Tensor): Ground truth image.

    Returns:
        Whatever ``psnr`` returns for the quantized reconstruction and ``img``.
    """
    quantized = clamp_image(img_recon)
    return psnr(quantized, img)

In [336]:
def clamp_image(img):
    """Clamp image values to like in [0, 1] and convert to unsigned int.
    Args:
        img (torch.Tensor):
    """
    # Values may lie outside [0, 1], so clamp input
    img_ = torch.clamp(img, 0., 1.)
    # Pixel values lie in {0, ..., 255}, so round float tensor
    return torch.round(img_ * 255) / 255.

In [337]:
def is_image_file(filename):
    """Return True if ``filename`` ends with a common image-file extension.

    The match is case-sensitive: only the all-lower-case and all-upper-case
    spellings listed below are accepted (e.g. ``.Jpg`` is rejected).
    """
    # str.endswith accepts a tuple and succeeds if any suffix matches.
    return filename.endswith(('.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'))
def load_image_path(imgDir):
    """Recursively collect image files under ``imgDir``.

    Walks ``imgDir`` with ``os.walk`` and keeps regular files accepted by
    ``is_image_file``. It prints how many were found and returns them.

    Args:
        imgDir (str): Root directory to search.

    Returns:
        Tuple[List[str], List[str]]:
            train_files: full paths of the image files.
            train_imageNames: the matching file names with only the final
                extension removed (``'a.b.jpg'`` -> ``'a.b'``), index-aligned
                with ``train_files``.

    Note:
        The order follows ``os.walk`` / directory-listing order, which is not
        sorted and can differ between filesystems. Callers that pair the i-th
        file with other per-index data (e.g. per-image checkpoints) rely on
        this order being stable.
    """
    train_files = []
    train_imageNames = []
    for path, _dirnames, filelist in os.walk(imgDir):
        for file in filelist:
            full_path = os.path.join(path, file)
            # Skip entries that are not regular files (e.g. broken symlinks)
            # and files without an image extension.
            if not (os.path.isfile(full_path) and is_image_file(file)):
                continue
            train_files.append(full_path)
            # splitext strips only the last extension; split('.')[0] would
            # truncate dotted names such as 'img.v2.jpg' to 'img'.
            train_imageNames.append(os.path.splitext(file)[0])
    train_nSamples = len(train_files)
    print(train_nSamples)

    return train_files, train_imageNames

In [338]:
# Reference (ground-truth) images. Set the INR_REF_IMAGE_DIR environment
# variable to point at a local copy; the default is the original machine's path.
imgDir_ref = os.environ.get(
    'INR_REF_IMAGE_DIR',
    '/export/hdd/scratch/dataset/mini_imagenet_raw/n01614925',
)
decode_image_path_ref, train_image_names_ref = load_image_path(imgDir_ref)
# NOTE(review): transform_size is not used by any cell shown here.
transform_size = transforms.Resize([224,224])
# One PSNR slot per evaluated image / checkpoint.
psnr_value = np.zeros([10])

600


In [339]:
# Evaluate each pre-trained SIREN (one checkpoint per image) against its
# reference image and store the PSNR in `psnr_value`.
#
# Every reference image is resized to image_h x image_w, so the coordinate grid
# fed to the network is the same for all of them. Build it once.
batch_size = 1
coordinates, _ = to_coordinates_and_features(torch.zeros([3, image_h, image_w]))
coordinates = coordinates.to(device).half()

# A single model instance is reused: load_state_dict (strict by default)
# overwrites every parameter for each checkpoint.
func_rep = Siren(
    dim_in=2,
    dim_hidden=layer_size,
    dim_out=3,
    num_layers=num_layers,
    final_activation=torch.nn.Identity(),
    w0_initial=w0_initial,
    w0=w0,
).to(device).half()
func_rep.eval()

# The checkpoint directory encodes the architecture (5 layers x 28 units) and
# the number of images (10); keep it in sync with `num_layers`, `layer_size`
# and the size of `psnr_value`.
checkpoint_template = '/usr/scratch/hchen799/INR/weights_ImageNet/trial_5x28_30fq_5000ep_10im/best_model_%d.pt'

for i in range(len(psnr_value)):
    image_path = decode_image_path_ref[i]
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports a missing/unreadable file by returning None.
        raise FileNotFoundError(f"cv2.imread could not read {image_path!r}")
    # cv2.resize takes dsize as (width, height).
    img = cv2.resize(img, (image_w, image_h), interpolation=cv2.INTER_AREA)
    # (H, W, C) -> (C, H, W) to match the layout of the reconstruction below.
    # NOTE(review): cv2 loads colour images in BGR order and no channel swap is
    # applied on either side, which assumes the model was trained on BGR data
    # too. Confirm.
    image_tensor_transformed = torch.from_numpy(np.transpose(img, (2, 0, 1)))
    print(image_tensor_transformed.shape)

    func_rep.load_state_dict(torch.load(checkpoint_template % i, map_location=device))

    # Inference only: don't build an autograd graph.
    with torch.no_grad():
        output = func_rep(coordinates)
    # Rows of `output` follow the row-major coordinate order -> (H, W, 3).
    output = clamp_image(output.reshape(image_h, image_w, 3))
    output = output.permute(2, 0, 1).cpu()
    psnr_value[i] = psnr(output, image_tensor_transformed / 255.0)

torch.Size([3, 224, 224])
torch.Size([3, 224, 224])
torch.Size([3, 224, 224])
torch.Size([3, 224, 224])
torch.Size([3, 224, 224])
torch.Size([3, 224, 224])
torch.Size([3, 224, 224])
torch.Size([3, 224, 224])
torch.Size([3, 224, 224])
torch.Size([3, 224, 224])


In [340]:
# Mean PSNR (dB) over all evaluated images. The count comes from psnr_value
# itself rather than a hard-coded 10, so it follows the number of images.
psnr_total = psnr_value.sum()
psnr_average = psnr_total / len(psnr_value)
print(psnr_average)

27.31619930267334


In [22]:
# Rebuild the normalized (row, col) coordinate grid for a 3 x image_h x image_w
# image so it can be inspected on its own. Only the image's shape matters, so
# a zero tensor is enough.
batch_size = 1
img = torch.zeros([3, image_h, image_w])
coordinates, _ = to_coordinates_and_features(img)

In [23]:
# One (row, col) pair per pixel: expected (image_h * image_w, 2).
print(coordinates.size())

torch.Size([50176, 2])


In [None]:
# Load one trained checkpoint into `func_rep` for inspection.
# PATH is defined here so this cell runs on a fresh kernel. It must point at a
# checkpoint whose architecture matches `func_rep`, which is built from
# num_layers=5 / layer_size=28; a checkpoint from a different trial (e.g. 13x49)
# would fail load_state_dict with a size mismatch.
PATH = '/usr/scratch/hchen799/INR/weights_ImageNet/trial_5x28_30fq_5000ep_10im/best_model_4.pt'
func_rep.load_state_dict(torch.load(PATH, map_location=device))