Skip to content

Commit

Permalink
Add filter threshold and new nn model
Browse files Browse the repository at this point in the history
  • Loading branch information
martin-steinegger committed Sep 18, 2018
1 parent 768aedd commit 2854ec5
Show file tree
Hide file tree
Showing 9 changed files with 108 additions and 38 deletions.
2 changes: 2 additions & 0 deletions data/CMakeLists.txt
Expand Up @@ -5,6 +5,8 @@ set(COMPILED_RESOURCES
hybridassembler.sh
nuclassembler.sh
predict_coding_acc9260_56x96.model
predict_coding_acc9623_57x32x64.model
predict_coding_acc9540_57x32x64.model
)

set(GENERATED_OUTPUT_HEADERS "")
Expand Down
2 changes: 1 addition & 1 deletion data/assembler.sh
Expand Up @@ -126,7 +126,7 @@ if [ -n "${PROTEIN_FILTER}" ]; then
RESULT="${TMP_PATH}/assembly_${STEP}_filtered"
if notExists "${TMP_PATH}/assembly_${STEP}_filtered"; then
# shellcheck disable=SC2086
"$MMSEQS" filternoncoding "${TMP_PATH}/assembly_${STEP}" "${TMP_PATH}/assembly_${STEP}_filtered" ${THREADS_PAR} \
"$MMSEQS" filternoncoding "${TMP_PATH}/assembly_${STEP}" "${TMP_PATH}/assembly_${STEP}_filtered" ${FILTERNONCODING_PAR} \
|| fail "filternoncoding died"
fi
fi
Expand Down
Binary file added data/predict_coding_acc9540_57x32x64.model
Binary file not shown.
Binary file added data/predict_coding_acc9623_57x32x64.model
Binary file not shown.
120 changes: 88 additions & 32 deletions src/assembler/filternoncoding.cpp
Expand Up @@ -4,13 +4,16 @@
#include "SubstitutionMatrix.h"

#include "Debug.h"
#include "AminoAcidLookupTables.h"
#include "DBReader.h"
#include "DBWriter.h"

#include "LocalParameters.h"

#include "kerasify/keras_model.h"
#include "predict_coding_acc9260_56x96.model.h"
#include "predict_coding_acc9540_57x32x64.model.h"
//#include "predict_coding_acc9260_56x96.model.h"
//#include "predict_coding_acc9623_57x32x64.model.h"

#ifdef OPENMP
#include <omp.h>
Expand All @@ -30,10 +33,12 @@ int filternoncoding(int argc, const char **argv, const Command& command) {

// Initialize model.
KerasModel model;
model.LoadModel(std::string((const char *)predict_coding_acc9260_56x96_model, predict_coding_acc9260_56x96_model_len));
// model.LoadModel(std::string((const char *)predict_coding_acc9623_57x32x64_model, predict_coding_acc9623_57x32x64_model_len));
model.LoadModel(std::string((const char *)predict_coding_acc9540_57x32x64_model, predict_coding_acc9540_57x32x64_model_len));

SubstitutionMatrix subMat(par.scoringMatrixFile.c_str(), 2.0, 0.0);
ReducedMatrix redMat(subMat.probMatrix, subMat.subMatrixPseudoCounts, 7, subMat.getBitFactor());
ReducedMatrix redMat7(subMat.probMatrix, subMat.subMatrixPseudoCounts, 7, subMat.getBitFactor());
// ReducedMatrix redMat3(subMat.probMatrix, subMat.subMatrixPseudoCounts, 3, subMat.getBitFactor());

// Create a 1D Tensor on length 20 for input data.
#pragma omp parallel
Expand All @@ -47,20 +52,31 @@ int filternoncoding(int argc, const char **argv, const Command& command) {
float counter[255];
std::fill(counter, counter + 255, 1.0);
Sequence seq(par.maxSeqLen, Sequence::AMINO_ACIDS, &subMat, par.kmerSize, false, false);
Sequence rseq(par.maxSeqLen, Sequence::AMINO_ACIDS, &redMat, 2, false, false);
Indexer indexer(redMat.alphabetSize, 2);
float *diAACnt = new float[redMat.alphabetSize * redMat.alphabetSize];
std::fill(diAACnt, diAACnt + redMat.alphabetSize * redMat.alphabetSize, 1.0);
Sequence rseq2mer(par.maxSeqLen, Sequence::AMINO_ACIDS, &redMat7, 2, false, false);
// Sequence rseq5mer(par.maxSeqLen, Sequence::AMINO_ACIDS, &redMat3, 5, false, false);
Indexer indexerDi(redMat7.alphabetSize, 2);
// Indexer indexerPenta(redMat3.alphabetSize, 5);
float *diAACnt = new float[redMat7.alphabetSize * redMat7.alphabetSize];
std::fill(diAACnt, diAACnt + redMat7.alphabetSize * redMat7.alphabetSize, 1.0);
// int pentaIdxRange = MathUtil::ipow<size_t>(redMat3.alphabetSize, 5);
// float *pentaAACnt = new float[pentaIdxRange];
// std::fill(pentaAACnt, pentaAACnt + pentaIdxRange, 1.0);
Charges charge;
Doolittle doolittle;

#pragma omp for schedule(static)
for (size_t id = 0; id < seqDb.getSize(); id++) {
std::vector<float> data;
char *seqData = seqDb.getData(id);
unsigned int dbKey = seqDb.getDbKey(id);
seq.mapSequence(id, dbKey, seqData);

//printf("%5d ", seq.L);
// printf("%5d ", seq.L);
// float chargeFlt = Util::averageValueOnAminoAcids(charge.values, seqData);
// float doolittleFlt = Util::averageValueOnAminoAcids(doolittle.values, seqData);
// printf("%3f ", chargeFlt);
// printf("%3f ", doolittleFlt);
float totalAACnt = 0;
data.push_back(static_cast<float>(seq.L));
for (int pos = 0; pos < seq.L; pos++) {
if (seq.int_sequence[pos] < subMat.alphabetSize - 1) {
counter[seq.int_sequence[pos]] += 1.0;
Expand All @@ -70,38 +86,77 @@ int filternoncoding(int argc, const char **argv, const Command& command) {
for (int aa = 0; aa < subMat.alphabetSize - 1; aa++) {
data.push_back(counter[aa] / (totalAACnt + subMat.alphabetSize - 1));
counter[aa] = 1.0;
//printf("%.4f ", data.back());
// printf("%.4f ", data.back());
}

rseq.mapSequence(id, dbKey, seqData);
float totalDiAACnt = 0;
while (rseq.hasNextKmer()) {
const int *kmer = rseq.nextKmer();
// ignore x
if (kmer[0] == redMat.alphabetSize - 1 || kmer[1] == redMat.alphabetSize - 1) {
continue;
// di matrix
{
rseq2mer.mapSequence(id, dbKey, seqData);
float totalDiAACnt = 0;
while (rseq2mer.hasNextKmer()) {
const int *kmer = rseq2mer.nextKmer();
// ignore x
if (static_cast<int>(kmer[0]) == redMat7.alphabetSize - 1 ||
static_cast<int>(kmer[1]) == redMat7.alphabetSize - 1) {
continue;
}
size_t index = indexerDi.int2index(kmer);
diAACnt[index] += 1.0;
totalDiAACnt += 1.0;
}
size_t index = indexer.int2index(kmer);
diAACnt[index] += 1.0;
totalDiAACnt += 1.0;
}
size_t kmer[2];
for (int raa = 0; raa < (redMat.alphabetSize * redMat.alphabetSize); raa++) {
indexer.index2int(kmer, raa, 2);
if (kmer[0] == redMat.alphabetSize - 1 || kmer[1] == redMat.alphabetSize - 1) {
continue;
size_t kmer[2];
float diRealRange = static_cast<float>((redMat7.alphabetSize - 1) * (redMat7.alphabetSize - 1));
for (int raa = 0; raa < (redMat7.alphabetSize * redMat7.alphabetSize); raa++) {
indexerDi.index2int(kmer, raa, 2);
if (static_cast<int>(kmer[0]) == redMat7.alphabetSize - 1 ||
static_cast<int>(kmer[1]) == redMat7.alphabetSize - 1) {
continue;
}
data.push_back(diAACnt[raa] / (totalDiAACnt + diRealRange));
// printf("%.4f ", data.back());
diAACnt[raa] = 1.0;
}
data.push_back(diAACnt[raa] / (totalDiAACnt + (redMat.alphabetSize - 1) * (redMat.alphabetSize - 1)));
//printf("%.4f ", data.back());
diAACnt[raa] = 1.0;
}
//printf("\n");
// 5mer
// {
// rseq5mer.mapSequence(id, dbKey, seqData);
// float totalPentaAACnt = 0;
// while (rseq5mer.hasNextKmer()) {
// const int *kmer = rseq5mer.nextKmer();
// // ignore x
// if (kmer[0] == redMat3.alphabetSize - 1 ||
// kmer[1] == redMat3.alphabetSize - 1 ||
// kmer[2] == redMat3.alphabetSize - 1 ||
// kmer[3] == redMat3.alphabetSize - 1 ||
// kmer[4] == redMat3.alphabetSize - 1) {
// continue;
// }
// size_t index = indexerPenta.int2index(kmer);
// pentaAACnt[index] += 1.0;
// totalPentaAACnt += 1.0;
// }
// size_t kmer5[5];
// int pentaRealRange = static_cast<float>(MathUtil::ipow<size_t>(redMat3.alphabetSize - 1, 5));
// for (int raa = 0; raa < pentaIdxRange; raa++) {
// indexerPenta.index2int(kmer5, raa, 5);
// if (static_cast<int>(kmer5[0]) == redMat3.alphabetSize - 1 ||
// static_cast<int>(kmer5[1]) == redMat3.alphabetSize - 1 ||
// static_cast<int>(kmer5[2]) == redMat3.alphabetSize - 1 ||
// static_cast<int>(kmer5[3]) == redMat3.alphabetSize - 1 ||
// static_cast<int>(kmer5[4]) == redMat3.alphabetSize - 1) {
// continue;
// }
// data.push_back(pentaAACnt[raa] / (totalPentaAACnt + pentaRealRange));
// printf("%.4f ", data.back());
// pentaAACnt[raa] = 1.0;
// }
// }
// printf("\n");
//printf("%d\n", data.size());
in.data_ = data;
// Run prediction.
Tensor out;
model.Apply(&in, &out);
if (out.data_[0] > 0.2) {
if (out.data_[0] > par.proteinFilterThreshold) {
// -1 dont write \0 byte
dbw.writeData(seqData, seqDb.getSeqLens(id) - 1, dbKey, thread_idx);
} else {
Expand All @@ -111,6 +166,7 @@ int filternoncoding(int argc, const char **argv, const Command& command) {
}

delete[] diAACnt;
// delete[] pentaAACnt;
}
// std::cout << "Filtered: " << static_cast<float>(cnt)/ static_cast<float>(seqDb.getSize()) << std::endl;
dbw.close(Sequence::AMINO_ACIDS);
Expand Down
16 changes: 14 additions & 2 deletions src/commons/LocalParameters.h
Expand Up @@ -18,18 +18,22 @@ class LocalParameters : public Parameters {

std::vector<MMseqsParameter> assembleresults;
std::vector<MMseqsParameter> extractorfssubset;
std::vector<MMseqsParameter> filternoncoding;
std::vector<MMseqsParameter> hybridassembleresults;
std::vector<MMseqsParameter> assemblerworkflow;
std::vector<MMseqsParameter> nuclassemblerworkflow;

PARAMETER(PARAM_FILTER_PROTEINS)
PARAMETER(PARAM_PROTEIN_FILTER_THRESHOLD)

int filterProteins;
float proteinFilterThreshold;

private:
LocalParameters() :
Parameters(),
PARAM_FILTER_PROTEINS(PARAM_FILTER_PROTEINS_ID,"--filter-proteins", "Filter Proteins", "filter proteins by a neural network [0,1]",typeid(int), (void *) &filterProteins, "^[0-1]{1}$")

PARAM_FILTER_PROTEINS(PARAM_FILTER_PROTEINS_ID,"--filter-proteins", "Filter Proteins", "filter proteins by a neural network [0,1]",typeid(int), (void *) &filterProteins, "^[0-1]{1}$"),
PARAM_PROTEIN_FILTER_THRESHOLD(PARAM_PROTEIN_FILTER_THRESHOLD_ID,"--protein-filter-threshold", "Protein Filter Threshold", "filter proteins lower than threshold [0.0,1.0]",typeid(float), (void *) &proteinFilterThreshold, "^0(\\.[0-9]+)?|1(\\.0+)?$")
{
// assembleresult
assembleresults.push_back(PARAM_MIN_SEQ_ID);
Expand All @@ -41,10 +45,16 @@ class LocalParameters : public Parameters {
extractorfssubset.push_back(PARAM_THREADS);
extractorfssubset.push_back(PARAM_V);

filternoncoding.push_back(PARAM_PROTEIN_FILTER_THRESHOLD);
filternoncoding.push_back(PARAM_THREADS);
filternoncoding.push_back(PARAM_V);

// assembler workflow
assemblerworkflow = combineList(rescorediagonal, kmermatcher);
assemblerworkflow = combineList(assemblerworkflow, extractorfs);
assemblerworkflow = combineList(assemblerworkflow, assembleresults);
assemblerworkflow = combineList(assemblerworkflow, filternoncoding);

assemblerworkflow.push_back(PARAM_FILTER_PROTEINS);
assemblerworkflow.push_back(PARAM_NUM_ITERATIONS);
assemblerworkflow.push_back(PARAM_REMOVE_TMP_FILES);
Expand All @@ -64,6 +74,8 @@ class LocalParameters : public Parameters {
hybridassembleresults.push_back(PARAM_RUNNER);

filterProteins = 1;
proteinFilterThreshold = 0.2;

}
LocalParameters(LocalParameters const&);
~LocalParameters() {};
Expand Down
2 changes: 1 addition & 1 deletion src/plass.cpp
Expand Up @@ -154,7 +154,7 @@ std::vector<struct Command> commands = {
"Martin Steinegger <martin.steinegger@mpibpc.mpg.de>",
"<i:sequenceDB> <i:alignmentDB> <o:sequenceDB>",
CITATION_MMSEQS2},
{"filternoncoding", filternoncoding, &par.onlythreads, COMMAND_HIDDEN,
{"filternoncoding", filternoncoding, &par.filternoncoding, COMMAND_HIDDEN,
"Filter non-coding protein sequences",
NULL,
"Martin Steinegger <martin.steinegger@mpibpc.mpg.de>",
Expand Down
2 changes: 1 addition & 1 deletion src/workflow/Assembler.cpp
Expand Up @@ -140,12 +140,12 @@ int assembler(int argc, const char **argv, const Command &command) {
cmd.addVariable("EXTRACTORFS_START_PAR", par.createParameterString(par.extractorfs).c_str());



par.addOrfStop = true;
cmd.addVariable("CREATEDB_PAR", par.createParameterString(par.createdb).c_str());
cmd.addVariable("TRANSLATENUCS_PAR", par.createParameterString(par.translatenucs).c_str());
cmd.addVariable("UNGAPPED_ALN_PAR", par.createParameterString(par.rescorediagonal).c_str());
cmd.addVariable("ASSEMBLE_RESULT_PAR", par.createParameterString(par.assembleresults).c_str());
cmd.addVariable("FILTERNONCODING_PAR", par.createParameterString(par.filternoncoding).c_str());

cmd.addVariable("THREADS_PAR", par.createParameterString(par.onlythreads).c_str());
cmd.addVariable("VERBOSITY_PAR", par.createParameterString(par.onlyverbosity).c_str());
Expand Down

0 comments on commit 2854ec5

Please sign in to comment.