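## Summary

Subsample every `database_id` partition of the `deep-protein-gen` parquet dataset for one subset (`SUBSET`). A partition with $n$ rows keeps each row independently with probability $\min\left(1, \sqrt{200 / n}\right)$. Small partitions are therefore kept whole, and about $\sqrt{200\,n}$ rows are retained from large ones. The result is written as one parquet file per partition to `<SUBSET>_data_rs<RANDOM_SEED>`.

----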

## Imports

In [1]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import tqdm
import tqdm.auto

## Parameters

In [2]:
# Which subset of the deep-protein-gen data to subsample, and the seed for the
# RandomState that decides which rows are kept.
SUBSET, RANDOM_SEED = "training", 0

In [3]:
# Root directory of all datapkg outputs. It can be overridden with the
# DATAPKG_OUTPUT_DIR environment variable (e.g. on a different machine);
# otherwise it defaults to ~/datapkg_output_dir, resolved to an absolute path.
DATAPKG_OUTPUT_DIR = (
    Path(os.environ.get("DATAPKG_OUTPUT_DIR", "~/datapkg_output_dir"))
    .expanduser()
    .resolve()
)
DATAPKG_OUTPUT_DIR

PosixPath('/home/kimlab_sysmaster/datapkg_output_dir')

In [4]:
# Dataset-specific subdirectory under the datapkg output root.
DEEP_PROTEIN_GEN_DIR = DATAPKG_OUTPUT_DIR.joinpath("deep-protein-gen")
DEEP_PROTEIN_GEN_DIR

PosixPath('/home/kimlab_sysmaster/datapkg_output_dir/deep-protein-gen')

## Workflow

In [5]:
# Input: the partitioned parquet dataset for the selected subset.
DATA_DIR = DEEP_PROTEIN_GEN_DIR.joinpath(f"{SUBSET}_data")
DATA_DIR

PosixPath('/home/kimlab_sysmaster/datapkg_output_dir/deep-protein-gen/training_data')

In [6]:
# Output: one parquet file per database_id; the directory name records the
# random seed used for subsampling.
OUTPUT_DATA_DIR = DEEP_PROTEIN_GEN_DIR.joinpath(f"{SUBSET}_data_rs{RANDOM_SEED}")
OUTPUT_DATA_DIR

PosixPath('/home/kimlab_sysmaster/datapkg_output_dir/deep-protein-gen/training_data_rs0')

In [7]:
# Partition folders of DATA_DIR (named `database_id=...`).
# os.listdir returns entries in arbitrary, filesystem-dependent order. The
# sampling loop draws from one RandomState shared across all folders, so the
# folder order decides which rows get sampled. Sort it so the subsample is
# reproducible on any machine.
folder_list = sorted(os.listdir(DATA_DIR))
folder_list[:3]

['database_id=G3DSA%3A1.10.10.180',
 'database_id=G3DSA%3A1.10.10.190',
 'database_id=G3DSA%3A1.10.10.520']

In [8]:
# Columns read from each input parquet file; a `database_id` column is added
# to the sampled frame afterwards.
columns = ["uniparc_id", "structure_id", "sequence"]
columns += ["residue_idx_1_corrected", "residue_idx_2_corrected", "distances"]
columns

['uniparc_id',
 'structure_id',
 'sequence',
 'residue_idx_1_corrected',
 'residue_idx_2_corrected',
 'distances']

In [None]:
# Subsample every `database_id` partition of DATA_DIR and write one parquet file
# per partition to OUTPUT_DATA_DIR.
#
# Sampling: each row of a partition with n rows is kept independently with
# probability
#     p = num_rows_per_group / n * max(1, n / num_rows_per_group) ** 0.5
# This is >= 1 (keep everything) for n <= num_rows_per_group, and equals
# sqrt(num_rows_per_group / n) otherwise. About sqrt(num_rows_per_group * n)
# rows are expected from large partitions.
#
# A single RandomState is shared across all partitions, so the result depends
# on the order of `folder_list`.
random_state = np.random.RandomState(RANDOM_SEED)

num_rows_per_group = 200

# pq.write_table does not create parent directories. Make sure the target
# exists so the notebook also runs from scratch on a fresh machine.
OUTPUT_DATA_DIR.mkdir(parents=True, exist_ok=True)

for folder in tqdm.auto.tqdm(folder_list, total=len(folder_list)):
    # Each partition directory must hold exactly one parquet file.
    file_list = list((DATA_DIR / folder).glob("*.parquet"))
    assert len(file_list) == 1, (folder, file_list)

    # Folder names look like `database_id=G3DSA%3A1.10.10.180` ("%3A" is a
    # URL-encoded ":"); strip the prefix to get e.g. "1.10.10.180".
    database_id = folder.replace("database_id=G3DSA%3A", "")

    pqfile = pq.ParquetFile(file_list[0])
    num_rows = pqfile.metadata.num_rows
    if num_rows == 0:
        # The sampling formula would divide by zero, and there is nothing to
        # sample. Skipping consumes no random numbers, so other partitions'
        # samples are unaffected.
        print(database_id, num_rows, "skipped (empty file)")
        continue
    # One uniform draw per row of the whole file, in file order.
    mask = random_state.rand(num_rows) < (
        num_rows_per_group / num_rows * max(1, num_rows / num_rows_per_group) ** 0.5
    )

    # Read one row group at a time, so only one unfiltered row group is held in
    # memory. Apply the slice of the file-wide mask that covers that row group.
    dfs = []
    start = 0
    for row_group in range(pqfile.num_row_groups):
        df = pqfile.read_row_group(
            row_group, use_threads=False, columns=columns
        ).to_pandas(integer_object_nulls=True)
        end = start + len(df)
        df = df.loc[mask[start:end], :]
        assert len(df.index) == len(df.index.drop_duplicates())
        dfs.append(df)
        start = end
    # Every row of the file must have been matched to exactly one mask entry.
    assert start == num_rows
    df = pd.concat(dfs, ignore_index=True, sort=False)
    df["database_id"] = database_id

    print(database_id, num_rows, mask.sum(), len(df))
    table = pa.Table.from_pandas(df, preserve_index=False)
    pq.write_table(table, OUTPUT_DATA_DIR / f"{database_id}.parquet")

HBox(children=(IntProgress(value=0, max=998), HTML(value='')))

1.10.10.180 1311 522 522
1.10.10.190 11 11 11
1.10.10.520 1138 475 475
1.10.10.570 1517 567 567
1.10.10.610 1913 639 639
1.10.10.630 9050 1424 1424
1.10.10.650 14292 1768 1768
1.10.10.690 3995 885 885
1.10.100.10 4609 1009 1009
1.10.1000.11 8532 1330 1330
1.10.101.10 68856 3641 3641
1.10.1020.10 6375 1119 1119
1.10.1030.10 33286 2595 2595
1.10.1040.20 5674 1051 1051
1.10.1060.10 91169 4198 4198
1.10.1070.11 14548 1781 1781
1.10.1080.10 3654 853 853
1.10.1090.10 1965 634 634
1.10.1100.10 418 281 281
1.10.1140.10 32444 2518 2518
1.10.1160.10 32080 2606 2606
1.10.12.10 75697 3916 3916
1.10.1200.10 245548 7020 7020
1.10.1200.20 2647 715 715
1.10.1200.60 2168 662 662
1.10.1200.70 9788 1345 1345
1.10.1200.80 15158 1702 1702
1.10.1220.10 90561 4145 4145
1.10.1240.10 16330 1829 1829
1.10.1240.20 3546 764 764
1.10.1240.30 1500 528 528
1.10.1290.10 108 108 108
1.10.1300.10 16450 1790 1790
1.10.1320.10 3007 786 786
1.10.135.10 3979 921 921
1.10.1390.10 10155 1415 1415
1.10.140.10 854 400 400
1.10