diff --git a/qcloud_cos/cos_client.py b/qcloud_cos/cos_client.py index 7028e73b..43680238 100644 --- a/qcloud_cos/cos_client.py +++ b/qcloud_cos/cos_client.py @@ -25,6 +25,7 @@ from .cos_exception import CosServiceError from .version import __version__ from .select_event_stream import EventStream +from .resumable_downloader import ResumableDownLoader logger = logging.getLogger(__name__) @@ -185,7 +186,7 @@ def __init__(self, conf, retry=1, session=None): else: self._session = session - def get_conf(): + def get_conf(self): """获取配置""" return self._conf @@ -2942,6 +2943,30 @@ def _check_all_upload_parts(self, bucket, key, uploadid, local_path, parts_num, already_exist_parts[part_num] = part['ETag'] return True + def download_file(self, Bucket, Key, DestFilePath, PartSize=20, MAXThread=5, EnableCRC=False, **Kwargs): + """小于等于20MB的文件简单下载,大于20MB的文件使用续传下载 + + :param Bucket(string): 存储桶名称. + :param Key(string): COS文件的路径名. + :param DestFilePath(string): 下载文件的目的路径. + :param PartSize(int): 分块下载的大小设置,单位为MB. + :param MAXThread(int): 并发下载的最大线程数. + :param EnableCRC(bool): 校验下载文件与源文件是否一致 + :param Kwargs(dict): 设置请求headers. 
+ """ + logger.debug("Start to download file, bucket: {0}, key: {1}, dest_filename: {2}, part_size: {3}MB, " "max_thread: {4}".format(Bucket, Key, DestFilePath, PartSize, MAXThread)) + + object_info = self.head_object(Bucket, Key) + file_size = object_info['Content-Length'] + if file_size <= 1024*1024*20: + response = self.get_object(Bucket, Key, **Kwargs) + response['Body'].get_stream_to_file(DestFilePath) + return + + downloader = ResumableDownLoader(self, Bucket, Key, DestFilePath, object_info, PartSize, MAXThread, EnableCRC, **Kwargs) + downloader.start() + def upload_file(self, Bucket, Key, LocalFilePath, PartSize=1, MAXThread=5, EnableMD5=False, **kwargs): """小于等于20MB的文件简单上传,大于20MB的文件使用分块上传 diff --git a/qcloud_cos/resumable_downloader.py b/qcloud_cos/resumable_downloader.py new file mode 100644 index 00000000..03f6fca5 --- /dev/null +++ b/qcloud_cos/resumable_downloader.py @@ -0,0 +1,206 @@ +# -*- coding: utf-8 -*- + +import json +import os +import sys +import threading +import logging +import uuid +import hashlib +import crcmod +from .cos_comm import * +from .streambody import StreamBody +from .cos_threadpool import SimpleThreadPool +logger = logging.getLogger(__name__) + +class ResumableDownLoader(object): + def __init__(self, cos_client, bucket, key, dest_filename, object_info, part_size=20, max_thread=5, enable_crc=False, **kwargs): + self.__cos_client = cos_client + self.__bucket = bucket + self.__key = key + self.__dest_file_path = os.path.abspath(dest_filename) + self.__object_info = object_info + self.__max_thread = max_thread + self.__enable_crc = enable_crc + self.__headers = kwargs + + self.__max_part_count = 100 # 取决于服务端是否对并发有限制 + self.__min_part_size = 1024 * 1024 # 1M + self.__part_size = self.__determine_part_size_internal(int(object_info['Content-Length']), part_size) + self.__finished_parts = [] + self.__lock = threading.Lock() + self.__record = None #记录当前的上下文 + self.__dump_record_dir = os.path.join(os.path.expanduser('~'), 
'.cos_download_tmp_file') + + record_filename = self.__get_record_filename(bucket, key, self.__dest_file_path) + self.__record_filepath = os.path.join(self.__dump_record_dir, record_filename) + self.__tmp_file = None + + if not os.path.exists(self.__dump_record_dir): + os.makedirs(self.__dump_record_dir) + + logger.debug('resumale downloader init finish, bucket: {0}, key: {1}'.format(bucket, key)) + + def start(self): + logger.debug('start resumable downloade, bucket: {0}, key: {1}'.format(self.__bucket, self.__key)) + self.__load_record() # 从record文件中恢复读取上下文 + + assert self.__tmp_file + open(self.__tmp_file, 'a').close() + + parts_need_to_download = self.__get_parts_need_to_download() + logger.debug('parts_need_to_download: {0}'.format(parts_need_to_download)) + pool = SimpleThreadPool(self.__max_thread) + for part in parts_need_to_download: + part_range = "bytes=" + str(part.start) + "-" + str(part.start + part.length - 1) + headers = dict.copy(self.__headers) + headers["Range"] = part_range + pool.add_task(self.__download_part, part, headers) + + pool.wait_completion() + result = pool.get_result() + if not result['success_all']: + raise CosClientError('some download_part fail after max_retry, please downloade_file again') + + if os.path.exists(self.__dest_file_path): + os.remove(self.__dest_file_path) + os.rename(self.__tmp_file, self.__dest_file_path) + + if self.__enable_crc: + self.__check_crc() + + self.__del_record() + logger.debug('download success, bucket: {0}, key: {1}'.format(self.__bucket, self.__key)) + + def __get_record_filename(self, bucket, key, dest_file_path): + dest_file_path_md5 = hashlib.md5(dest_file_path).hexdigest() + key_md5 = hashlib.md5(key).hexdigest() + return '{0}_{1}.{2}'.format(bucket, key_md5, dest_file_path_md5) + + def __determine_part_size_internal(self, file_size, part_size): + real_part_size = part_size * 1024 * 1024 # MB + if real_part_size < self.__min_part_size: + real_part_size = self.__min_part_size + + while 
real_part_size * self.__max_part_count < file_size: + real_part_size = real_part_size * 2 + logger.debug('finish to determine part size, file_size: {0}, part_size: {1}'.format(file_size, real_part_size)) + return real_part_size + + def __split_to_parts(self): + parts = [] + file_size = int(self.__object_info['Content-Length']) + num_parts = (file_size + self.__part_size - 1) // self.__part_size + for i in range(num_parts): + start = i * self.__part_size + if i == num_parts - 1: + length = file_size - start + else: + length = self.__part_size + + parts.append(PartInfo(i + 1, start, length)) + return parts + + def __get_parts_need_to_download(self): + all_set = set(self.__split_to_parts()) + logger.debug('all_set: {0}'.format(len(all_set))) + finished_set = set(self.__finished_parts) + logger.debug('finished_set: {0}'.format(len(finished_set))) + return list(all_set - finished_set) + + def __download_part(self, part, headers): + with open(self.__tmp_file, 'rb+') as f: + f.seek(part.start, 0) + part_range = None + traffic_limit = None + if 'Range' in headers: + part_range = headers['Range'] + + if 'TrafficLimit' in headers: + traffic_limit = headers['TrafficLimit'] + logger.debug("part_id: {0}, part_range: {1}, traffic_limit:{2}".format(part.part_id, part_range, traffic_limit)) + result = self.__cos_client.get_object(Bucket=self.__bucket, Key=self.__key, **headers) + result["Body"].pget_stream_to_file(f, part.start, part.length) + + self.__finish_part(part) + + def __finish_part(self, part): + logger.debug('download part finished,bucket: {0}, key: {1}, part_id: {2}'. + format(self.__bucket, self.__key, part.part_id)) + with self.__lock: + self.__finished_parts.append(part) + self.__record['parts'].append({'part_id': part.part_id, + 'start': part.start, + 'length': part.length}) + self.__dump_record(self.__record) + + def __dump_record(self, record): + with open(self.__record_filepath, 'w') as f: + json.dump(record, f) + logger.debug('dump record to {0}, bucket: {1}, key: {2}'. 
+ format(self.__record_filepath, self.__bucket, self.__key)) + + def __load_record(self): + record = None + + if os.path.exists(self.__record_filepath): + with open(self.__record_filepath, 'r') as f: + record = json.load(f) + + ret = self.__check_record(record) + # record记录是否跟head object的一致,不一致则删除 + if ret == False: + self.__del_record() + record = None + else: + self.__part_size = record['part_size'] + self.__tmp_file = record['tmp_filename'] + if not os.path.exists(self.__tmp_file): + record = None + self.__tmp_file = None + self.__del_record() + else: + self.__finished_parts = list(PartInfo(p['part_id'], p['start'], p['length']) for p in record['parts']) + logger.debug('load record: finished parts nums: {0}'.format(len(self.__finished_parts))) + self.__record = record + + if not record: + self.__tmp_file = "{file_name}_{uuid}".format(file_name=self.__dest_file_path, uuid=uuid.uuid4().hex) + record = {'bucket': self.__bucket, 'key': self.__key, 'tmp_filename':self.__tmp_file, + 'mtime':self.__object_info['Last-Modified'], 'etag':self.__object_info['ETag'], + 'file_size':self.__object_info['Content-Length'], 'part_size': self.__part_size, 'parts':[]} + self.__record = record + self.__dump_record(record) + + def __check_record(self, record): + return record['etag'] == self.__object_info['ETag'] and\ + record['mtime'] == self.__object_info['Last-Modified'] and\ + record['file_size'] == self.__object_info['Content-Length'] + + def __del_record(self): + os.remove(self.__record_filepath) + logger.debug('ResumableDownLoader delete record_file, path: {0}'.format(self.__record_filepath)) + + def __check_crc(self): + logger.debug('start to check crc') + c64 = crcmod.mkCrcFun(0x142F0E1EBA9EA3693L, initCrc=0L, xorOut=0xffffffffffffffffL, rev=True) + with open(self.__dest_file_path,'rb') as f: + local_crc64 = str(c64(f.read())) + object_crc64 = self.__object_info['x-cos-hash-crc64ecma'] + if local_crc64 is not None and object_crc64 is not None and local_crc64 != object_crc64: 
+ raise CosClientError('crc of client: {0} is mismatch with cos: {1}'.format(local_crc64, object_crc64)) + +class PartInfo(object): + def __init__(self, part_id, start, length): + self.part_id = part_id + self.start = start + self.length = length + + def __eq__(self, other): + return self.__key() == other.__key() + + def __hash__(self): + return hash(self.__key()) + + def __key(self): + return self.part_id, self.start, self.length diff --git a/qcloud_cos/streambody.py b/qcloud_cos/streambody.py index e373e807..3bce1055 100644 --- a/qcloud_cos/streambody.py +++ b/qcloud_cos/streambody.py @@ -53,3 +53,34 @@ def get_stream_to_file(self, file_name, auto_decompress=False): if os.path.exists(file_name): os.remove(file_name) os.rename(tmp_file_name, file_name) + + def pget_stream_to_file(self, fdst, offset, expected_len, auto_decompress=False): + """保存流到本地文件的offset偏移""" + use_chunked = False + use_encoding = False + if 'Transfer-Encoding' in self._rt.headers and self._rt.headers['Transfer-Encoding'] == "chunked": + use_chunked = True + elif 'Content-Length' not in self._rt.headers: + raise IOError("download failed without Content-Length header or Transfer-Encoding header") + + if 'Content-Encoding' in self._rt.headers: + use_encoding = True + read_len = 0 + fdst.seek(offset, 0) + + if use_encoding and not auto_decompress: + chunk = self._rt.raw.read(1024) + while chunk: + read_len += len(chunk) + fdst.write(chunk) + chunk = self._rt.raw.read(1024) + else: + for chunk in self._rt.iter_content(chunk_size=1024): + if chunk: + read_len += len(chunk) + fdst.write(chunk) + + + if not use_chunked and not (use_encoding and auto_decompress) and read_len != expected_len: + raise IOError("download failed with incomplete file") + diff --git a/ut/test.py b/ut/test.py index 2de93b58..12151f61 100644 --- a/ut/test.py +++ b/ut/test.py @@ -1164,6 +1164,53 @@ def _test_get_object_sensitive_content_recognition(): print(response) assert response +def test_download_file(): + """测试断点续传下载接口""" 
+ #测试普通下载 + client.download_file(copy_test_bucket, test_object, 'test_download_file.local') + if os.path.exists('test_download_file.local'): + os.remove('test_download_file.local') + + # 测试限速下载 + client.download_file(copy_test_bucket, test_object, 'test_download_traffic_limit.local', TrafficLimit='819200') + if os.path.exists('test_download_traffic_limit.local'): + os.remove('test_download_traffic_limit.local') + + # 测试crc64校验开关 + client.download_file(copy_test_bucket, test_object, 'test_download_crc.local', EnableCRC=True) + if os.path.exists('test_download_crc.local'): + os.remove('test_download_crc.local') + + # 测试源文件的md5与下载下来后的文件md5 + file_size = 25 # MB + file_id = str(random.randint(0, 1000)) + str(random.randint(0, 1000)) + file_name = "tmp" + file_id + "_" + str(file_size) + "MB" + gen_file(file_name, file_size) + + source_file_md5 = None + dest_file_md5 = None + with open(file_name, 'rb') as f: + source_file_md5 = get_raw_md5(f.read()) + + client.put_object_from_local_file( + Bucket=copy_test_bucket, + LocalFilePath=file_name, + Key=file_name + ) + + client.download_file(copy_test_bucket, file_name, 'test_download_md5.local') + if os.path.exists('test_download_md5.local'): + with open('test_download_md5.local', 'rb') as f: + dest_file_md5 = get_raw_md5(f.read()) + assert source_file_md5 and dest_file_md5 and source_file_md5 == dest_file_md5 + + # 释放资源 + client.delete_object( + Bucket=copy_test_bucket, + Key=file_name + ) + if os.path.exists(file_name): + os.remove(file_name) if __name__ == "__main__": setUp() @@ -1190,6 +1237,7 @@ def _test_get_object_sensitive_content_recognition(): test_put_get_delete_bucket_domain() test_select_object() _test_get_object_sensitive_content_recognition() + test_download_file() """ tearDown()