-
Notifications
You must be signed in to change notification settings - Fork 0
/
model_maker.py
66 lines (60 loc) · 2.51 KB
/
model_maker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#from torchvision.models import resnet18
import torch.nn as nn
from models.mlp import MLP
from models.mlp_bn import MLPBN
from models.resnet18 import resnet18
from models.resnet18_mixup import resnet18_mixup
from models.resnet18_dropblock import resnet18_dropblock
from models.vgg13 import vgg13
from models.vgg13_mixup import vgg13_mixup
from models.vgg13_dropblock import vgg13_dropblock
class ModelMaker:
    """Build a model from an architecture name and a dataset type.

    Architectures dispatched by ``_make_model``: ``'mlp'``, ``'mlp-bn'``,
    ``'resnet18'`` and ``'vgg13'``. The two CNN architectures accept an
    ``option`` of ``None``, ``'mixup'`` or ``'dropblock'``. They can only be
    built for the dataset types listed in ``_CNN_DATASETS``.

    The constructed model is exposed through the ``model`` property.
    """

    # Dataset types for which the resnet18 / vgg13 variants can be built.
    _CNN_DATASETS = ('cifar10', 'svhn')
    # Variants understood by the CNN branches of ``_make_model``.
    _CNN_OPTIONS = (None, 'mixup', 'dropblock')

    def __init__(self, architecture, dataset_type, option=None, *args, **kwargs):
        """Store the configuration and build the model immediately.

        Args:
            architecture: One of 'mlp', 'mlp-bn', 'resnet18', 'vgg13'.
            dataset_type: Dataset identifier; CNNs require one of
                ``_CNN_DATASETS``.
            option: CNN variant (None, 'mixup' or 'dropblock'); not used
                by the MLP architectures.
            *args: Accepted for interface compatibility; not used.
            **kwargs: For MLP architectures, must contain ``num_classes``
                and ``input_size`` (and may override ``num_hidden_layers``
                / ``num_units``). For the 'mixup' and 'dropblock' CNN
                variants, forwarded to the model constructor.

        Raises:
            ValueError: Unknown architecture, unsupported dataset for a
                CNN, or unknown CNN option.
            TypeError: An MLP architecture was requested without
                ``num_classes`` or ``input_size`` in ``kwargs``.
        """
        self._architecture = architecture
        self._dataset_type = dataset_type
        # Initialise up front so `model` can report "no model" cleanly even
        # if building fails part-way.
        self._model = None
        if self._is_mlp():
            self._set_params(**kwargs)
        self._make_model(option, *args, **kwargs)

    def _set_params(self, num_classes, input_size, *args,
                    num_hidden_layers=2, num_units=1000, **kwargs):
        """Collect the constructor arguments for the MLP models.

        Extra positional/keyword arguments are accepted and ignored so the
        full ``**kwargs`` passed to ``__init__`` can be forwarded here.
        """
        self._params = {
            'num_classes': num_classes,
            'input_size': input_size,
            'num_hidden_layers': num_hidden_layers,
            'num_units': num_units,
        }

    def _is_mlp(self):
        """Return True when the architecture name starts with 'mlp'."""
        return self._architecture.startswith('mlp')

    def _check_cnn_request(self, label, option):
        """Raise ValueError if the CNN cannot be built for this dataset/option.

        ``label`` is the architecture name used in the error message.
        """
        if self._dataset_type not in self._CNN_DATASETS:
            raise ValueError('{} for the dataset is not supported yet'.format(label))
        if option not in self._CNN_OPTIONS:
            raise ValueError('Unsupported option for {}: {!r}'.format(label, option))

    def _make_model(self, option=None, *args, **kwargs):
        """Instantiate the model selected by the architecture and option.

        Raises:
            ValueError: See ``__init__``.
        """
        if self._architecture == 'mlp':
            print('Use Multilayer Perceptron')
            self._model = MLP(**self._params)
        elif self._architecture == 'mlp-bn':
            print('Use MLP with batch normalization')
            self._model = MLPBN(**self._params)
        elif self._architecture == 'resnet18':
            print('Use resnet18')
            self._check_cnn_request('Resnet', option)
            if option is None:
                self._model = resnet18()  # This is not from torchvision
            elif option == 'mixup':
                self._model = resnet18_mixup(**kwargs)
            else:  # 'dropblock' (validated above)
                self._model = resnet18_dropblock(**kwargs)
        elif self._architecture == 'vgg13':
            print('Use VGG13')
            self._check_cnn_request('VGG13', option)
            if option is None:
                self._model = vgg13()  # This is not from torchvision
            elif option == 'mixup':
                self._model = vgg13_mixup(**kwargs)
            else:  # 'dropblock' (validated above)
                self._model = vgg13_dropblock(**kwargs)
        else:
            raise ValueError('Invalid model')

    @property
    def model(self):
        """Return the built model.

        Raises:
            NameError: If no model has been built.
        """
        # Explicit None check: module objects such as nn.Sequential define
        # __len__, so plain truthiness is not a reliable "exists" test.
        if getattr(self, '_model', None) is None:
            raise NameError('No model exists. You have to make model first')
        return self._model