## Загрузим необходимые данные

In [1]:
from pathlib import Path
import pandas as pd
import json
import jsonlines

DATA_DIRECTORY = Path('./downloads/data/lean/')  # change as necessary

In [2]:
with open(DATA_DIRECTORY / 'all_lemmas_statements.json') as fh:
    lemmas = json.load(fh)

ds = pd.read_parquet(DATA_DIRECTORY / 'arguments_full_names.parquet')

ds.head()

Unnamed: 0,goal,tactic,decl,args_names,other_args,split
0,"α : Type u,\n_inst_1 : inhabited α,\ni b_fst :...",unfold read read',buffer.read_eq_read',buffer.read\nbuffer.read',,test
1,"α : Type u,\n_inst_1 : inhabited α,\ni b_fst :...",simp [array.read_eq_read'],buffer.read_eq_read',array.read_eq_read',,test
2,"α : Type u,\n_inst_1 : inhabited α,\nb : buffe...",cases b,buffer.read_eq_read',,,test
3,"α : Type u,\ni : ℕ,\nv : α,\nb_fst : ℕ,\nb_snd...",unfold write write',buffer.write_eq_write',buffer.write\nbuffer.write',,train
4,"α : Type u,\ni : ℕ,\nv : α,\nb_fst : ℕ,\nb_snd...",simp [array.write_eq_write'],buffer.write_eq_write',array.write_eq_write',,train


In [3]:
ds['args_names'].value_counts()

                                                114371
rfl                                               3387
rfl\nrfl                                           473
mul_comm                                           388
refl                                               382
                                                 ...  
mv_polynomial.coeff_monomial\nfinsupp.single         1
mv_polynomial.monomial_eq                            1
add_monoid_algebra.support_mul                       1
mv_polynomial.support_monomial\nif_neg               1
option.some_injective\noption.some_ne_none           1
Name: args_names, Length: 43045, dtype: int64

Примерно в половине случаев отсутствуют аргументы.

Колонка **goal** используется как `question`.

In [4]:
args_names = ds['args_names'][0].split()

args_statements = [lemmas[name] for name in args_names]
for stm in args_statements:
    print(stm)

def read : Π (b : buffer α), fin b.size → α	| ⟨n, a⟩ i := a.read i
def read' [inhabited α] : buffer α → nat → α	| ⟨n, a⟩ i := a.read' i


In [5]:
lemmas

{'': 'def {u} cond {a : Type u} : bool → a → a → a\t| tt x y := x\t| ff x y := y',
 'A': 'def A : finset (agreed_triple C J) := finset.univ.filter (λ (a : agreed_triple C J), a.judge_pair.agree r a.contestant ∧ a.judge_pair.distinct)',
 'ADE_inequality.A': "def A  (r   : ℕ+) : multiset ℕ+ := A' 1 r",
 "ADE_inequality.A'": "def A' (q r : ℕ+) : multiset ℕ+ := {1,q,r}",
 "ADE_inequality.D'": "def D' (r   : ℕ+) : multiset ℕ+ := {2,2,r}",
 "ADE_inequality.E'": "def E' (r   : ℕ+) : multiset ℕ+ := {2,3,r}",
 'ADE_inequality.E6': "def E6            : multiset ℕ+ := E' 3",
 'ADE_inequality.E7': "def E7            : multiset ℕ+ := E' 4",
 'ADE_inequality.E8': "def E8            : multiset ℕ+ := E' 5",
 'ADE_inequality.admissible': "def admissible (pqr : multiset ℕ+) : Prop := (∃ q r, A' q r = pqr) ∨ (∃ r, D' r = pqr) ∨ (E' 3 = pqr ∨ E' 4 = pqr ∨ E' 5 = pqr)",
 'ADE_inequality.admissible.one_lt_sum_inv': 'lemma admissible.one_lt_sum_inv {pqr : multiset ℕ+} : admissible pqr → 1 < sum_inv pqr',
 "A

В качестве `context` берутся пары (название леммы, формулировка), где название леммы проходит по всем элементам в **args_names**.

## Формирование файла с леммами для DPR retrieval

In [6]:
import pandas as pd

In [7]:
lemmas_df = pd.DataFrame(lemmas.items(), columns=['id', 'text'])
lemmas_df['title'] = lemmas_df['id']

lemmas_df.to_csv(DATA_DIRECTORY / 'all_lemmas.tsv', index=None, sep='\t')

## Нужная структура датасета

In [8]:
print("""
[
  {
    "question": "....",
    "answers": ["...", "...", "..."],
    "positive_ctxs": [{
        "title": "...",
        "text": "...."
    }],
    "negative_ctxs": ["..."],
    "hard_negative_ctxs": ["..."],
    "passage_id": '...'
  },
  ...
]

""")


[
  {
    "question": "....",
    "answers": ["...", "...", "..."],
    "positive_ctxs": [{
        "title": "...",
        "text": "...."
    }],
    "negative_ctxs": ["..."],
    "hard_negative_ctxs": ["..."],
    "passage_id": '...'
  },
  ...
]




## Формируем датасет

In [9]:
def make_data_example(line, all_lemmas):
    question = line["goal"]
    positive_titles = line["args_names"].split()
    positive_statements = [all_lemmas[arg] for arg in positive_titles]

    example = {"question": question,
               "positive_ctxs": [{"title": title, "text": statement, "passage_id": title} 
                                 for title, statement in zip(positive_titles, positive_statements)],
               "negative_ctxs": [],
               "hard_negative_ctxs": []}
    
    if len(positive_titles) == 0:
        example['positive_ctxs'] = [{'title': '<None>', 'text': '<None>', 'passage_id': '<None>'}]
        
    return example

In [10]:
for i, line in ds[:1].iterrows():
    example = make_data_example(line, lemmas)
example

{'question': "α : Type u,\n_inst_1 : inhabited α,\ni b_fst : ℕ,\nb_snd : array b_fst α,\nh : i < buffer.size ⟨b_fst, b_snd⟩\n⊢ buffer.read ⟨b_fst, b_snd⟩ ⟨i, h⟩ = buffer.read' ⟨b_fst, b_snd⟩ i",
 'positive_ctxs': [{'title': 'buffer.read',
   'text': 'def read : Π (b : buffer α), fin b.size → α\t| ⟨n, a⟩ i := a.read i',
   'passage_id': 'buffer.read'},
  {'title': "buffer.read'",
   'text': "def read' [inhabited α] : buffer α → nat → α\t| ⟨n, a⟩ i := a.read' i",
   'passage_id': "buffer.read'"}],
 'negative_ctxs': [],
 'hard_negative_ctxs': []}

In [11]:
dataset_train = [make_data_example(line, lemmas) for _, line in ds[ds.split == 'train'].iterrows()]
assert len(dataset_train) == len(ds[ds.split == 'train'])

dataset_dev = [make_data_example(line, lemmas) for _, line in ds[ds.split == 'valid'].iterrows()]
assert len(dataset_dev) == len(ds[ds.split == 'valid'])

Проанализируем полученный датасет

In [24]:
import matplotlib.pyplot as plt
%matplotlib inline

In [25]:
pd.Series([len(item['positive_ctxs']) for item in dataset_train]).value_counts()

1     154299
2      17170
3       6813
4       2903
5       1440
6        664
7        378
8        232
9        108
10        62
12        42
11        42
13        22
14        19
15        15
17        11
18        10
16         8
19         6
20         4
21         3
28         3
22         2
26         2
33         2
27         2
23         2
37         1
65         1
43         1
60         1
29         1
45         1
25         1
74         1
54         1
dtype: int64

In [26]:
dataset_train[:5]

[{'question': "α : Type u,\ni : ℕ,\nv : α,\nb_fst : ℕ,\nb_snd : array b_fst α,\nh : i < buffer.size ⟨b_fst, b_snd⟩\n⊢ buffer.write ⟨b_fst, b_snd⟩ ⟨i, h⟩ v = buffer.write' ⟨b_fst, b_snd⟩ i v",
  'positive_ctxs': [{'title': 'buffer.write',
    'text': 'def write : Π (b : buffer α), fin b.size → α → buffer α\t| ⟨n, a⟩ i v := ⟨n, a.write i v⟩',
    'passage_id': 'buffer.write'},
   {'title': "buffer.write'",
    'text': "def write' : buffer α → nat → α → buffer α\t| ⟨n, a⟩ i v := ⟨n, a.write' i v⟩",
    'passage_id': "buffer.write'"}],
  'negative_ctxs': [],
  'hard_negative_ctxs': []},
 {'question': "α : Type u,\ni : ℕ,\nv : α,\nb_fst : ℕ,\nb_snd : array b_fst α,\nh : i < buffer.size ⟨b_fst, b_snd⟩\n⊢ ⟨b_fst, b_snd.write ⟨i, h⟩ v⟩ = ⟨b_fst, b_snd.write' i v⟩",
  'positive_ctxs': [{'title': "array.write_eq_write'",
    'text': "lemma write_eq_write' (a : array n α) {i : nat} (h : i < n) (v : α) : write a ⟨i, h⟩ v = write' a i v",
    'passage_id': "array.write_eq_write'"}],
  'negative_ctx

Сохраняем

In [27]:
with jsonlines.open(DATA_DIRECTORY / 'lean-questions-train.json', mode='w') as f:
    f.write(dataset_train)

In [28]:
with jsonlines.open(DATA_DIRECTORY / 'lean-questions-dev.json', mode='w') as f:
    f.write(dataset_dev)

Проверяем

In [29]:
print(len(dataset_train))
dataset_train[4]

184273


{'question': 'α : Type,\np : parser α\n⊢ ∀ (x : char_buffer), p.bind parser.pure x = p x',
 'positive_ctxs': [{'title': '<None>',
   'text': '<None>',
   'passage_id': '<None>'}],
 'negative_ctxs': [],
 'hard_negative_ctxs': []}

### Создаем датасет, в котором все пары (параграф, лемма) на отдельной строке

In [30]:
def create_one_line_data(data):
    for i in range(len(data)):
        if len(data[i]['positive_ctxs']) == 0:
            data[i]['positive_ctxs'] = [{'title': '<None>', 'text': '<None>'}]

    dataset_per_line = sum([[{'question': para['question'], 
                              'positive_ctxs': [context], 
                              'negative_ctxs': para['negative_ctxs'],
                              'hard_negative_ctxs': para['hard_negative_ctxs']} 
                             for context in para['positive_ctxs']] 
                            for para in data], [])
    
    return dataset_per_line

In [31]:
for mode in ['train', 'dev']:
    with open(DATA_DIRECTORY / f'lean-questions-{mode}.json', "r") as f:
        data = json.load(f)

    one_line_data = create_one_line_data(data)

    with jsonlines.open(DATA_DIRECTORY / f'lean-questions-one-lemma-{mode}.json', mode="w") as f:
        f.write(one_line_data)