Skip to content

Commit

Permalink
fix ensure_digest(block=True) call if no headers_buffer is available: (
Browse files Browse the repository at this point in the history
…#85)

* fix ensure_digest() call if no headers_buffer is available:
make ensure_digest() a regular method and move header_filter into RecordBuilder to
support optional filter of headers during record creation
add test for calling ensure_digest() on existing record
bump version to 1.7.1

* fix setup.py --doctest-module -> --doctest-modules
  • Loading branch information
ikreymer committed Jul 12, 2019
1 parent 2c3f0c0 commit b963fef
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 14 deletions.
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from setuptools.command.test import test as TestCommand
import glob

__version__ = '1.7.0'
__version__ = '1.7.1'


class PyTest(TestCommand):
Expand All @@ -18,7 +18,7 @@ def run_tests(self):
import pytest
import sys
import os
errcode = pytest.main(['--doctest-module', './warcio', '--cov', 'warcio', '-v', 'test/'])
errcode = pytest.main(['--doctest-modules', './warcio', '--cov', 'warcio', '-v', 'test/'])
sys.exit(errcode)

setup(
Expand Down
3 changes: 3 additions & 0 deletions test/test_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,9 @@ def test_request_response_concur(self, is_gzip, builder_factory):

req = sample_request(builder)

# test explicitly calling ensure_digest with block digest enabled on a record
writer.ensure_digest(resp, block=True, payload=True)

writer.write_request_response_pair(req, resp)

stream = writer.get_stream()
Expand Down
19 changes: 10 additions & 9 deletions warcio/recordbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@ class RecordBuilder(object):
NO_BLOCK_DIGEST_TYPES = ('warcinfo')


def __init__(self, warc_version=None):
def __init__(self, warc_version=None, header_filter=None):
self.warc_version = self._parse_warc_version(warc_version)

self.header_filter = header_filter

def create_warcinfo_record(self, filename, info):
warc_headers = StatusAndHeaders(self.warc_version, [])
warc_headers.add_header('WARC-Type', 'warcinfo')
Expand Down Expand Up @@ -152,20 +154,19 @@ def _make_warc_id(cls):
def _make_warc_date(cls, use_micros=False):
return datetime_to_iso_date(datetime.datetime.utcnow(), use_micros=use_micros)

@classmethod
def ensure_digest(cls, record, block=True, payload=True):
def ensure_digest(self, record, block=True, payload=True):
if block:
if (record.rec_headers.get_header('WARC-Block-Digest') or
(record.rec_type in cls.NO_BLOCK_DIGEST_TYPES)):
(record.rec_type in self.NO_BLOCK_DIGEST_TYPES)):
block = False

if payload:
if (record.rec_headers.get_header('WARC-Payload-Digest') or
(record.rec_type in cls.NO_PAYLOAD_DIGEST_TYPES)):
(record.rec_type in self.NO_PAYLOAD_DIGEST_TYPES)):
payload = False

block_digester = cls._create_digester() if block else None
payload_digester = cls._create_digester() if payload else None
block_digester = self._create_digester() if block else None
payload_digester = self._create_digester() if payload else None

has_length = (record.length is not None)

Expand All @@ -180,14 +181,14 @@ def ensure_digest(cls, record, block=True, payload=True):
record.raw_stream.seek(pos)
except:
pos = 0
temp_file = cls._create_temp_file()
temp_file = self._create_temp_file()

if block_digester and record.http_headers:
if not record.http_headers.headers_buff:
record.http_headers.compute_headers_buffer(self.header_filter)
block_digester.update(record.http_headers.headers_buff)

for buf in cls._iter_stream(record.raw_stream):
for buf in self._iter_stream(record.raw_stream):
if block_digester:
block_digester.update(buf)

Expand Down
5 changes: 2 additions & 3 deletions warcio/warcwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,13 @@
class BaseWARCWriter(RecordBuilder):

def __init__(self, gzip=True, *args, **kwargs):
super(BaseWARCWriter, self).__init__(warc_version=kwargs.get('warc_version'))
super(BaseWARCWriter, self).__init__(warc_version=kwargs.get('warc_version'),
header_filter=kwargs.get('header_filter'))
self.gzip = gzip
self.hostname = gethostname()

self.parser = StatusAndHeadersParser([], verify=False)

self.header_filter = kwargs.get('header_filter')

def write_request_response_pair(self, req, resp, params=None):
url = resp.rec_headers.get_header('WARC-Target-URI')
dt = resp.rec_headers.get_header('WARC-Date')
Expand Down

0 comments on commit b963fef

Please sign in to comment.