/
train.py
109 lines (87 loc) · 4.5 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import torch
import torch.nn.functional as F
from tqdm import tqdm
from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader
from model import ConvolutionalNeuralNetwork
from utils import update_parameters, get_accuracy
def _meta_batch_loss(model, batch, args):
    """Compute the MAML outer loss and mean accuracy for one batch of tasks.

    For every task in ``batch``, the model's parameters are adapted on the
    task's ``'train'`` split via ``update_parameters``. The adapted
    parameters are then evaluated on the task's ``'test'`` split.

    Args:
        model: the meta-learned network. It is called both as ``model(x)``
            and as ``model(x, params=params)``.
        batch: mapping with ``'train'`` and ``'test'`` entries. Each entry
            is an ``(inputs, targets)`` pair whose leading dimension is
            iterated over one task at a time.
        args: parsed command-line namespace. Reads ``device``,
            ``step_size`` and ``first_order``.

    Returns:
        ``(outer_loss, accuracy)``:
            * ``outer_loss`` is the cross-entropy on the ``'test'`` split,
              averaged over tasks and still attached to the autograd graph.
            * ``accuracy`` is ``get_accuracy`` on the ``'test'`` split,
              averaged over tasks and computed without gradient tracking.
    """
    train_inputs, train_targets = batch['train']
    train_inputs = train_inputs.to(device=args.device)
    train_targets = train_targets.to(device=args.device)

    test_inputs, test_targets = batch['test']
    test_inputs = test_inputs.to(device=args.device)
    test_targets = test_targets.to(device=args.device)

    outer_loss = torch.tensor(0., device=args.device)
    accuracy = torch.tensor(0., device=args.device)
    for train_input, train_target, test_input, test_target in zip(
            train_inputs, train_targets, test_inputs, test_targets):
        # Inner loop: adapt the parameters on this task's 'train' split.
        train_logit = model(train_input)
        inner_loss = F.cross_entropy(train_logit, train_target)

        model.zero_grad()
        params = update_parameters(model, inner_loss,
            step_size=args.step_size, first_order=args.first_order)

        # Outer objective: evaluate the adapted parameters on the 'test' split.
        test_logit = model(test_input, params=params)
        outer_loss += F.cross_entropy(test_logit, test_target)

        with torch.no_grad():
            accuracy += get_accuracy(test_logit, test_target)

    # Average over the tasks actually present. A final partial batch can
    # contain fewer than args.batch_size tasks.
    num_tasks = train_inputs.size(0)
    outer_loss = outer_loss / num_tasks
    accuracy = accuracy / num_tasks
    return outer_loss, accuracy


def train(args):
    """Meta-train a ConvolutionalNeuralNetwork on Omniglot with MAML.

    Builds an Omniglot meta-training dataloader from the command-line
    options. It then performs exactly ``args.num_batches`` meta-updates
    with Adam (lr=1e-3), one per batch of tasks.

    Args:
        args: parsed command-line namespace. Reads ``folder``,
            ``num_shots``, ``num_ways``, ``download``, ``batch_size``,
            ``num_workers``, ``hidden_size``, ``device``, ``step_size``,
            ``first_order``, ``num_batches`` and ``output_folder``.

    Side effects:
        If ``args.output_folder`` is not None, the model's ``state_dict``
        is saved to ``maml_omniglot_<shots>shot_<ways>way.pt`` in that
        folder. The folder is created if it does not exist.
    """
    dataset = omniglot(args.folder, shots=args.num_shots, ways=args.num_ways,
        shuffle=True, test_shots=15, meta_train=True, download=args.download)
    dataloader = BatchMetaDataLoader(dataset, batch_size=args.batch_size,
        shuffle=True, num_workers=args.num_workers)

    model = ConvolutionalNeuralNetwork(1, args.num_ways,
        hidden_size=args.hidden_size)
    model.to(device=args.device)
    model.train()
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    with tqdm(dataloader, total=args.num_batches) as pbar:
        for batch_idx, batch in enumerate(pbar):
            # Check the budget before training: batch_idx is 0-based, so
            # this gives exactly num_batches updates (the old post-step
            # check did num_batches + 1).
            if batch_idx >= args.num_batches:
                break

            model.zero_grad()
            outer_loss, accuracy = _meta_batch_loss(model, batch, args)
            outer_loss.backward()
            meta_optimizer.step()

            pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))

    # Save model
    if args.output_folder is not None:
        # open() below fails if the directory is missing, so create it first.
        os.makedirs(args.output_folder, exist_ok=True)
        filename = os.path.join(args.output_folder, 'maml_omniglot_'
            '{0}shot_{1}way.pt'.format(args.num_shots, args.num_ways))
        with open(filename, 'wb') as f:
            torch.save(model.state_dict(), f)
if __name__ == '__main__':
    # Command-line entry point: parse the MAML/Omniglot options, select the
    # device and run train().
    import argparse

    # The title belongs in `description`. ArgumentParser's first positional
    # parameter is `prog`, which would make the usage line print the title
    # as if it were the program name.
    parser = argparse.ArgumentParser(
        description='Model-Agnostic Meta-Learning (MAML)')

    parser.add_argument('folder', type=str,
        help='Path to the folder the data is downloaded to.')
    parser.add_argument('--num-shots', type=int, default=5,
        help='Number of examples per class (k in "k-shot", default: 5).')
    parser.add_argument('--num-ways', type=int, default=5,
        help='Number of classes per task (N in "N-way", default: 5).')

    parser.add_argument('--first-order', action='store_true',
        help='Use the first-order approximation of MAML.')
    parser.add_argument('--step-size', type=float, default=0.4,
        help='Step-size for the gradient step for adaptation (default: 0.4).')
    parser.add_argument('--hidden-size', type=int, default=64,
        help='Number of channels for each convolutional layer (default: 64).')

    parser.add_argument('--output-folder', type=str, default=None,
        help='Path to the output folder for saving the model (optional).')
    parser.add_argument('--batch-size', type=int, default=16,
        help='Number of tasks in a mini-batch of tasks (default: 16).')
    parser.add_argument('--num-batches', type=int, default=100,
        help='Number of batches the model is trained over (default: 100).')
    parser.add_argument('--num-workers', type=int, default=1,
        help='Number of workers for data loading (default: 1).')
    parser.add_argument('--download', action='store_true',
        help='Download the Omniglot dataset in the data folder.')
    parser.add_argument('--use-cuda', action='store_true',
        help='Use CUDA if available.')

    args = parser.parse_args()
    # Use CUDA only if it was requested and is actually available;
    # otherwise fall back to the CPU.
    args.device = torch.device('cuda' if args.use_cuda
        and torch.cuda.is_available() else 'cpu')

    train(args)