In [1]:
import random

import torch
import torch.nn as nn
from cmws import util
import torch
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVPerspectiveCameras,
    PointLights,
    DirectionalLights,
    Materials,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    HardPhongShader,
    TexturesUV,
    TexturesVertex,
    BlendParams,
    softmax_rgb_blend
)
from pytorch3d.structures.meshes import (
    Meshes,
    join_meshes_as_batch,
    join_meshes_as_scene,
)
import numpy as np
from cmws.examples.stacking_3d import data, render
from cmws import util
import matplotlib.pyplot as plt

# Default device for all tensors: GPU when available, otherwise CPU. This matches
# the selection made in the experiment cell, instead of a hard-coded 'cpu' that
# was silently overridden there.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
def plot(target_img, img, losses, iteration, path):
    """Save a three-panel progress figure: target image, current render, loss curve.

    Args:
        target_img: image tensor shown in the left panel (moved to CPU and
            detached before ``imshow``; presumably (H, W, 3) RGB — TODO confirm
            against ``render.render_cube``).
        img: current rendered image tensor, shown in the middle panel.
        losses: sequence of per-iteration loss values plotted in the right panel.
        iteration: iteration number shown in the middle panel's title.
        path: output file path handed to ``util.save_fig``.
    """
    fig, (ax_target, ax_render, ax_loss) = plt.subplots(1, 3, figsize=(12, 4))

    # The two image panels share the same treatment: draw, title, hide ticks.
    image_panels = (
        (ax_target, target_img, "Target"),
        (ax_render, img, f"Rendered (iter {iteration})"),
    )
    for panel_ax, panel_image, panel_title in image_panels:
        panel_ax.imshow(panel_image.cpu().detach())
        panel_ax.set_title(panel_title)
        panel_ax.set_xticks([])
        panel_ax.set_yticks([])

    ax_loss.plot(losses)
    ax_loss.set_xlabel("Iteration")
    ax_loss.set_ylabel("Loss")

    util.save_fig(fig, path)


def get_position(raw_position, true_position, dim=0):
    """Build a position tensor in which a single coordinate is learned.

    Copies ``true_position`` (detached, so no gradient flows into it) and
    overwrites coordinate ``dim`` with ``raw_position``. Gradients with respect
    to ``raw_position`` flow through the in-place slice assignment.

    Args:
        raw_position: 0-dim tensor holding the learned coordinate value.
        true_position: tensor supplying the fixed values of all other coordinates.
        dim: index of the coordinate to learn (0 = x, 1 = y, 2 = z). Defaults to
            0, which reproduces the previous hard-coded behaviour.

    Returns:
        A new tensor equal to ``true_position`` except at index ``dim``.

    Raises:
        NotImplementedError: if ``raw_position`` is not a scalar. Previously this
            case silently returned ``None``, which only failed later inside the
            renderer.
    """
    if raw_position.ndim != 0:
        raise NotImplementedError(
            "get_position only supports a scalar raw_position "
            f"(got shape {tuple(raw_position.shape)})"
        )
    position = true_position.clone().detach()
    # Alternative parameterisation from the original experiment:
    # raw_position.sigmoid() * 1.6 - 0.8 squashes R into (-0.8, 0.8).
    position[dim] = raw_position
    return position

In [3]:
# Device for every tensor in this experiment (GPU when available).
device = "cuda" if torch.cuda.is_available() else "cpu"

# ---- Experiment configuration ----------------------------------------------
# flip_loss=True negates the squared-error loss, so Adam *increases* the pixel
# difference to the target instead of reducing it. NOTE(review): the original
# marks this as a "hack" — confirm why it is needed (e.g. a gradient-sign issue
# in render.render_cube?) before trusting the reconstruction it produces.
flip_loss = True
# False: size/colour stay fixed at the true values and only x is learned.
# True: size/colour are also learned (zero log-size, random colour logits),
# mapped through exp()/softmax() before rendering.
learn_size_and_color = False
seed = 0  # only used for the random colour init when learn_size_and_color=True
im_size = 256
num_iterations = 100
lr = 5e-2
save_dir = "save/cube"

# ---- Target image -----------------------------------------------------------
true_size = torch.tensor(0.7, device=device)
# Alternative target used earlier: torch.tensor([0.0, 0.0, -1.0], device=device)
true_position = torch.tensor([0.0, -1.0, 0.0], device=device)
true_color = torch.tensor([1.0, 0, 0], device=device)
target_img = render.render_cube(true_size, true_color, true_position, im_size)

# ---- Parameters ------------------------------------------------------------
if learn_size_and_color:
    torch.manual_seed(seed)
    raw_size = torch.tensor(0.0, device=device, requires_grad=True)
    raw_color = torch.randn(3, device=device, requires_grad=True)
else:
    raw_size = torch.tensor(0.7, device=device, requires_grad=False)
    raw_color = torch.tensor([1.0, 0, 0], device=device, requires_grad=False)

# Only the x coordinate is learned (see get_position). Start at 0.2, or use
# torch.zeros((), device=device, requires_grad=True) to start from 0.
raw_x_position = torch.tensor(0.2, device=device, requires_grad=True)

# Hand the optimizer only tensors that are actually trained; frozen size/colour
# (requires_grad=False) would never receive a gradient.
optimizer = torch.optim.Adam(
    [p for p in (raw_size, raw_x_position, raw_color) if p.requires_grad], lr=lr
)

# ---- Optimization loop -----------------------------------------------------
losses = []
frame_paths = []
for i in range(num_iterations):
    optimizer.zero_grad()
    if learn_size_and_color:
        # Keep size positive and colour on the simplex.
        size, color = raw_size.exp(), torch.softmax(raw_color, 0)
    else:
        size, color = raw_size, raw_color
    img = render.render_cube(
        size,
        color,
        get_position(raw_x_position, true_position),
        im_size,
    )

    sq_err = (img - target_img).pow(2).sum()
    loss = -sq_err if flip_loss else sq_err
    loss.backward()
    optimizer.step()

    losses.append(loss.item())
    frame_path = f"{save_dir}/{i}.png"
    plot(target_img, img, losses, i, frame_path)
    frame_paths.append(frame_path)

    print(f"Iter. {i} | Loss {losses[-1]}")

util.make_gif(frame_paths, f"{save_dir}/reconstruction.gif", 10)

16:42:44 | /om/user/katiemc/continuous_mws/cmws/util.py:184 | INFO: Saved to save/cube/0.png
Iter. 0 | Loss 3445.971923828125
16:42:45 | /om/user/katiemc/continuous_mws/cmws/util.py:184 | INFO: Saved to save/cube/1.png
Iter. 1 | Loss 5085.2734375
16:42:47 | /om/user/katiemc/continuous_mws/cmws/util.py:184 | INFO: Saved to save/cube/2.png
Iter. 2 | Loss 6506.6943359375
16:42:49 | /om/user/katiemc/continuous_mws/cmws/util.py:184 | INFO: Saved to save/cube/3.png
Iter. 3 | Loss 7722.6533203125
16:42:51 | /om/user/katiemc/continuous_mws/cmws/util.py:184 | INFO: Saved to save/cube/4.png
Iter. 4 | Loss 8761.4697265625
16:42:52 | /om/user/katiemc/continuous_mws/cmws/util.py:184 | INFO: Saved to save/cube/5.png
Iter. 5 | Loss 9627.1650390625
16:42:54 | /om/user/katiemc/continuous_mws/cmws/util.py:184 | INFO: Saved to save/cube/6.png
Iter. 6 | Loss 10428.8291015625
16:42:56 | /om/user/katiemc/continuous_mws/cmws/util.py:184 | INFO: Saved to save/cube/7.png
Iter. 7 | Loss 11049.3994140625
16:42:5

16:50:25 | /om/user/katiemc/continuous_mws/cmws/util.py:184 | INFO: Saved to save/cube/66.png
Iter. 66 | Loss 42847.58203125
16:50:41 | /om/user/katiemc/continuous_mws/cmws/util.py:184 | INFO: Saved to save/cube/67.png
Iter. 67 | Loss 42043.34765625
16:50:56 | /om/user/katiemc/continuous_mws/cmws/util.py:184 | INFO: Saved to save/cube/68.png
Iter. 68 | Loss 41517.2421875
16:51:12 | /om/user/katiemc/continuous_mws/cmws/util.py:184 | INFO: Saved to save/cube/69.png
Iter. 69 | Loss 41261.31640625
16:51:29 | /om/user/katiemc/continuous_mws/cmws/util.py:184 | INFO: Saved to save/cube/70.png
Iter. 70 | Loss 42080.59765625
16:51:45 | /om/user/katiemc/continuous_mws/cmws/util.py:184 | INFO: Saved to save/cube/71.png
Iter. 71 | Loss 42371.3125
16:52:01 | /om/user/katiemc/continuous_mws/cmws/util.py:184 | INFO: Saved to save/cube/72.png
Iter. 72 | Loss 41785.37890625
16:52:14 | /om/user/katiemc/continuous_mws/cmws/util.py:184 | INFO: Saved to save/cube/73.png
Iter. 73 | Loss 41242.12109375
16:52

  8%|▊         | 8/100 [00:00<00:01, 71.29it/s]

Iter. 99 | Loss 54003.15234375


100%|██████████| 100/100 [00:01<00:00, 70.83it/s]


16:57:37 | /om/user/katiemc/continuous_mws/cmws/util.py:194 | INFO: Saved to save/cube/reconstruction.gif


In [4]:
# Current position estimate: true_position with coordinate 0 (x) replaced by the
# learned raw_x_position.
get_position(raw_position=raw_x_position, true_position=true_position)

In [5]:
# NOTE(review): verbatim duplicate of the previous cell; safe to delete.
get_position(raw_x_position, true_position)

In [6]:
# Apply the softmax squashing used by the learned-colour variant to the current
# raw_color. Softmax outputs are strictly positive, so it can never produce an
# exactly one-hot (pure) colour.
raw_color.softmax(dim=0)

tensor([0.5761, 0.2119, 0.2119])

In [7]:
# Inspect the colour tensor from the experiment cell (fixed, requires_grad=False
# in the default configuration, so training never changes it).
raw_color

tensor([1., 0., 0.])

In [8]:
# Rebinding `raw_x_position` below discards the parameter trained in the
# experiment cell; the optimizer keeps pointing at the old tensor. Snapshot its
# value first so the trained result stays inspectable.
# NOTE: re-running this cell snapshots the zero init instead of the trained value.
trained_raw_x_position = raw_x_position.detach().clone()
raw_x_position = torch.zeros((), device=device, requires_grad=True)

In [9]:
# Display the freshly re-initialised zero scalar from the previous cell.
raw_x_position


tensor(0., requires_grad=True)

In [10]:
# Fresh zero-initialised scalar, viewed through the sigmoid squashing that
# get_position has commented out: sigmoid in (0, 1), so the result is in (-0.8, 0.8).
raw_x_position = torch.tensor(0.0, device=device, requires_grad=True)
torch.sigmoid(raw_x_position) * 1.6 - 0.8

tensor(0., grad_fn=<SubBackward0>)

In [11]:
# Number of dimensions (0 for a scalar); get_position only handles the 0-dim case.
raw_x_position.dim()

0

In [12]:
# Rebuild by hand what get_position would return with the sigmoid squashing enabled:
# copy the true position (no gradient) and write the squashed x into index 0.
position = true_position.detach().clone()
position[0] = torch.sigmoid(raw_x_position) * 1.6 - 0.8

In [13]:
# grad_fn=<CopySlices> in the output shows the slice assignment above is
# differentiable, so gradients reach raw_x_position through `position`.
position

tensor([ 0., -1.,  0.], grad_fn=<CopySlices>)

In [14]:
# The raw scalar itself is unchanged by writing its squashed value into `position`.
raw_x_position

tensor(0., requires_grad=True)