-
Notifications
You must be signed in to change notification settings - Fork 0
/
vae_model_frame.py
40 lines (30 loc) · 974 Bytes
/
vae_model_frame.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch.nn as nn
from abc import abstractmethod
from typing import List, Callable, Union, Any, TypeVar, Tuple
# from torch import tensor as Tensor
# Type alias used in the annotations below. The tensor *class* is
# `torch.Tensor`; lowercase `torch.tensor` is a factory function, and a
# `TypeVar` is a generic placeholder rather than an alias for a concrete type.
from torch import Tensor
class VAE_Frame(nn.Module):
    """Base interface for VAE models built on ``nn.Module``.

    Subclasses are expected to override the methods below. Each one raises
    ``NotImplementedError`` here, so a missing override fails loudly when it
    is called instead of silently returning ``None``.

    Note: ``@abstractmethod`` on ``forward`` and ``loss_function`` documents
    intent only. The ``abc`` module enforces it solely for classes whose
    metaclass is ``ABCMeta``, and ``nn.Module`` does not use that metaclass.
    As a result, a subclass that omits these methods can still be
    instantiated and fails only when the method is called.
    """

    def __init__(self):
        super().__init__()

    def encoder(self, input: Tensor) -> List[Tensor]:
        """Encode ``input`` into a list of tensors (subclass responsibility).

        :param input: Input batch.
        :return: List of tensors produced by the encoder.
        :raises NotImplementedError: if the subclass does not override it.
        """
        raise NotImplementedError

    def decoder(self, input: Tensor):
        """Decode ``input`` (subclass responsibility).

        :param input: Tensor to decode.
        :raises NotImplementedError: if the subclass does not override it.
        """
        raise NotImplementedError

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """Produce an image from the input image ``x``.

        :param x: Image
        :return: Image
        :raises NotImplementedError: if the subclass does not override it.
        """
        raise NotImplementedError

    def sample(self, samples_num: int, **kwargs) -> Tensor:
        """Draw ``samples_num`` samples from the model.

        :param samples_num: Number of samples to draw.
        :return: Sampled images
        :raises NotImplementedError: if the subclass does not override it.
        """
        raise NotImplementedError

    @abstractmethod  # Intent marker only; not enforced on nn.Module (see class docstring).
    def forward(self, *inputs: Any, **kwargs: Any) -> Any:
        """Run the model (subclass responsibility).

        The variadic signature lets overrides define their own inputs;
        ``nn.Module.__call__`` dispatches ``model(...)`` to this method.

        :raises NotImplementedError: if the subclass does not override it.
        """
        # Raise rather than `pass`, so a missing override is not mistaken
        # for a model that returns None.
        raise NotImplementedError

    @abstractmethod  # Intent marker only; not enforced on nn.Module (see class docstring).
    def loss_function(self, *inputs: Any, **kwargs):
        """Compute the training loss (subclass responsibility).

        :raises NotImplementedError: if the subclass does not override it.
        """
        raise NotImplementedError