## Summary

----

## Imports

In [1]:
import os
import time
from pathlib import Path

import tqdm

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

## Parameters

In [2]:
SUBSET = "training"
RANDOM_SEED = 0

In [3]:
DATAPKG_OUTPUT_DIR = Path("~/datapkg_output_dir").expanduser().resolve()
DATAPKG_OUTPUT_DIR

PosixPath('/home/kimlab_sysmaster/datapkg_output_dir')

In [4]:
DEEP_PROTEIN_GEN_DIR = DATAPKG_OUTPUT_DIR / "deep-protein-gen"
DEEP_PROTEIN_GEN_DIR

PosixPath('/home/kimlab_sysmaster/datapkg_output_dir/deep-protein-gen')

## Workflow

In [5]:
DATA_DIR = DEEP_PROTEIN_GEN_DIR / f"{SUBSET}_data"
DATA_DIR

PosixPath('/home/kimlab_sysmaster/datapkg_output_dir/deep-protein-gen/training_data')

In [6]:
OUTPUT_DATA_DIR = DEEP_PROTEIN_GEN_DIR / f"{SUBSET}_data_rs{RANDOM_SEED}"
OUTPUT_DATA_DIR

PosixPath('/home/kimlab_sysmaster/datapkg_output_dir/deep-protein-gen/training_data_rs0')

In [7]:
folder_list = os.listdir(DATA_DIR)
folder_list[:3]

['database_id=G3DSA%3A1.10.10.180',
 'database_id=G3DSA%3A1.10.10.190',
 'database_id=G3DSA%3A1.10.10.520']

In [8]:
columns = [
    "uniparc_id",
    "structure_id",
    "sequence",
    "residue_idx_1_corrected",
    "residue_idx_2_corrected",
    "distances",
]
columns

['uniparc_id',
 'structure_id',
 'sequence',
 'residue_idx_1_corrected',
 'residue_idx_2_corrected',
 'distances']

In [None]:
random_state = np.random.RandomState(RANDOM_SEED)

num_rows_per_group = 200

for folder in tqdm.tqdm_notebook(folder_list, total=len(folder_list)):
    file_list = list((DATA_DIR / folder).glob("*.parquet"))
    assert len(file_list) == 1

    database_id = folder.replace("database_id=G3DSA%3A", "")

    output_file = OUTPUT_DATA_DIR / f"{database_id}.parquet"
    if output_file.is_file():
        try:
            df = pq.read_table(output_file, use_threads=False).to_pandas(
                integer_object_nulls=True
            )
            assert len(df) > (num_rows_per_group // 2)
            continue
        except Exception as e:
            print(type(e), e)

    pqfile = pq.ParquetFile(file_list[0])
    num_rows = pqfile.metadata.num_rows
    mask = random_state.rand(num_rows) < (
        num_rows_per_group / num_rows * max(1, num_rows / num_rows_per_group) ** 0.5
    )

    dfs = []
    start = 0
    for row_group in range(pqfile.num_row_groups):
        df = pqfile.read_row_group(
            row_group, use_threads=False, columns=columns
        ).to_pandas(integer_object_nulls=True)
        end = start + len(df)
        df = df.loc[mask[start:end], :]
        assert len(df.index) == len(df.index.drop_duplicates())
        dfs.append(df)
        start = end
    assert start == num_rows
    df = pd.concat(dfs, ignore_index=True, sort=False)
    df["database_id"] = database_id

    print(database_id, num_rows, mask.sum(), len(df))
    table = pa.Table.from_pandas(df, preserve_index=False)
    
    n_tries = 0
    while n_tries < 5:
        n_tries += 1
        try:
            pq.write_table(table, output_file)
            break
        except OSError:
            time.sleep(5)

HBox(children=(IntProgress(value=0, max=998), HTML(value='')))

<class 'AssertionError'> 
1.10.10.190 11 11 11
<class 'AssertionError'> 
1.10.287.120 9 9 9
1.10.3080.10 46125 3157 3157
<class 'AssertionError'> 
1.10.3200.10 29 29 29
<class 'AssertionError'> 
1.10.3400.10 19 19 19
<class 'AssertionError'> 
1.10.3960.10 37 37 37
<class 'AssertionError'> 
1.20.10.10 16 16 16
<class 'AssertionError'> 
1.20.120.590 32 32 32
<class 'AssertionError'> 
1.20.120.630 93 93 93
<class 'AssertionError'> 
1.20.1280.10 57 57 57
<class 'AssertionError'> 
1.20.1400.10 8 8 8
<class 'AssertionError'> 
1.20.1480.10 28 28 28
<class 'AssertionError'> 
1.20.200.20 77 77 77
1.20.210.10 867948 13126 13126
1.20.225.10 348 265 265
1.20.5.100 40429 2811 2811
1.20.5.1010 707 373 373
1.20.5.1130 150 150 150
1.20.5.120 1189 492 492
1.20.5.130 233 222 222
1.20.5.210 1585 597 597
1.20.5.220 1894 597 597
1.20.5.260 1157 524 524
1.20.5.270 1497 548 548
1.20.5.550 2724 715 715
1.20.5.560 1461 546 546
1.20.5.600 137 137 137
1.20.5.620 23607 2166 2166
1.20.5.630 1172 474 474
1.20.5.790