-
Notifications
You must be signed in to change notification settings - Fork 0
/
sphere_color.py
49 lines (41 loc) · 1.03 KB
/
sphere_color.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
from rendering import rendering
from utils import ray_generator
from model import Sphere
from config import (
ORIGIN,
RADIUS,
HEIGHT,
WIDTH,
SAVE_DIR,
FOCUS,
target_color,
color,
learning_rate,
)
from sphere_train import sphere_train
def main(save_dir="test_img"):
    """Recover a sphere's colour by fitting it to a rendered target image.

    Renders a reference sphere built from ``config`` (``ORIGIN``, ``RADIUS``,
    ``target_color``) to get target pixel colours. It then optimises a colour
    tensor initialised from ``config.color`` with SGD (learning rate
    ``config.learning_rate``) by handing everything to ``sphere_train``.

    Args:
        save_dir: Directory forwarded to ``sphere_train`` as its ``save_dir``
            argument. Defaults to ``"test_img"``, the value previously
            hard-coded here.
            NOTE(review): ``SAVE_DIR`` is imported from ``config`` but never
            used in this file — confirm whether it was meant to be the default.

    Returns:
        The value returned by ``sphere_train``. The original script bound it
        to ``loss`` and discarded it; it is now returned for programmatic use.
    """
    # Camera rays from the configured image size and focal setting
    # (presumably one ray per pixel — confirm against utils.ray_generator).
    rays_origin, rays_direction = ray_generator(HEIGHT, WIDTH, FOCUS)

    # Ground-truth sphere whose rendered colours are the optimisation target.
    target_sphere = Sphere(
        torch.tensor(ORIGIN), torch.tensor(RADIUS), torch.tensor(target_color)
    )
    target_px_colors = rendering(
        target_sphere,
        torch.tensor(rays_origin),
        torch.tensor(rays_direction),
        # NOTE(review): positional 0.8 / 1.2 look like near/far bounds along
        # each ray — confirm against rendering()'s signature.
        0.8,
        1.2,
        white_background=False,
    )

    # Leaf tensor that SGD updates; requires_grad=True so gradients reach it.
    color_to_optimize = torch.tensor(color, requires_grad=True, dtype=torch.float32)
    # Optimizers need an ordered iterable of params; a set's iteration order
    # is not deterministic across runs, so pass a list.
    optimizer = torch.optim.SGD(params=[color_to_optimize], lr=learning_rate)

    loss = sphere_train(
        color_to_optimize,
        rays_origin,
        rays_direction,
        target_px_colors,
        optimizer,
        save_dir=save_dir,
    )
    return loss
# Run the colour-fitting experiment only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()