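## Implicit Zoo Data Visualization
A quick-start notebook for visualizing INRs (implicit neural representations): it downloads a few demo checkpoints, rebuilds the SIREN networks, and renders the images they encode.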

In [None]:
# if not installed already, please setup all dependencies.
! git clone https://github.com/qimaqi/Implicit-Zoo.git

In [None]:
cd Implicit-Zoo

## Get data samples
Download a few demo data samples and extract them into `dataset/`.

In [None]:
!mkdir dataset

In [None]:
!wget https://www.dropbox.com/scl/fi/7poo46l7dgvfk6mmuk1s4/demo_data.zip?rlkey=lxbbk20vn45pjego266h0aaz1&st=iodkjet4&dl=0 -O demo_data.zip

In [None]:
!unzip demo_data.zip -d dataset

## Import Dependencies

In [None]:
import math
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
from typing import List, Tuple, Union
import numpy as np 
from torchvision import transforms
import matplotlib.pyplot as plt

## Create INRs

Define the coordinate-grid helper and the SIREN layers used to rebuild each INR from its checkpoint.

In [None]:

def make_coordinates(
    shape: Union[Tuple[int, ...], List[int], np.ndarray],
    bs: int,
    coord_range: Union[Tuple[float, float], List[float]] = (-1, 1),
) -> torch.Tensor:
    """Build a batch of evenly spaced 2-D grid coordinates.

    Args:
        shape: grid size as ``(height, width)``; only the first two entries
            are read.
        bs: batch size; the same grid is repeated ``bs`` times.
        coord_range: ``(low, high)`` bounds used for both axes; both
            endpoints are included (``np.linspace``).

    Returns:
        float32 tensor of shape ``(bs, height * width, 2)``. Each row is an
        ``(x, y)`` pair and rows are ordered row-major (y varies slowest), so a
        per-row prediction of shape ``(height * width, C)`` can be reshaped
        with ``.view(height, width, C)``.
    """
    y_coordinates = np.linspace(coord_range[0], coord_range[1], shape[0])
    x_coordinates = np.linspace(coord_range[0], coord_range[1], shape[1])
    # Default 'xy' indexing yields (height, width) grids, so ravel() is row-major.
    x_grid, y_grid = np.meshgrid(x_coordinates, y_coordinates)
    coordinates = np.stack([x_grid.ravel(), y_grid.ravel()], axis=-1)
    coordinates = np.repeat(coordinates[np.newaxis, ...], bs, axis=0)
    return torch.from_numpy(coordinates).type(torch.float)

class Sine(nn.Module):
    """Element-wise sinusoidal activation computing ``sin(w0 * x)``.

    Args:
        w0: frequency multiplier applied to the input before the sine.
    """

    def __init__(self, w0=1.0):
        super().__init__()
        self.w0 = w0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scaled = self.w0 * x
        return scaled.sin()


class Siren(nn.Module):
    """One SIREN layer: an affine map followed by an activation (``Sine`` by default).

    Weights are drawn uniformly from ``[-b, b]`` where ``b = 1 / dim_in`` for
    the first layer and ``b = sqrt(c / dim_in) / w0`` otherwise; the bias
    (if any) starts at zero.

    Args:
        dim_in: input feature size.
        dim_out: output feature size.
        w0: frequency factor used by the default ``Sine`` activation and by
            the hidden-layer init bound.
        c: constant in the hidden-layer init bound.
        is_first: select the first-layer init bound.
        use_bias: whether to learn a bias term.
        activation: module applied after the affine map; ``None`` means
            ``Sine(w0)``.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        w0=30.0,
        c=6.0,
        is_first=False,
        use_bias=True,
        activation=None,
    ):
        super().__init__()
        self.w0, self.c = w0, c
        self.dim_in, self.dim_out = dim_in, dim_out
        self.is_first = is_first

        w_init = torch.zeros(dim_out, dim_in)
        b_init = None
        if use_bias:
            b_init = torch.zeros(dim_out)
        # Fill the tensors in place before wrapping them as Parameters.
        self.init_(w_init, b_init, c=c, w0=w0)

        # Register weight before bias so state_dict key order is unchanged.
        self.weight = nn.Parameter(w_init)
        self.bias = None if not use_bias else nn.Parameter(b_init)
        self.activation = activation if activation is not None else Sine(w0)

    def init_(self, weight: torch.Tensor, bias: torch.Tensor, c: float, w0: float):
        """Initialise ``weight`` (uniform) and ``bias`` (zeros, may be None) in place."""
        if self.is_first:
            bound = 1 / self.dim_in
        else:
            bound = math.sqrt(c / self.dim_in) / w0
        weight.uniform_(-bound, bound)

        if bias is not None:
            bias.zero_()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.activation(F.linear(x, self.weight, self.bias))


class INR(nn.Module):
    """MLP of SIREN layers that maps 2-D grid coordinates to pixel values.

    Layer stack: a first ``Siren`` (in_dim -> hidden_dim), ``n_layers - 2``
    hidden ``Siren`` layers, and a final ``nn.Linear`` head
    (hidden_dim -> out_channels), where ``hidden_dim = in_dim * up_scale``.
    Parameters are registered through ``self.seq``, so state-dict keys are
    ``seq.<idx>.weight`` / ``seq.<idx>.bias``.

    Args:
        image_size: grid size ``(height, width)`` evaluated by
            ``predict_entire_image``.
        in_dim: coordinate dimensionality; ``predict_entire_image`` feeds
            2-D coordinates, so it assumes ``in_dim == 2``.
        n_layers: total number of layers including the linear head (>= 2).
        up_scale: hidden-width multiplier.
        out_channels: number of values predicted per coordinate.
        device: device the module is moved to at construction.

    Raises:
        ValueError: if ``n_layers < 2``.
    """

    def __init__(
        self,
        image_size,
        in_dim: int = 2,
        n_layers: int = 3,
        up_scale: int = 16,
        out_channels: int = 1,
        device='cpu',
    ):
        super().__init__()
        if n_layers < 2:
            raise ValueError(
                f"n_layers must be >= 2 (first Siren layer + linear head), got {n_layers}"
            )
        self.image_size = np.array(image_size)
        self.device = device
        hidden_dim = in_dim * up_scale
        # is_first=True selects Siren's first-layer init bound (1 / dim_in).
        # This only affects random init; weights loaded via load_state_dict
        # overwrite it.
        self.layers = [Siren(dim_in=in_dim, dim_out=hidden_dim, is_first=True)]
        for _ in range(n_layers - 2):
            self.layers.append(Siren(hidden_dim, hidden_dim))
        self.layers.append(nn.Linear(hidden_dim, out_channels))
        # The plain list above is not registered with nn.Module; the Sequential
        # is what owns the parameters (and defines the state-dict keys).
        self.seq = nn.Sequential(*self.layers)
        # Honour the `device` argument so weights and inputs agree.
        self.to(device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.seq(x)

    @torch.no_grad()
    def predict_entire_image(self):
        """Evaluate the network on the full ``image_size`` coordinate grid.

        Returns:
            CPU tensor of shape ``(out_channels, height, width)``.
        """
        # Follow the parameters' device so this still works if the module was
        # moved with .to()/.cuda() after construction.
        device = next(self.parameters()).device
        size = tuple(int(s) for s in self.image_size)
        coords = make_coordinates(size, 1).to(device)
        # Output is (1, height * width, C) in row-major pixel order.
        pixels = self.forward(coords)
        image = pixels.view(*size, -1).permute(2, 0, 1)
        return image.cpu()


## Visualize CIFAR data

In [None]:
siren_model = INR(
    image_size=(32, 32),
    in_dim=2,
    n_layers=3,
    up_scale=32,
    out_channels=3,
)
# Load a checkpoint from the demo dataset. torch.load unpickles the file, so
# only load checkpoints from sources you trust.
path_to_checkpoint = 'dataset/demo_data/cifar_demo.ckpt'
checkpoint = torch.load(path_to_checkpoint, map_location='cpu')
siren_model.load_state_dict(checkpoint['params'])
image = siren_model.predict_entire_image()

# Map the network output back to pixel space: x * std + mean, written as two
# Normalize steps (divide by 1/std, then subtract -mean). Uses ImageNet
# channel statistics (presumably the normalisation the INRs were fit with —
# the training code is not shown here). Later cells reuse `inv_normalize`.
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])
inv_normalize = transforms.Compose([
    transforms.Normalize(mean=torch.zeros_like(mean), std=1 / std),
    transforms.Normalize(mean=-mean, std=torch.ones_like(std)),
])

image = torch.clamp(inv_normalize(image), 0, 1)
pil_image = transforms.ToPILImage()(image)

fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(pil_image)
ax.set_title('Reconstructed CIFAR image from INR')
plt.show()

## Visualize ImageNet data

In [None]:
siren_model = INR(
    image_size=(256, 256),
    in_dim=2,
    n_layers=4,
    up_scale=128,
    out_channels=3,
)
# Load a checkpoint from the demo dataset. torch.load unpickles the file, so
# only load checkpoints from sources you trust.
path_to_checkpoint = 'dataset/demo_data/imagenet_demo.ckpt'
checkpoint = torch.load(path_to_checkpoint, map_location='cpu')
siren_model.load_state_dict(checkpoint['params'])
image = siren_model.predict_entire_image()

# `inv_normalize` (x * std + mean with ImageNet statistics) is defined in the
# CIFAR cell above, which must be run first.
image = torch.clamp(inv_normalize(image), 0, 1)
pil_image = transforms.ToPILImage()(image)

fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(pil_image)
ax.set_title('Reconstructed ImageNet image from INR')
plt.show()

## Visualize Cityscapes data

In [None]:
siren_model = INR(
    image_size=(320, 640),
    in_dim=2,
    n_layers=5,
    up_scale=128,
    out_channels=3,
)
# Load a checkpoint from the demo dataset. torch.load unpickles the file, so
# only load checkpoints from sources you trust.
path_to_checkpoint = 'dataset/demo_data/cityscapes_demo.ckpt'
checkpoint = torch.load(path_to_checkpoint, map_location='cpu')
siren_model.load_state_dict(checkpoint['params'])
image = siren_model.predict_entire_image()

# `inv_normalize` (x * std + mean with ImageNet statistics) is defined in the
# CIFAR cell above, which must be run first.
image = torch.clamp(inv_normalize(image), 0, 1)
pil_image = transforms.ToPILImage()(image)

# The image is 320 x 640 (twice as wide as tall), so use a wide figure.
fig, ax = plt.subplots(1, 1, figsize=(20, 10))
ax.imshow(pil_image)
ax.set_title('Reconstructed Cityscapes image from INR')
plt.show()