Commit d6dd190

fix

Michael Diskin committed Sep 17, 2021
1 parent 62dabc7
Showing 3 changed files with 20 additions and 11 deletions.
4 changes: 2 additions & 2 deletions config.json
@@ -23,9 +23,9 @@
"net_structure_type": 0,
"num_attention_heads": 64,
"num_hidden_groups": 1,
"num_hidden_layers": 12,
"num_hidden_layers": 8,
"num_memory_blocks": 0,
"pad_token_id": 0,
"type_vocab_size": 2,
"vocab_size": 30000
}
}
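The config change reduces the encoder depth from 12 hidden layers to 8. Since "num_hidden_groups" stays at 1, all ALBERT layers share one set of weights, so this cuts compute per forward pass rather than parameter count. A minimal sketch of reading the edited file, using the stock transformers AlbertConfig as a stand-in for the repo's LeanAlbertConfig:

    from transformers import AlbertConfig  # stand-in for LeanAlbertConfig

    # load the edited config and confirm the reduced depth; with one hidden
    # group, every layer reuses the same weights, so only depth changes
    config = AlbertConfig.from_json_file("config.json")
    assert config.num_hidden_layers == 8
    assert config.num_hidden_groups == 1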
10 changes: 5 additions & 5 deletions dataset.py
@@ -14,12 +14,12 @@ def __init__(self, X, Y=None, device="cpu"):
self.X = X
self.Y = Y
self.device = device

def __len__(self):
return len(self.X["input_ids"])

def __getitem__(self, index):
-        to_return = {key: torch.LongTensor(value[index])[:512].to(self.device) for key, value in self.X.items()}
+        to_return = {key: torch.LongTensor(value[index]).to(self.device) for key, value in self.X.items()}
to_return["index"] = index
if self.Y is not None:
to_return["labels"] = self.Y[index].to(self.device)
@@ -30,7 +30,7 @@ def make_dataset(tokenizer, data, has_labels=True, device="cpu",
answer_field="answer", pos_label=True):
questions = [elem[first_key] for elem in data]
passages = [elem[second_key] for elem in data]
-    X = tokenizer(text=questions, text_pair=passages, truncation=True)
+    X = tokenizer(text=questions, text_pair=passages, truncation=True, max_length=512)
if has_labels:
Y = torch.FloatTensor([int(elem[answer_field]==pos_label) for elem in data])
else:
@@ -39,7 +39,7 @@ def make_dataset(tokenizer, data, has_labels=True, device="cpu",


class OrderedBatchSampler(Sampler):

def __init__(self, data, batch_size, length_func=None, shuffle=True, random_state=187):
if length_func is None:
length_func = lambda x: 0
@@ -62,4 +62,4 @@ def make_dataloader(dataset, batch_size=16, shuffle=True, key="input_ids"):
length_func = lambda x: len(x[key]) if key else None
sampler = OrderedBatchSampler(dataset, batch_size=batch_size,
length_func=length_func, shuffle=shuffle)
-    return DataLoader(dataset, collate_fn=collate_fn, batch_sampler=sampler)
\ No newline at end of file
+    return DataLoader(dataset, collate_fn=collate_fn, batch_sampler=sampler)
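Moving the 512-token cap from __getitem__ into the tokenizer call means truncation now happens at encoding time, respecting the question/passage pair-truncation strategy and keeping the special tokens intact, instead of blindly slicing already-encoded sequences. A small sketch of the behaviour, assuming the public "albert-base-v2" tokenizer rather than the repo's own args.tokenizer_path:

    from transformers import AlbertTokenizerFast

    # assumed public checkpoint, not this repo's tokenizer
    tokenizer = AlbertTokenizerFast.from_pretrained("albert-base-v2")
    enc = tokenizer(text=["a short question"],
                    text_pair=["a passage " * 400],  # long enough to force truncation
                    truncation=True, max_length=512)
    # input_ids, token_type_ids and attention_mask are all capped together
    assert all(len(field[0]) <= 512 for field in enc.values())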
17 changes: 13 additions & 4 deletions main.py
@@ -78,20 +78,26 @@ def get_status(corr, pred):
args = argument_parser.parse_args()
train_data = read_infile(args.train_file)
dev_data = read_infile(args.dev_file)

config = LeanAlbertConfig.from_pretrained(args.model_config)
tokenizer = AlbertTokenizerFast.from_pretrained(args.tokenizer_path, return_token_type_ids=True)
model = LeanAlbertModel(config)
model.resize_token_embeddings(len(tokenizer))
# model.load_state_dict(torch.load(args.model_path)["model"])
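    # the checkpoint nests weights under an "albert." prefix; key[7:] strips
    # len("albert.") characters so the names match LeanAlbertModel's state dict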
    c = torch.load(args.model_path)["model"]
a = {}
for key in c.keys():
if "albert" in key:
a[key[7:]] = c[key]
model.load_state_dict(a)


train_dataset = make_dataset(tokenizer, train_data, pos_label=args.pos_label,
answer_field=args.answer_field,
first_key=args.first_sentence,
second_key=args.second_sentence,
device="cuda:0")
dev_dataset = make_dataset(tokenizer, dev_data, pos_label=args.pos_label,
-                              answer_field=args.answer_field,
+                             answer_field=args.answer_field,
first_key=args.first_sentence,
second_key=args.second_sentence,
device="cuda:0")
@@ -111,7 +117,7 @@ def get_status(corr, pred):
if args.load_file:
bert_classifier.load_state_dict(torch.load(args.load_file))
if args.train:
-        model.train()
+        bert_classifier.train()
for epoch in range(args.nepochs):
progress_bar = tqdm.tqdm(train_dataloader)
metrics = initialize_metrics()
@@ -122,6 +128,7 @@ def get_status(corr, pred):
if (args.eval_every_n_batches > 0 and i % args.eval_every_n_batches == 0 and
len(train_dataloader) - i >= args.eval_every_n_batches // 2) or\
i == len(train_dataloader):
bert_classifier.eval()
dev_metrics = initialize_metrics()
dev_progress_bar = tqdm.tqdm(dev_dataloader)
for j, batch in enumerate(dev_progress_bar):
@@ -131,6 +138,8 @@ def get_status(corr, pred):
if dev_metrics["accuracy"] > best_score:
best_score = dev_metrics["accuracy"]
best_weights = copy.deepcopy(bert_classifier.state_dict())
bert_classifier.train()

bert_classifier.load_state_dict(best_weights)
    ## load the best state
bert_classifier.eval()
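The new loading block above remaps checkpoint keys before load_state_dict, because the saved model stores everything under an "albert." prefix that the bare LeanAlbertModel does not expect. A hedged, generic version of the same remapping (using startswith instead of the commit's looser "albert" in key substring test):

    import torch

    def strip_prefix(state_dict, prefix="albert."):
        # drop the prefix from matching keys so they load into the bare submodule
        return {k[len(prefix):]: v
                for k, v in state_dict.items() if k.startswith(prefix)}

    # toy demonstration; real usage: strip_prefix(torch.load(path)["model"])
    toy = {"albert.embeddings.weight": torch.zeros(2, 2),
           "classifier.weight": torch.zeros(2)}
    print(strip_prefix(toy).keys())  # dict_keys(['embeddings.weight'])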

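The remaining main.py edits fix train/eval mode handling: the old loop called model.train() on the backbone instead of the full classifier, and never switched to eval mode for the mid-epoch validation passes, leaving dropout active while measuring dev accuracy. A toy sketch of the toggle the diff introduces (model and shapes are illustrative only):

    import torch
    import torch.nn as nn

    # toy stand-in for bert_classifier; dropout makes the mode switch observable
    clf = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.5), nn.Linear(8, 2))

    def validate(clf, x):
        clf.eval()                 # freeze dropout for deterministic dev metrics
        with torch.no_grad():
            logits = clf(x)
        clf.train()                # restore training mode, as the diff now does
        return logits

    clf.train()
    print(validate(clf, torch.randn(4, 8)).shape)  # torch.Size([4, 2])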