In [None]:
import json
from pathlib import Path

import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt

from utils import ray_generator, plot_rays
from model import Sphere
from rendering import rendering
from loss import mse_loss, mse2psnr

In [None]:
# --- constants -------------------------------------------------------------
# Image resolution (pixels) and focal length handed to ray_generator.
HEIGHT, WIDTH = 400, 400
FOCUS = 1200

# tn / tf: presumably the near / far sampling bounds along each ray that
# `rendering` expects (they match the 0.8 / 1.2 literals used below) — confirm.
tn, tf = 0.8, 1.2

# Sphere geometry passed as the first two Sphere(...) arguments.
ORIGIN = [0.0, 0.0, -1.0]
RADIUS = [0.1]

# Step size for the SGD optimizer that fits the sphere color.
learning_rate = 2e-1

# RGB triples: the starting guess to optimize, and the color to recover.
color = [0.0, 1.0, 1.0]  # cyan
target_color = [1.0, 0.0, 1.0]  # magenta

In [None]:
# Render the ground-truth sphere (fixed `target_color`) that the optimization
# below tries to reproduce, and save it for reference.
rays_origin, rays_direction = ray_generator(HEIGHT, WIDTH, FOCUS)

target_sphere = Sphere(
    torch.tensor(ORIGIN), torch.tensor(RADIUS), torch.tensor(target_color)
)
target_px_colors = rendering(
    target_sphere,
    torch.tensor(rays_origin),
    torch.tensor(rays_direction),
    tn,  # config-cell bounds instead of repeating the 0.8 / 1.2 literals
    tf,
    white_background=False,
)

# Reshape the rendered pixel colors into an (HEIGHT, WIDTH, 3) image for display.
target_img = target_px_colors.reshape(HEIGHT, WIDTH, 3).cpu().numpy()

# savefig raises FileNotFoundError if the folder is missing (e.g. fresh clone).
Path("images").mkdir(parents=True, exist_ok=True)
plt.title(f"target color: {target_color}")
plt.axis(False)
plt.imshow(target_img)
plt.savefig("images/target_sphere.jpg", bbox_inches="tight")

In [None]:
# The single trainable parameter: the sphere's RGB color, initialised to `color`.
color_to_optimize = torch.tensor(color, requires_grad=True, dtype=torch.float32)

# Parameters go in an ordered collection (list): PyTorch requires a
# deterministic parameter ordering, which a set does not provide.
optimizer = torch.optim.SGD(params=[color_to_optimize], lr=learning_rate)

In [None]:
# Render the sphere at the *starting* color for a before/after comparison.
# Built from `color` rather than the trainable `color_to_optimize`, so
# re-running this cell after training still shows the unoptimized sphere that
# the title describes, and the render is not tied to the trainable tensor.
unoptim_sphere = Sphere(
    torch.tensor(ORIGIN),
    torch.tensor(RADIUS),
    torch.tensor(color, dtype=torch.float32),
)
unoptim_px_colors = rendering(
    unoptim_sphere,
    torch.tensor(rays_origin),
    torch.tensor(rays_direction),
    tn,
    tf,
    white_background=False,
)
unoptim_img = unoptim_px_colors.reshape(HEIGHT, WIDTH, 3).detach().cpu().numpy()

Path("images").mkdir(parents=True, exist_ok=True)
plt.title(f"unoptimized color: {color}")
plt.axis(False)
plt.imshow(unoptim_img)
plt.savefig("images/unoptimized_sphere.jpg", bbox_inches="tight")

In [None]:
# Side-by-side view: target render (left) next to the unoptimized render (right).
concat_image = np.hstack([target_img, unoptim_img])
plt.axis("off")
plt.imshow(concat_image)

In [None]:
# Fit the sphere color to the target by gradient descent on the mean squared
# per-pixel error, printing progress and saving a snapshot every few epochs.
N_EPOCHS = 200
SNAPSHOT_EVERY = 10
SNAPSHOT_DIR = Path("sphere_img")
SNAPSHOT_DIR.mkdir(parents=True, exist_ok=True)

# Restart from the initial guess so re-running this cell reproduces the same
# run instead of continuing from the previously fitted color. (The SGD
# optimizer is created without momentum, so there is no optimizer state to reset.)
with torch.no_grad():
    color_to_optimize.copy_(torch.tensor(color))

# Loop-invariant inputs: build them once rather than on every epoch.
sphere_origin = torch.tensor(ORIGIN)
sphere_radius = torch.tensor(RADIUS)
rays_o = torch.tensor(rays_origin)
rays_d = torch.tensor(rays_direction)

losses = []
for epoch in range(N_EPOCHS):
    sphere = Sphere(sphere_origin, sphere_radius, color_to_optimize)
    pred_px_colors = rendering(sphere, rays_o, rays_d, tn, tf, white_background=False)

    loss = ((pred_px_colors - target_px_colors) ** 2).mean()
    losses.append(loss.item())

    # Snapshot *before* the optimizer step so the printed color, the loss and
    # the saved image all describe the same parameter value.
    if epoch % SNAPSHOT_EVERY == 0:
        current_color = [round(c, 4) for c in color_to_optimize.detach().tolist()]
        print(f"loss: {loss.item():.4f}")
        print(f"color: {current_color}")

        img = pred_px_colors.reshape(HEIGHT, WIDTH, 3).detach().cpu().numpy()
        fig, ax = plt.subplots()
        ax.set_title(f"{current_color}")
        ax.axis(False)
        ax.imshow(img)
        fig.savefig(SNAPSHOT_DIR / f"sphere_epoch_{epoch}.jpg", bbox_inches="tight")
        # A fresh figure per snapshot, closed immediately: reusing the current
        # figure would stack every snapshot on one axes and keep them all alive.
        plt.close(fig)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
# Training curve of the color fit from the training cell.
Path("images").mkdir(parents=True, exist_ok=True)
plt.plot(losses)
plt.title("sphere color fit: mean squared pixel error")
plt.xlabel("epochs")
plt.ylabel("loss")
plt.savefig("images/sphere_loss.jpg", bbox_inches="tight")

In [None]:
# Load a previously saved loss curve; the next cell indexes it by epoch and
# plots it next to each snapshot. Presumably one value per epoch from an
# earlier full training run — confirm against whatever wrote this file.
with open("outputs/sphere_losses.json", "r", encoding="utf-8") as file:
    # json.load reads from the file; json.dump(file) raised TypeError
    # (dump serializes an object and needs both `obj` and `fp`).
    t_losses = json.load(file)

In [None]:
# Re-run the color fit and, every few epochs, save a two-panel frame: the
# current render next to the reference loss curve `t_losses`, with the current
# epoch marked on it.
N_EPOCHS = 200
SNAPSHOT_EVERY = 10
# NOTE: same folder and file names as the plain training cell, so these frames
# overwrite those snapshots.
SNAPSHOT_DIR = Path("sphere_img")
SNAPSHOT_DIR.mkdir(parents=True, exist_ok=True)

# Start from the initial guess. An earlier training cell already fits
# `color_to_optimize`; without this reset the frames would show an
# already-fitted sphere that does not match the loss curve being annotated.
with torch.no_grad():
    color_to_optimize.copy_(torch.tensor(color))

# Loop-invariant inputs: build them once rather than on every epoch.
sphere_origin = torch.tensor(ORIGIN)
sphere_radius = torch.tensor(RADIUS)
rays_o = torch.tensor(rays_origin)
rays_d = torch.tensor(rays_direction)

losses = []
for epoch in range(N_EPOCHS):
    sphere = Sphere(sphere_origin, sphere_radius, color_to_optimize)
    pred_px_colors = rendering(sphere, rays_o, rays_d, tn, tf, white_background=False)

    loss = ((pred_px_colors - target_px_colors) ** 2).mean()
    losses.append(loss.item())

    # Snapshot before the optimizer step so the title color matches the image.
    if epoch % SNAPSHOT_EVERY == 0:
        current_color = [round(c, 4) for c in color_to_optimize.detach().tolist()]
        print(f"loss: {loss.item():.4f}")
        print(f"color: {current_color}")

        img = pred_px_colors.reshape(HEIGHT, WIDTH, 3).detach().cpu().numpy()
        fig, axs = plt.subplots(1, 2, figsize=(20, 9))
        axs[0].imshow(img)
        axs[0].set_title(f"{current_color}")
        axs[0].axis(False)

        axs[1].plot(t_losses)
        # The reference curve may be shorter than this run; don't IndexError.
        if epoch < len(t_losses):
            axs[1].scatter(epoch, t_losses[epoch], c="m", label="epoch no.")
            axs[1].legend()
        axs[1].set(xlabel="epoch", ylabel="loss")

        fig.tight_layout()
        fig.savefig(SNAPSHOT_DIR / f"sphere_epoch_{epoch}.jpg", bbox_inches="tight")
        plt.show()
        plt.close(fig)  # don't accumulate 20 open figures outside the inline backend

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [None]:
from train import graphing_sphere_train, sphere_train

# Same fit via the project's training helper. Reset the parameter first so this
# run starts from the initial `color` rather than from wherever earlier cells
# left `color_to_optimize`; this keeps the cell reproducible on re-run.
with torch.no_grad():
    color_to_optimize.copy_(torch.tensor(color))

# Alternative helper (presumably the plotting variant of sphere_train):
# graphing_sphere_train(color_to_optimize, rays_origin, rays_direction, target_px_colors, optimizer, save_dir="test_img")
sphere_train(
    color_to_optimize,
    rays_origin,
    rays_direction,
    target_px_colors,
    optimizer,
    save_dir="test_img",
)