CMSC 25500 Final Project: Pierce Hoenigman

Evaluation of Out-of-Bounds Pattern Conservation in Implicit Neural Representation Networks

In [None]:
#@title Imports

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Resize, Compose, ToTensor, Normalize
import torchaudio

from PIL import Image
import skimage

import numpy as np
import matplotlib.pyplot as plt

# Run on the GPU when PyTorch can see one, otherwise on the CPU. Kept as a plain
# string so it can be passed directly to `.to(device)` calls in later cells.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


In [None]:
#@title Flexible implicit representation network
#Sitzmann et al.: https://arxiv.org/pdf/2006.09661, github: https://github.com/vsitzmann/siren/tree/master
#Tancik et al.: https://arxiv.org/pdf/2006.10739, github: https://github.com/tancik/fourier-feature-networks/blob/master/Demo.ipynb
#Mehrabian et al.: https://arxiv.org/pdf/2409.09323, github: https://github.com/Ali-Meh619/FKAN/blob/main/FKAN_INR.ipynb


class Sinusoidal(nn.Module):
  """Linear layer followed by sin(omega * .): a SIREN sine layer (inspired by Sitzmann et al.).

  Args:
    input_dim: in-features of the wrapped nn.Linear (the fan-in used by the init).
    output_dim: out-features of the wrapped nn.Linear.
    freq: omega_0, the multiplier applied to the linear output before the sine.
    freq_first: True for the network's first layer. It then uses the SIREN
      first-layer init U(-1/input_dim, 1/input_dim). Otherwise the hidden-layer
      init U(-sqrt(6/input_dim)/omega_0, sqrt(6/input_dim)/omega_0) is used.
  """
  def __init__(self, input_dim, output_dim, freq=30, freq_first=False):
    super().__init__()
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.omega = freq
    self.omega_first = freq_first

    self.layer = nn.Linear(input_dim, output_dim, bias=True)

    with torch.no_grad():
      # Both schemes are symmetric and depend only on the fan-in (input_dim).
      if self.omega_first:
        bound = 1 / self.input_dim
      else:
        bound = np.sqrt(6 / self.input_dim) / self.omega
      self.layer.weight.uniform_(-bound, bound)

  def forward(self, x):
    """Return sin(omega * (x @ W.T + b))."""
    return torch.sin(self.omega * self.layer(x))

class Radial(nn.Module):
  """Gaussian radial-basis-function layer (inspired by Sitzmann et al.'s RBF layer).

  Each of the ``output_dim`` units owns a learnable center c_j (initialized
  uniformly in [-1, 1]) and a learnable scale s_j (initialized to 10). For an
  input point p, unit j returns exp(-(s_j * ||p - c_j||^2)^2).
  """
  def __init__(self, input_dim, output_dim):
    super().__init__()
    self.input_dim = input_dim
    self.output_dim = output_dim

    # Allocate storage, then fill it in place with the initial values.
    self.centers = nn.Parameter(torch.Tensor(output_dim, input_dim))
    nn.init.uniform_(self.centers, -1, 1)
    self.sigs = nn.Parameter(torch.Tensor(output_dim))
    nn.init.constant_(self.sigs, 10)

  def forward(self, input):
    """Map (B, N, input_dim) to (1, N, output_dim); only batch item 0 is used."""
    points = input[0, ...]                                  # (N, input_dim)
    offsets = points.unsqueeze(1) - self.centers.unsqueeze(0)  # (N, output_dim, input_dim)
    scaled_sq_dist = offsets.pow(2).sum(-1) * self.sigs.unsqueeze(0)
    activations = self.gaussian(scaled_sq_dist)
    return activations.unsqueeze(0)

  def gaussian(self, x):
    """Elementwise exp(-x**2)."""
    return torch.exp(-x.pow(2))

class Fourier(nn.Module):
  """Fourier-series KAN layer (inspired by Mehrabian et al., FKAN).

  This uses a truncated Fourier basis rather than splines (or Chebyshev
  polynomials). For each input row x:

      out[o] = bias[o] + sum_i sum_{k=1..gridsize}
               (coeffs[0, o, i, k-1] * cos(k * x[i]) + coeffs[1, o, i, k-1] * sin(k * x[i]))

  Args:
    input_dim: number of input features per row.
    output_dim: number of output features.
    gridsize: number of harmonics k = 1..gridsize per input feature.
  """
  def __init__(self, input_dim, output_dim, gridsize):
    super().__init__()
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.gridsize = gridsize

    self.coeffs = torch.nn.Parameter(torch.randn(2, output_dim, input_dim, gridsize) /
                                            (np.sqrt(input_dim) * np.sqrt(gridsize)))
    self.bias = torch.nn.Parameter(torch.zeros(1, output_dim))

  def forward(self, x):
    """Map x of shape (..., input_dim) to a 2-D tensor of shape (rows, output_dim).

    All leading dimensions are flattened into rows. For the (1, N, input_dim)
    inputs used in this notebook, the output is (N, output_dim).
    """
    k = torch.arange(1, self.gridsize + 1, device=x.device, dtype=x.dtype)
    k = k.reshape(1, 1, 1, self.gridsize)
    # (rows, 1, input_dim, 1) broadcasts against coeffs (1, output_dim, input_dim, gridsize).
    x = x.reshape(-1, 1, self.input_dim, 1)
    kx = k * x

    out = torch.sum(torch.cos(kx) * self.coeffs[0:1], (-2, -1))
    out = out + torch.sum(torch.sin(kx) * self.coeffs[1:2], (-2, -1))
    out = out + self.bias
    return out.reshape(-1, self.output_dim)

class Flex_Model(nn.Module):
  """Coordinate MLP (implicit neural representation) with a selectable activation scheme.

  Inspired by Sitzmann et al. (SIREN), Tancik et al. (Fourier features) and
  Mehrabian et al. (FKAN).

  Args:
    input_dim: width of the tensor fed to ``self.model``. For 'ffn' this is the
      width of the Fourier-feature embedding and must be even. The raw
      coordinates projected by ``B`` must then have 2 features.
    output_dim: number of output channels.
    hidden_dim: width of each hidden layer.
    nhidden: number of hidden (hidden_dim -> hidden_dim) layers.
    freq: omega_0 of the sine layers (used only by 'sin').
    activation: one of 'sin', 'relu', 'tanh', 'rbf', 'ffn', 'fkan'.

  Raises:
    ValueError: if ``activation`` is not one of the supported names.
  """
  def __init__(self, input_dim, output_dim, hidden_dim=256, nhidden=3, freq=30, activation='sin'):
    super().__init__()
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.hidden_dim = hidden_dim
    self.nhidden = nhidden
    self.freq = freq
    self.activation = activation

    # Every hidden layer is constructed separately. `nhidden * [layer]` would
    # repeat a single module object, so all hidden layers would share one set
    # of weights.
    if activation == 'sin':
      self.modellst = [Sinusoidal(input_dim, hidden_dim, freq=freq, freq_first=True)]
      self.modellst += [Sinusoidal(hidden_dim, hidden_dim, freq=freq, freq_first=False)
                        for _ in range(nhidden)]
      finallayer = nn.Linear(hidden_dim, output_dim)
      with torch.no_grad():
        # Same fan-in-scaled uniform init as the hidden sine layers.
        bound = np.sqrt(6 / hidden_dim) / freq
        finallayer.weight.uniform_(-bound, bound)
      self.modellst.append(finallayer)

    elif activation in ['relu', 'ffn']:
      if activation == 'ffn':
        #inspired by Tancik
        assert input_dim % 2 == 0, 'for ffn, input dim needs to be even'
        # Gaussian projection (std 10) of 2-D coordinates onto input_dim // 2
        # frequencies. Registered as a buffer so `.to(...)` moves it together
        # with the model and it is stored in state_dict.
        self.register_buffer('B', torch.randn((input_dim // 2, 2)) * 10)
      self.modellst = [nn.Linear(input_dim, hidden_dim, bias=True), nn.ReLU()]
      self.modellst += self._hidden_stack(nhidden, hidden_dim, nn.ReLU)
      self.modellst.append(nn.Linear(hidden_dim, output_dim, bias=True))

    elif activation == 'tanh':
      self.modellst = [nn.Linear(input_dim, hidden_dim, bias=True), nn.Tanh()]
      self.modellst += self._hidden_stack(nhidden, hidden_dim, nn.Tanh)
      self.modellst.append(nn.Linear(hidden_dim, output_dim, bias=True))

    elif activation == 'rbf':
      # The RBF layer output feeds the first hidden Linear directly (no extra activation).
      self.modellst = [Radial(input_dim, hidden_dim)]
      self.modellst += self._hidden_stack(nhidden, hidden_dim, nn.ReLU)
      self.modellst.append(nn.Linear(hidden_dim, output_dim, bias=True))

    elif activation == 'fkan':
      #inspired by Mehrabian
      #the paper doesn't mention this but seems to use sin instead of ReLU; this seems like cheating
      #so I will also check that this works--if not, that's an interesting result as well
      self.gridsize = 100 #what they found worked well was 270
      self.modellst = [Fourier(input_dim, hidden_dim, self.gridsize), nn.LayerNorm(hidden_dim)]
      self.modellst += self._hidden_stack(nhidden, hidden_dim, nn.ReLU)
      self.modellst.append(nn.Linear(hidden_dim, output_dim, bias=True))

    else:
      raise ValueError('activation must be in [sin, relu, tanh, rbf, ffn, fkan]')

    self.model = nn.Sequential(*self.modellst)

  @staticmethod
  def _hidden_stack(nhidden, hidden_dim, act_cls):
    """Return ``nhidden`` independent (Linear, activation) pairs as a flat list."""
    layers = []
    for _ in range(nhidden):
      layers += [nn.Linear(hidden_dim, hidden_dim, bias=True), act_cls()]
    return layers

  def forward(self, x):
    """Return ``(output, features)`` for coordinates ``x``.

    ``features`` is the tensor actually fed to ``self.model``. It is a detached,
    grad-enabled copy of ``x``, or, for 'ffn', the [sin, cos] Fourier-feature
    embedding of that copy.
    """
    x = x.clone().detach().requires_grad_(True)
    if self.activation == 'ffn':
      #inspired by Tancik
      x_proj = (2 * np.pi * x) @ self.B.T
      x = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
    output = self.model(x)
    return output, x


In [None]:
#@title Preprocessing utils

def get_img_tensor(fname, sidelength):
  """Load an image file as a normalized tensor (inspired by Sitzmann et al.).

  Args:
    fname: path to the image file.
    sidelength: passed to torchvision ``Resize``. An int resizes the shorter
      edge to ``sidelength`` and keeps the aspect ratio. NOTE(review): callers
      that reshape to ``sidelength**2`` pixels therefore assume square images.

  Returns:
    A (C, H, W) float tensor from ``ToTensor`` shifted/scaled by
    ``Normalize(mean=0.5, std=0.5)``, i.e. [0, 1] -> [-1, 1] for 8-bit images.
  """
  transform = Compose([
      Resize(sidelength),
      ToTensor(),
      Normalize(torch.Tensor([0.5]), torch.Tensor([0.5]))
  ])
  # Context manager so the underlying file handle is released deterministically;
  # the transform fully reads the pixel data before the file is closed.
  with Image.open(fname) as img:
    return transform(img)

def get_mgrid(sidelen, dim=2):
  """Return a flattened ``dim``-dimensional coordinate grid over [-1, 1] (inspired by Sitzmann et al.).

  Each axis is sampled at ``sidelen`` evenly spaced points. The result has shape
  (sidelen**dim, dim), ordered with the last coordinate varying fastest
  ('ij' / matrix indexing).
  """
  axis = torch.linspace(-1, 1, steps=sidelen)
  # 'ij' is the current default; passing it explicitly avoids the PyTorch
  # warning about the argument becoming required.
  mgrid = torch.stack(torch.meshgrid(*([axis] * dim), indexing='ij'), dim=-1)
  return mgrid.reshape(-1, dim)

class PatternFitting(Dataset):
  """One-item dataset of (pixel coordinates, pixel values) for a pattern image.

  Inspired by Sitzmann et al. Item 0 is ``(coords, pixels)``:
    * coords: the (sidelength**2, 2) grid from ``get_mgrid``.
    * pixels: with ``selected_channel``, that channel of the normalized image
      reshaped to (1, sidelength**2, 1). Without it, the normalized image
      permuted to channels-last and NOT flattened.
  """
  def __init__(self, fname, sidelength, selected_channel=None):
    super().__init__()
    channels_last = get_img_tensor(fname, sidelength).permute(1, 2, 0)
    if selected_channel is None:
      self.pixels = channels_last
    else:
      one_channel = channels_last[:, :, selected_channel]
      self.pixels = torch.reshape(one_channel, (1, sidelength ** 2, 1))
    self.coords = get_mgrid(sidelength, 2)

  def __len__(self):
    return 1

  def __getitem__(self, idx):
    # Only index 0 (or a negative index) is valid.
    if idx > 0:
      raise IndexError
    return self.coords, self.pixels

class AudioFitting(Dataset):
  """One-item dataset pairing a time grid with the samples of an audio file.

  Item 0 is ``(coords, samples)``:
    * coords: 1-D tensor of ``len`` evenly spaced time points in [-1, 1].
    * samples: (1, len, channels), or (1, len, 1) when ``selected_channel`` is given.
  """
  def __init__(self, aud_file, selected_channel=None):
    super().__init__()
    aud_tensor, sample_rate = torchaudio.load(aud_file, format='mp3')
    self._set_audio(aud_tensor, sample_rate, selected_channel)

  def _set_audio(self, aud_tensor, sample_rate, selected_channel=None):
    """Store a (channels, frames) waveform as (1, frames, channels) plus its time grid."""
    self.sample_rate = sample_rate
    self.channels, self.len = aud_tensor.shape
    # Transpose rather than reshape: reshape would interleave the samples of
    # different channels instead of swapping the axes.
    self.aud_tensor = aud_tensor.transpose(0, 1).unsqueeze(0)
    if selected_channel is not None:
      self.aud_tensor = self.aud_tensor[:, :, selected_channel].unsqueeze(2)
    self.coords = torch.linspace(-1, 1, steps=self.len)

  @property
  def samples(self):
    """Alias for ``aud_tensor``."""
    return self.aud_tensor

  def get_time(self):
    """Number of frames in the loaded waveform."""
    return self.len

  def __len__(self):
    return 1

  def __getitem__(self, idx):
    if idx > 0: raise IndexError
    return self.coords, self.aud_tensor



In [None]:
#@title Image training utils

def train_siren(img_siren, dataloader, total_steps, img_dim, steps_til_summary=10, lr=1e-4):
  """Fit ``img_siren`` to the first (coords, pixels) batch of ``dataloader`` with Adam + MSE.

  Every ``steps_til_summary`` steps (starting at step 0), it prints the training
  loss and the MSE of two out-of-range tiles against the ground truth:
    * near: xcell_idx=0, ycell_idx=1
    * far: xcell_idx=-10, ycell_idx=10

  It also plots four views: the training reconstruction, those two tiles, and a
  10x zoomed-out view centered on the origin. ``steps_til_summary=0`` disables
  summaries.

  Args:
    img_siren: model returning ``(output, coords)`` (e.g. ``Flex_Model``).
    dataloader: yields ``(coords, pixels)``; only the first batch is used.
    total_steps: number of optimizer steps.
    img_dim: side length of the square image, used to reshape outputs for plotting.
    steps_til_summary: summary interval in steps.
    lr: Adam learning rate.
  """
  optim = torch.optim.Adam(lr=lr, params=img_siren.parameters())

  model_input, ground_truth = next(iter(dataloader))
  model_input, ground_truth = model_input.to(device), ground_truth.to(device)
  ground_truth = torch.squeeze(ground_truth)

  for step in range(total_steps):
    model_output, _ = img_siren(model_input)
    loss = F.mse_loss(torch.squeeze(model_output), ground_truth)

    if steps_til_summary and not step % steps_til_summary:
      _plot_summary(img_siren, img_dim, step, loss, model_output, ground_truth)

    optim.zero_grad()
    loss.backward()
    optim.step()


def _to_image(t, img_dim):
  """Detach ``t``, move it to the CPU and return it as an (img_dim, img_dim) numpy array."""
  return t.detach().cpu().reshape(img_dim, img_dim).numpy()


def _plot_summary(img_siren, img_dim, step, loss, model_output, ground_truth):
  """Print the training/near/far losses at ``step`` and plot four (img_dim, img_dim) views."""
  near_cell = show_cell(img_siren, img_dim, xcell_idx=0, ycell_idx=1, magnif=1, just_return=True)
  near_loss = F.mse_loss(torch.squeeze(near_cell), ground_truth.to(near_cell.device))
  far_cell = show_cell(img_siren, img_dim, xcell_idx=-10, ycell_idx=10, magnif=1, just_return=True)
  far_loss = F.mse_loss(torch.squeeze(far_cell), ground_truth.to(far_cell.device))
  wide_area = show_cell(img_siren, img_dim, xcell_idx=0, ycell_idx=0, magnif=10, just_return=True)

  print(f"Step {step}, Total loss {loss.item():.6f}, Near loss {near_loss.item():.6f}, Far loss {far_loss.item():.6f}")

  _, axes = plt.subplots(1, 4, figsize=(16, 4))
  for ax, view in zip(axes, (model_output, near_cell, far_cell, wide_area)):
    ax.imshow(_to_image(view, img_dim))
  plt.show()

def show_cell(img_siren, img_dim=256, xcell_idx=0, ycell_idx=0, magnif=1, shift_factor = 1, just_return=False):
  """Evaluate the network, without gradients, on a shifted and/or scaled training grid.

  The grid is ``get_mgrid(img_dim, 2) * magnif``. Its x/y coordinates are then
  offset by ``int(xcell_idx * shift_factor * img_dim)`` and
  ``int(ycell_idx * shift_factor * img_dim)`` coordinate units.

  Args:
    img_siren: model returning ``(output, coords)``; evaluated on its own device.
    img_dim: grid side length (the output has img_dim**2 points).
    xcell_idx, ycell_idx: integer tile offsets along each axis.
    magnif: scale applied to the [-1, 1] grid before shifting (>1 zooms out).
    shift_factor: multiplier on the tile offset.
    just_return: if True, return the raw model output instead of plotting it.

  Returns:
    The model output when ``just_return`` is True. Otherwise the output is
    plotted as an (img_dim, img_dim) image and None is returned.
  """
  with torch.no_grad():
    coords = get_mgrid(img_dim, 2) * magnif
    # NOTE(review): the training grid spans [-1, 1] (width 2), but this offset is
    # in multiples of img_dim coordinate units, so tile 1 lies img_dim/2 grid
    # widths away rather than adjacent. Kept unchanged because reported results
    # depend on it; confirm intent.
    coords[:, 0] += int(xcell_idx * shift_factor * img_dim)
    coords[:, 1] += int(ycell_idx * shift_factor * img_dim)
    # Evaluate on the model's own device; it may have been moved to the CPU after training.
    coords = coords.to(_device_of(img_siren)).unsqueeze(0)
    model_out, _ = img_siren(coords)

  if just_return:
    return model_out

  _, ax = plt.subplots(figsize=(16, 16))
  ax.imshow(model_out.cpu().view(img_dim, img_dim).numpy())
  plt.show()


def _device_of(module):
  """Device of ``module``'s first parameter; CUDA-if-available when it has no parameters."""
  param = next(module.parameters(), None)
  if param is None:
    return 'cuda' if torch.cuda.is_available() else 'cpu'
  return param.device

def eval_square(img_siren, img_dim, ground_truth, near_dist=1, sf=1):
  """Mean and std of the per-tile MSE over the ring of tiles at Chebyshev distance ``near_dist``.

  The tile at integer offset (x, y) is rendered with
  ``show_cell(..., xcell_idx=x, ycell_idx=y, magnif=1, shift_factor=sf)`` and
  compared to ``ground_truth``. The ring at distance d >= 1 contains 8*d distinct
  tiles, and each one (including the corners) is counted exactly once.
  d = 0 is the single in-range tile (0, 0).

  Returns:
    ``(mean, std)`` as 0-d CPU tensors. std is 0 when only one tile is evaluated.
  """
  d = abs(int(near_dist))
  tiles = [(x, y)
           for x in range(-d, d + 1)
           for y in range(-d, d + 1)
           if max(abs(x), abs(y)) == d]

  losses = []
  for x, y in tiles:
    cell = show_cell(img_siren, img_dim, xcell_idx=x, ycell_idx=y, magnif=1, shift_factor=sf, just_return=True)
    tile_loss = F.mse_loss(torch.squeeze(cell), ground_truth.to(cell.device))
    losses.append(tile_loss.detach().cpu())

  losses = torch.stack(losses)
  # torch.std of a single value is NaN (Bessel correction); report 0 spread instead.
  spread = torch.std(losses) if losses.numel() > 1 else torch.zeros((), dtype=losses.dtype)
  return (torch.mean(losses), spread)

def eval_model(img_siren, img_dim, ground_truth, modinfo, near_dist=1, far_dist=10, sf=1, verbose=True):
  """Compare in-range, near and far tile reconstruction errors for one model.

  Calls ``eval_square`` at distance 0 (the training tile), ``near_dist`` and
  ``far_dist``.

  Returns:
    ``[modinfo, regular_loss, near_loss, far_loss, regular_sd, near_sd, far_sd]``.
  """
  regular_loss, regular_sd = eval_square(img_siren, img_dim, ground_truth, near_dist=0, sf=sf)
  near_loss, near_sd = eval_square(img_siren, img_dim, ground_truth, near_dist=near_dist, sf=sf)
  far_loss, far_sd = eval_square(img_siren, img_dim, ground_truth, near_dist=far_dist, sf=sf)
  if verbose:
    # Adjacent f-strings concatenate without picking up source indentation.
    banner = "########## MODEL EVAL ##########"
    print(f"{banner}\n"
          f"Model: {modinfo}\n"
          f"Regular loss: {regular_loss}, Dist {near_dist} tile loss: {near_loss}, "
          f"Dist {far_dist} tile loss: {far_loss}\n"
          f"Regular SD: {regular_sd}, Dist {near_dist} tile SD: {near_sd}, "
          f"Dist {far_dist} tile SD: {far_sd}\n"
          f"{banner}")
  return [modinfo, regular_loss, near_loss, far_loss, regular_sd, near_sd, far_sd]



In [None]:
#@title Audio training utils


In [None]:
#@title Data loading

# Pattern images to fit; the first three are checkerboards, the last three tartans.
img_patterns = ['checker1.png','checker2.png','checker3.png','tartan1.png','tartan2.png','tartan3.png']
# Audio inputs are currently disabled (empty list), so the audio loop below loads nothing.
aud_patterns = [] #['bartok1.mp3'] #['hihat1.mp3','hihat3.mp3','hihat2-5.mp3','bartok1.mp3','bartok2.mp3','bartok1-5.mp3']
# Side length (pixels) each image is resized to for the non-FKAN models.
sidelen = 128 #256
#img_shift_factor = [1,1,1.2,1,1,1.2]
#aud_shift_factor = [1,1,1.2,1,1,1.33]

### Loading in IMAGES ###
# One single-item dataset + DataLoader per image; only channel 0 is fit.
img_pattern_fits = []
img_dataloaders = []
for pattern_file in img_patterns:
  img_pattern_fit = PatternFitting(pattern_file, sidelen, selected_channel=0)
  img_dataloader = DataLoader(img_pattern_fit, batch_size=1, pin_memory=True, num_workers=0)

  img_pattern_fits.append(img_pattern_fit)
  img_dataloaders.append(img_dataloader)

# Same images at half resolution for the FKAN model.
# NOTE(review): presumably because the Fourier layer materializes
# (N, hidden_dim, input_dim, gridsize) intermediates; confirm.
fkan_pattern_fits = []
fkan_dataloaders = []
for pattern_file in img_patterns:
  fkan_pattern_fit = PatternFitting(pattern_file, sidelen // 2, selected_channel=0)
  fkan_dataloader = DataLoader(fkan_pattern_fit, batch_size=1, pin_memory=True, num_workers=0)

  fkan_pattern_fits.append(fkan_pattern_fit)
  fkan_dataloaders.append(fkan_dataloader)

### Loading in AUDIO ###
# One single-item dataset + DataLoader per audio file, keeping only channel 0.
aud_fits = []
aud_dataloaders = []
for aud_file in aud_patterns:
  aud_fitting = AudioFitting(aud_file, selected_channel=0)
  aud_dataloader = DataLoader(aud_fitting, batch_size=1, pin_memory=True, num_workers=0)

  aud_fits.append(aud_fitting)
  aud_dataloaders.append(aud_dataloader)


In [None]:
#@title Image training and eval loop
# Train one model per (activation type, pattern image) pair.
# NOTE(review): 'ffn' and 'fkan' are also trained again by the two cells below.
model_types = ['relu', 'tanh', 'rbf', 'sin', 'ffn', 'fkan']

final_losses = []
models = []  # (label, trained model moved to CPU) pairs
for mtype in model_types:
  for i, img in enumerate(img_dataloaders):
    torch.cuda.empty_cache()
    print(f'########## Starting to train model {mtype} on data {img_patterns[i]} ##########')
    # For 'ffn', input_dim is the width of the sin/cos Fourier-feature embedding
    # (256), not the 2-D coordinate dimension.
    indim = 256 if mtype=='ffn' else 2
    model = Flex_Model(indim, 1, hidden_dim=256, nhidden=3, freq=30, activation=mtype).to(device)

    # FKAN trains on the half-resolution loaders, so img_dim must match (64).
    img = fkan_dataloaders[i] if mtype=='fkan' else img
    sidelen = 64 if mtype=='fkan' else 128
    # Checkerboards (first three images) get 301 steps, tartans 601. The summary
    # interval (nsteps - 1) // 3 yields four summaries: steps 0, 100, 200, 300
    # or 0, 200, 400, 600.
    nsteps = 301 if i < 3 else 601
    train_siren(model, img, total_steps=nsteps, img_dim=sidelen, steps_til_summary=(nsteps - 1) // 3, lr=1e-4) #provides visuals

    # Module.to() moves parameters in place and returns the module, so the GPU is
    # freed for the next model.
    models.append((f'{mtype} / {img_patterns[i]}', model.to('cpu')))
    #_, ground_truth = next(iter(img))
    #ground_truth = ground_truth.to(device)
    #losses = eval_model(model, sidelen, ground_truth, f'{mtype} / {img_patterns[i]}', near_dist=1, far_dist=2, verbose=True)
    #final_losses.append(losses)


In [None]:
# Re-run the training loop for the Fourier-feature (ffn) model only; results are
# appended to the same `models` list.
for mtype in ['ffn']:
  for i, img in enumerate(img_dataloaders):
    torch.cuda.empty_cache()
    print(f'########## Starting to train model {mtype} on data {img_patterns[i]} ##########')
    # 256 = width of the sin/cos Fourier-feature embedding fed to the MLP.
    indim = 256 if mtype=='ffn' else 2
    model = Flex_Model(indim, 1, hidden_dim=256, nhidden=3, freq=30, activation=mtype).to(device)

    img = fkan_dataloaders[i] if mtype=='fkan' else img
    sidelen = 64 if mtype=='fkan' else 128
    # 301 steps for the checkerboards (first three images), 601 for the tartans.
    nsteps = 301 if i < 3 else 601
    train_siren(model, img, total_steps=nsteps, img_dim=sidelen, steps_til_summary=(nsteps - 1) // 3, lr=1e-4) #provides visuals

    models.append((f'{mtype} / {img_patterns[i]}', model.to('cpu')))

In [None]:
# Re-run the training loop for the FKAN model only; results are appended to the
# same `models` list.
for mtype in ['fkan']:
  for i, img in enumerate(img_dataloaders):
    torch.cuda.empty_cache()
    print(f'########## Starting to train model {mtype} on data {img_patterns[i]} ##########')
    indim = 256 if mtype=='ffn' else 2
    model = Flex_Model(indim, 1, hidden_dim=256, nhidden=3, freq=30, activation=mtype).to(device)

    # FKAN uses the half-resolution loaders, so img_dim is 64 here.
    img = fkan_dataloaders[i] if mtype=='fkan' else img
    sidelen = 64 if mtype=='fkan' else 128
    # 301 steps for the checkerboards (first three images), 601 for the tartans.
    nsteps = 301 if i < 3 else 601
    train_siren(model, img, total_steps=nsteps, img_dim=sidelen, steps_til_summary=(nsteps - 1) // 3, lr=1e-4) #provides visuals

    models.append((f'{mtype} / {img_patterns[i]}', model.to('cpu')))

In [None]:
#@title Audio training and eval loop