Skip to content

Commit

Permalink
Merge pull request #56 from jaebeom-kim/master
Browse files Browse the repository at this point in the history
Metabuli 1.0.4
  • Loading branch information
jaebeom-kim committed Mar 13, 2024
2 parents 6fdf834 + d26d051 commit 9dd746e
Show file tree
Hide file tree
Showing 14 changed files with 340 additions and 241 deletions.
168 changes: 98 additions & 70 deletions src/commons/Classifier.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "Classifier.h"
#include "FileUtil.h"
#include "QueryIndexer.h"
#include "common.h"

Classifier::Classifier(LocalParameters & par) {
Expand Down Expand Up @@ -36,91 +37,118 @@ Classifier::~Classifier() {
}

void Classifier::startClassify(const LocalParameters &par) {

cout << "Indexing query file ...";
queryIndexer->indexQueryFile();
size_t numOfSeq = queryIndexer->getReadNum_1();
size_t totalReadLength = queryIndexer->getTotalReadLength();
const vector<QuerySplit> & queryReadSplit = queryIndexer->getQuerySplits();
cout << "Done" << endl;
cout << "Total number of sequences: " << numOfSeq << endl;
cout << "Total read length: " << totalReadLength << "nt" << endl;

QueryKmerBuffer kmerBuffer;
QueryKmerBuffer queryKmerBuffer;
Buffer<Match> matchBuffer;
vector<Query> queryList;

size_t numOfTatalQueryKmerCnt = 0;
size_t processedSeqCnt = 0;

reporter->openReadClassificationFile();
#ifdef OPENMP
omp_set_num_threads(par.threads);
#endif

bool complete = false;
size_t processedReadCnt = 0;
size_t tries = 0;
size_t totalSeqCnt = 0;

// Extract k-mers from query sequences and compare them to target k-mer DB
KSeqWrapper* kseq1 = KSeqFactory(par.filenames[0].c_str());
KSeqWrapper* kseq2 = nullptr;
if (par.seqMode == 2) { kseq2 = KSeqFactory(par.filenames[1].c_str()); }
// while (true) {
// bool success = false;
// while (!success) {
//
// }
// if (complete) {
// break;
// }
// }

for (size_t splitIdx = 0; splitIdx < queryReadSplit.size(); splitIdx++) {
// Allocate memory for query list
queryList.clear();
queryList.resize(queryReadSplit[splitIdx].end - queryReadSplit[splitIdx].start);

// Allocate memory for query k-mer list and match list
kmerBuffer.reallocateMemory(queryReadSplit[splitIdx].kmerCnt);
if (queryReadSplit.size() == 1) {
size_t remain = queryIndexer->getAvailableRam() - queryReadSplit[splitIdx].kmerCnt * sizeof(QueryKmer) - numOfSeq * 200;
matchBuffer.reallocateMemory(remain / sizeof(Match));
} else {
matchBuffer.reallocateMemory(queryReadSplit[splitIdx].kmerCnt * matchPerKmer);
while (!complete) {
tries++;

// Get splits for remaining sequences
if (tries == 1) {
cout << "Indexing query file ...";
}
queryIndexer->setBytesPerKmer(matchPerKmer);
queryIndexer->indexQueryFile(processedReadCnt);
const vector<QuerySplit> & queryReadSplit = queryIndexer->getQuerySplits();

if (tries == 1) {
totalSeqCnt = queryIndexer->getReadNum_1();
cout << "Done" << endl;
cout << "Total number of sequences: " << queryIndexer->getReadNum_1() << endl;
cout << "Total read length: " << queryIndexer->getTotalReadLength() << "nt" << endl;
}

// Set up kseq
KSeqWrapper* kseq1 = KSeqFactory(par.filenames[0].c_str());
KSeqWrapper* kseq2 = nullptr;
if (par.seqMode == 2) { kseq2 = KSeqFactory(par.filenames[1].c_str()); }

// Move kseq to unprocessed reads
for (size_t i = 0; i < processedReadCnt; i++) {
kseq1->ReadEntry();
if (par.seqMode == 2) { kseq2->ReadEntry(); }
}

// Initialize query k-mer buffer and match buffer
kmerBuffer.startIndexOfReserve = 0;
matchBuffer.startIndexOfReserve = 0;

// Extract query k-mer
kmerExtractor->extractQueryKmers(kmerBuffer,
queryList,
queryReadSplit[splitIdx],
par,
kseq1,
kseq2);
numOfTatalQueryKmerCnt += kmerBuffer.startIndexOfReserve;

// Search matches between query and target k-mers
kmerMatcher->matchKmers(&kmerBuffer, &matchBuffer);
kmerMatcher->sortMatches(&matchBuffer);

// Classify queries based on the matches.
// omp_set_num_threads(1);
taxonomer->assignTaxonomy(matchBuffer.buffer, matchBuffer.startIndexOfReserve, queryList, par);
processedSeqCnt += queryReadSplit[splitIdx].end - queryReadSplit[splitIdx].start;
cout << "The number of processed sequences: " << processedSeqCnt << " (" << (double) processedSeqCnt / (double) numOfSeq << ")" << endl;

// Write classification results
reporter->writeReadClassification(queryList);
for (size_t splitIdx = 0; splitIdx < queryReadSplit.size(); splitIdx++) {
// Allocate memory for query list
queryList.clear();
queryList.resize(queryReadSplit[splitIdx].end - queryReadSplit[splitIdx].start);

// Allocate memory for query k-mer buffer
queryKmerBuffer.reallocateMemory(queryReadSplit[splitIdx].kmerCnt);

// Allocate memory for match buffer
if (queryReadSplit.size() == 1) {
size_t remain = queryIndexer->getAvailableRam()
- (queryReadSplit[splitIdx].kmerCnt * sizeof(QueryKmer))
- (queryIndexer->getReadNum_1() * 200); // TODO: check it later
matchBuffer.reallocateMemory(remain / sizeof(Match));
} else {
matchBuffer.reallocateMemory(queryReadSplit[splitIdx].kmerCnt * matchPerKmer);
}

// Initialize query k-mer buffer and match buffer
queryKmerBuffer.startIndexOfReserve = 0;
matchBuffer.startIndexOfReserve = 0;

// Extract query k-mers
kmerExtractor->extractQueryKmers(queryKmerBuffer,
queryList,
queryReadSplit[splitIdx],
par,
kseq1,
kseq2); // sync kseq1 and kseq2

// Search matches between query and target k-mers
if (kmerMatcher->matchKmers(&queryKmerBuffer, &matchBuffer)) {
kmerMatcher->sortMatches(&matchBuffer);

// Classify queries based on the matches.
taxonomer->assignTaxonomy(matchBuffer.buffer, matchBuffer.startIndexOfReserve, queryList, par);

// Write classification results
reporter->writeReadClassification(queryList);

// Print progress
processedReadCnt += queryReadSplit[splitIdx].readCnt;
cout << "The number of processed sequences: " << processedReadCnt << " (" << (double) processedReadCnt / (double) totalSeqCnt << ")" << endl;

numOfTatalQueryKmerCnt += queryKmerBuffer.startIndexOfReserve;
} else { // search was incomplete
// Increase matchPerKmer and try again
matchPerKmer *= 2;
// delete kseq1;
// delete kseq2;
cout << "The search was incomplete. Increasing --match-per-kmer to " << matchPerKmer << " and trying again." << endl;
break;
}
}

delete kseq1;
if (par.seqMode == 2) {
delete kseq2;
}
if (processedReadCnt == totalSeqCnt) {
complete = true;
}
}

cout << "Number of query k-mers: " << numOfTatalQueryKmerCnt << endl;
cout << "The number of matches: " << kmerMatcher->getTotalMatchCnt() << endl;
reporter->closeReadClassificationFile();

// Write report files
reporter->writeReportFile(numOfSeq, taxonomer->getTaxCounts());
reporter->writeReportFile(totalSeqCnt, taxonomer->getTaxCounts());

// Memory deallocation
free(matchBuffer.buffer);
delete kseq1;
delete kseq2;
}
20 changes: 9 additions & 11 deletions src/commons/KmerExtractor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,14 @@ void KmerExtractor::fillQueryKmerBufferParallel(KSeqWrapper *kseq1,
QueryKmerBuffer &kmerBuffer,
vector<Query> &queryList,
const QuerySplit &currentSplit,
const LocalParameters &par) {
size_t queryNum = currentSplit.end - currentSplit.start;
const LocalParameters &par) {
size_t processedQueryNum = 0;

// Array to store reads of thread number
vector<string> reads1(par.threads);

while (processedQueryNum < queryNum) {
size_t currentQueryNum = min(queryNum - processedQueryNum, (size_t) par.threads);
// Array to store reads of thread number
vector<string> reads1(par.threads);
while (processedQueryNum < currentSplit.readCnt) {
size_t currentQueryNum = min(currentSplit.readCnt - processedQueryNum, (size_t) par.threads);
size_t count = 0;
while (count < currentQueryNum) {
// Read query
Expand Down Expand Up @@ -120,15 +119,14 @@ void KmerExtractor::fillQueryKmerBufferParallel_paired(KSeqWrapper *kseq1,
vector<Query> &queryList,
const QuerySplit &currentSplit,
const LocalParameters &par) {
size_t queryNum = currentSplit.end - currentSplit.start;
size_t processedQueryNum = 0;

// Array to store reads of thread number
vector<string> reads1(par.threads);
vector<string> reads2(par.threads);

while (processedQueryNum < queryNum) {
size_t currentQueryNum = min(queryNum - processedQueryNum, (size_t) par.threads);
while (processedQueryNum < currentSplit.readCnt) {
size_t currentQueryNum = min(currentSplit.readCnt - processedQueryNum, (size_t) par.threads);
size_t count = 0;

// Fill reads in sequential
Expand Down
23 changes: 12 additions & 11 deletions src/commons/KmerMatcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,18 +109,15 @@ void KmerMatcher::loadTaxIdList(const LocalParameters & par) {
}


int KmerMatcher::matchKmers(QueryKmerBuffer * queryKmerBuffer,
Buffer<Match> * matchBuffer,
const string & db){
bool KmerMatcher::matchKmers(QueryKmerBuffer * queryKmerBuffer,
Buffer<Match> * matchBuffer,
const string & db){
// Set database files
string targetDiffIdxFileName;
string targetInfoFileName;
string diffIdxSplitFileName;
if (db.empty()) {
targetDiffIdxFileName = dbDir + "/diffIdx";
targetInfoFileName = dbDir + "/info";
diffIdxSplitFileName = dbDir + "/split";
} else {
} else { // for the case of multiple databases
targetDiffIdxFileName = dbDir + "/" + db + "/diffIdx";
targetInfoFileName = dbDir + "/" + db + "/info";
diffIdxSplitFileName = dbDir + "/" + db + "/split";
Expand Down Expand Up @@ -460,17 +457,18 @@ querySplits, queryKmerList, matchBuffer, cout, targetDiffIdxFileName, numOfDiffI
free(diffIdxBuffer);
free(kmerInfoBuffer);
} // End of omp parallel

if (hasOverflow) {
std::cout << "overflow!!!" << std::endl;
return 2;
return false;
}
} // end of while(completeSplitCnt < threadNum)

std::cout << "Time spent for the comparison: " << double(time(nullptr) - beforeSearch) << std::endl;
free(splitCheckList);
queryKmerNum = 0;

totalMatchCnt += matchBuffer->startIndexOfReserve;
return 1;
return true;
}

void KmerMatcher::sortMatches(Buffer<Match> * matchBuffer) {
Expand Down Expand Up @@ -542,5 +540,8 @@ bool KmerMatcher::compareMatches(const Match& a, const Match& b) {
if (a.qInfo.pos != b.qInfo.pos)
return a.qInfo.pos < b.qInfo.pos;

return a.hamming < b.hamming;
if (a.hamming != b.hamming)
return a.hamming < b.hamming;

return a.dnaEncoding < b.dnaEncoding;
}
10 changes: 8 additions & 2 deletions src/commons/KmerMatcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ class KmerMatcher {
unordered_map<TaxID, TaxID> taxId2speciesId;
unordered_map<TaxID, TaxID> taxId2genusId;

string targetDiffIdxFileName;
string targetInfoFileName;
string diffIdxSplitFileName;


struct QueryKmerSplit {
QueryKmerSplit(size_t start, size_t end, size_t length,
const DiffIdxSplit &diffIdxSplit)
Expand Down Expand Up @@ -101,8 +106,9 @@ class KmerMatcher {

virtual ~KmerMatcher();

int matchKmers(QueryKmerBuffer *queryKmerBuffer, Buffer<Match> *matchBuffer,
const string &db = string());
bool matchKmers(QueryKmerBuffer *queryKmerBuffer,
Buffer<Match> *matchBuffer,
const string &db = string());

void sortMatches(Buffer<Match> *matchBuffer);

Expand Down
8 changes: 8 additions & 0 deletions src/commons/LocalParameters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,13 @@ LocalParameters::LocalParameters() :
typeid(std::string),
(void *) &cladeRank,
"^.*$"),
SKIP_SECONDARY(SKIP_SECONDARY_ID,
"--skip-secondary",
"Skip the results of already observed reads. (0: No, 1: Yes)",
"Skip secondary classification",
typeid(int),
(void *) &skipSecondary,
"[0-1]"),
PRINT_MODE(PRINT_MODE_ID,
"--print-mode",
"[1] Only filtered reads [2] Both filtered and removed reads",
Expand Down Expand Up @@ -363,6 +370,7 @@ LocalParameters::LocalParameters() :
grade.push_back(&COVERAGE_COL);
grade.push_back(&PRINT_COLUMNS);
grade.push_back(&CLADE_RANK);
grade.push_back(&SKIP_SECONDARY);

// Apply thresholds
applyThreshold.push_back(&MIN_SP_SCORE);
Expand Down
2 changes: 2 additions & 0 deletions src/commons/LocalParameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class LocalParameters : public Parameters {
PARAMETER(COVERAGE_COL)
PARAMETER(PRINT_COLUMNS)
PARAMETER(CLADE_RANK)
PARAMETER(SKIP_SECONDARY)

// Filter
PARAMETER(PRINT_MODE)
Expand Down Expand Up @@ -126,6 +127,7 @@ class LocalParameters : public Parameters {
int scoreCol;
int coverageCol;
std::string cladeRank;
int skipSecondary;

// Add to library
bool assembly;
Expand Down
2 changes: 1 addition & 1 deletion src/commons/QueryFilter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ void QueryFilter::printFilteredReads() {
void QueryFilter::filterReads(LocalParameters & par) {

cout << "Indexing query file ...";
queryIndexer->indexQueryFile();
queryIndexer->indexQueryFile(0);
size_t numOfSeq = queryIndexer->getReadNum_1();
size_t totalReadLength = queryIndexer->getTotalReadLength();
const vector<QuerySplit> & queryReadSplit = queryIndexer->getQuerySplits();
Expand Down
Loading

0 comments on commit 9dd746e

Please sign in to comment.