forked from human-analysis/pytorchnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
checkpoints.py
31 lines (24 loc) · 938 Bytes
/
checkpoints.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# checkpoints.py
import os
import torch
class Checkpoints:
    """Save and restore model weights for a training run.

    Configuration is read from ``args``:

    * ``args.save_dir`` -- directory that checkpoints are written into.
    * ``args.resume`` -- checkpoint filename returned by ``latest('resume')``.
    * ``args.save_results`` -- when truthy, ``save_dir`` is created eagerly
      at construction time.
    """

    def __init__(self, args):
        self.dir_save = args.save_dir
        self.model_filename = args.resume
        self.save_results = args.save_results
        if self.save_results:
            # exist_ok avoids the check-then-create race of isdir() + makedirs().
            os.makedirs(self.dir_save, exist_ok=True)

    def latest(self, name):
        """Return the checkpoint filename registered under ``name``.

        Only ``'resume'`` is recognised; it maps to ``args.resume``. Any other
        name returns ``None``.
        """
        if name == 'resume':
            return self.model_filename
        return None

    def save(self, epoch, model, best):
        """Write ``model.state_dict()`` to ``<save_dir>/model_epoch_<epoch>.pth``.

        Args:
            epoch: epoch number embedded in the filename (formatted with ``%d``).
            model: object exposing ``state_dict()``, e.g. a ``torch.nn.Module``.
            best: the checkpoint is written only when this is truthy. Truthiness
                (rather than ``is True``) is used so that flags such as
                ``numpy.bool_`` or a 0-d ``torch.Tensor`` produced by a
                comparison are honoured instead of silently ignored.
        """
        if not best:
            return
        if self.dir_save:
            # save_dir is only created eagerly when save_results is set; make
            # sure it exists before writing so torch.save does not fail.
            os.makedirs(self.dir_save, exist_ok=True)
        path = os.path.join(self.dir_save, 'model_epoch_%d.pth' % epoch)
        torch.save(model.state_dict(), path)

    def load(self, model, filename, map_location=None):
        """Load weights from ``filename`` into ``model`` and return ``model``.

        Args:
            model: object exposing ``load_state_dict()``; it is updated in place.
            filename: path to a file holding a state dict readable by
                ``torch.load`` (e.g. one written by ``save``).
            map_location: forwarded to ``torch.load``; pass e.g. ``'cpu'`` to
                load a GPU-saved checkpoint on a machine without CUDA. The
                default ``None`` matches the previous behaviour.

        Returns:
            The same ``model`` object, after ``load_state_dict``.

        Raises:
            FileNotFoundError: if ``filename`` is ``None`` or is not an
                existing file. This subclasses ``Exception``, so existing
                ``except Exception`` handlers still catch it.
        """
        if filename is None or not os.path.isfile(filename):
            raise FileNotFoundError(
                "=> no checkpoint found at '{}'".format(filename))
        print("=> loading checkpoint '{}'".format(filename))
        # NOTE(security): torch.load unpickles the file and can execute
        # arbitrary code; only load checkpoints from trusted sources.
        state_dict = torch.load(filename, map_location=map_location)
        model.load_state_dict(state_dict)
        return model