26 changes: 11 additions & 15 deletions src/restic/repository/index.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package repository

import (
"bytes"
"encoding/json"
"io"
"restic"
Expand All @@ -10,7 +9,6 @@ import (

"restic/errors"

"restic/crypto"
"restic/debug"
)

Expand Down Expand Up @@ -177,15 +175,15 @@ func (idx *Index) Has(id restic.ID, tpe restic.BlobType) bool {
return false
}

// LookupSize returns the length of the cleartext content behind the
// given id
func (idx *Index) LookupSize(id restic.ID, tpe restic.BlobType) (cleartextLength uint, err error) {
// LookupSize returns the length of the plaintext content of the blob with the
// given id.
func (idx *Index) LookupSize(id restic.ID, tpe restic.BlobType) (plaintextLength uint, err error) {
blobs, err := idx.Lookup(id, tpe)
if err != nil {
return 0, err
}

return blobs[0].Length - crypto.Extension, nil
return uint(restic.PlaintextLength(int(blobs[0].Length))), nil
}

// Supersedes returns the list of indexes this index supersedes, if any.
Expand Down Expand Up @@ -452,12 +450,11 @@ func isErrOldIndex(err error) bool {
var ErrOldIndexFormat = errors.New("index has old format")

// DecodeIndex loads and unserializes an index from buf.
func DecodeIndex(rd io.Reader) (idx *Index, err error) {
func DecodeIndex(buf []byte) (idx *Index, err error) {
debug.Log("Start decoding index")
idxJSON := jsonIndex{}
idxJSON := &jsonIndex{}

dec := json.NewDecoder(rd)
err = dec.Decode(&idxJSON)
err = json.Unmarshal(buf, idxJSON)
if err != nil {
debug.Log("Error %v", err)

Expand Down Expand Up @@ -491,12 +488,11 @@ func DecodeIndex(rd io.Reader) (idx *Index, err error) {
}

// DecodeOldIndex loads and unserializes an index in the old format from buf.
func DecodeOldIndex(rd io.Reader) (idx *Index, err error) {
func DecodeOldIndex(buf []byte) (idx *Index, err error) {
debug.Log("Start decoding old index")
list := []*packJSON{}

dec := json.NewDecoder(rd)
err = dec.Decode(&list)
err = json.Unmarshal(buf, &list)
if err != nil {
debug.Log("Error %#v", err)
return nil, errors.Wrap(err, "Decode")
Expand All @@ -523,15 +519,15 @@ func DecodeOldIndex(rd io.Reader) (idx *Index, err error) {
}

// LoadIndexWithDecoder loads the index and decodes it with fn.
func LoadIndexWithDecoder(repo restic.Repository, id restic.ID, fn func(io.Reader) (*Index, error)) (idx *Index, err error) {
func LoadIndexWithDecoder(repo restic.Repository, id restic.ID, fn func([]byte) (*Index, error)) (idx *Index, err error) {
debug.Log("Loading index %v", id.Str())

buf, err := repo.LoadAndDecrypt(restic.IndexFile, id)
if err != nil {
return nil, err
}

idx, err = fn(bytes.NewReader(buf))
idx, err = fn(buf)
if err != nil {
debug.Log("error while decoding index %v: %v", id, err)
return nil, err
Expand Down
3 changes: 2 additions & 1 deletion src/restic/repository/index_rebuild.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ func RebuildIndex(repo restic.Repository) error {
debug.Log("new index saved as %v", id.Str())

for indexID := range oldIndexes {
err := repo.Backend().Remove(restic.IndexFile, indexID.String())
h := restic.Handle{Type: restic.IndexFile, Name: indexID.String()}
err := repo.Backend().Remove(h)
if err != nil {
fmt.Fprintf(os.Stderr, "unable to remove index %v: %v\n", indexID.Str(), err)
}
Expand Down
17 changes: 13 additions & 4 deletions src/restic/repository/index_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func TestIndexSerialize(t *testing.T) {
err := idx.Encode(wr)
OK(t, err)

idx2, err := repository.DecodeIndex(wr)
idx2, err := repository.DecodeIndex(wr.Bytes())
OK(t, err)
Assert(t, idx2 != nil,
"nil returned for decoded index")
Expand Down Expand Up @@ -136,7 +136,7 @@ func TestIndexSerialize(t *testing.T) {
Assert(t, id2.Equal(id),
"wrong ID returned: want %v, got %v", id, id2)

idx3, err := repository.DecodeIndex(wr3)
idx3, err := repository.DecodeIndex(wr3.Bytes())
OK(t, err)
Assert(t, idx3 != nil,
"nil returned for decoded index")
Expand Down Expand Up @@ -288,7 +288,7 @@ var exampleLookupTest = struct {
func TestIndexUnserialize(t *testing.T) {
oldIdx := restic.IDs{restic.TestParseID("ed54ae36197f4745ebc4b54d10e0f623eaaaedd03013eb7ae90df881b7781452")}

idx, err := repository.DecodeIndex(bytes.NewReader(docExample))
idx, err := repository.DecodeIndex(docExample)
OK(t, err)

for _, test := range exampleTests {
Expand Down Expand Up @@ -326,8 +326,17 @@ func TestIndexUnserialize(t *testing.T) {
}
}

func BenchmarkDecodeIndex(b *testing.B) {
b.ResetTimer()

for i := 0; i < b.N; i++ {
_, err := repository.DecodeIndex(docExample)
OK(b, err)
}
}

func TestIndexUnserializeOld(t *testing.T) {
idx, err := repository.DecodeOldIndex(bytes.NewReader(docOldExample))
idx, err := repository.DecodeOldIndex(docOldExample)
OK(t, err)

for _, test := range exampleTests {
Expand Down
5 changes: 3 additions & 2 deletions src/restic/repository/key.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package repository

import (
"bytes"
"encoding/json"
"fmt"
"os"
Expand Down Expand Up @@ -146,7 +147,7 @@ func SearchKey(s *Repository, password string, maxKeys int) (*Key, error) {
// LoadKey loads a key from the backend.
func LoadKey(s *Repository, name string) (k *Key, err error) {
h := restic.Handle{Type: restic.KeyFile, Name: name}
data, err := backend.LoadAll(s.be, h, nil)
data, err := backend.LoadAll(s.be, h)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -232,7 +233,7 @@ func AddKey(s *Repository, password string, template *crypto.Key) (*Key, error)
Name: restic.Hash(buf).String(),
}

err = s.be.Save(h, buf)
err = s.be.Save(h, bytes.NewReader(buf))
if err != nil {
return nil, err
}
Expand Down
8 changes: 3 additions & 5 deletions src/restic/repository/master_index.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ func (mi *MasterIndex) Lookup(id restic.ID, tpe restic.BlobType) (blobs []restic
for _, idx := range mi.idx {
blobs, err = idx.Lookup(id, tpe)
if err == nil {
debug.Log("MasterIndex.Lookup",
"found id %v: %v", id.Str(), blobs)
debug.Log("found id %v: %v", id.Str(), blobs)
return
}
}
Expand All @@ -46,9 +45,8 @@ func (mi *MasterIndex) LookupSize(id restic.ID, tpe restic.BlobType) (uint, erro
defer mi.idxMutex.RUnlock()

for _, idx := range mi.idx {
length, err := idx.LookupSize(id, tpe)
if err == nil {
return length, nil
if idx.Has(id, tpe) {
return idx.LookupSize(id, tpe)
}
}

Expand Down
85 changes: 46 additions & 39 deletions src/restic/repository/packer_manager.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package repository

import (
"crypto/sha256"
"io"
"io/ioutil"
"os"
"restic"
"sync"

"restic/errors"
"restic/hashing"

"restic/crypto"
"restic/debug"
Expand All @@ -17,15 +19,22 @@ import (

// Saver implements saving data in a backend.
type Saver interface {
Save(h restic.Handle, jp []byte) error
Save(restic.Handle, io.Reader) error
}

// Packer holds a pack.Packer together with a hash writer.
type Packer struct {
*pack.Packer
hw *hashing.Writer
tmpfile *os.File
}

// packerManager keeps a list of open packs and creates new ones on demand.
type packerManager struct {
be Saver
key *crypto.Key
pm sync.Mutex
packs []*pack.Packer
be Saver
key *crypto.Key
pm sync.Mutex
packers []*Packer

pool sync.Pool
}
Expand All @@ -50,18 +59,18 @@ func newPackerManager(be Saver, key *crypto.Key) *packerManager {

// findPacker returns a packer for a new blob of size bytes. Either a new one is
// created or one is returned that already has some blobs.
func (r *packerManager) findPacker(size uint) (packer *pack.Packer, err error) {
func (r *packerManager) findPacker(size uint) (packer *Packer, err error) {
r.pm.Lock()
defer r.pm.Unlock()

// search for a suitable packer
if len(r.packs) > 0 {
if len(r.packers) > 0 {
debug.Log("searching packer for %d bytes\n", size)
for i, p := range r.packs {
if p.Size()+size < maxPackSize {
for i, p := range r.packers {
if p.Packer.Size()+size < maxPackSize {
debug.Log("found packer %v", p)
// remove from list
r.packs = append(r.packs[:i], r.packs[i+1:]...)
r.packers = append(r.packers[:i], r.packers[i+1:]...)
return p, nil
}
}
Expand All @@ -74,64 +83,62 @@ func (r *packerManager) findPacker(size uint) (packer *pack.Packer, err error) {
return nil, errors.Wrap(err, "ioutil.TempFile")
}

return pack.NewPacker(r.key, tmpfile), nil
hw := hashing.NewWriter(tmpfile, sha256.New())
p := pack.NewPacker(r.key, hw)
packer = &Packer{
Packer: p,
hw: hw,
tmpfile: tmpfile,
}

return packer, nil
}

// insertPacker appends p to r.packers.
func (r *packerManager) insertPacker(p *pack.Packer) {
func (r *packerManager) insertPacker(p *Packer) {
r.pm.Lock()
defer r.pm.Unlock()

r.packs = append(r.packs, p)
debug.Log("%d packers\n", len(r.packs))
r.packers = append(r.packers, p)
debug.Log("%d packers\n", len(r.packers))
}

// savePacker stores p in the backend.
func (r *Repository) savePacker(p *pack.Packer) error {
debug.Log("save packer with %d blobs\n", p.Count())
n, err := p.Finalize()
func (r *Repository) savePacker(p *Packer) error {
debug.Log("save packer with %d blobs\n", p.Packer.Count())
_, err := p.Packer.Finalize()
if err != nil {
return err
}

tmpfile := p.Writer().(*os.File)
f, err := fs.Open(tmpfile.Name())
if err != nil {
return errors.Wrap(err, "Open")
}

data := make([]byte, n)
m, err := io.ReadFull(f, data)
_, err = p.tmpfile.Seek(0, 0)
if err != nil {
return errors.Wrap(err, "ReadFul")
return errors.Wrap(err, "Seek")
}

if uint(m) != n {
return errors.Errorf("read wrong number of bytes from %v: want %v, got %v", tmpfile.Name(), n, m)
}

if err = f.Close(); err != nil {
return errors.Wrap(err, "Close")
}

id := restic.Hash(data)
id := restic.IDFromHash(p.hw.Sum(nil))
h := restic.Handle{Type: restic.DataFile, Name: id.String()}

err = r.be.Save(h, data)
err = r.be.Save(h, p.tmpfile)
if err != nil {
debug.Log("Save(%v) error: %v", h, err)
return err
}

debug.Log("saved as %v", h)

err = fs.Remove(tmpfile.Name())
err = p.tmpfile.Close()
if err != nil {
return errors.Wrap(err, "close tempfile")
}

err = fs.Remove(p.tmpfile.Name())
if err != nil {
return errors.Wrap(err, "Remove")
}

// update blobs in the index
for _, b := range p.Blobs() {
for _, b := range p.Packer.Blobs() {
debug.Log(" updating blob %v to pack %v", b.ID.Str(), id.Str())
r.idx.Store(restic.PackedBlob{
Blob: restic.Blob{
Expand All @@ -152,5 +159,5 @@ func (r *packerManager) countPacker() int {
r.pm.Lock()
defer r.pm.Unlock()

return len(r.packs)
return len(r.packers)
}
61 changes: 24 additions & 37 deletions src/restic/repository/packer_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"restic"
"restic/backend/mem"
"restic/crypto"
"restic/mock"
"testing"
)

Expand Down Expand Up @@ -46,32 +47,19 @@ func randomID(rd io.Reader) restic.ID {

const maxBlobSize = 1 << 20

func saveFile(t testing.TB, be Saver, filename string, n int) {
f, err := os.Open(filename)
if err != nil {
t.Fatal(err)
}

data := make([]byte, n)
m, err := io.ReadFull(f, data)

if m != n {
t.Fatalf("read wrong number of bytes from %v: want %v, got %v", filename, m, n)
}
func saveFile(t testing.TB, be Saver, f *os.File, id restic.ID) {
h := restic.Handle{Type: restic.DataFile, Name: id.String()}
t.Logf("save file %v", h)

if err = f.Close(); err != nil {
if err := be.Save(h, f); err != nil {
t.Fatal(err)
}

h := restic.Handle{Type: restic.DataFile, Name: restic.Hash(data).String()}

err = be.Save(h, data)
if err != nil {
if err := f.Close(); err != nil {
t.Fatal(err)
}

err = os.Remove(filename)
if err != nil {
if err := os.Remove(f.Name()); err != nil {
t.Fatal(err)
}
}
Expand Down Expand Up @@ -105,41 +93,39 @@ func fillPacks(t testing.TB, rnd *randReader, be Saver, pm *packerManager, buf [
continue
}

bytesWritten, err := packer.Finalize()
_, err = packer.Finalize()
if err != nil {
t.Fatal(err)
}

tmpfile := packer.Writer().(*os.File)
saveFile(t, be, tmpfile.Name(), int(bytesWritten))
if _, err = packer.tmpfile.Seek(0, 0); err != nil {
t.Fatal(err)
}

packID := restic.IDFromHash(packer.hw.Sum(nil))
saveFile(t, be, packer.tmpfile, packID)
}

return bytes
}

func flushRemainingPacks(t testing.TB, rnd *randReader, be Saver, pm *packerManager) (bytes int) {
if pm.countPacker() > 0 {
for _, packer := range pm.packs {
for _, packer := range pm.packers {
n, err := packer.Finalize()
if err != nil {
t.Fatal(err)
}
bytes += int(n)

tmpfile := packer.Writer().(*os.File)
saveFile(t, be, tmpfile.Name(), bytes)
packID := restic.IDFromHash(packer.hw.Sum(nil))
saveFile(t, be, packer.tmpfile, packID)
}
}

return bytes
}

type fakeBackend struct{}

func (f *fakeBackend) Save(h restic.Handle, data []byte) error {
return nil
}

func TestPackerManager(t *testing.T) {
rnd := newRandReader(rand.NewSource(23))

Expand All @@ -157,17 +143,18 @@ func TestPackerManager(t *testing.T) {
func BenchmarkPackerManager(t *testing.B) {
rnd := newRandReader(rand.NewSource(23))

be := &fakeBackend{}
pm := newPackerManager(be, crypto.NewRandomKey())
be := &mock.Backend{
SaveFn: func(restic.Handle, io.Reader) error { return nil },
}
blobBuf := make([]byte, maxBlobSize)

t.ResetTimer()

bytes := 0
for i := 0; i < t.N; i++ {
bytes := 0
pm := newPackerManager(be, crypto.NewRandomKey())
bytes += fillPacks(t, rnd, be, pm, blobBuf)
bytes += flushRemainingPacks(t, rnd, be, pm)
t.Logf("saved %d bytes", bytes)
}

bytes += flushRemainingPacks(t, rnd, be, pm)
t.Logf("saved %d bytes", bytes)
}
82 changes: 63 additions & 19 deletions src/restic/repository/repack.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package repository

import (
"bytes"
"crypto/sha256"
"io"
"io/ioutil"
"os"
"restic"
"restic/crypto"
"restic/debug"
"restic/hashing"
"restic/pack"

"restic/errors"
Expand All @@ -18,30 +21,47 @@ import (
func Repack(repo restic.Repository, packs restic.IDSet, keepBlobs restic.BlobSet) (err error) {
debug.Log("repacking %d packs while keeping %d blobs", len(packs), len(keepBlobs))

buf := make([]byte, 0, maxPackSize)
for packID := range packs {
// load the complete pack
// load the complete pack into a temp file
h := restic.Handle{Type: restic.DataFile, Name: packID.String()}

l, err := repo.Backend().Load(h, buf[:cap(buf)], 0)
if errors.Cause(err) == io.ErrUnexpectedEOF {
err = nil
buf = buf[:l]
tempfile, err := ioutil.TempFile("", "restic-temp-repack-")
if err != nil {
return errors.Wrap(err, "TempFile")
}

beRd, err := repo.Backend().Load(h, 0, 0)
if err != nil {
return err
}

debug.Log("pack %v loaded (%d bytes)", packID.Str(), len(buf))
defer beRd.Close()

blobs, err := pack.List(repo.Key(), bytes.NewReader(buf), int64(len(buf)))
hrd := hashing.NewReader(beRd, sha256.New())
packLength, err := io.Copy(tempfile, hrd)
if err != nil {
return errors.Wrap(err, "Copy")
}

hash := restic.IDFromHash(hrd.Sum(nil))
debug.Log("pack %v loaded (%d bytes), hash %v", packID.Str(), packLength, hash.Str())

if !packID.Equal(hash) {
return errors.Errorf("hash does not match id: want %v, got %v", packID, hash)
}

_, err = tempfile.Seek(0, 0)
if err != nil {
return errors.Wrap(err, "Seek")
}

blobs, err := pack.List(repo.Key(), tempfile, packLength)
if err != nil {
return err
}

debug.Log("processing pack %v, blobs: %v", packID.Str(), len(blobs))
var plaintext []byte
var buf []byte
for _, entry := range blobs {
h := restic.BlobHandle{ID: entry.ID, Type: entry.Type}
if !keepBlobs.Has(h) {
Expand All @@ -50,21 +70,36 @@ func Repack(repo restic.Repository, packs restic.IDSet, keepBlobs restic.BlobSet

debug.Log(" process blob %v", h)

ciphertext := buf[entry.Offset : entry.Offset+entry.Length]
plaintext = plaintext[:len(plaintext)]
if len(plaintext) < len(ciphertext) {
plaintext = make([]byte, len(ciphertext))
buf = buf[:len(buf)]
if uint(len(buf)) < entry.Length {
buf = make([]byte, entry.Length)
}
buf = buf[:entry.Length]

n, err := tempfile.ReadAt(buf, int64(entry.Offset))
if err != nil {
return errors.Wrap(err, "ReadAt")
}

debug.Log(" ciphertext %d, plaintext %d", len(plaintext), len(ciphertext))
if n != len(buf) {
return errors.Errorf("read blob %v from %v: not enough bytes read, want %v, got %v",
h, tempfile.Name(), len(buf), n)
}

n, err := crypto.Decrypt(repo.Key(), plaintext, ciphertext)
n, err = crypto.Decrypt(repo.Key(), buf, buf)
if err != nil {
return err
}
plaintext = plaintext[:n]

_, err = repo.SaveBlob(entry.Type, plaintext, entry.ID)
buf = buf[:n]

id := restic.Hash(buf)
if !id.Equal(entry.ID) {
return errors.Errorf("read blob %v from %v: wrong data returned, hash is %v",
h, tempfile.Name(), id)
}

_, err = repo.SaveBlob(entry.Type, buf, entry.ID)
if err != nil {
return err
}
Expand All @@ -73,14 +108,23 @@ func Repack(repo restic.Repository, packs restic.IDSet, keepBlobs restic.BlobSet

keepBlobs.Delete(h)
}

if err = tempfile.Close(); err != nil {
return errors.Wrap(err, "Close")
}

if err = os.Remove(tempfile.Name()); err != nil {
return errors.Wrap(err, "Remove")
}
}

if err := repo.Flush(); err != nil {
return err
}

for packID := range packs {
err := repo.Backend().Remove(restic.DataFile, packID.String())
h := restic.Handle{Type: restic.DataFile, Name: packID.String()}
err := repo.Backend().Remove(h)
if err != nil {
debug.Log("error removing pack %v: %v", packID.Str(), err)
return err
Expand Down
57 changes: 26 additions & 31 deletions src/restic/repository/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (r *Repository) LoadAndDecrypt(t restic.FileType, id restic.ID) ([]byte, er
debug.Log("load %v with id %v", t, id.Str())

h := restic.Handle{Type: t, Name: id.String()}
buf, err := backend.LoadAll(r.be, h, nil)
buf, err := backend.LoadAll(r.be, h)
if err != nil {
debug.Log("error loading %v: %v", id.Str(), err)
return nil, err
Expand All @@ -64,33 +64,20 @@ func (r *Repository) LoadAndDecrypt(t restic.FileType, id restic.ID) ([]byte, er
return nil, errors.New("invalid data returned")
}

plain := make([]byte, len(buf))

// decrypt
n, err := r.decryptTo(plain, buf)
n, err := r.decryptTo(buf, buf)
if err != nil {
return nil, err
}

return plain[:n], nil
return buf[:n], nil
}

// loadBlob tries to load and decrypt content identified by t and id from a
// pack from the backend, the result is stored in plaintextBuf, which must be
// large enough to hold the complete blob.
func (r *Repository) loadBlob(id restic.ID, t restic.BlobType, plaintextBuf []byte) (int, error) {
debug.Log("load %v with id %v (buf %p, len %d)", t, id.Str(), plaintextBuf, len(plaintextBuf))

// lookup plaintext size of blob
size, err := r.idx.LookupSize(id, t)
if err != nil {
return 0, err
}

// make sure the plaintext buffer is large enough, extend otherwise
if len(plaintextBuf) < int(size) {
return 0, errors.Errorf("buffer is too small: %d < %d", len(plaintextBuf), size)
}
debug.Log("load %v with id %v (buf len %v, cap %d)", t, id.Str(), len(plaintextBuf), cap(plaintextBuf))

// lookup packs
blobs, err := r.idx.Lookup(id, t)
Expand All @@ -109,8 +96,14 @@ func (r *Repository) loadBlob(id restic.ID, t restic.BlobType, plaintextBuf []by

// load blob from pack
h := restic.Handle{Type: restic.DataFile, Name: blob.PackID.String()}
ciphertextBuf := make([]byte, blob.Length)
n, err := r.be.Load(h, ciphertextBuf, int64(blob.Offset))

if uint(cap(plaintextBuf)) < blob.Length {
return 0, errors.Errorf("buffer is too small: %v < %v", cap(plaintextBuf), blob.Length)
}

plaintextBuf = plaintextBuf[:blob.Length]

n, err := restic.ReadAt(r.be, h, int64(blob.Offset), plaintextBuf)
if err != nil {
debug.Log("error loading blob %v: %v", blob, err)
lastError = err
Expand All @@ -125,7 +118,7 @@ func (r *Repository) loadBlob(id restic.ID, t restic.BlobType, plaintextBuf []by
}

// decrypt
n, err = r.decryptTo(plaintextBuf, ciphertextBuf)
n, err = r.decryptTo(plaintextBuf, plaintextBuf)
if err != nil {
lastError = errors.Errorf("decrypting blob %v failed: %v", id, err)
continue
Expand Down Expand Up @@ -224,7 +217,7 @@ func (r *Repository) SaveJSONUnpacked(t restic.FileType, item interface{}) (rest
// SaveUnpacked encrypts data and stores it in the backend. Returned is the
// storage hash.
func (r *Repository) SaveUnpacked(t restic.FileType, p []byte) (id restic.ID, err error) {
ciphertext := make([]byte, len(p)+crypto.Extension)
ciphertext := restic.NewBlobBuffer(len(p))
ciphertext, err = r.Encrypt(ciphertext, p)
if err != nil {
return restic.ID{}, err
Expand All @@ -233,7 +226,7 @@ func (r *Repository) SaveUnpacked(t restic.FileType, p []byte) (id restic.ID, er
id = restic.Hash(ciphertext)
h := restic.Handle{Type: t, Name: id.String()}

err = r.be.Save(h, ciphertext)
err = r.be.Save(h, bytes.NewReader(ciphertext))
if err != nil {
debug.Log("error saving blob %v: %v", h, err)
return restic.ID{}, err
Expand All @@ -248,15 +241,15 @@ func (r *Repository) Flush() error {
r.pm.Lock()
defer r.pm.Unlock()

debug.Log("manually flushing %d packs", len(r.packs))
debug.Log("manually flushing %d packs", len(r.packerManager.packers))

for _, p := range r.packs {
for _, p := range r.packerManager.packers {
err := r.savePacker(p)
if err != nil {
return err
}
}
r.packs = r.packs[:0]
r.packerManager.packers = r.packerManager.packers[:0]
return nil
}

Expand Down Expand Up @@ -387,7 +380,7 @@ func (r *Repository) SearchKey(password string, maxKeys int) error {
// Init creates a new master key with the supplied password, initializes and
// saves the repository config.
func (r *Repository) Init(password string) error {
has, err := r.be.Test(restic.ConfigFile, "")
has, err := r.be.Test(restic.Handle{Type: restic.ConfigFile})
if err != nil {
return err
}
Expand Down Expand Up @@ -528,16 +521,18 @@ func (r *Repository) Close() error {
return r.be.Close()
}

// LoadBlob loads a blob of type t from the repository to the buffer.
// LoadBlob loads a blob of type t from the repository to the buffer. buf must
// be large enough to hold the encrypted blob, since it is used as scratch
// space.
func (r *Repository) LoadBlob(t restic.BlobType, id restic.ID, buf []byte) (int, error) {
debug.Log("load blob %v into buf %p", id.Str(), buf)
debug.Log("load blob %v into buf (len %v, cap %v)", id.Str(), len(buf), cap(buf))
size, err := r.idx.LookupSize(id, t)
if err != nil {
return 0, err
}

if len(buf) < int(size) {
return 0, errors.Errorf("buffer is too small for data blob (%d < %d)", len(buf), size)
if cap(buf) < restic.CiphertextLength(int(size)) {
return 0, errors.Errorf("buffer is too small for data blob (%d < %d)", cap(buf), restic.CiphertextLength(int(size)))
}

n, err := r.loadBlob(id, t, buf)
Expand Down Expand Up @@ -571,7 +566,7 @@ func (r *Repository) LoadTree(id restic.ID) (*restic.Tree, error) {
}

debug.Log("size is %d, create buffer", size)
buf := make([]byte, size)
buf := restic.NewBlobBuffer(int(size))

n, err := r.loadBlob(id, restic.TreeBlob, buf)
if err != nil {
Expand Down
158 changes: 145 additions & 13 deletions src/restic/repository/repository_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ package repository_test

import (
"bytes"
"crypto/rand"
"crypto/sha256"
"io"
mrand "math/rand"
"math/rand"
"path/filepath"
"testing"
"time"

"restic"
"restic/archiver"
Expand All @@ -17,13 +17,15 @@ import (

var testSizes = []int{5, 23, 2<<18 + 23, 1 << 20}

var rnd = rand.New(rand.NewSource(time.Now().UnixNano()))

func TestSave(t *testing.T) {
repo, cleanup := repository.TestRepository(t)
defer cleanup()

for _, size := range testSizes {
data := make([]byte, size)
_, err := io.ReadFull(rand.Reader, data)
_, err := io.ReadFull(rnd, data)
OK(t, err)

id := restic.Hash(data)
Expand All @@ -38,7 +40,7 @@ func TestSave(t *testing.T) {
// OK(t, repo.SaveIndex())

// read back
buf := make([]byte, size)
buf := restic.NewBlobBuffer(size)
n, err := repo.LoadBlob(restic.DataBlob, id, buf)
OK(t, err)
Equals(t, len(buf), n)
Expand All @@ -59,7 +61,7 @@ func TestSaveFrom(t *testing.T) {

for _, size := range testSizes {
data := make([]byte, size)
_, err := io.ReadFull(rand.Reader, data)
_, err := io.ReadFull(rnd, data)
OK(t, err)

id := restic.Hash(data)
Expand All @@ -72,7 +74,7 @@ func TestSaveFrom(t *testing.T) {
OK(t, repo.Flush())

// read back
buf := make([]byte, size)
buf := restic.NewBlobBuffer(size)
n, err := repo.LoadBlob(restic.DataBlob, id, buf)
OK(t, err)
Equals(t, len(buf), n)
Expand All @@ -94,7 +96,7 @@ func BenchmarkSaveAndEncrypt(t *testing.B) {
size := 4 << 20 // 4MiB

data := make([]byte, size)
_, err := io.ReadFull(rand.Reader, data)
_, err := io.ReadFull(rnd, data)
OK(t, err)

id := restic.ID(sha256.Sum256(data))
Expand Down Expand Up @@ -145,6 +147,113 @@ func BenchmarkLoadTree(t *testing.B) {
}
}

func TestLoadBlob(t *testing.T) {
repo, cleanup := repository.TestRepository(t)
defer cleanup()

length := 1000000
buf := restic.NewBlobBuffer(length)
_, err := io.ReadFull(rnd, buf)
OK(t, err)

id, err := repo.SaveBlob(restic.DataBlob, buf, restic.ID{})
OK(t, err)
OK(t, repo.Flush())

// first, test with buffers that are too small
for _, testlength := range []int{length - 20, length, restic.CiphertextLength(length) - 1} {
buf = make([]byte, 0, testlength)
n, err := repo.LoadBlob(restic.DataBlob, id, buf)
if err == nil {
t.Errorf("LoadBlob() did not return an error for a buffer that is too small to hold the blob")
continue
}

if n != 0 {
t.Errorf("LoadBlob() returned an error and n > 0")
continue
}
}

// then use buffers that are large enough
base := restic.CiphertextLength(length)
for _, testlength := range []int{base, base + 7, base + 15, base + 1000} {
buf = make([]byte, 0, testlength)
n, err := repo.LoadBlob(restic.DataBlob, id, buf)
if err != nil {
t.Errorf("LoadBlob() returned an error for buffer size %v: %v", testlength, err)
continue
}

if n != length {
t.Errorf("LoadBlob() returned the wrong number of bytes: want %v, got %v", length, n)
continue
}
}
}

func BenchmarkLoadBlob(b *testing.B) {
repo, cleanup := repository.TestRepository(b)
defer cleanup()

length := 1000000
buf := restic.NewBlobBuffer(length)
_, err := io.ReadFull(rnd, buf)
OK(b, err)

id, err := repo.SaveBlob(restic.DataBlob, buf, restic.ID{})
OK(b, err)
OK(b, repo.Flush())

b.ResetTimer()
b.SetBytes(int64(length))

for i := 0; i < b.N; i++ {
n, err := repo.LoadBlob(restic.DataBlob, id, buf)
OK(b, err)
if n != length {
b.Errorf("wanted %d bytes, got %d", length, n)
}

id2 := restic.Hash(buf[:n])
if !id.Equal(id2) {
b.Errorf("wrong data returned, wanted %v, got %v", id.Str(), id2.Str())
}
}
}

func BenchmarkLoadAndDecrypt(b *testing.B) {
repo, cleanup := repository.TestRepository(b)
defer cleanup()

length := 1000000
buf := restic.NewBlobBuffer(length)
_, err := io.ReadFull(rnd, buf)
OK(b, err)

dataID := restic.Hash(buf)

storageID, err := repo.SaveUnpacked(restic.DataFile, buf)
OK(b, err)
// OK(b, repo.Flush())

b.ResetTimer()
b.SetBytes(int64(length))

for i := 0; i < b.N; i++ {
data, err := repo.LoadAndDecrypt(restic.DataFile, storageID)
OK(b, err)
if len(data) != length {
b.Errorf("wanted %d bytes, got %d", length, len(data))
}

id2 := restic.Hash(data)
if !dataID.Equal(id2) {
b.Errorf("wrong data returned, wanted %v, got %v", storageID.Str(), id2.Str())
}
}
}

func TestLoadJSONUnpacked(t *testing.T) {
repo, cleanup := repository.TestRepository(t)
defer cleanup()
Expand Down Expand Up @@ -182,25 +291,48 @@ func TestRepositoryLoadIndex(t *testing.T) {
}

func BenchmarkLoadIndex(b *testing.B) {
repodir, cleanup := Env(b, repoFixture)
repository.TestUseLowSecurityKDFParameters(b)

repo, cleanup := repository.TestRepository(b)
defer cleanup()

repo := repository.TestOpenLocal(b, repodir)
idx := repository.NewIndex()

for i := 0; i < 5000; i++ {
idx.Store(restic.PackedBlob{
Blob: restic.Blob{
Type: restic.DataBlob,
Length: 1234,
ID: restic.NewRandomID(),
Offset: 1235,
},
PackID: restic.NewRandomID(),
})
}

id, err := repository.SaveIndex(repo, idx)
OK(b, err)

b.Logf("index saved as %v (%v entries)", id.Str(), idx.Count(restic.DataBlob))
fi, err := repo.Backend().Stat(restic.Handle{Type: restic.IndexFile, Name: id.String()})
OK(b, err)
b.Logf("filesize is %v", fi.Size)

b.ResetTimer()

for i := 0; i < b.N; i++ {
repo.SetIndex(repository.NewMasterIndex())
OK(b, repo.LoadIndex())
_, err := repository.LoadIndex(repo, id)
OK(b, err)
}
}

// saveRandomDataBlobs generates random data blobs and saves them to the repository.
func saveRandomDataBlobs(t testing.TB, repo restic.Repository, num int, sizeMax int) {
for i := 0; i < num; i++ {
size := mrand.Int() % sizeMax
size := rand.Int() % sizeMax

buf := make([]byte, size)
_, err := io.ReadFull(rand.Reader, buf)
_, err := io.ReadFull(rnd, buf)
OK(t, err)

_, err = repo.SaveBlob(restic.DataBlob, buf, restic.ID{})
Expand Down
2 changes: 1 addition & 1 deletion src/restic/test/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func Random(seed, count int) []byte {

for j := range data {
cur := i + j
if len(p) >= cur {
if cur >= len(p) {
break
}
p[cur] = data[j]
Expand Down
25 changes: 17 additions & 8 deletions src/restic/testing.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,26 @@ type fakeFileSystem struct {
repo Repository
knownBlobs IDSet
duplication float32
buf []byte
chunker *chunker.Chunker
}

// saveFile reads from rd and saves the blobs in the repository. The list of
// IDs is returned.
func (fs fakeFileSystem) saveFile(rd io.Reader) (blobs IDs) {
blobs = IDs{}
ch := chunker.New(rd, fs.repo.Config().ChunkerPolynomial)
func (fs *fakeFileSystem) saveFile(rd io.Reader) (blobs IDs) {
if fs.buf == nil {
fs.buf = make([]byte, chunker.MaxSize)
}

if fs.chunker == nil {
fs.chunker = chunker.New(rd, fs.repo.Config().ChunkerPolynomial)
} else {
fs.chunker.Reset(rd, fs.repo.Config().ChunkerPolynomial)
}

blobs = IDs{}
for {
chunk, err := ch.Next(getBuf())
chunk, err := fs.chunker.Next(fs.buf)
if errors.Cause(err) == io.EOF {
break
}
Expand All @@ -50,7 +60,6 @@ func (fs fakeFileSystem) saveFile(rd io.Reader) (blobs IDs) {

fs.knownBlobs.Insert(id)
}
freeBuf(chunk.Data)

blobs = append(blobs, id)
}
Expand All @@ -64,7 +73,7 @@ const (
maxNodes = 32
)

func (fs fakeFileSystem) treeIsKnown(tree *Tree) (bool, []byte, ID) {
func (fs *fakeFileSystem) treeIsKnown(tree *Tree) (bool, []byte, ID) {
data, err := json.Marshal(tree)
if err != nil {
fs.t.Fatalf("json.Marshal(tree) returned error: %v", err)
Expand All @@ -76,7 +85,7 @@ func (fs fakeFileSystem) treeIsKnown(tree *Tree) (bool, []byte, ID) {
return fs.blobIsKnown(id, TreeBlob), data, id
}

func (fs fakeFileSystem) blobIsKnown(id ID, t BlobType) bool {
func (fs *fakeFileSystem) blobIsKnown(id ID, t BlobType) bool {
if rand.Float32() < fs.duplication {
return false
}
Expand All @@ -94,7 +103,7 @@ func (fs fakeFileSystem) blobIsKnown(id ID, t BlobType) bool {
}

// saveTree saves a tree of fake files in the repo and returns the ID.
func (fs fakeFileSystem) saveTree(seed int64, depth int) ID {
func (fs *fakeFileSystem) saveTree(seed int64, depth int) ID {
rnd := rand.NewSource(seed)
numNodes := int(rnd.Int63() % maxNodes)

Expand Down
11 changes: 11 additions & 0 deletions src/restic/testing_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,14 @@ func TestCreateSnapshot(t *testing.T) {

checker.TestCheckRepo(t, repo)
}

func BenchmarkTestCreateSnapshot(t *testing.B) {
repo, cleanup := repository.TestRepository(t)
defer cleanup()

t.ResetTimer()

for i := 0; i < t.N; i++ {
restic.TestCreateSnapshot(t, repo, testSnapshotTime.Add(time.Duration(i)*time.Second), testDepth, 0)
}
}
6 changes: 6 additions & 0 deletions vendor/manifest
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@
"revision": "17b591df37844cde689f4d5813e5cea0927d8dd2",
"branch": "master"
},
{
"importpath": "github.com/pkg/profile",
"repository": "https://github.com/pkg/profile",
"revision": "1c16f117a3ab788fdf0e334e623b8bccf5679866",
"branch": "HEAD"
},
{
"importpath": "github.com/pkg/sftp",
"repository": "https://github.com/pkg/sftp",
Expand Down
1 change: 1 addition & 0 deletions vendor/src/github.com/pkg/profile/AUTHORS
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Dave Cheney <dave@cheney.net>
24 changes: 24 additions & 0 deletions vendor/src/github.com/pkg/profile/LICENSE
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
Copyright (c) 2013 Dave Cheney. All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:

* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
47 changes: 47 additions & 0 deletions vendor/src/github.com/pkg/profile/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
profile
=======

Simple profiling support package for Go

[![Build Status](https://travis-ci.org/pkg/profile.svg?branch=master)](https://travis-ci.org/pkg/profile) [![GoDoc](http://godoc.org/github.com/pkg/profile?status.svg)](http://godoc.org/github.com/pkg/profile)


installation
------------

go get github.com/pkg/profile

usage
-----

Enabling profiling in your application is as simple as one line at the top of your main function

```go
import "github.com/pkg/profile"

func main() {
defer profile.Start().Stop()
...
}
```

options
-------

What to profile is controlled by the config values passed to profile.Start.
By default CPU profiling is enabled.

```go
import "github.com/pkg/profile"

func main() {
// p.Stop() must be called before the program exits to
// ensure profiling information is written to disk.
p := profile.Start(profile.MemProfile, profile.ProfilePath("."), profile.NoShutdownHook)
...
}
```

Several convenience package level values are provided for cpu, memory, and block (contention) profiling.

For more complex options, consult the [documentation](http://godoc.org/github.com/pkg/profile).
56 changes: 56 additions & 0 deletions vendor/src/github.com/pkg/profile/example_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package profile_test

import (
"flag"
"os"

"github.com/pkg/profile"
)

// ExampleStart shows the simplest possible use of the package.
func ExampleStart() {
	// start a simple CPU profile and register
	// a defer to Stop (flush) the profiling data.
	defer profile.Start().Stop()
}

// ExampleCPUProfile requests cpu profiling explicitly.
func ExampleCPUProfile() {
	// CPU profiling is the default profiling mode, but you can specify it
	// explicitly for completeness.
	defer profile.Start(profile.CPUProfile).Stop()
}

// ExampleMemProfile selects memory profiling instead of the default.
func ExampleMemProfile() {
	// use memory profiling, rather than the default cpu profiling.
	defer profile.Start(profile.MemProfile).Stop()
}

// ExampleMemProfileRate selects memory profiling with a non-default rate.
func ExampleMemProfileRate() {
	// use memory profiling with custom rate.
	defer profile.Start(profile.MemProfileRate(2048)).Stop()
}

// ExampleProfilePath redirects the profile output to a specific
// directory, here $HOME.
func ExampleProfilePath() {
	// set the location that the profile will be written to, and make
	// sure the data is flushed when the function returns. The original
	// example deferred Start without calling Stop, so the profile was
	// started at return and never stopped or flushed.
	defer profile.Start(profile.ProfilePath(os.Getenv("HOME"))).Stop()
}

// ExampleNoShutdownHook opts out of the package's SIGINT handler.
func ExampleNoShutdownHook() {
	// disable the automatic shutdown hook.
	defer profile.Start(profile.NoShutdownHook).Stop()
}

// ExampleStart_withFlags lets the user choose the profiling mode on
// the command line.
func ExampleStart_withFlags() {
	// use the flags package to selectively enable profiling.
	mode := flag.String("profile.mode", "", "enable profiling mode, one of [cpu, mem, block]")
	flag.Parse()
	switch *mode {
	case "cpu":
		defer profile.Start(profile.CPUProfile).Stop()
	case "mem":
		defer profile.Start(profile.MemProfile).Stop()
	case "block":
		defer profile.Start(profile.BlockProfile).Stop()
	default:
		// do nothing
	}
}
216 changes: 216 additions & 0 deletions vendor/src/github.com/pkg/profile/profile.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
// Package profile provides a simple way to manage runtime/pprof
// profiling of your Go application.
package profile

import (
"io/ioutil"
"log"
"os"
"os/signal"
"path/filepath"
"runtime"
"runtime/pprof"
"sync/atomic"
)

// Profiling modes. Exactly one mode is active per session; the zero
// value of profile selects cpuMode, making cpu profiling the default.
const (
	cpuMode = iota
	memMode
	blockMode
	traceMode
)

// profile holds the configuration and state of a single profiling
// session, assembled by Start from functional options.
type profile struct {
	// quiet suppresses informational messages during profiling.
	quiet bool

	// noShutdownHook controls whether the profiling package should
	// hook SIGINT to write profiles cleanly.
	noShutdownHook bool

	// mode holds the type of profiling that will be made (one of
	// cpuMode, memMode, blockMode, traceMode).
	mode int

	// path holds the base path where various profiling files are written.
	// If blank, the base path will be generated by ioutil.TempDir.
	path string

	// memProfileRate holds the rate for the memory profile.
	memProfileRate int

	// closer holds a cleanup function that runs after each profile.
	closer func()

	// stopped records if a call to profile.Stop has been made.
	stopped uint32
}

// NoShutdownHook disables the SIGINT handler that would otherwise be
// installed to write profiles cleanly on interrupt. Programs with
// more sophisticated signal handling should use this option and
// ensure the Stop() function returned from Start() is called during
// shutdown.
func NoShutdownHook(p *profile) {
	p.noShutdownHook = true
}

// Quiet turns off the informational log messages that are otherwise
// printed while profiling.
func Quiet(p *profile) {
	p.quiet = true
}

// CPUProfile selects cpu profiling for the session, replacing any
// previously selected profiling mode.
func CPUProfile(p *profile) {
	p.mode = cpuMode
}

// DefaultMemProfileRate is the default memory profiling rate used by
// the MemProfile option.
// See also http://golang.org/pkg/runtime/#pkg-variables
const DefaultMemProfileRate = 4096

// MemProfile selects memory profiling at DefaultMemProfileRate,
// replacing any previously selected profiling mode.
func MemProfile(p *profile) {
	p.mode = memMode
	p.memProfileRate = DefaultMemProfileRate
}

// MemProfileRate returns an option that selects memory profiling at
// the given rate, replacing any previously selected profiling mode.
func MemProfileRate(rate int) func(*profile) {
	return func(p *profile) {
		p.mode = memMode
		p.memProfileRate = rate
	}
}

// BlockProfile selects block (contention) profiling for the session,
// replacing any previously selected profiling mode.
func BlockProfile(p *profile) {
	p.mode = blockMode
}

// ProfilePath returns an option that sets the base directory the
// profiling files are written to. When it is not used (path left
// blank), a temporary directory created by ioutil.TempDir is used
// instead.
func ProfilePath(path string) func(*profile) {
	return func(p *profile) {
		p.path = path
	}
}

// Stop stops the profile and flushes any unwritten data.
// It is safe to call Stop more than once; only the first call has any
// effect.
func (p *profile) Stop() {
	if !atomic.CompareAndSwapUint32(&p.stopped, 0, 1) {
		// someone has already called close
		return
	}
	p.closer()
	// Release the package-level session flag so Start may be called again.
	atomic.StoreUint32(&started, 0)
}

// started is non-zero while a profiling session is running; it is
// used (via atomic CAS) to ensure only one session is active at a time.
var started uint32

// Start starts a new profiling session.
// The caller should call the Stop method on the value returned
// to cleanly stop profiling. Only one session may be active at a
// time; a second call to Start before Stop aborts via log.Fatal.
func Start(options ...func(*profile)) interface {
	Stop()
} {
	if !atomic.CompareAndSwapUint32(&started, 0, 1) {
		log.Fatal("profile: Start() already called")
	}

	// Apply the functional options to a zero-valued profile; with no
	// options this leaves mode == cpuMode, the default.
	var prof profile
	for _, option := range options {
		option(&prof)
	}

	// Resolve the output directory: the configured path if set,
	// otherwise a fresh temporary directory.
	path, err := func() (string, error) {
		if p := prof.path; p != "" {
			return p, os.MkdirAll(p, 0777)
		}
		return ioutil.TempDir("", "profile")
	}()

	if err != nil {
		log.Fatalf("profile: could not create initial output directory: %v", err)
	}

	// logf logs unless the Quiet option was given.
	logf := func(format string, args ...interface{}) {
		if !prof.quiet {
			log.Printf(format, args...)
		}
	}

	switch prof.mode {
	case cpuMode:
		fn := filepath.Join(path, "cpu.pprof")
		f, err := os.Create(fn)
		if err != nil {
			log.Fatalf("profile: could not create cpu profile %q: %v", fn, err)
		}
		logf("profile: cpu profiling enabled, %s", fn)
		// NOTE(review): the error from StartCPUProfile is ignored; a
		// failure to start (e.g. profiling already active elsewhere)
		// would be silent here.
		pprof.StartCPUProfile(f)
		prof.closer = func() {
			pprof.StopCPUProfile()
			f.Close()
			logf("profile: cpu profiling disabled, %s", fn)
		}

	case memMode:
		fn := filepath.Join(path, "mem.pprof")
		f, err := os.Create(fn)
		if err != nil {
			log.Fatalf("profile: could not create memory profile %q: %v", fn, err)
		}
		// Raise the sampling rate for the duration of the session and
		// restore the previous value on Stop.
		old := runtime.MemProfileRate
		runtime.MemProfileRate = prof.memProfileRate
		logf("profile: memory profiling enabled (rate %d), %s", runtime.MemProfileRate, fn)
		prof.closer = func() {
			pprof.Lookup("heap").WriteTo(f, 0)
			f.Close()
			runtime.MemProfileRate = old
			logf("profile: memory profiling disabled, %s", fn)
		}

	case blockMode:
		fn := filepath.Join(path, "block.pprof")
		f, err := os.Create(fn)
		if err != nil {
			log.Fatalf("profile: could not create block profile %q: %v", fn, err)
		}
		runtime.SetBlockProfileRate(1)
		logf("profile: block profiling enabled, %s", fn)
		prof.closer = func() {
			pprof.Lookup("block").WriteTo(f, 0)
			f.Close()
			runtime.SetBlockProfileRate(0)
			logf("profile: block profiling disabled, %s", fn)
		}

	case traceMode:
		fn := filepath.Join(path, "trace.out")
		f, err := os.Create(fn)
		if err != nil {
			log.Fatalf("profile: could not create trace output file %q: %v", fn, err)
		}
		if err := startTrace(f); err != nil {
			log.Fatalf("profile: could not start trace: %v", err)
		}
		logf("profile: trace enabled, %s", fn)
		prof.closer = func() {
			stopTrace()
			// NOTE(review): f is not closed in this closer, unlike the
			// other modes — presumably an oversight; confirm upstream.
			logf("profile: trace disabled, %s", fn)
		}
	}

	// Unless disabled via NoShutdownHook, flush profiles and exit on
	// SIGINT.
	if !prof.noShutdownHook {
		go func() {
			c := make(chan os.Signal, 1)
			signal.Notify(c, os.Interrupt)
			<-c

			log.Println("profile: caught interrupt, stopping profiles")
			prof.Stop()

			os.Exit(0)
		}()
	}

	return &prof
}
304 changes: 304 additions & 0 deletions vendor/src/github.com/pkg/profile/profile_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,304 @@
package profile

import (
"bufio"
"bytes"
"io"
"io/ioutil"
"os"
"os/exec"
"path/filepath"
"strings"
"testing"
)

// checkFn is a check applied to the stdout, stderr, and error
// produced by running one of the test programs.
type checkFn func(t *testing.T, stdout, stderr []byte, err error)

// profileTests is a table of small, self-contained programs that
// exercise the profile package, each paired with the checks applied
// to the program's stdout, stderr, and exit status by TestProfile.
var profileTests = []struct {
	name   string
	code   string
	checks []checkFn
}{{
	name: "default profile (cpu)",
	code: `
package main
import "github.com/pkg/profile"
func main() {
	defer profile.Start().Stop()
}
`,
	checks: []checkFn{
		NoStdout,
		Stderr("profile: cpu profiling enabled"),
		NoErr,
	},
}, {
	name: "memory profile",
	code: `
package main
import "github.com/pkg/profile"
func main() {
	defer profile.Start(profile.MemProfile).Stop()
}
`,
	checks: []checkFn{
		NoStdout,
		Stderr("profile: memory profiling enabled"),
		NoErr,
	},
}, {
	name: "memory profile (rate 2048)",
	code: `
package main
import "github.com/pkg/profile"
func main() {
	defer profile.Start(profile.MemProfileRate(2048)).Stop()
}
`,
	checks: []checkFn{
		NoStdout,
		Stderr("profile: memory profiling enabled (rate 2048)"),
		NoErr,
	},
}, {
	name: "double start",
	code: `
package main
import "github.com/pkg/profile"
func main() {
	profile.Start()
	profile.Start()
}
`,
	checks: []checkFn{
		NoStdout,
		Stderr("cpu profiling enabled", "profile: Start() already called"),
		Err,
	},
}, {
	name: "block profile",
	code: `
package main
import "github.com/pkg/profile"
func main() {
	defer profile.Start(profile.BlockProfile).Stop()
}
`,
	checks: []checkFn{
		NoStdout,
		Stderr("profile: block profiling enabled"),
		NoErr,
	},
}, {
	name: "profile path",
	code: `
package main
import "github.com/pkg/profile"
func main() {
	defer profile.Start(profile.ProfilePath(".")).Stop()
}
`,
	checks: []checkFn{
		NoStdout,
		Stderr("profile: cpu profiling enabled, cpu.pprof"),
		NoErr,
	},
}, {
	name: "profile path error",
	code: `
package main
import "github.com/pkg/profile"
func main() {
	defer profile.Start(profile.ProfilePath("README.md")).Stop()
}
`,
	checks: []checkFn{
		NoStdout,
		Stderr("could not create initial output"),
		Err,
	},
}, {
	name: "multiple profile sessions",
	code: `
package main
import "github.com/pkg/profile"
func main() {
	profile.Start(profile.CPUProfile).Stop()
	profile.Start(profile.MemProfile).Stop()
	profile.Start(profile.BlockProfile).Stop()
	profile.Start(profile.CPUProfile).Stop()
}
`,
	checks: []checkFn{
		NoStdout,
		Stderr("profile: cpu profiling enabled",
			"profile: cpu profiling disabled",
			"profile: memory profiling enabled",
			"profile: memory profiling disabled",
			"profile: block profiling enabled",
			"profile: block profiling disabled"),
		NoErr,
	},
}, {
	name: "profile quiet",
	code: `
package main
import "github.com/pkg/profile"
func main() {
	defer profile.Start(profile.Quiet).Stop()
}
`,
	checks: []checkFn{NoStdout, NoStderr, NoErr},
}}

// TestProfile compiles and runs each program in profileTests and
// applies that case's checks to the captured output.
func TestProfile(t *testing.T) {
	for _, tt := range profileTests {
		t.Log(tt.name)
		stdout, stderr, err := runTest(t, tt.code)
		for _, f := range tt.checks {
			f(t, stdout, stderr, err)
		}
	}
}

// NoStdout checks that stdout was blank.
func NoStdout(t *testing.T, stdout, _ []byte, _ error) {
	// The original shadowed the builtin len with `len := len(stdout)`;
	// use a neutral name instead.
	if n := len(stdout); n > 0 {
		t.Errorf("stdout: wanted 0 bytes, got %d", n)
	}
}

// Stderr returns a check that verifies the given lines appear, in
// order, in the output written to stderr.
func Stderr(lines ...string) checkFn {
	return func(t *testing.T, _, stderr []byte, _ error) {
		if !validateOutput(bytes.NewReader(stderr), lines) {
			t.Errorf("stderr: wanted '%s', got '%s'", lines, stderr)
		}
	}
}

// NoStderr checks that stderr was blank.
func NoStderr(t *testing.T, _, stderr []byte, _ error) {
	// The original shadowed the builtin len with `len := len(stderr)`;
	// use a neutral name instead.
	if n := len(stderr); n > 0 {
		t.Errorf("stderr: wanted 0 bytes, got %d", n)
	}
}

// Err checks that a non-nil error was returned.
func Err(t *testing.T, _, _ []byte, err error) {
	if err != nil {
		return
	}
	t.Errorf("expected error")
}

// NoErr checks that the returned error was nil.
func NoErr(t *testing.T, _, _ []byte, err error) {
	if err == nil {
		return
	}
	t.Errorf("error: expected nil, got %v", err)
}

// validateOutput reports whether each string in want appears, in
// order, as a substring of successive lines read from r. Lines after
// the last wanted string are ignored. (The original comment misnamed
// the function "validatedOutput".)
func validateOutput(r io.Reader, want []string) bool {
	s := bufio.NewScanner(r)
	for _, line := range want {
		// A short read (no more lines) or a non-matching line fails.
		if !s.Scan() || !strings.Contains(s.Text(), line) {
			return false
		}
	}
	return true
}

// validateOutputTests is the table of inputs, expected lines, and
// expected results for TestValidateOutput.
var validateOutputTests = []struct {
	input string
	lines []string
	want  bool
}{{
	input: "",
	want:  true,
}, {
	input: `profile: yes
`,
	want: true,
}, {
	input: `profile: yes
`,
	lines: []string{"profile: yes"},
	want:  true,
}, {
	input: `profile: yes
profile: no
`,
	lines: []string{"profile: yes"},
	want:  true,
}, {
	input: `profile: yes
profile: no
`,
	lines: []string{"profile: yes", "profile: no"},
	want:  true,
}, {
	input: `profile: yes
profile: no
`,
	lines: []string{"profile: no"},
	want:  false,
}}

// TestValidateOutput runs validateOutput against every case in
// validateOutputTests.
func TestValidateOutput(t *testing.T) {
	for _, tc := range validateOutputTests {
		got := validateOutput(strings.NewReader(tc.input), tc.lines)
		if got != tc.want {
			t.Errorf("validateOutput(%q, %q), want %v, got %v", tc.input, tc.lines, tc.want, got)
		}
	}
}

// runTest executes the go program supplied and returns the contents of stdout,
// stderr, and an error which may contain status information about the result
// of the program.
//
// The program is written to main.go inside a throwaway GOPATH-style
// directory and executed with "go run", so the go toolchain must be
// available on PATH. Setup failures abort the test via t.Fatal; only
// the outcome of running the program itself is returned.
func runTest(t *testing.T, code string) ([]byte, []byte, error) {
	// chk aborts the test on any setup error.
	chk := func(err error) {
		if err != nil {
			t.Fatal(err)
		}
	}
	gopath, err := ioutil.TempDir("", "profile-gopath")
	chk(err)
	defer os.RemoveAll(gopath)

	srcdir := filepath.Join(gopath, "src")
	err = os.Mkdir(srcdir, 0755)
	chk(err)
	src := filepath.Join(srcdir, "main.go")
	err = ioutil.WriteFile(src, []byte(code), 0644)
	chk(err)

	cmd := exec.Command("go", "run", src)

	var stdout, stderr bytes.Buffer
	cmd.Stdout = &stdout
	cmd.Stderr = &stderr
	err = cmd.Run()
	return stdout.Bytes(), stderr.Bytes(), err
}
11 changes: 11 additions & 0 deletions vendor/src/github.com/pkg/profile/trace.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// +build go1.7

package profile

import "runtime/trace"

// TraceProfile enables execution tracing.
// It disables any previous profiling settings.
func TraceProfile(p *profile) { p.mode = traceMode }

// startTrace and stopTrace delegate to runtime/trace, available from go1.7.
var startTrace = trace.Start
var stopTrace = trace.Stop
10 changes: 10 additions & 0 deletions vendor/src/github.com/pkg/profile/trace16.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// +build !go1.7

package profile

import "io"

// mock trace support for Go 1.6 and earlier.

// startTrace is a no-op placeholder; execution tracing needs go1.7.
func startTrace(w io.Writer) error { return nil }

// stopTrace is a no-op placeholder; execution tracing needs go1.7.
func stopTrace() {}
10 changes: 10 additions & 0 deletions vendor/src/github.com/pkg/profile/trace_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// +build go1.7

package profile_test

import "github.com/pkg/profile"

// ExampleTraceProfile selects execution tracing (go1.7+).
func ExampleTraceProfile() {
	// use execution tracing, rather than the default cpu profiling.
	defer profile.Start(profile.TraceProfile).Stop()
}