27 changes: 26 additions & 1 deletion qcloud_cos/cos_client.py
@@ -25,6 +25,7 @@
from .cos_exception import CosServiceError
from .version import __version__
from .select_event_stream import EventStream
from .resumable_downloader import ResumableDownLoader
logger = logging.getLogger(__name__)


@@ -185,7 +186,7 @@ def __init__(self, conf, retry=1, session=None):
else:
self._session = session

def get_conf():
def get_conf(self):
"""获取配置"""
return self._conf

@@ -2942,6 +2943,30 @@ def _check_all_upload_parts(self, bucket, key, uploadid, local_path, parts_num,
already_exist_parts[part_num] = part['ETag']
return True

    def download_file(self, Bucket, Key, DestFilePath, PartSize=20, MAXThread=5, EnableCRC=False, **kwargs):
        """Simple download for files of 20MB or less; resumable download for files larger than 20MB.

        :param Bucket(string): bucket name.
        :param Key(string): path of the object in COS.
        :param DestFilePath(string): destination path for the downloaded file.
        :param PartSize(int): part size for multipart download, in MB.
        :param MAXThread(int): maximum number of concurrent download threads.
        :param EnableCRC(bool): verify that the downloaded file matches the source object.
        :param kwargs(dict): request headers to set.
        """
        logger.debug("Start to download file, bucket: {0}, key: {1}, dest_filename: {2}, part_size: {3}MB, "
                     "max_thread: {4}".format(Bucket, Key, DestFilePath, PartSize, MAXThread))

        object_info = self.head_object(Bucket, Key)
        file_size = int(object_info['Content-Length'])  # the Content-Length header arrives as a string
        if file_size <= 1024 * 1024 * 20:
            response = self.get_object(Bucket, Key, **kwargs)
            response['Body'].get_stream_to_file(DestFilePath)
            return

        downloader = ResumableDownLoader(self, Bucket, Key, DestFilePath, object_info, PartSize, MAXThread, EnableCRC, **kwargs)
        downloader.start()

def upload_file(self, Bucket, Key, LocalFilePath, PartSize=1, MAXThread=5, EnableMD5=False, **kwargs):
"""小于等于20MB的文件简单上传,大于20MB的文件使用分块上传

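
A minimal usage sketch of the new interface, assuming a configured client; the bucket and key names below are hypothetical:

# -*- coding: utf-8 -*-
from qcloud_cos import CosConfig, CosS3Client

config = CosConfig(Region='ap-guangzhou', SecretId='SECRET_ID', SecretKey='SECRET_KEY')
client = CosS3Client(config)

# objects <= 20MB fall through to a single GET; larger objects are split into
# PartSize-MB ranges downloaded by up to MAXThread threads, with an optional crc64 check
client.download_file(
    Bucket='examplebucket-1250000000',
    Key='movie.mp4',
    DestFilePath='/tmp/movie.mp4',
    PartSize=20,
    MAXThread=5,
    EnableCRC=True
)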
206 changes: 206 additions & 0 deletions qcloud_cos/resumable_downloader.py
@@ -0,0 +1,206 @@
# -*- coding: utf-8 -*-

import json
import os
import sys
import threading
import logging
import uuid
import hashlib
import crcmod
from .cos_comm import *
from .cos_exception import CosClientError
from .streambody import StreamBody
from .cos_threadpool import SimpleThreadPool
logger = logging.getLogger(__name__)

class ResumableDownLoader(object):
    """Resumable downloader: splits the object into ranges, downloads them concurrently
    into a temp file, and persists progress to a record file so an interrupted
    download can pick up where it left off."""
    def __init__(self, cos_client, bucket, key, dest_filename, object_info, part_size=20, max_thread=5, enable_crc=False, **kwargs):
self.__cos_client = cos_client
self.__bucket = bucket
self.__key = key
self.__dest_file_path = os.path.abspath(dest_filename)
self.__object_info = object_info
self.__max_thread = max_thread
self.__enable_crc = enable_crc
self.__headers = kwargs

        self.__max_part_count = 100  # depends on whether the server limits concurrent ranged requests
        self.__min_part_size = 1024 * 1024  # 1MB
self.__part_size = self.__determine_part_size_internal(int(object_info['Content-Length']), part_size)
self.__finished_parts = []
self.__lock = threading.Lock()
        self.__record = None  # the persisted context of the current download
self.__dump_record_dir = os.path.join(os.path.expanduser('~'), '.cos_download_tmp_file')

record_filename = self.__get_record_filename(bucket, key, self.__dest_file_path)
self.__record_filepath = os.path.join(self.__dump_record_dir, record_filename)
self.__tmp_file = None

if not os.path.exists(self.__dump_record_dir):
os.makedirs(self.__dump_record_dir)

        logger.debug('resumable downloader init finished, bucket: {0}, key: {1}'.format(bucket, key))

def start(self):
        logger.debug('start resumable download, bucket: {0}, key: {1}'.format(self.__bucket, self.__key))
        self.__load_record()  # restore the download context from the record file

        assert self.__tmp_file
        open(self.__tmp_file, 'a').close()  # make sure the temp file exists before workers seek into it

parts_need_to_download = self.__get_parts_need_to_download()
logger.debug('parts_need_to_download: {0}'.format(parts_need_to_download))
pool = SimpleThreadPool(self.__max_thread)
for part in parts_need_to_download:
part_range = "bytes=" + str(part.start) + "-" + str(part.start + part.length - 1)
headers = dict.copy(self.__headers)
headers["Range"] = part_range
pool.add_task(self.__download_part, part, headers)

pool.wait_completion()
result = pool.get_result()
        if not result['success_all']:
            raise CosClientError('some download parts failed after max retries, please call download_file again')

if os.path.exists(self.__dest_file_path):
os.remove(self.__dest_file_path)
os.rename(self.__tmp_file, self.__dest_file_path)

if self.__enable_crc:
self.__check_crc()

self.__del_record()
logger.debug('download success, bucket: {0}, key: {1}'.format(self.__bucket, self.__key))

def __get_record_filename(self, bucket, key, dest_file_path):
        dest_file_path_md5 = hashlib.md5(dest_file_path.encode('utf-8')).hexdigest()
        key_md5 = hashlib.md5(key.encode('utf-8')).hexdigest()
        return '{0}_{1}.{2}'.format(bucket, key_md5, dest_file_path_md5)
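
    # For illustration only (hypothetical names): bucket 'examplebucket-1250000000',
    # key 'movie.mp4' and dest '/tmp/movie.mp4' always map to the same record file,
    # 'examplebucket-1250000000_<md5(key)>.<md5(dest_path)>', so a rerun of the same
    # download finds its previous progress under ~/.cos_download_tmp_file.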

def __determine_part_size_internal(self, file_size, part_size):
        real_part_size = part_size * 1024 * 1024  # part_size is given in MB
if real_part_size < self.__min_part_size:
real_part_size = self.__min_part_size

while real_part_size * self.__max_part_count < file_size:
real_part_size = real_part_size * 2
logger.debug('finish to determine part size, file_size: {0}, part_size: {1}'.format(file_size, real_part_size))
return real_part_size
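
    # Worked example (sketch): for a 4096MB object with the default 20MB part size,
    # 20MB*100 and 40MB*100 are both < 4096MB, so the size doubles twice to 80MB;
    # 80MB*100 >= 4096MB stops the loop and the object splits into ceil(4096/80) = 52 parts.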

    def __split_to_parts(self):
        parts = []
        file_size = int(self.__object_info['Content-Length'])
        num_parts = (file_size + self.__part_size - 1) // self.__part_size  # ceiling division
for i in range(num_parts):
start = i * self.__part_size
if i == num_parts - 1:
length = file_size - start
else:
length = self.__part_size

parts.append(PartInfo(i + 1, start, length))
return parts

    def __get_parts_need_to_download(self):
        all_set = set(self.__split_to_parts())
        logger.debug('all_set: {0}'.format(len(all_set)))
        finished_set = set(self.__finished_parts)
        logger.debug('finished_set: {0}'.format(len(finished_set)))
        return list(all_set - finished_set)

    def __download_part(self, part, headers):
        with open(self.__tmp_file, 'rb+') as f:
            f.seek(part.start, 0)
            range_header = None
            traffic_limit = None
            if 'Range' in headers:
                range_header = headers['Range']

            if 'TrafficLimit' in headers:
                traffic_limit = headers['TrafficLimit']
            logger.debug("part_id: {0}, part_range: {1}, traffic_limit: {2}".format(part.part_id, range_header, traffic_limit))
            result = self.__cos_client.get_object(Bucket=self.__bucket, Key=self.__key, **headers)
            # write the ranged response into the temp file at this part's offset
            result["Body"].pget_stream_to_file(f, part.start, part.length)

self.__finish_part(part)

def __finish_part(self, part):
        logger.debug('download part finished, bucket: {0}, key: {1}, part_id: {2}'.
                     format(self.__bucket, self.__key, part.part_id))
with self.__lock:
self.__finished_parts.append(part)
self.__record['parts'].append({'part_id': part.part_id,
'start': part.start,
'length': part.length})
self.__dump_record(self.__record)

def __dump_record(self, record):
with open(self.__record_filepath, 'w') as f:
json.dump(record, f)
logger.debug('dump record to {0}, bucket: {1}, key: {2}'.
format(self.__record_filepath, self.__bucket, self.__key))

def __load_record(self):
record = None

if os.path.exists(self.__record_filepath):
with open(self.__record_filepath, 'r') as f:
record = json.load(f)

            ret = self.__check_record(record)
            # discard the record if it no longer matches the current head_object result
            if not ret:
                self.__del_record()
                record = None
else:
self.__part_size = record['part_size']
self.__tmp_file = record['tmp_filename']
if not os.path.exists(self.__tmp_file):
record = None
self.__tmp_file = None
self.__del_record()
else:
self.__finished_parts = list(PartInfo(p['part_id'], p['start'], p['length']) for p in record['parts'])
logger.debug('load record: finished parts nums: {0}'.format(len(self.__finished_parts)))
self.__record = record

        if not record:
            self.__tmp_file = "{file_name}_{uuid}".format(file_name=self.__dest_file_path, uuid=uuid.uuid4().hex)
            record = {'bucket': self.__bucket, 'key': self.__key, 'tmp_filename': self.__tmp_file,
                      'mtime': self.__object_info['Last-Modified'], 'etag': self.__object_info['ETag'],
                      'file_size': self.__object_info['Content-Length'], 'part_size': self.__part_size, 'parts': []}
            self.__record = record
            self.__dump_record(record)
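
    # A freshly dumped record file looks roughly like this (all values illustrative):
    # {"bucket": "examplebucket-1250000000", "key": "movie.mp4",
    #  "tmp_filename": "/tmp/movie.mp4_9f3a...", "mtime": "Wed, 01 Jan 2020 00:00:00 GMT",
    #  "etag": "\"0123abcd...\"", "file_size": "104857600", "part_size": 20971520,
    #  "parts": [{"part_id": 1, "start": 0, "length": 20971520}]}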

def __check_record(self, record):
return record['etag'] == self.__object_info['ETag'] and\
record['mtime'] == self.__object_info['Last-Modified'] and\
record['file_size'] == self.__object_info['Content-Length']

def __del_record(self):
os.remove(self.__record_filepath)
logger.debug('ResumableDownLoader delete record_file, path: {0}'.format(self.__record_filepath))

    def __check_crc(self):
        logger.debug('start to check crc')
        # CRC-64/XZ (crc64ecma), the polynomial behind the x-cos-hash-crc64ecma header
        c64 = crcmod.mkCrcFun(0x142F0E1EBA9EA3693, initCrc=0, xorOut=0xffffffffffffffff, rev=True)
        crc = 0
        with open(self.__dest_file_path, 'rb') as f:
            # hash in chunks so large downloads are not read into memory at once
            chunk = f.read(1024 * 1024)
            while chunk:
                crc = c64(chunk, crc)
                chunk = f.read(1024 * 1024)
        local_crc64 = str(crc)
        object_crc64 = self.__object_info.get('x-cos-hash-crc64ecma')
        if object_crc64 is not None and local_crc64 != object_crc64:
            raise CosClientError('crc of client: {0} does not match cos: {1}'.format(local_crc64, object_crc64))

class PartInfo(object):
def __init__(self, part_id, start, length):
self.part_id = part_id
self.start = start
self.length = length

def __eq__(self, other):
return self.__key() == other.__key()

def __hash__(self):
return hash(self.__key())

def __key(self):
return self.part_id, self.start, self.length
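
PartInfo's value-based __eq__ and __hash__ are what let __get_parts_need_to_download resume with plain set arithmetic; a minimal sketch of that behavior (part sizes here are toy values):

# parts rebuilt from the record compare equal to freshly split ones,
# so set subtraction leaves exactly the ranges still missing
all_parts = {PartInfo(1, 0, 100), PartInfo(2, 100, 100), PartInfo(3, 200, 50)}
finished = {PartInfo(1, 0, 100)}  # e.g. restored from the record file
remaining = sorted(all_parts - finished, key=lambda p: p.part_id)
print([p.part_id for p in remaining])  # -> [2, 3]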
31 changes: 31 additions & 0 deletions qcloud_cos/streambody.py
@@ -53,3 +53,34 @@ def get_stream_to_file(self, file_name, auto_decompress=False):
if os.path.exists(file_name):
os.remove(file_name)
os.rename(tmp_file_name, file_name)

    def pget_stream_to_file(self, fdst, offset, expected_len, auto_decompress=False):
        """Write the response stream into the local file starting at the given offset."""
        use_chunked = False
        use_encoding = False
        if 'Transfer-Encoding' in self._rt.headers and self._rt.headers['Transfer-Encoding'] == "chunked":
            use_chunked = True
        elif 'Content-Length' not in self._rt.headers:
            raise IOError("download failed without Content-Length header or Transfer-Encoding header")

        if 'Content-Encoding' in self._rt.headers:
            use_encoding = True
        read_len = 0
        fdst.seek(offset, 0)

        if use_encoding and not auto_decompress:
            # read the raw (still-compressed) bytes so the length check stays meaningful
            chunk = self._rt.raw.read(1024)
            while chunk:
                read_len += len(chunk)
                fdst.write(chunk)
                chunk = self._rt.raw.read(1024)
        else:
            for chunk in self._rt.iter_content(chunk_size=1024):
                if chunk:
                    read_len += len(chunk)
                    fdst.write(chunk)

        # length can only be verified when the server sent a Content-Length for the raw payload
        if not use_chunked and not (use_encoding and auto_decompress) and read_len != expected_len:
            raise IOError("download failed with incomplete file")
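
A sketch of how the downloader drives this method, assuming the configured client from above and a preallocated local file; names are illustrative:

# fetch bytes [start, start + length) of the object and write them
# into the temp file at the matching offset, as each worker thread does
start, length = 0, 1024 * 1024
resp = client.get_object(Bucket='examplebucket-1250000000', Key='movie.mp4',
                         Range='bytes={0}-{1}'.format(start, start + length - 1))
with open('/tmp/movie.mp4_tmp', 'rb+') as f:
    resp['Body'].pget_stream_to_file(f, start, length)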

48 changes: 48 additions & 0 deletions ut/test.py
@@ -1164,6 +1164,53 @@ def _test_get_object_sensitive_content_recognition():
print(response)
assert response

def test_download_file():
    """Test the resumable download interface"""
    # plain download
    client.download_file(copy_test_bucket, test_object, 'test_download_file.local')
    if os.path.exists('test_download_file.local'):
        os.remove('test_download_file.local')

    # rate-limited download
    client.download_file(copy_test_bucket, test_object, 'test_download_traffic_limit.local', TrafficLimit='819200')
    if os.path.exists('test_download_traffic_limit.local'):
        os.remove('test_download_traffic_limit.local')

    # crc64 verification switch
    client.download_file(copy_test_bucket, test_object, 'test_download_crc.local', EnableCRC=True)
    if os.path.exists('test_download_crc.local'):
        os.remove('test_download_crc.local')

    # compare the md5 of the source file with that of the downloaded copy
file_size = 25 # MB
file_id = str(random.randint(0, 1000)) + str(random.randint(0, 1000))
file_name = "tmp" + file_id + "_" + str(file_size) + "MB"
gen_file(file_name, file_size)

source_file_md5 = None
dest_file_md5 = None
with open(file_name, 'rb') as f:
source_file_md5 = get_raw_md5(f.read())

client.put_object_from_local_file(
Bucket=copy_test_bucket,
LocalFilePath=file_name,
Key=file_name
)

client.download_file(copy_test_bucket, file_name, 'test_download_md5.local')
if os.path.exists('test_download_md5.local'):
with open('test_download_md5.local', 'rb') as f:
dest_file_md5 = get_raw_md5(f.read())
assert source_file_md5 and dest_file_md5 and source_file_md5 == dest_file_md5

    # clean up
client.delete_object(
Bucket=copy_test_bucket,
Key=file_name
)
if os.path.exists(file_name):
os.remove(file_name)

if __name__ == "__main__":
setUp()
@@ -1190,6 +1237,7 @@ def _test_get_object_sensitive_content_recognition():
test_put_get_delete_bucket_domain()
test_select_object()
_test_get_object_sensitive_content_recognition()
test_download_file()
"""

tearDown()