-
Notifications
You must be signed in to change notification settings - Fork 23
/
write_selected_data.py
103 lines (83 loc) · 4.37 KB
/
write_selected_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import argparse
import os
import torch
def parse_args(argv=None):
    """Parse the command-line options of the data-selection script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what the script entry point relies on.

    Returns:
        argparse.Namespace with the attributes ``train_file_names``,
        ``train_files``, ``target_task_names`` (lists of str),
        ``output_path`` (str, default ``"selected_data"``), ``max_samples``
        (int or None) and ``percentage`` (float or None).

    Raises:
        SystemExit: raised by argparse when a required option is missing or
            a value cannot be converted.
    """
    argparser = argparse.ArgumentParser(
        description='Script for selecting the data for training')
    # The three list options have no usable default: leaving one out used to
    # yield None and a TypeError further down the script, so they are
    # required here and argparse reports the problem up front.
    argparser.add_argument('--train_file_names', type=str, nargs='+',
                           required=True,
                           help='The path to the score file')
    argparser.add_argument('--train_files', type=str, nargs='+',
                           required=True,
                           help='The path of the training file that corresponds to the score file')
    argparser.add_argument('--target_task_names', type=str, nargs='+',
                           required=True,
                           help='The name of the target task')
    argparser.add_argument('--output_path', type=str,
                           default="selected_data",
                           help='The path to the output')
    argparser.add_argument('--max_samples', type=int, default=None,
                           help='The maximum number of samples')
    argparser.add_argument('--percentage', type=float, default=None,
                           help='The percentage of the data to be selected')
    return argparser.parse_args(argv)
def count_lines(filename):
    """Return how many lines the text file at ``filename`` contains.

    The file is read as UTF-8; bytes that cannot be decoded are ignored
    rather than raising.
    """
    with open(filename, 'r', encoding='utf-8', errors='ignore') as handle:
        # Iterating the handle yields one item per line, so tallying the
        # items gives the line count without holding the file in memory.
        return sum(1 for _ in handle)
def main(args=None):
    """Select the highest-scoring training samples for each target task.

    For every task in ``args.target_task_names`` this:

    1. loads ``<output_path>/<task>/<name>_influence_score.pt`` for each
       name in ``args.train_file_names``;
    2. ranks all samples of all files together by descending score;
    3. writes the full ranking to ``<output_path>/<task>/sorted.csv``
       (only if that file does not exist yet);
    4. copies the top-ranked lines of the matching ``args.train_files`` to
       ``<output_path>/<task>/top_<p{percentage}|num{max_samples}>.jsonl``.

    Args:
        args: Namespace with the attributes produced by ``parse_args``.
            When ``None`` the command line is parsed.

    Raises:
        ValueError: if the file lists are missing or of different length, or
            if neither ``percentage`` nor ``max_samples`` is given.
        IndexError: if a selected sample index lies beyond the lines that
            were read from its training file.
    """
    if args is None:
        args = parse_args()

    # Explicit checks instead of ``assert``: asserts vanish under ``-O``.
    if not args.train_file_names or not args.train_files:
        raise ValueError(
            "--train_file_names and --train_files must both be given")
    if len(args.train_file_names) != len(args.train_files):
        raise ValueError(
            "--train_file_names and --train_files must have the same length")
    if not args.target_task_names:
        raise ValueError("--target_task_names must be given")
    if args.percentage is None and args.max_samples is None:
        raise ValueError("one of --percentage or --max_samples is required")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    for target_task in args.target_task_names:
        output_path = os.path.join(args.output_path, target_task)
        score_paths = [
            os.path.join(output_path, f"{task_name}_influence_score.pt")
            for task_name in args.train_file_names
        ]

        # Load each score file once; both the per-file sample counts and
        # the concatenated score vector come from these tensors.
        # NOTE(review): assumes each file holds a 1-D tensor with one score
        # per training line - confirm against the score-writing code.
        scores_per_file = [
            torch.load(score_path, map_location=device)
            for score_path in score_paths
        ]
        num_samples = [len(score) for score in scores_per_file]
        total_samples = sum(num_samples)

        # --percentage takes precedence over --max_samples when both are set.
        if args.percentage is not None:
            max_samples = int(args.percentage * total_samples)
            data_amount_name = f"p{args.percentage}"
        else:
            max_samples = args.max_samples
            data_amount_name = f"num{args.max_samples}"

        # Never ask for more samples than exist (or fewer than zero).
        clamped = max(0, min(max_samples, total_samples))
        if clamped != max_samples:
            print(f"Requested {max_samples} samples for {target_task} but "
                  f"{total_samples} are available; selecting {clamped}.")
        max_samples = clamped

        all_scores = torch.cat(scores_per_file, dim=0)

        # For every position of the concatenated vector remember which file
        # it came from and its line index inside that file.
        file_specific_index = torch.cat(
            [torch.arange(line_num) for line_num in num_samples]).to(device)
        data_from = torch.cat(
            [torch.full((line_num,), i, dtype=torch.long)
             for i, line_num in enumerate(num_samples)]).to(device)

        # Sort the scores and map the order back to (file, line) pairs.
        sorted_scores, sorted_order = torch.sort(
            all_scores, dim=0, descending=True)
        # One bulk transfer to Python lists instead of an ``.item()`` call
        # per element.
        sorted_file_ids = data_from[sorted_order].tolist()
        sorted_line_ids = file_specific_index[sorted_order].tolist()

        sorted_score_file = os.path.join(output_path, "sorted.csv")
        if not os.path.exists(sorted_score_file):
            with open(sorted_score_file, 'w', encoding='utf-8') as file:
                file.write("file name, index, score\n")
                file.writelines(
                    f"{args.train_file_names[file_id]}, {line_id}, {round(score, 6)}\n"
                    for score, line_id, file_id in zip(
                        sorted_scores.tolist(), sorted_line_ids,
                        sorted_file_ids))

        # Keep only as many lines per file as there are scores for it.
        all_lines = []
        for train_file, line_num in zip(args.train_files, num_samples):
            with open(train_file, 'r', encoding='utf-8',
                      errors='ignore') as file:
                all_lines.append(file.readlines()[:line_num])

        top_file = os.path.join(output_path, f"top_{data_amount_name}.jsonl")
        with open(top_file, 'w', encoding='utf-8', errors='ignore') as file:
            for line_id, file_id in zip(sorted_line_ids[:max_samples],
                                        sorted_file_ids[:max_samples]):
                try:
                    line = all_lines[file_id][line_id]
                except IndexError as err:
                    raise IndexError(
                        f"sample {line_id} selected from "
                        f"{args.train_files[file_id]}, but only "
                        f"{len(all_lines[file_id])} lines were read from it"
                    ) from err
                # The last line of a file may lack a trailing newline; add
                # one so consecutive records are not fused together.
                file.write(line if line.endswith("\n") else line + "\n")


if __name__ == "__main__":
    main()