-
Notifications
You must be signed in to change notification settings - Fork 286
/
manual_poser.py
137 lines (111 loc) · 5.3 KB
/
manual_poser.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import sys
sys.path.append(os.getcwd())
from tkinter import Frame, Label, BOTH, Tk, LEFT, HORIZONTAL, Scale, Button, GROOVE, filedialog, PhotoImage, messagebox
import PIL.Image
import PIL.ImageTk
import numpy
import torch
from poser.morph_rotate_combine_poser import MorphRotateCombinePoser256Param6
from poser.poser import Poser
from tha.combiner import CombinerSpec
from tha.face_morpher import FaceMorpherSpec
from tha.two_algo_face_rotator import TwoAlgoFaceRotatorSpec
from util import extract_pytorch_image_from_filelike, rgba_to_numpy_image
class ManualPoserApp:
    """Tkinter window for posing an image with one slider per pose parameter.

    The window is laid out left to right as: the loaded source image, a
    control panel (one slider per entry of ``poser.pose_parameters()`` plus a
    "Load Image ..." button), and the image returned by ``poser.pose``.
    A callback that reschedules itself about 30 times per second re-renders
    the posed image whenever the sliders move or a new image is loaded.
    """

    # Delay between two refresh callbacks, in milliseconds (about 30 per second).
    _UPDATE_INTERVAL_MS = 1000 // 30

    # Summed absolute slider difference below which the pose counts as unchanged.
    _POSE_EPSILON = 1e-5

    def __init__(self,
                 master,
                 poser: "Poser",
                 torch_device: torch.device):
        """Build the widgets and start the periodic refresh.

        Args:
            master: Tk root (or other container) that hosts the widgets and
                provides ``title`` and ``after``.
            poser: object supplying ``pose_parameters()``, ``image_size()``
                and ``pose(image, pose)``.
            torch_device: device on which the source image and the pose
                tensor are created.
        """
        super().__init__()
        self.master = master
        self.poser = poser
        self.torch_device = torch_device

        self.master.title("Manual Poser")

        # Left pane: the source image. Propagation is switched off so the
        # frame keeps the requested 256x256 size regardless of its content.
        source_image_frame = Frame(self.master, width=256, height=256)
        source_image_frame.pack_propagate(0)
        source_image_frame.pack(side=LEFT)

        self.source_image_label = Label(source_image_frame, text="Nothing yet!")
        self.source_image_label.pack(fill=BOTH, expand=True)

        # Middle pane: one slider (with its name underneath) per parameter.
        control_frame = Frame(self.master, borderwidth=2, relief=GROOVE)
        control_frame.pack(side=LEFT, fill='y')

        # Query the parameters once so the slider list and pose_size agree.
        pose_parameters = self.poser.pose_parameters()

        self.param_sliders = []
        for param in pose_parameters:
            slider = Scale(control_frame,
                           from_=param.lower_bound,
                           to=param.upper_bound,
                           length=256,
                           resolution=0.001,
                           orient=HORIZONTAL)
            slider.set(param.default_value)
            slider.pack(fill='x')
            self.param_sliders.append(slider)

            label = Label(control_frame, text=param.display_name)
            label.pack()

        # Right pane: the posed result, same fixed size as the source pane.
        posed_image_frame = Frame(self.master, width=256, height=256)
        posed_image_frame.pack_propagate(0)
        posed_image_frame.pack(side=LEFT)

        self.posed_image_label = Label(posed_image_frame, text="Nothing yet!")
        self.posed_image_label.pack(fill=BOTH, expand=True)

        self.load_source_image_button = Button(control_frame, text="Load Image ...", relief=GROOVE,
                                               command=self.load_image)
        self.load_source_image_button.pack(fill='x')

        self.pose_size = len(pose_parameters)
        self.source_image = None
        self.posed_image = None
        self.current_pose = None
        self.last_pose = None
        self.needs_update = False

        self._schedule_update()

    def load_image(self):
        """Ask the user for a PNG file and make it the source image.

        Nothing changes when the dialog is cancelled. An image whose width or
        height differs from ``poser.image_size()`` is rejected with an error
        dialog and the previously loaded image (if any) stays in place.
        """
        file_name = filedialog.askopenfilename(
            filetypes=[("PNG", '*.png')],
            initialdir="data/illust")
        if not file_name:
            # Dialog was cancelled: no file was chosen.
            return

        image = PhotoImage(file=file_name)
        required_size = self.poser.image_size()
        if image.width() != required_size or image.height() != required_size:
            message = "The loaded image has size %dx%d, but we require %dx%d." \
                      % (image.width(), image.height(), required_size, required_size)
            messagebox.showerror("Wrong image size!", message)
            # Do not hand an image of the wrong size to the poser.
            return

        self.source_image_label.configure(image=image, text="")
        # Keep a Python reference on the widget: Tk itself does not hold one,
        # and the picture disappears once the PhotoImage is garbage collected.
        self.source_image_label.image = image
        self.source_image_label.pack()

        # The extra leading dimension comes from unsqueeze(dim=0).
        self.source_image = extract_pytorch_image_from_filelike(file_name).to(self.torch_device).unsqueeze(dim=0)
        self.needs_update = True

    def update_pose(self):
        """Store the current slider values in ``self.current_pose``.

        The result is a tensor on ``self.torch_device`` with one leading
        dimension of size 1 followed by one entry per slider.
        """
        slider_values = [slider.get() for slider in self.param_sliders]
        # One tensor construction instead of a separate element write per
        # slider into a tensor that may live on another device.
        pose = torch.tensor(slider_values,
                            dtype=torch.get_default_dtype(),
                            device=self.torch_device)
        self.current_pose = pose.unsqueeze(dim=0)

    def update_image(self):
        """Periodic callback: re-render the posed image if needed, then reschedule."""
        self.update_pose()
        if self._needs_render():
            self._render_posed_image()
        self._schedule_update()

    def _schedule_update(self):
        """Queue the next ``update_image`` call on the Tk event loop."""
        self.master.after(self._UPDATE_INTERVAL_MS, self.update_image)

    def _needs_render(self):
        """Return True when a source image exists and the output is stale."""
        if self.source_image is None:
            return False
        if self.needs_update or self.last_pose is None:
            return True
        pose_change = (self.last_pose - self.current_pose).abs().sum().item()
        return not (pose_change < self._POSE_EPSILON)

    def _render_posed_image(self):
        """Run the poser on the source image and show the result in the right pane."""
        self.last_pose = self.current_pose

        # The result is only displayed, so no autograd graph is needed.
        with torch.no_grad():
            posed_image = self.poser.pose(self.source_image, self.current_pose).detach().cpu()

        numpy_image = rgba_to_numpy_image(posed_image[0])
        # Scale to 0..255 and round before converting to 8-bit RGBA.
        pil_image = PIL.Image.fromarray(numpy.uint8(numpy.rint(numpy_image * 255.0)), mode='RGBA')
        photo_image = PIL.ImageTk.PhotoImage(image=pil_image)

        self.posed_image_label.configure(image=photo_image, text="")
        # Keep a Python reference so the displayed image is not garbage collected.
        self.posed_image_label.image = photo_image
        self.posed_image_label.pack()

        self.needs_update = False
def main():
    """Create the poser from the model files under ``data/`` and run the window."""
    # Use CUDA when PyTorch reports it as available; otherwise fall back to
    # the CPU rather than requesting a device that is not present.
    # NOTE(review): loading the model files on a CPU-only machine is not
    # verified here - confirm the poser's loader handles it.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    poser = MorphRotateCombinePoser256Param6(
        morph_module_spec=FaceMorpherSpec(),
        morph_module_file_name="data/face_morpher.pt",
        rotate_module_spec=TwoAlgoFaceRotatorSpec(),
        rotate_module_file_name="data/two_algo_face_rotator.pt",
        combine_module_spec=CombinerSpec(),
        combine_module_file_name="data/combiner.pt",
        device=device)
    root = Tk()
    # Keep a reference to the app for the lifetime of the event loop.
    app = ManualPoserApp(master=root, poser=poser, torch_device=device)
    root.mainloop()


if __name__ == "__main__":
    main()