Merge pull request BVLC#2193 from immars/datumfile
Plain Datumfile input format for minimum memory usage
weiliu89 committed Apr 14, 2015
2 parents c1d560e + d03d600 commit 7b259c7
Showing 4 changed files with 174 additions and 4 deletions.
13 changes: 9 additions & 4 deletions examples/imagenet/create_imagenet.sh
@@ -9,6 +9,9 @@ TOOLS=build/tools
TRAIN_DATA_ROOT=/path/to/imagenet/train/
VAL_DATA_ROOT=/path/to/imagenet/val/

# leveldb|lmdb|datumfile
DB_TYPE=datumfile

# Set RESIZE=true to resize the images to 256x256. Leave as false if images have
# already been resized using another tool.
RESIZE=false
@@ -34,24 +37,26 @@ if [ ! -d "$VAL_DATA_ROOT" ]; then
exit 1
fi

echo "Creating train lmdb..."
echo "Creating train $DB_TYPE ..."

GLOG_logtostderr=1 $TOOLS/convert_imageset \
--resize_height=$RESIZE_HEIGHT \
--resize_width=$RESIZE_WIDTH \
--shuffle \
--backend=$DB_TYPE \
$TRAIN_DATA_ROOT \
$DATA/train.txt \
$EXAMPLE/ilsvrc12_train_lmdb
$EXAMPLE/ilsvrc12_train_$DB_TYPE

echo "Creating val lmdb..."
echo "Creating val $DB_TYPE ..."

GLOG_logtostderr=1 $TOOLS/convert_imageset \
--resize_height=$RESIZE_HEIGHT \
--resize_width=$RESIZE_WIDTH \
--shuffle \
--backend=$DB_TYPE \
$VAL_DATA_ROOT \
$DATA/val.txt \
$EXAMPLE/ilsvrc12_val_lmdb
$EXAMPLE/ilsvrc12_val_$DB_TYPE

echo "Done."
84 changes: 84 additions & 0 deletions include/caffe/util/db.hpp
@@ -181,6 +181,90 @@ class LMDB : public DB {
MDB_dbi mdb_dbi_;
};


#define MAX_BUF 104857600 // max entry size
class DatumFileCursor : public Cursor {
public:
explicit DatumFileCursor(const string& path) {
this->path_ = path;
in_ = NULL;
SeekToFirst();
}
virtual ~DatumFileCursor() {
if (in_ != NULL && in_->is_open()) {
in_->close();
delete in_;
in_ = NULL;
}
}
virtual void SeekToFirst();

virtual void Next();

virtual string key() {
CHECK(valid()) << "not valid state at key()";
return key_;
}
virtual string value() {
CHECK(valid()) << "not valid state at value()";
return value_;
}

virtual bool valid() { return valid_; }

private:
string path_;
std::ifstream* in_;
bool valid_;

string key_, value_;
};

class DatumFileTransaction : public Transaction {
public:
explicit DatumFileTransaction(std::ofstream* out) {
this->out_ = out;
}

virtual void Put(const string& key, const string& value);

virtual void Commit() {
out_->flush();
}

private:
std::ofstream* out_;
DISABLE_COPY_AND_ASSIGN(DatumFileTransaction);
};


class DatumFileDB : public DB {
public:
DatumFileDB() { out_ = NULL; can_write_ = false;}
virtual ~DatumFileDB() { Close(); }
virtual void Open(const string& source, Mode mode) {
path_ = source;
this->can_write_ = mode != db::READ;
}
virtual void Close() {
if (out_ != NULL) {
out_->close();
delete out_;
out_ = NULL;
}
}
virtual DatumFileCursor* NewCursor() {
return new DatumFileCursor(this->path_);
}
virtual Transaction* NewTransaction();

private:
string path_;
std::ofstream* out_;

bool can_write_;
};

DB* GetDB(DataParameter::DB backend);
DB* GetDB(const string& backend);
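
For orientation, here is a minimal usage sketch of the db interface declared above with the new datumfile backend. It is not part of the commit; the /tmp path and the key/value strings are made up, real callers store serialized Datum messages as values, and the READ/NEW modes come from the existing Mode enum in this header.

#include <iostream>
#include <string>
#include "caffe/util/db.hpp"

namespace db = caffe::db;

int main() {
  // Write two records into a single flat datumfile.
  db::DB* writer = db::GetDB("datumfile");
  writer->Open("/tmp/toy_datumfile", db::NEW);
  db::Transaction* txn = writer->NewTransaction();
  txn->Put("00000000_cat.jpg", "serialized-datum-bytes-0");
  txn->Put("00000001_dog.jpg", "serialized-datum-bytes-1");
  txn->Commit();  // flushes the underlying ofstream
  delete txn;
  writer->Close();
  delete writer;

  // Read the records back in insertion order with a forward-only cursor.
  db::DB* reader = db::GetDB("datumfile");
  reader->Open("/tmp/toy_datumfile", db::READ);
  db::Cursor* cursor = reader->NewCursor();
  while (cursor->valid()) {
    std::cout << cursor->key() << " ("
              << cursor->value().size() << " value bytes)" << std::endl;
    cursor->Next();
  }
  delete cursor;
  reader->Close();
  delete reader;
  return 0;
}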

1 change: 1 addition & 0 deletions src/caffe/proto/caffe.proto
@@ -456,6 +456,7 @@ message DataParameter {
enum DB {
LEVELDB = 0;
LMDB = 1;
DATUMFILE = 2;
}
// Specify the data source.
optional string source = 1;
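
The new value slots into DataParameter's existing backend field alongside LEVELDB and LMDB. As a rough sketch of selecting it from C++ (assuming the standard protobuf-generated accessors; the OpenSource helper below is illustrative, not part of the commit):

#include "caffe/proto/caffe.pb.h"
#include "caffe/util/db.hpp"

// Open whatever source/backend a DataParameter names; DATUMFILE now
// resolves to the DatumFileDB implementation through db::GetDB.
caffe::db::DB* OpenSource(const caffe::DataParameter& param) {
  caffe::db::DB* db = caffe::db::GetDB(param.backend());
  db->Open(param.source(), caffe::db::READ);
  return db;
}
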
80 changes: 80 additions & 0 deletions src/caffe/util/db.cpp
@@ -59,12 +59,90 @@ void LMDBTransaction::Put(const string& key, const string& value) {
MDB_CHECK(mdb_put(mdb_txn_, *mdb_dbi_, &mdb_key, &mdb_value, 0));
}


void DatumFileCursor::SeekToFirst() {
if (in_ && in_->is_open()) {
in_->close();
}
LOG(INFO) << "reset ifstream " << path_;
in_ = new std::ifstream(path_.c_str(),
std::ifstream::in|std::ifstream::binary);
Next();
}

void DatumFileCursor::Next() {
valid_ = false;
CHECK(in_->is_open()) << "file is not open!" << path_;

uint32_t record_size = 0, key_size = 0, value_size = 0;
in_->read(reinterpret_cast<char*>(&record_size), sizeof record_size);
if (in_->gcount() != (sizeof record_size) || record_size > MAX_BUF) {
CHECK(in_->eof() && record_size <= MAX_BUF)
<<"record_size read error: gcount\t"
<< in_->gcount() << "\trecord_size\t" << record_size;
return;
}

in_->read(reinterpret_cast<char*>(&key_size), sizeof key_size);
CHECK(in_->gcount() == sizeof key_size && key_size <= MAX_BUF)
<< "key_size read error: gcount\t"
<< in_->gcount() << "\tkey_size\t" << key_size;

key_.resize(key_size);
in_->read(&key_[0], key_size);
CHECK(in_->gcount() == key_size)
<< "key read error: gcount\t"
<< in_->gcount() << "\tkey_size\t" << key_size;

in_->read(reinterpret_cast<char*>(&value_size), sizeof value_size);
CHECK(in_->gcount() == sizeof value_size && value_size <= MAX_BUF)
<< "value_size read error: gcount\t"
<< in_->gcount() << "\tvalue_size\t" << value_size;

value_.resize(value_size);
in_->read(&value_[0], value_size);
CHECK(in_->gcount() == value_size)
<< "value read error: gcount\t"
<< in_->gcount() << "\tvalue_size\t" << value_size;

valid_ = true;
}

void DatumFileTransaction::Put(const string& key, const string& value) {
try {
uint32_t key_size = key.size(), value_size = value.size();
uint32_t record_size = key_size + value_size
+ sizeof key_size + sizeof value_size;
out_->write(reinterpret_cast<char*>(&record_size), sizeof record_size);
out_->write(reinterpret_cast<char*>(&key_size), sizeof key_size);
out_->write(key.data(), key_size);
out_->write(reinterpret_cast<char*>(&value_size), sizeof value_size);
out_->write(value.data(), value_size);
} catch(std::ios_base::failure& e) {
LOG(FATAL) << "Exception: "
<< e.what() << " rdstate: " << out_->rdstate() << '\n';
}
}

Transaction* DatumFileDB::NewTransaction() {
if (!this->out_) {
out_ = new std::ofstream();
out_->open(this->path_.c_str(),
std::ofstream::out | std::ofstream::trunc | std::ofstream::binary);
out_->exceptions(out_->exceptions() | std::ios::failbit);
LOG(INFO) << "Output created: " << path_ << std::endl;
}
return new DatumFileTransaction(this->out_);
}

DB* GetDB(DataParameter::DB backend) {
switch (backend) {
case DataParameter_DB_LEVELDB:
return new LevelDB();
case DataParameter_DB_LMDB:
return new LMDB();
case DataParameter_DB_DATUMFILE:
return new DatumFileDB();
default:
LOG(FATAL) << "Unknown database backend";
}
@@ -75,6 +153,8 @@ DB* GetDB(const string& backend) {
return new LevelDB();
} else if (backend == "lmdb") {
return new LMDB();
} else if (backend == "datumfile") {
return new DatumFileDB();
} else {
LOG(FATAL) << "Unknown database backend";
}
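
On disk, DatumFileTransaction::Put frames every entry as a uint32 record_size (key_size + value_size + 8 bytes for the two size headers), a uint32 key_size followed by the key bytes, then a uint32 value_size followed by the value bytes; DatumFileCursor::Next consumes the same framing sequentially. Below is a minimal standalone sketch of scanning such a file outside of Caffe (the tool name and argument are made up):

#include <stdint.h>
#include <fstream>
#include <iostream>
#include <string>

int main(int argc, char** argv) {
  if (argc < 2) {
    std::cerr << "usage: dump_datumfile_keys <datumfile>" << std::endl;
    return 1;
  }
  std::ifstream in(argv[1], std::ifstream::in | std::ifstream::binary);
  uint32_t record_size = 0;
  // Each record: [record_size][key_size][key bytes][value_size][value bytes].
  while (in.read(reinterpret_cast<char*>(&record_size), sizeof record_size)) {
    uint32_t key_size = 0, value_size = 0;
    in.read(reinterpret_cast<char*>(&key_size), sizeof key_size);
    std::string key(key_size, '\0');
    in.read(&key[0], key_size);
    in.read(reinterpret_cast<char*>(&value_size), sizeof value_size);
    in.seekg(value_size, std::ifstream::cur);  // skip the serialized Datum payload
    std::cout << key << " (" << value_size << " value bytes)" << std::endl;
  }
  return 0;
}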
