-
Notifications
You must be signed in to change notification settings - Fork 6
/
test_learning.py
37 lines (31 loc) · 1.51 KB
/
test_learning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import os
import unittest

import torch

from sf_vae import Learning
from sf_vae import VAE
class TestAudioTools(unittest.TestCase):
    """Tests for ``sf_vae.Learning`` driven by a ``VAE`` model.

    These tests depend on machine-local resources: a synthesised vowel
    dataset and a trained VAE checkpoint. When a required resource is
    missing, the test is skipped instead of erroring, so the suite can run
    on machines that do not have the data.
    """

    # Root of the soundgen vowel-synthesis dataset on the author's machine.
    DATA_ROOT = r"D:\These\data\Audio\Phonemes\vowel\synthesis_soundgen"
    # Trajectory directories; os.path.join yields the original strings on Windows.
    F1_TRAJECTORY = os.path.join(DATA_ROOT, "formant_1", "f2-1600")
    F2_TRAJECTORY = os.path.join(DATA_ROOT, "formant_2", "f1-800")
    # Paths relative to the working directory the tests are launched from.
    # os.path.join (not a hard-coded backslash) keeps them valid on POSIX too.
    VAE_CHECKPOINT = os.path.join("checkpoints", "vae_trained")
    PCA_SAVE_PATH = os.path.join("checkpoints", "pca-regression")

    def _require_path(self, path):
        """Skip the current test unless *path* exists on disk."""
        if not os.path.exists(path):
            self.skipTest(f"required resource not found: {path}")

    def _load_trained_vae(self):
        """Return a ``VAE`` loaded from ``VAE_CHECKPOINT``; skip if the file is absent."""
        self._require_path(self.VAE_CHECKPOINT)
        vae = VAE()
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host;
        # load_state_dict then copies the tensors into the model's parameters.
        checkpoint = torch.load(self.VAE_CHECKPOINT, map_location="cpu")
        vae.load_state_dict(checkpoint["model_state_dict"])
        return vae

    def test_get_s(self):
        """``get_s`` on the formant-1 trajectory leaves ``learn.s`` with shape (16, 16)."""
        self._require_path(self.F1_TRAJECTORY)
        learn = Learning(
            config_factor=dict(factor="f1", path_trajectory=self.F1_TRAJECTORY, dim=3),
            model=VAE(),
        )
        learn.get_s()
        self.assertEqual(learn.s.shape, (16, 16))

    def test_get_u(self):
        """Smoke test: ``Learning`` can be built from a trained VAE and a save path."""
        self._require_path(self.F1_TRAJECTORY)
        vae = self._load_trained_vae()
        # TODO: bind this instance, call get_u() and assert on its result. The
        # call was disabled in the original test, so only construction is
        # exercised for now.
        Learning(
            config_factor=dict(factor="f1", path_trajectory=self.F1_TRAJECTORY, dim=3),
            model=vae,
            path_save=self.PCA_SAVE_PATH,
        )

    def test_learning(self):
        """Smoke test: calling a ``Learning`` instance on the formant-2 trajectory does not raise."""
        self._require_path(self.F2_TRAJECTORY)
        vae = self._load_trained_vae()
        learn = Learning(
            config_factor=dict(factor="f2", path_trajectory=self.F2_TRAJECTORY, dim=3),
            model=vae,
            path_save=self.PCA_SAVE_PATH,
        )
        learn()
# Run every TestCase in this module when executed as a script.
if __name__ == "__main__":
    unittest.main(module="__main__")