-
Notifications
You must be signed in to change notification settings - Fork 1
/
trainer.py
85 lines (69 loc) · 3.11 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import sys
import logging
import copy
import torch
from utils import factory
from utils.data_manager import DataManager
from utils.toolkit import count_parameters
import argparse
def train(args):
    """Seed torch deterministically and launch a single training run.

    Deep copies of ``args['seed']`` and ``args['device']`` are taken up
    front and written back into ``args`` right before ``_train`` is called,
    so the run starts from private copies of those two entries.

    Args:
        args: Mutable mapping of experiment settings; must contain the
            keys ``'seed'`` and ``'device'``. It is modified in place.
    """
    # Snapshot in the same order the entries are later restored.
    snapshot = {key: copy.deepcopy(args[key]) for key in ('seed', 'device')}

    # The torch generators are always seeded with the constant 0 here,
    # independently of whatever is stored under args['seed'].
    fixed_seed = 0
    torch.manual_seed(fixed_seed)
    torch.cuda.manual_seed_all(fixed_seed)
    torch.backends.cudnn.deterministic = True

    for key, value in snapshot.items():
        args[key] = value

    _train(args)
def _train(args):
    """Run one incremental-learning experiment described by ``args``.

    Sets up root logging (a ``.log`` file named after the run settings plus
    stdout), resolves the device through ``_set_device``, builds the data
    manager and the model, and then, for every task reported by the data
    manager, trains, evaluates and logs the accuracy curves.

    Args:
        args: Mapping of experiment settings. Keys read directly here are
            ``'seed'``, ``'model_name'``, ``'convnet_type'``, ``'dataset'``,
            ``'init_cls'``, ``'increment'``, ``'longtail'`` and
            ``'shuffle'``. The whole mapping is also handed to
            ``factory.get_model`` and mutated by ``_set_device``.
    """
    distribution = 'Normal'
    if args['longtail'] == 1:
        distribution = 'Longtail'

    name_fields = (
        args['seed'],
        args['model_name'],
        args['convnet_type'],
        args['dataset'],
        args['init_cls'],
        args['increment'],
        distribution,
    )
    logfilename = '{}_{}_{}_{}_{}_{}_{}'.format(*name_fields)

    file_handler = logging.FileHandler(filename=logfilename + '.log')
    console_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s [%(filename)s] => %(message)s',
        handlers=[file_handler, console_handler],
    )

    # Log the raw settings before _set_device rewrites args['device'].
    logging.info(args)
    summary_lines = (
        ('Seed: {}', 'seed'),
        ('Model: {}', 'model_name'),
        ('Convnet: {}', 'convnet_type'),
        ('Dataset: {}', 'dataset'),
    )
    for template, key in summary_lines:
        logging.info(template.format(args[key]))

    _set_device(args)

    data_manager = DataManager(
        args['dataset'],
        args['shuffle'],
        args['seed'],
        args['init_cls'],
        args['increment'],
        args['longtail'],
    )
    model = factory.get_model(args['model_name'], args)

    metrics = ('top1', 'top5')
    cnn_curve = {metric: [] for metric in metrics}
    nme_curve = {metric: [] for metric in metrics}

    for _ in range(data_manager.nb_tasks):
        total_params = count_parameters(model._network)
        logging.info('All params: {}'.format(total_params))
        trainable_params = count_parameters(model._network, True)
        logging.info('Trainable params: {}'.format(trainable_params))

        model.incremental_train(data_manager)
        cnn_accy, nme_accy = model.eval_task()
        model.after_task()

        if nme_accy is None:
            # Only classifier (CNN) accuracy is available for this task.
            logging.info('No NME accuracy.')
            logging.info('CNN: {}'.format(cnn_accy['grouped']))
            for metric in metrics:
                cnn_curve[metric].append(cnn_accy[metric])
            logging.info('CNN top1 curve: {}'.format(cnn_curve['top1']))
            logging.info('CNN top5 curve: {}\n'.format(cnn_curve['top5']))
            continue

        logging.info('CNN: {}'.format(cnn_accy['grouped']))
        logging.info('NME: {}'.format(nme_accy['grouped']))
        for metric in metrics:
            cnn_curve[metric].append(cnn_accy[metric])
        for metric in metrics:
            nme_curve[metric].append(nme_accy[metric])
        # The CNN curves are still accumulated above, but in this branch
        # only the NME curves are written to the log (labelled "NCM").
        logging.info('NCM top1 curve: {}'.format(nme_curve['top1']))
        logging.info('NCM top5 curve: {}\n'.format(nme_curve['top5']))
def _set_device(args):
device_type = args['device']
if device_type == -1:
device = torch.device('cpu')
else:
device = torch.device('cuda:{}'.format(device_type))
args['device'] = device