In [None]:
'''
Test hdr_model
'''

from hdr_model import PLFourierNet, HDRDataset
from torch.utils.data import DataLoader
import torch
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt

# Evaluate the HDR model on every dataset image and show the reference
# images (first figure) next to the predictions (second figure).

# Restore the checkpoint weights into a freshly constructed model on the CPU.
ckpt_path = "hdr_model/last.ckpt"
model = PLFourierNet(pos_embed=False)
# Fetch the state dict once: its tensors share storage with the model
# parameters, so the in-place copy_ below updates the model itself.
model_state = model.state_dict()
for k, v in torch.load(ckpt_path, map_location='cpu')['state_dict'].items():
    model_state[k].copy_(v)

# Single source of truth for the image resolution used by the dataset and
# by the reshapes below (kept under a cell-specific name on purpose).
hdr_side_length = 256
dataset = HDRDataset(side_length=hdr_side_length)
loader = DataLoader(dataset, batch_size=1, pin_memory=True)
ys_pred = []
ys = []
# One latent value per loader index; zs[i] raises IndexError if the loader
# yields more batches than there are entries in zs.
zs = np.arange(0.0, 0.05, 0.005)
# Pure inference: skip autograd bookkeeping for the forward passes.
with torch.no_grad():
    for i, batch in enumerate(loader):
        x, y, _ = batch
        z = torch.tensor([[zs[i]]]).float().to(model.device)
        print(z)
        x = x.float().to(model.device)
        y = y.float().to(model.device)
        _, y_pred = model(x, z)
        # PSNR computed as -10*log10(MSE).
        loss = ((y_pred - y) ** 2).mean()
        psnr = -10*torch.log10(loss)
        print(psnr)
        ys_pred.append(y_pred[0].cpu().numpy().reshape(hdr_side_length, hdr_side_length))
        ys.append(y[0].cpu().numpy().reshape(hdr_side_length, hdr_side_length))

length = len(dataset)
# squeeze=False keeps a 2-D axes array even when the dataset holds a single
# image, so the indexing below works for any positive length.
fig1, axs1 = plt.subplots(1, length, figsize=(length * 4, 4), squeeze=False)
fig2, axs2 = plt.subplots(1, length, figsize=(length * 4, 4), squeeze=False)
axs1, axs2 = axs1[0], axs2[0]

for i, y_pred_ in enumerate(ys_pred):
    # The reference panel shows the image file itself (converted to
    # grayscale), not the target tensor collected in `ys`.
    with Image.open(dataset.images[i]) as ref_image:
        y_ = np.array(ref_image.convert("L"))
    axs1[i].imshow(y_, cmap='gray')
    axs1[i].axis("off")
    axs2[i].imshow(y_pred_, cmap='gray')
    axs2[i].axis("off")



In [None]:
'''
test_hdr_model with slide bar
'''

import torch
import cv2
from IPython.display import display
import ipywidgets as widgets
import matplotlib.pyplot as plt
from hdr_model_mul import PLFourierNet

def show_image(z):
    """Render the HDR model prediction for a single latent value.

    A fresh ``PLFourierNet(pos_embed=False)`` is built and the weights from
    ``hdr_model/last.ckpt`` are copied into it on every call, so each slider
    move pays the checkpoint-loading cost.

    Args:
        z: Scalar latent value; must lie in ``[0, 5]``.

    Raises:
        AssertionError: If ``z`` is outside ``[0, 5]``.
    """
    assert z>=0 and z<=5
    ckpt_path = "hdr_model/last.ckpt"
    model = PLFourierNet(pos_embed=False)
    # Fetch the state dict once; its tensors alias the model parameters, so
    # copy_ writes the checkpoint values straight into the model.
    model_state = model.state_dict()
    for k, v in torch.load(ckpt_path, map_location='cpu')['state_dict'].items():
        model_state[k].copy_(v)
    side_length = 256
    # Dense coordinate grid over [-1, 1]^2, flattened to (side_length**2, 2).
    x = torch.stack(
        torch.meshgrid(
            [
                torch.linspace(-1.0, 1.0, side_length),
                torch.linspace(-1.0, 1.0, side_length),
            ]
        ),
        dim=-1,
    ).view(-1, 2).to(model.device)
    z = torch.tensor([[z]]).float().to(model.device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        _, y_pred = model(x, z)
    # Reshape with side_length (not a literal) so the grid size and the image
    # size cannot drift apart.
    y_pred = y_pred[0].cpu().numpy().reshape(side_length, side_length)
    plt.imshow(y_pred)
    plt.show()

# The slider range (0 to 0.06) is well inside the [0, 5] bound asserted above.
contrast_slider = widgets.FloatSlider(value=0.001, min=0.0, max=0.06, step=0.001, description='latent')
# Must stay the last expression of the cell so the widget is displayed.
widgets.interactive(show_image, z=contrast_slider)

In [None]:
'''
test hdr_model with concatenation of [x,y,z]
'''

import torch
import cv2
from IPython.display import display
import ipywidgets as widgets
import matplotlib.pyplot as plt
from hdr_model_xyz import PLFourierNet

def show_image(ckpt_path="hdr_xyz/last-v1.ckpt", z_range=(0.01, 0.02), n_slices=5, side_length=256):
    """Render the xyz-concatenation model at several values of the third coordinate.

    The model input is a dense (x, y, z) grid: x and y span [-1, 1] with
    ``side_length`` samples each, z spans ``z_range`` with ``n_slices``
    samples. One image is shown per z sample.

    Args:
        ckpt_path: Checkpoint whose ``'state_dict'`` entries are copied into
            a freshly constructed ``PLFourierNet()``.
        z_range: ``(start, end)`` of the third input coordinate.
        n_slices: Number of z samples, i.e. number of images shown.
        side_length: Number of samples along each spatial axis.
    """
    model = PLFourierNet()
    # Fetch the state dict once; its tensors alias the model parameters.
    model_state = model.state_dict()
    for k, v in torch.load(ckpt_path, map_location='cpu')['state_dict'].items():
        model_state[k].copy_(v)
    # meshgrid gives (side, side, n_slices, 3); flattening the spatial axes
    # and swapping yields (n_slices, side*side, 3), one row of points per slice.
    x = torch.stack(
        torch.meshgrid(
            [
                torch.linspace(-1.0, 1.0, side_length),
                torch.linspace(-1.0, 1.0, side_length),
                torch.linspace(z_range[0], z_range[1], n_slices),
            ]
        ),
        dim=-1,
    ).view(-1, n_slices, 3).transpose(0, 1)
    # Inference only: avoid recording an autograd graph for every grid point.
    with torch.no_grad():
        y_pred = model(x)
    print(y_pred.shape)
    # Min-max normalisation over all slices together, so relative brightness
    # between slices is preserved.
    y_pred = (y_pred-y_pred.min())/(y_pred.max()-y_pred.min())
    for y in y_pred:
        # Each slice is reshaped to (side_length, side_length, 3) for imshow.
        y = y.cpu().numpy().reshape(side_length, side_length, 3)
        plt.imshow(y)
        plt.show()

show_image()

In [None]:
'''
test hdr_model with relu activation
'''

import torch
import cv2
from IPython.display import display
import ipywidgets as widgets
import matplotlib.pyplot as plt
from hdr_model_relu import PLFourierNet

def show_image(ckpt_path="hdr_relu/last-v1.ckpt", z_range=(0.00, 0.04), n_slices=5, side_length=256):
    """Render the ReLU-variant model at several values of the third coordinate.

    Builds a dense (x, y, z) grid, with x and y in [-1, 1] sampled
    ``side_length`` times and z sampled ``n_slices`` times across
    ``z_range``, then shows one image per z sample.

    Args:
        ckpt_path: Checkpoint whose ``'state_dict'`` entries are copied into
            a freshly constructed ``PLFourierNet()``.
        z_range: ``(start, end)`` of the third input coordinate.
        n_slices: Number of z samples, i.e. number of images shown.
        side_length: Number of samples along each spatial axis.
    """
    model = PLFourierNet()
    # Look the state dict up once instead of rebuilding it for every key;
    # its tensors share storage with the model parameters.
    model_state = model.state_dict()
    for k, v in torch.load(ckpt_path, map_location='cpu')['state_dict'].items():
        model_state[k].copy_(v)
    grid_axes = [
        torch.linspace(-1.0, 1.0, side_length),
        torch.linspace(-1.0, 1.0, side_length),
        torch.linspace(z_range[0], z_range[1], n_slices),
    ]
    # (side, side, n_slices, 3) -> (side*side, n_slices, 3) -> (n_slices, side*side, 3)
    x = torch.stack(torch.meshgrid(grid_axes), dim=-1).view(-1, n_slices, 3).transpose(0, 1)
    # No gradients are needed for visualisation.
    with torch.no_grad():
        y_pred = model(x)
    print(y_pred.shape)
    # Joint min-max normalisation across every slice.
    y_pred = (y_pred-y_pred.min())/(y_pred.max()-y_pred.min())
    for y in y_pred:
        # Each slice is reshaped to (side_length, side_length, 3) for imshow.
        y = y.cpu().numpy().reshape(side_length, side_length, 3)
        plt.imshow(y)
        plt.show()

show_image()

In [None]:
'''
test motion model xyt
'''

import torch
import cv2
from IPython.display import display
import ipywidgets as widgets
import matplotlib.pyplot as plt
from motion import PLFourierNet
import numpy as np

def show_image(ckpt_path="motion_xyt/last-v2.ckpt", side_length=256, sequence_range=(5, 20), skip=4):
    """Render one image per time sample from the xyt motion model.

    The number of time samples equals the number of values in
    ``np.arange(sequence_range[0], sequence_range[1], skip)``; the time
    coordinate itself is sampled uniformly in [-0.01, 0.00].

    Args:
        ckpt_path: Lightning checkpoint to restore the model from.
        side_length: Number of samples along each spatial axis.
        sequence_range: ``(start, stop)`` used only to derive the frame count.
        skip: Step used together with ``sequence_range`` for the frame count.
    """
    # load_from_checkpoint is a classmethod: call it on the class rather
    # than on a throwaway instance (newer Lightning releases reject the
    # instance form).
    model = PLFourierNet.load_from_checkpoint(ckpt_path)
    frames = len(np.arange(sequence_range[0], sequence_range[1], skip))

    # (side, side, frames, 3) -> (frames, side*side, 3): one row of (x, y, t)
    # points per frame.
    x = torch.stack(
        torch.meshgrid(
            [
                torch.linspace(-1.0, 1.0, side_length),
                torch.linspace(-1.0, 1.0, side_length),
                torch.linspace(-0.01, 0.00, frames),
            ]
        ),
        dim=-1,
    ).view(-1, frames, 3).transpose(0, 1)
    # The restored model may not live on the CPU; keep the input with it.
    x = x.to(model.device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        y_pred = model(x)
    print(y_pred.shape)
    # Min-max normalisation over all frames together.
    y_pred = (y_pred-y_pred.min())/(y_pred.max()-y_pred.min())
    for y in y_pred:
        # Each frame is reshaped to (side_length, side_length, 3) for imshow.
        y = y.cpu().numpy().reshape(side_length, side_length, 3)
        plt.imshow(y)
        plt.axis('off')
        plt.show()

show_image()

In [None]:
'''
test motion model phase
'''

import torch
import cv2
from IPython.display import display
import ipywidgets as widgets
import matplotlib.pyplot as plt
from motion_phase import PLFourierNet
import numpy as np

def show_image(z):
    """Show the motion-phase prediction and the main representation for latent ``z``.

    A fresh ``PLFourierNet(pos_embed=False)`` is filled with the weights from
    ``motion_phase/last.ckpt`` on every call. The model is evaluated on a
    dense [-1, 1]^2 grid; every prediction is displayed first, followed by
    the main representation.

    Args:
        z: Scalar latent value supplied by the slider.
    """
    ckpt_path = "motion_phase/last.ckpt"
    model = PLFourierNet(pos_embed = False)
    # Fetch the state dict once; its tensors alias the model parameters.
    model_state = model.state_dict()
    for k, v in torch.load(ckpt_path, map_location='cpu')['state_dict'].items():
        model_state[k].copy_(v)
    side_length = 256

    # Coordinate grid flattened to (side_length**2, 2); moved to the model
    # device just like the latent below.
    x = torch.stack(
        torch.meshgrid(
            [
                torch.linspace(-1.0, 1.0, side_length),
                torch.linspace(-1.0, 1.0, side_length),
            ]
        ),
        dim=-1,
    ).view(-1, 2).to(model.device)
    z = torch.tensor([[z]]).float().to(model.device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        main_representation, y_pred = model(x, z)

    # Independent min-max normalisation of both outputs for display.
    y_pred = (y_pred-y_pred.min())/(y_pred.max()-y_pred.min())
    main_representation = (main_representation-main_representation.min())/(main_representation.max()-main_representation.min())

    for y in y_pred:
        # Each prediction is reshaped to (side_length, side_length, 3) for imshow.
        y = y.cpu().numpy().reshape(side_length, side_length, 3)
        plt.imshow(y)
        plt.axis('off')
        plt.show()

    main_representation = main_representation.cpu().numpy().reshape(side_length, side_length, 3)
    plt.imshow(main_representation)
    plt.axis('off')
    plt.show()


contrast_slider = widgets.FloatSlider(value=0.000, min=0.0, max=0.1, step=0.005, description='latent')
# Must stay the last expression of the cell so the widget is displayed.
widgets.interactive(show_image, z=contrast_slider)

In [None]:
'''
test motion model phase
'''

import torch
import cv2
from IPython.display import display
import ipywidgets as widgets
import matplotlib.pyplot as plt
from motion_phase import PLFourierNet
import numpy as np

def show_image(z):
    """Show the motion-phase prediction and the main representation for latent ``z``.

    Uses the weights stored in ``motion_phase/last-v1.ckpt``, copied into a
    new ``PLFourierNet(pos_embed=False)`` on every call. All predictions are
    plotted first, then the main representation.

    Args:
        z: Scalar latent value supplied by the slider.
    """
    ckpt_path = "motion_phase/last-v1.ckpt"
    model = PLFourierNet(pos_embed = False)
    # One state_dict() lookup is enough: the returned tensors share storage
    # with the parameters, so copy_ updates the model in place.
    model_state = model.state_dict()
    for k, v in torch.load(ckpt_path, map_location='cpu')['state_dict'].items():
        model_state[k].copy_(v)
    side_length = 256

    grid_axes = [
        torch.linspace(-1.0, 1.0, side_length),
        torch.linspace(-1.0, 1.0, side_length),
    ]
    # (side, side, 2) -> (side*side, 2), placed on the same device as the latent.
    x = torch.stack(torch.meshgrid(grid_axes), dim=-1).view(-1, 2).to(model.device)
    z = torch.tensor([[z]]).float().to(model.device)
    # Visualisation does not need gradients.
    with torch.no_grad():
        main_representation, y_pred = model(x, z)

    # Each output is rescaled to [0, 1] with its own min and max.
    y_pred = (y_pred-y_pred.min())/(y_pred.max()-y_pred.min())
    main_representation = (main_representation-main_representation.min())/(main_representation.max()-main_representation.min())

    for y in y_pred:
        # Each prediction is reshaped to (side_length, side_length, 3) for imshow.
        y = y.cpu().numpy().reshape(side_length, side_length, 3)
        plt.imshow(y)
        plt.axis('off')
        plt.show()

    main_representation = main_representation.cpu().numpy().reshape(side_length, side_length, 3)
    plt.imshow(main_representation)
    plt.axis('off')
    plt.show()


contrast_slider = widgets.FloatSlider(value=0.000, min=0.0, max=0.05, step=0.005, description='latent')
# Must stay the last expression of the cell so the widget is displayed.
widgets.interactive(show_image, z=contrast_slider)

In [1]:
'''
test motion model phase with modulated positional embedding (motion_modulate_posembed)
'''

import torch
import cv2
from IPython.display import display
import ipywidgets as widgets
import matplotlib.pyplot as plt
from motion_modulate_posembed import PLFourierNet
import numpy as np

# Cell configuration. Defined before show_image because the function reads
# the module-level `side_length` when it runs.
side_length = 64
sequence_range=[15,175]
skip = 10
frames = len(np.arange(sequence_range[0], sequence_range[1], skip))


def show_image(z):
    """Show the prediction and main representation for a spatially varying latent.

    The scalar ``z`` becomes a phase offset of ``z/20`` of a full period in a
    sine (along rows) plus cosine (along columns) pattern, scaled by 0.01;
    that pattern is the per-pixel latent fed to the model.

    Reads the module-level ``side_length`` for the grid size.

    Args:
        z: Scalar slider value.
    """
    # NOTE(review): absolute, machine-specific path; the notebook is not
    # portable until this points at a location relative to the project.
    ckpt_path = "/home/xxy/Documents/code/Fourier-Manifold/motion_phase_pos_embed/last-v12.ckpt"
    model = PLFourierNet()
    # Fetch the state dict once; its tensors alias the model parameters.
    model_state = model.state_dict()
    for k, v in torch.load(ckpt_path, map_location='cpu')['state_dict'].items():
        model_state[k].copy_(v)

    # Coordinate grid flattened to (side_length**2, 2).
    x = torch.stack(
        torch.meshgrid(
            [
                torch.linspace(-1.0, 1.0, side_length),
                torch.linspace(-1.0, 1.0, side_length),
            ]
        ),
        dim=-1,
    ).view(-1, 2)
    # Shared phase ramp for the sine and cosine terms, computed once.
    phase = torch.linspace(-1.0, 1.0, side_length)*(2**(10))*np.pi+z/20*2*np.pi
    # Column vector + row vector broadcasts to (side_length, side_length).
    latent = torch.sin(phase).view(-1, 1) + torch.cos(phase).view(1, -1)
    latent = 0.01*latent.view(-1, 1)
    # Add a leading batch axis and keep both inputs on the model device.
    x = x.unsqueeze(0).to(model.device)
    latent = latent.unsqueeze(0).to(model.device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        main_representation, y_pred = model(x, latent)
    print(y_pred.shape)
    # Independent min-max normalisation of both outputs for display.
    y_pred = (y_pred-y_pred.min())/(y_pred.max()-y_pred.min())
    main_representation = (main_representation-main_representation.min())/(main_representation.max()-main_representation.min())

    for y in y_pred:
        # Each prediction is reshaped to (side_length, side_length, 3) for imshow.
        y = y.cpu().numpy().reshape(side_length, side_length, 3)
        plt.imshow(y)
        plt.axis('off')
        plt.show()

    main_representation = main_representation.cpu().numpy().reshape(side_length, side_length, 3)
    plt.imshow(main_representation)
    plt.axis('off')
    plt.show()


contrast_slider = widgets.FloatSlider(value=0.000, min=0.0, max=frames-1, step=0.3, description='latent')
# Must stay the last expression of the cell so the widget is displayed.
widgets.interactive(show_image, z=contrast_slider)

interactive(children=(FloatSlider(value=0.0, description='latent', max=15.0, step=0.3), Output()), _dom_classe…

In [None]:
# Preview the training batches produced by the motion data module.
# numpy / pyplot are imported here so the cell also runs on a fresh kernel.
import matplotlib.pyplot as plt
import numpy as np

from motion_modulate_posembed import HDRDataModule

kwargs = {
    'project_name': 'motion_phase_pos_embed',
    'batch_size': 10,
    'side_length': 64,
    'pos_embed': False,
    'sequence_range': [15,75],
    'skip': 10,
    'motion_path': './motion_data/man'
}
# Derive the frame count from this cell's own configuration instead of
# whatever `sequence_range` / `skip` an earlier cell left in the kernel.
frames = len(np.arange(kwargs['sequence_range'][0], kwargs['sequence_range'][1], kwargs['skip']))
loader = HDRDataModule(
    batch_size=kwargs['batch_size'],
    side_length=kwargs['side_length'],
    motion_path=kwargs['motion_path'],
    sequence_range=kwargs['sequence_range'],
    skip=kwargs['skip'],
)
loader.setup()
for d in loader.train_dataloader():
    x, y_pred = d
    # Min-max normalisation over the whole batch for display.
    y_pred = (y_pred-y_pred.min())/(y_pred.max()-y_pred.min())
    # Never create fewer panels than there are images in the batch;
    # squeeze=False keeps the axes indexable even for a single column.
    n_cols = max(frames, len(y_pred))
    fig, ax = plt.subplots(1, n_cols, figsize=(20,20), squeeze=False)
    ax = ax[0]
    for i, y in enumerate(y_pred):
        # Each image is reshaped to (side_length, side_length, 3) for imshow.
        y = y.detach().cpu().numpy().reshape(kwargs['side_length'], kwargs['side_length'], 3)
        ax[i].imshow(y)
    # Hide the axes of every panel, including any left empty.
    for panel in ax:
        panel.axis('off')
    plt.show()