In [None]:
from nbdev.showdoc import *

In [1]:
from functools import cached_property
from pathlib import Path
from pandas import DataFrame
import logging
import numpy as np

from mokapot.constants import CHUNK_SIZE_READ_ALL_DATA
from mokapot.dataset import (
    OnDiskPsmDataset,
)

from mokapot.parsers.pin import parse_in_chunks, read_percolator
from mokapot.brew import make_train_sets

from misugaru.core import *

In [2]:
class Data:
    """PSM dataset read with ``read_percolator``, plus helpers for training splits.

    The object returned by ``read_percolator`` is kept on ``self.psms``. The
    fold assignment and the row counts are cached properties, so each one is
    computed on first access and reused afterwards.
    """

    def __init__(self, psm_file: Path):
        """Read ``psm_file`` and reject datasets with fewer than two PSMs.

        Parameters
        ----------
        psm_file : Path
            File handed to ``read_percolator``.

        Raises
        ------
        ValueError
            If the loaded dataset has zero or one PSM.
        """
        self.psms: OnDiskPsmDataset = read_percolator(psm_file, max_workers=MAX_WORKERS)
        n_psms = self.size
        if n_psms < 2:
            # Both cases are rejected; the message states which one occurred so
            # a single-PSM file is not reported as being empty.
            if n_psms == 0:
                raise ValueError("Dataset contains no PSMs")
            raise ValueError(
                "Dataset contains a single PSM; at least two are required"
            )

    def get_train_psms_splits(self, n_subset: int) -> tuple[DataFrame, DataFrame, DataFrame]:
        """Load the training PSMs selected by ``make_train_sets``.

        Parameters
        ----------
        n_subset : int
            Forwarded to ``make_train_sets`` as ``subset_max_train``.

        Returns
        -------
        tuple
            The three items yielded by ``parse_in_chunks``, in order.

        Raises
        ------
        ValueError
            If ``parse_in_chunks`` does not yield exactly three items (the
            result is unpacked into three names).
        """
        # NOTE(review): the three-way unpack below presumably relies on
        # N_FOLDS == 3 — confirm against misugaru.core.

        # Mokapot functions often expect lists of datasets, hence the
        # single-element lists wrapped around this one dataset.
        train_idx = list(make_train_sets(
            test_idx=[self.fold_idx],  # expects list
            subset_max_train=n_subset,
            data_size=[self.size],  # expects list
            rng=RNG,
        ))
        fold_a, fold_b, fold_c = parse_in_chunks(
            datasets=[self.psms],  # expects list
            train_idx=train_idx,
            chunk_size=CHUNK_SIZE_READ_ALL_DATA,
            max_workers=MAX_WORKERS,
        )
        return fold_a, fold_b, fold_c

    @cached_property
    def fold_idx(self):
        """Result of ``self.psms._split(N_FOLDS, RNG)``, computed once and cached."""
        return self.psms._split(N_FOLDS, RNG)

    @cached_property
    def n_decoys(self) -> np.int64:
        """Sum of the negated target column of ``self.psms.spectra_dataframe``."""
        # assumes the target column is boolean, so `~` acts as logical NOT —
        # TODO confirm for datasets not produced by read_percolator
        return (~self.psms.spectra_dataframe[self.psms.target_column]).sum()

    @cached_property
    def n_targets(self) -> np.int64:
        """Sum of the target column of ``self.psms.spectra_dataframe``."""
        return (self.psms.spectra_dataframe[self.psms.target_column]).sum()

    @cached_property
    def size(self) -> int:
        """Row count of ``self.psms.spectra_dataframe`` (a plain ``int`` from ``len``)."""
        return len(self.psms.spectra_dataframe)

In [None]:
# Usage
# pathlib does not expand "~" on its own; expanduser() resolves it to the
# current user's home directory so the file is found regardless of how the
# reader treats the path.
path = Path("~/repos/matcha/data/10k_psms_test.parquet").expanduser()
data = Data(path)

In [None]:
# Display Data.n_decoys: the sum of the negated target column.
data.n_decoys

np.int64(4698)

In [None]:
# Display Data.n_targets: the sum of the target column.
data.n_targets

np.int64(5302)

In [None]:
# Display Data.size: the row count of the spectra dataframe.
data.size

10000

In [None]:
# Request training splits with n_subset=10, unpack the three returned items,
# and display the first one.
a, b, c = data.get_train_psms_splits(n_subset=10)
a

Unnamed: 0,SpecId,Label,ScanNr,ExpMass,Mass,MS8_feature_5,missedCleavages,MS8_feature_7,MS8_feature_13,MS8_feature_20,MS8_feature_21,MS8_feature_22,MS8_feature_24,MS8_feature_29,MS8_feature_30,MS8_feature_32,Peptide,Proteins
913,11393,True,2111,853.465767,853.465767,7,1,5.646877,0.830555,0.714286,1.142857,16.142857,13,5.526901,0.709493,0.571429,_.180002._,_.dummy._
8787,8561,True,7734,2192.104823,2192.104823,18,2,6.923104,0.97211,0.277778,1.944444,25.055556,40,7.133587,0.217704,0.222222,_.3629503._,_.dummy._
3720,3626,False,4474,1092.519987,1092.519987,9,0,6.54987,0.436707,0.333333,0.666667,25.222222,9,6.4632,0.48736,0.333333,_.776503._,_.dummy._
8754,11991,True,7696,1889.995925,1889.995925,17,1,6.620053,0.891212,0.176471,1.529412,19.941176,29,6.604076,0.256311,0.176471,_.2864803._,_.dummy._
6916,5224,True,6522,1257.68297,1257.68297,11,1,6.993753,0.02311,1.181818,0.636364,23.272727,20,7.041632,0.70508,1.0,_.1238503._,_.dummy._
5325,10781,True,5472,1214.589998,1214.589998,11,0,7.317326,1.493152,0.727273,1.454545,19.818182,24,7.338107,0.371911,0.363636,_.1130802._,_.dummy._
3415,122,False,4314,1080.527381,1080.527381,10,1,7.345293,0.743355,0.5,0.9,20.8,14,7.273375,0.519051,0.4,_.743303._,_.dummy._
3456,6281,True,4335,1317.732623,1317.732623,11,1,6.05795,1.351949,0.454545,1.636364,28.0,23,6.186919,0.38394,0.272727,_.1396802._,_.dummy._
9823,5964,True,8620,1281.690364,1281.690364,10,2,5.653715,1.479566,0.3,1.2,32.2,15,6.096302,0.170965,0.3,_.1300302._,_.dummy._
63,12359,False,209,775.346046,775.346046,7,0,5.634303,1.008405,0.571429,0.571429,28.428571,8,5.736019,0.344076,0.571429,_.52302._,_.dummy._
