Skip to content

Commit

Permalink
marisa: Cleaned up everything
Browse files Browse the repository at this point in the history
  • Loading branch information
pgaskin committed Mar 23, 2020
1 parent 0b4392a commit 1a80ab6
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 142 deletions.
77 changes: 24 additions & 53 deletions marisa/marisa.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,54 +5,31 @@
#include <string>

#include "libmarisa.h"
#include "marisa.h"
#include "shim.h"

#define try_cstr(out_err) \
*(out_err) = NULL; \
try

#define catch_cstr(out_err) \
catch (const marisa::Exception &ex) { \
const char* b = "marisa: "; \
char* err = reinterpret_cast<char*>( \
calloc(strlen(b)+strlen(ex.what())+1, sizeof(char))); \
strcpy(err, b); \
strcat(err, ex.what()); \
*(out_err) = err; \
return; \
} catch (const go::error &ex) { \
const char* b = "go shim error: "; \
char* err = reinterpret_cast<char*>( \
calloc(strlen(b)+strlen(ex.what())+1, sizeof(char))); \
strcpy(err, b); \
strcat(err, ex.what()); \
*(out_err) = err; \
return; \
} catch (const std::runtime_error &ex) { \
const char* b = "c++ runtime error: "; \
#define catch_go_ex(t, ctx) \
catch (const t &ex) { \
const char* b = ctx; \
char* err = reinterpret_cast<char*>( \
calloc(strlen(b)+strlen(ex.what())+1, sizeof(char))); \
strcpy(err, b); \
strcat(err, ex.what()); \
*(out_err) = err; \
return; \
} catch (const std::exception &ex) { \
const char* b = "c++ error: "; \
char* err = reinterpret_cast<char*>( \
calloc(strlen(b)+strlen(ex.what())+1, sizeof(char))); \
strcpy(err, b); \
strcat(err, ex.what()); \
*(out_err) = err; \
return; \
} catch (...) { \
*(out_err) = strdup("marisa: unknown c++ exception"); \
return; \
return err; \
}

extern "C" void marisa_read_all(int iid, char ***out_wd, size_t *out_wd_sz, char **out_err) {
try_cstr(out_err) {
if (!out_wd || !out_wd_sz || !out_err)
#define catch_go \
catch_go_ex(marisa::Exception, "marisa: ") \
catch_go_ex(go::error, "go shim: ") \
catch_go_ex(std::runtime_error, "c++ runtime: ") \
catch_go_ex(std::exception, "c++ error: ") \
catch (...) { return strdup("marisa: unknown c++ exception"); } \
return NULL;

#define go_func extern "C" const char*

go_func marisa_read_all(int iid, char ***out_wd, size_t *out_wd_sz) {
try {
if (!out_wd || !out_wd_sz)
throw std::runtime_error("parameter is null");
go::pstream r(iid);
marisa::Trie t;
Expand All @@ -68,25 +45,19 @@ extern "C" void marisa_read_all(int iid, char ***out_wd, size_t *out_wd_sz, char
}
if (*out_wd_sz != t.num_keys())
throw std::runtime_error("expected " + std::to_string(t.num_keys()) + " keys, got " + std::to_string(*out_wd_sz));
} catch_cstr(out_err)
} catch_go
}

extern "C" void marisa_write_all(int iid, const char** in_wd, size_t in_wd_sz, char **out_err) {
try_cstr(out_err) {
if ((in_wd_sz && !in_wd) || !out_err)
go_func marisa_write_all(int iid, const char** wd, size_t wd_sz) {
try {
if (wd_sz && !wd)
throw std::runtime_error("parameter is null");
marisa::Keyset k;
for (size_t i = 0; i < in_wd_sz; i++)
k.push_back(in_wd[i]);
for (size_t i = 0; i < wd_sz; i++)
k.push_back(wd[i]);
marisa::Trie t;
t.build(k);
go::pstream w(iid);
marisa::write(w, t);
} catch_cstr(out_err)
}

extern "C" void marisa_wd_free(char **wd, size_t wd_sz) {
for (size_t i = 0; i < wd_sz; i++)
free(wd[i]);
free(wd);
} catch_go
}
94 changes: 48 additions & 46 deletions marisa/marisa.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ package marisa

//#cgo CPPFLAGS: -Wall
//#cgo LDFLAGS:
//#include <stddef.h>
//#include <stdlib.h>
//#include "marisa.h"
//const char* marisa_read_all(int iid, char ***out_wd, size_t *out_wd_sz);
//const char* marisa_write_all(int iid, const char** wd, size_t wd_sz);
import "C"

import (
Expand All @@ -18,65 +20,65 @@ import (

func ReadAll(r io.Reader) ([]string, error) {
iid := iopPut(r)
defer iopDel(iid)

var out_wd **C.char
var out_wd_sz C.size_t
var out_err *C.char

C.marisa_read_all(
err := C.marisa_read_all(
(C.int)(iid),
(***C.char)(unsafe.Pointer(&out_wd)),
(*C.size_t)(unsafe.Pointer(&out_wd_sz)),
(**C.char)(unsafe.Pointer(&out_err)),
)

if out_wd != nil {
defer C.marisa_wd_free(out_wd, out_wd_sz)
}
if out_err != nil {
defer C.free(unsafe.Pointer(out_err))
return nil, errors.New(C.GoString(out_err))
}

wd := make([]string, int(out_wd_sz))
for i, w := range (*[1 << 28]*C.char)(unsafe.Pointer(out_wd))[:int(out_wd_sz):int(out_wd_sz)] {
wd[i] = C.GoString(w)
}
return wd, nil
iopDel(iid)
return gostrs(out_wd, out_wd_sz), goerr(err)
}

func WriteAll(w io.Writer, wd []string) error {
iid := iopPut(w)
defer iopDel(iid)
wd_ptr, wd_sz, wd_free := cstrs(wd)
err := C.marisa_write_all(
(C.int)(iid),
(**C.char)(wd_ptr),
(C.size_t)(wd_sz),
)
wd_free()
iopDel(iid)
return goerr(err)
}

in_wd := make([]*C.char, len(wd))
for i, w := range wd {
in_wd[i] = C.CString(w)
func goerr(p *C.char) (err error) {
if p != nil {
err = errors.New(C.GoString(p))
C.free(unsafe.Pointer(p))
}
defer func() {
for _, p := range in_wd {
C.free(unsafe.Pointer(p))
}
}()

var out_err *C.char
return
}

var in_wd_ptr unsafe.Pointer
if len(in_wd) != 0 {
in_wd_ptr = unsafe.Pointer(&in_wd[0])
func gostrs(p **C.char, n C.size_t) (s []string) {
if p != nil {
s = make([]string, int(n))
for i, v := range (*[1 << 28]*C.char)(unsafe.Pointer(p))[:int(n):int(n)] {
s[i] = C.GoString(v)
C.free(unsafe.Pointer(v))
}
C.free(unsafe.Pointer(p))
}
C.marisa_write_all(
(C.int)(iid),
(**C.char)(in_wd_ptr),
(C.size_t)(len(in_wd)),
(**C.char)(unsafe.Pointer(&out_err)),
)
return
}

if out_err != nil {
defer C.free(unsafe.Pointer(out_err))
return errors.New(C.GoString(out_err))
func cstrs(s []string) (p **C.char, n C.size_t, free func()) {
n = (C.size_t)(len(s))
if len(s) == 0 {
free = func() {}
return
}

return nil
c := make([]*C.char, len(s))
for i, v := range s {
c[i] = C.CString(v)
}
p = (**C.char)(unsafe.Pointer(&c[0]))
free = func() {
for _, v := range c {
C.free(unsafe.Pointer(v))
}
}
return
}
15 changes: 0 additions & 15 deletions marisa/marisa.h

This file was deleted.

17 changes: 9 additions & 8 deletions marisa/shim.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package marisa

//#include <stddef.h>
import "C"

import (
Expand Down Expand Up @@ -63,7 +64,7 @@ func iopDel(iid int) {
}

//export go_iop_read
func go_iop_read(iid C.int, buf *C.char, buf_n C.int, out_err **C.char) C.int {
func go_iop_read(iid C.int, buf *C.char, buf_n C.size_t, out_err **C.char) C.int {
switch i := iopGet(int(iid)).(type) {
case io.Reader:
n, err := i.Read((*[1 << 28]byte)(unsafe.Pointer(buf))[:int(buf_n):int(buf_n)])
Expand All @@ -72,20 +73,20 @@ func go_iop_read(iid C.int, buf *C.char, buf_n C.int, out_err **C.char) C.int {
return C.int(-1)
}
} else if err != nil {
*out_err = C.CString(fmt.Sprintf("iop_read: read up to %d bytes from iid %d: %v", buf_n, int(iid), err))
*out_err = C.CString(fmt.Sprintf("go_iop_read: read up to %d bytes from iid %d: %v", buf_n, int(iid), err))
}
return C.int(n)
case nil:
*out_err = C.CString(fmt.Sprintf("iop_read: iid %d has been deleted", int(iid)))
*out_err = C.CString(fmt.Sprintf("go_iop_read: iid %d has been deleted", int(iid)))
return C.int(0)
default:
*out_err = C.CString(fmt.Sprintf("iop_read: iid %d is a %T, not an io.Reader", int(iid), i))
*out_err = C.CString(fmt.Sprintf("go_iop_read: iid %d is a %T, not an io.Reader", int(iid), i))
return C.int(0)
}
}

//export go_iop_write
func go_iop_write(iid C.int, buf *C.char, buf_n C.int, out_err **C.char) C.int {
func go_iop_write(iid C.int, buf *C.char, buf_n C.size_t, out_err **C.char) C.int {
switch i := iopGet(int(iid)).(type) {
case io.Writer:
n, err := i.Write((*[1 << 28]byte)(unsafe.Pointer(buf))[:int(buf_n):int(buf_n)])
Expand All @@ -94,14 +95,14 @@ func go_iop_write(iid C.int, buf *C.char, buf_n C.int, out_err **C.char) C.int {
return C.int(-1)
}
} else if err != nil {
*out_err = C.CString(fmt.Sprintf("iop_write: write up to %d bytes to iid %d: %v", buf_n, int(iid), err))
*out_err = C.CString(fmt.Sprintf("go_iop_write: write up to %d bytes to iid %d: %v", buf_n, int(iid), err))
}
return C.int(n)
case nil:
*out_err = C.CString(fmt.Sprintf("iop_write: iid %d has been deleted", int(iid)))
*out_err = C.CString(fmt.Sprintf("go_iop_write: iid %d has been deleted", int(iid)))
return C.int(0)
default:
*out_err = C.CString(fmt.Sprintf("iop_write: iid %d is a %T, not an io.Writer", int(iid), i))
*out_err = C.CString(fmt.Sprintf("go_iop_write: iid %d is a %T, not an io.Writer", int(iid), i))
return C.int(0)
}
}
Loading

0 comments on commit 1a80ab6

Please sign in to comment.