import time
import torch
import random
import torch.nn as nn
import numpy as np

from transformers import AdamW, get_linear_schedule_with_warmup
from colbert.infra import ColBERTConfig
from colbert.training.rerank_batcher import RerankBatcher

from colbert.utils.amp import MixedPrecisionManager
from colbert.training.lazy_batcher import LazyBatcher
from colbert.parameters import DEVICE

from colbert.modeling.colbert import ColBERT
from colbert.modeling.reranker.electra import ElectraReranker

from colbert.utils.utils import print_message
from colbert.training.utils import print_progress, manage_checkpoints

def train(config: ColBERTConfig, triples, queries=None, collection=None):
    config.checkpoint = config.checkpoint or 'bert-base-uncased'

    if config.rank < 1:
        config.help()

    # Fix all RNG seeds for reproducibility.
    random.seed(12345)
    np.random.seed(12345)
    torch.manual_seed(12345)
    torch.cuda.manual_seed_all(12345)

    # The configured batch size is global; each of the nranks processes
    # handles an equal share of it.
    assert config.bsize % config.nranks == 0, (config.bsize, config.nranks)
    config.bsize = config.bsize // config.nranks

    print("Using config.bsize =", config.bsize, "(per process) and config.accumsteps =", config.accumsteps)

    if collection is not None:
        if config.reranker:
            reader = RerankBatcher(config, triples, queries, collection, (0 if config.rank == -1 else config.rank), config.nranks)
        else:
            reader = LazyBatcher(config, triples, queries, collection, (0 if config.rank == -1 else config.rank), config.nranks)
    else:
        raise NotImplementedError()

    if not config.reranker:
        colbert = ColBERT(name=config.checkpoint, colbert_config=config)
    else:
        colbert = ElectraReranker.from_pretrained(config.checkpoint)

    colbert = colbert.to(DEVICE)
    colbert.train()

    # Wrap the model for multi-GPU training; this assumes torch.distributed
    # has already been initialized by the launcher.
    colbert = torch.nn.parallel.DistributedDataParallel(colbert, device_ids=[config.rank],
                                                        output_device=config.rank,
                                                        find_unused_parameters=True)

    optimizer = AdamW(filter(lambda p: p.requires_grad, colbert.parameters()), lr=config.lr, eps=1e-8)
    optimizer.zero_grad()

    scheduler = None
    if config.warmup is not None:
        print(f"#> LR will use {config.warmup} warmup steps and linear decay over {config.maxsteps} steps.")
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=config.warmup,
                                                    num_training_steps=config.maxsteps)
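
    # For example, with config.warmup = 3_000 and config.maxsteps = 100_000
    # (illustrative values), the learning rate ramps linearly from 0 up to
    # config.lr over the first 3,000 steps, then decays linearly back to 0
    # by step 100,000. Without warmup, scheduler stays None and the LR is
    # held constant at config.lr.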

    # Optionally freeze the BERT backbone for the first `warmup_bert` steps,
    # training only the layers on top of it.
    warmup_bert = config.warmup_bert
    if warmup_bert is not None:
        set_bert_grad(colbert, False)

    amp = MixedPrecisionManager(config.amp)
    labels = torch.zeros(config.bsize, dtype=torch.long, device=DEVICE)

    start_time = time.time()
    train_loss = None
    train_loss_mu = 0.999

    start_batch_idx = 0

    # if config.resume:
    #     assert config.checkpoint is not None
    #     start_batch_idx = checkpoint['batch']
    #
    #     reader.skip_to_batch(start_batch_idx, checkpoint['arguments']['bsize'])

    for batch_idx, BatchSteps in zip(range(start_batch_idx, config.maxsteps), reader):
        if (warmup_bert is not None) and warmup_bert <= batch_idx:
            set_bert_grad(colbert, True)
            warmup_bert = None

        this_batch_loss = 0.0

        for batch in BatchSteps:
            with amp.context():
                try:
                    # Retrieval batches unpack as (queries, passages, target_scores).
                    queries, passages, target_scores = batch
                    encoding = [queries, passages]
                except ValueError:
                    # Reranker batches unpack as (encoding, target_scores).
                    encoding, target_scores = batch
                    encoding = [encoding.to(DEVICE)]

                scores = colbert(*encoding)

                if config.use_ib_negatives:
                    scores, ib_loss = scores

                scores = scores.view(-1, config.nway)

                if len(target_scores) and not config.ignore_scores:
                    # Distill from the teacher's target_scores via KL divergence.
                    target_scores = torch.tensor(target_scores).view(-1, config.nway).to(DEVICE)
                    target_scores = target_scores * config.distillation_alpha
                    target_scores = torch.nn.functional.log_softmax(target_scores, dim=-1)

                    log_scores = torch.nn.functional.log_softmax(scores, dim=-1)
                    loss = torch.nn.KLDivLoss(reduction='batchmean', log_target=True)(log_scores, target_scores)
                else:
                    loss = nn.CrossEntropyLoss()(scores, labels[:scores.size(0)])
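
                # Shape note: after the .view() above, scores is (per-step batch, nway).
                # The KL branch compares two log-probability distributions (hence
                # log_target=True); the cross-entropy branch targets class 0, i.e.
                # the first of the nway passages, which the batcher places the
                # positive at. labels is sliced in case the final batch is smaller
                # than config.bsize.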

                if config.use_ib_negatives:
                    if config.rank < 1:
                        print('\t\t\t\t', loss.item(), ib_loss.item())

                    loss += ib_loss

                # Scale so that gradients accumulated over accumsteps chunks
                # average to one full-batch step.
                loss = loss / config.accumsteps

            if config.rank < 1:
                print_progress(scores)

            amp.backward(loss)

            this_batch_loss += loss.item()

        # Track an exponential moving average of the per-batch loss.
        train_loss = this_batch_loss if train_loss is None else train_loss
        train_loss = train_loss_mu * train_loss + (1 - train_loss_mu) * this_batch_loss

        amp.step(colbert, optimizer, scheduler)

        if config.rank < 1:
            print_message(batch_idx, train_loss)

            manage_checkpoints(config, colbert, optimizer, batch_idx+1, savepath=None)

    if config.rank < 1:
        print_message("#> Done with all triples!")

        ckpt_path = manage_checkpoints(config, colbert, optimizer, batch_idx+1, savepath=None, consumed_all_triples=True)

        return ckpt_path  # TODO: This should validate and return the best checkpoint, not just the last one.
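

# Example invocation (a hedged sketch; Trainer and Run are defined elsewhere in
# the ColBERT codebase, not in this file). train() is normally driven by a
# launcher that initializes torch.distributed and sets config.rank / config.nranks,
# since the DistributedDataParallel wrapper above requires an initialized process
# group. Paths and hyperparameter values below are illustrative only:
#
#   from colbert.infra import Run, RunConfig, ColBERTConfig
#   from colbert import Trainer
#
#   with Run().context(RunConfig(nranks=1, experiment="example")):
#       config = ColBERTConfig(bsize=32, lr=1e-5, warmup=3_000, maxsteps=100_000)
#       trainer = Trainer(triples="triples.jsonl", queries="queries.tsv",
#                         collection="collection.tsv", config=config)
#       checkpoint_path = trainer.train()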


def set_bert_grad(colbert, value):
    # Toggle requires_grad on the BERT backbone only; any layers on top of it
    # keep their current setting. When the model is wrapped in
    # DistributedDataParallel, .bert lives under .module, so recurse.
    try:
        for p in colbert.bert.parameters():
            assert p.requires_grad is (not value)
            p.requires_grad = value
    except AttributeError:
        set_bert_grad(colbert.module, value)