-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_bn_mnist2mnist.py
98 lines (91 loc) · 3 KB
/
train_bn_mnist2mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os

from comet_ml import Experiment, OfflineExperiment
import torch.utils.data as data
from torchvision import transforms

from src.trainers.incrementals.components.adversarials import (
    IADATrainerComponent,
    IASMTrainerComponent,
    IASVTrainerComponent
)
from src.trainers import (
    DATrainerComponent,
    SMTrainerComponent,
    CDATrainerComponent,
    CSMTrainerComponent,
    AdaINDATrainerComponent,
)
from src.trainers.incrementals.components.conditional_adversarial import (
    ICADATrainerComponent,
    ICASMTrainerComponent,
)
from src.trainers.incrementals.mnist import IncrementalMnistTrainer
from src.trainers.incrementals.cityscapes import IncrementalCityscapesTrainer
from src.datasets import (
    IDAMNIST,
    IDASVHN
)
from src.models.incrementals.components import (
    DANNClassifier,
    DANNEncoder,
    AdaBNClassifier,
    DANNSourceGenerator,
    CDANNSourceGenerator,
    DANNSourceDiscriminator,
    DANNDomainDiscriminator,
    DANNConvSourceDiscriminator,
    Classifier,
    DomainDiscriminator,
    SourceGenerator,
    SourceDiscriminator,
    Encoder,
    VGGEncoder,
    Decoder,
)
from src.models.incrementals.adversarial import IncrementalAdversarialModel
from src.models.incrementals.conditional_adversarial import IncrementalConditionalAdversarialModel
from src.analyzers import TargetImageSaver, TargetFeatureVisualizer
# Comet.ml experiment that the trainer logs to (project "iada", workspace "yamad07").
# The API key is a credential, so it must not live in source control. It is read
# from the COMET_API_KEY environment variable. When the variable is unset, None
# is passed, which is the same as omitting api_key, so comet_ml falls back to its
# own configuration lookup.
# NOTE(review): the key previously hard-coded here is still in the repository
# history and should be revoked/rotated on comet.ml.
experiment = Experiment(
    api_key=os.environ.get("COMET_API_KEY"),
    project_name="iada",
    workspace="yamad07",
)
# Image preprocessing pipelines for the two domains. Both finish with the same
# tensor conversion and single-channel normalisation, (x - 0.5) / 0.25.
# NOTE(review): neither pipeline is referenced later in this script (IDAMNIST is
# built without a transform argument). Confirm whether they are dead code or
# were meant to be passed to the dataset.
_NORM_MEAN = (0.5, )
_NORM_STD = (0.25, )

# Source domain: digits resized to 28x28.
source_transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# Target domain: digits squashed to 22 rows x 28 columns, then padded with 3
# pixels above and below (Pad takes left, top, right, bottom), giving 28x28 again.
target_transform = transforms.Compose([
    transforms.Resize((22, 28)),
    transforms.Pad((0, 3, 0, 3)),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])
# Project dataset for incremental domain adaptation on MNIST. It is rooted at
# ./data/ with download=True, which presumably fetches the raw files if they
# are missing. Confirm against IDAMNIST.
mnist_dataset = IDAMNIST(root='./data/', download=True)

# Both loaders use identical settings: batches of 256, reshuffled every epoch.
# NOTE(review): training and validation iterate the very same dataset object.
# Unless IDAMNIST or the trainer separates the splits internally, validation
# metrics are computed on training data.
_loader_kwargs = dict(batch_size=256, shuffle=True)
train_data_loader = data.DataLoader(mnist_dataset, **_loader_kwargs)
validate_data_loader = data.DataLoader(mnist_dataset, **_loader_kwargs)
# Every component is sized to the same feature width. This is presumably the
# length of the DANNEncoder output; confirm against the encoder definition.
_N_FEATURES = 576

# Sub-networks of the adversarial model, built in the original order so that
# any random weight initialisation draws from the RNG in the same sequence.
_components = dict(
    classifier=DANNClassifier(_N_FEATURES),
    domain_discriminator=DANNDomainDiscriminator(_N_FEATURES),
    source_generator=DANNSourceGenerator(z_dim=128, num_features=_N_FEATURES),
    source_discriminator=DANNSourceDiscriminator(_N_FEATURES),
    source_encoder=DANNEncoder(),
    target_encoder=DANNEncoder(),
)
model = IncrementalAdversarialModel(**_components, n_features=_N_FEATURES)
# Training schedule: each trainer component is listed next to its epoch count.
# This assumes epoch_component_list[i] applies to trainer_component_list[i];
# confirm in IncrementalMnistTrainer.
_schedule = [
    (SMTrainerComponent(), 100),
    (AdaINDATrainerComponent(), 1),
]

trainer = IncrementalMnistTrainer(
    model=model,
    trainer_component_list=[component for component, _ in _schedule],
    epoch_component_list=[n_epochs for _, n_epochs in _schedule],
    experiment=experiment,
    train_data_loader=train_data_loader,
    valid_data_loader=validate_data_loader,
    # NOTE(review): the device index is hard-coded. It presumably selects GPU 1,
    # which would not exist on a single-GPU host.
    cuda_id=1,
    # The meaning of size_list is defined by IncrementalMnistTrainer, which is
    # not visible here.
    size_list=[0, 2, 3, 4, 5, 6, 7, 8],
    analyzer_list=[TargetImageSaver()],
)
trainer.train()