Skip to content

Commit

Permalink
restore khmer nodegraph compat
Browse files Browse the repository at this point in the history
  • Loading branch information
luizirber committed Mar 5, 2022
1 parent 39a5c8a commit e67852f
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 9 deletions.
2 changes: 2 additions & 0 deletions include/sourmash.h
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,8 @@ uintptr_t nodegraph_noccupied(const SourmashNodegraph *ptr);

uintptr_t nodegraph_ntables(const SourmashNodegraph *ptr);

/* Save the nodegraph behind `ptr` to the file named by `filename`, using the
 * version 4 ("khmer") on-disk layout rather than the default one. */
void nodegraph_save_khmer(const SourmashNodegraph *ptr, const char *filename);

void nodegraph_save(const SourmashNodegraph *ptr, const char *filename);

const uint8_t *nodegraph_to_buffer(const SourmashNodegraph *ptr,
Expand Down
17 changes: 17 additions & 0 deletions src/core/src/ffi/nodegraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,23 @@ unsafe fn nodegraph_save(ptr: *const SourmashNodegraph, filename: *const c_char)
}
}

// FFI entry point: write the nodegraph behind `ptr` to the file named by
// `filename`, using the version 4 layout produced by `Nodegraph::write_v4`.
//
// Safety contract for callers: `ptr` must be acceptable to
// `SourmashNodegraph::as_rust`, and `filename` must point to a NUL-terminated
// C string that stays alive for the duration of the call. A null `filename`
// trips the assertion below; a non-UTF-8 one is reported as an error.
ffi_fn! {
unsafe fn nodegraph_save_khmer(ptr: *const SourmashNodegraph, filename: *const c_char) -> Result<()> {
    let ng = SourmashNodegraph::as_rust(ptr);

    // FIXME use buffer + len instead of c_str
    assert!(!filename.is_null());
    // SAFETY: `filename` was checked to be non-null just above; the caller
    // guarantees it is NUL-terminated and valid while this function runs.
    let c_str = CStr::from_ptr(filename);
    let path = c_str.to_str()?;

    // `write_v4` issues many small writes (header fields, trailing bytes), so
    // buffer them instead of hitting the file once per field.
    let mut wtr = std::io::BufWriter::new(std::fs::File::create(path)?);
    ng.write_v4(&mut wtr)?;
    // Flush explicitly: dropping a `BufWriter` silently discards flush errors.
    std::io::Write::flush(&mut wtr)?;

    Ok(())
}
}

ffi_fn! {
unsafe fn nodegraph_to_buffer(ptr: *const SourmashNodegraph, compression: u8, size: *mut usize) -> Result<*const u8> {
let ng = SourmashNodegraph::as_rust(ptr);
Expand Down
67 changes: 62 additions & 5 deletions src/core/src/sketch/nodegraph.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::fs::File;
use std::io;
use std::path::Path;
use std::slice;
use std::{fs::File, io::BufWriter};

use bitmagic::BVector;
use byteorder::{BigEndian, ByteOrder, LittleEndian, ReadBytesExt, WriteBytesExt};
Expand Down Expand Up @@ -169,16 +170,24 @@ impl Nodegraph {
pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), Error> {
// TODO: if it ends with gz, open a compressed file
// might use get_output here?
self.save_to_writer(&mut File::create(path)?)?;
let fp = File::create(path)?;
self.save_to_writer(&mut BufWriter::new(fp))?;
Ok(())
}

pub fn save_to_writer<W>(&self, wtr: &mut W) -> Result<(), Error>
where
W: io::Write,
{
self.write_v5(wtr)
}

fn write_v5<W>(&self, wtr: &mut W) -> Result<(), Error>
where
W: io::Write,
{
wtr.write_all(b"OXLI")?;
wtr.write_u8(99)?; // version
wtr.write_u8(5)?; // version
wtr.write_u8(2)?; // ht_type
wtr.write_u32::<LittleEndian>(self.ksize as u32)?; // ksize
wtr.write_u8(self.bs.len() as u8)?; // n_tables
Expand All @@ -196,6 +205,54 @@ impl Nodegraph {
Ok(())
}

/// Serialize this nodegraph into `wtr` using the version 4 layout.
///
/// Layout, all integers little-endian:
///
/// * magic `OXLI`, version byte `4`, ht_type byte `2`
/// * `ksize` as `u32`, number of tables as `u8`, occupied bins as `u64`
/// * for every table: its size in bits as `u64`, followed by
///   `tablesize / 8 + 1` bytes holding the bits, packed into 32-bit words
///
/// # Errors
///
/// Returns an error if any write to `wtr` fails.
pub(crate) fn write_v4<W>(&self, wtr: &mut W) -> Result<(), Error>
where
    W: io::Write,
{
    use fixedbitset::FixedBitSet;

    wtr.write_all(b"OXLI")?;
    wtr.write_u8(4)?; // version
    wtr.write_u8(2)?; // ht_type
    wtr.write_u32::<LittleEndian>(self.ksize as u32)?; // ksize
    // NOTE(review): `as u8` truncates silently if there are more than 255 tables.
    wtr.write_u8(self.bs.len() as u8)?; // n_tables
    wtr.write_u64::<LittleEndian>(self.occupied_bins as u64)?; // n_occupied

    for count in &self.bs {
        let tablesize = count.len();
        wtr.write_u64::<LittleEndian>(tablesize as u64)?;

        // The payload is `byte_size` bytes: `div` whole 32-bit words followed
        // by `rem` leading bytes of the next word.
        let byte_size = tablesize / 8 + 1;
        let (div, rem) = (byte_size / 4, byte_size % 4);

        // Re-pack the set bits into 32-bit blocks.
        let mut fbs = FixedBitSet::with_capacity(tablesize);
        fbs.extend(count.ones());
        let blocks = fbs.as_slice();

        // Encode the whole words explicitly as little-endian. This replaces a
        // raw-pointer reinterpretation of the block slice, which emitted
        // native-endian bytes and therefore a different file on big-endian
        // hosts (while the tail below was always little-endian).
        let mut buf = vec![0u8; div * 4];
        LittleEndian::write_u32_into(&blocks[..div], &mut buf);
        wtr.write_all(&buf)?;

        if rem != 0 {
            // When `tablesize` is 0 or a multiple of 32 the bitset has exactly
            // `div` blocks, so there is no block at index `div`; the trailing
            // byte required by the layout is then zero padding. Indexing
            // directly would panic in that case.
            let last = blocks.get(div).copied().unwrap_or(0);
            let mut tail = [0u8; 4];
            LittleEndian::write_u32(&mut tail, last);
            wtr.write_all(&tail[..rem])?;
        }
    }
    Ok(())
}

pub fn from_reader<R>(rdr: R) -> Result<Nodegraph, Error>
where
R: io::Read,
Expand All @@ -208,7 +265,7 @@ impl Nodegraph {
let version = rdr.read_u8()?;
match version {
4 => Self::read_v4(rdr),
99 => Self::read_v99(rdr),
5 => Self::read_v5(rdr),
_ => todo!("throw error, version not supported"),
}
}
Expand Down Expand Up @@ -262,7 +319,7 @@ impl Nodegraph {
})
}

fn read_v99<R>(mut rdr: R) -> Result<Nodegraph, Error>
fn read_v5<R>(mut rdr: R) -> Result<Nodegraph, Error>
where
R: io::Read,
{
Expand Down
10 changes: 7 additions & 3 deletions src/sourmash/nodegraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,12 @@ def from_buffer(buf):
ng_ptr = rustcall(lib.nodegraph_from_buffer, buf, len(buf))
return Nodegraph._from_objptr(ng_ptr)

def save(self, filename, version=5):
    """Write this nodegraph to ``filename``.

    ``version`` must be at least 4 (checked with ``assert``). ``version=4``
    is written through ``lib.nodegraph_save_khmer``; every other accepted
    value is written through ``lib.nodegraph_save``.
    """
    assert version >= 4
    # Pick the native entry point first, then make a single call.
    writer = lib.nodegraph_save_khmer if version == 4 else lib.nodegraph_save
    self._methodcall(writer, to_bytes(filename))

def to_bytes(self, compression=1):
size = ffi.new("uintptr_t *")
Expand Down Expand Up @@ -94,7 +98,7 @@ def to_khmer_nodegraph(self):
load_nodegraph = khmer.Nodegraph.load

with NamedTemporaryFile() as f:
self.save(f.name)
self.save(f.name, version=4)
f.file.flush()
f.file.seek(0)
return load_nodegraph(f.name)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_nodegraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def test_nodegraph_same_file():
khmer_ng = load_nodegraph(ng_file)

with NamedTemporaryFile() as f1, NamedTemporaryFile() as f2, NamedTemporaryFile() as f3:
sourmash_ng.save(f1.name)
sourmash_ng.save(f1.name, version=4)
khmer_sm_ng.save(f2.name)
khmer_ng.save(f3.name)

Expand Down

0 comments on commit e67852f

Please sign in to comment.