# 使用asyncio包处理并发

## 使用事件循环驱动的协程链(yield from)实现并发

In [1]:
import threading
import itertools
import time
import sys


class Signal:
    """Mutable flag shared between the supervisor and the spinner thread."""

    # While True the spinner keeps animating; set to False to make it stop.
    go: bool = True
    
    
def spin(msg, signal):
    """Animate a text spinner on stdout until ``signal.go`` becomes falsy.

    Each frame writes one spinner character followed by ``msg``, then moves
    the cursor back with backspaces so the next frame overwrites it. After
    the loop ends, the status line is blanked out with spaces.
    """
    out = sys.stdout
    for frame in itertools.cycle('|/-\\'):
        status = ' '.join((frame, msg))
        rewind = '\x08' * len(status)
        out.write(status)
        out.flush()
        out.write(rewind)
        time.sleep(.1)
        if not signal.go:
            break
    # Erase the last status line and return the cursor to its start.
    out.write(' ' * len(status) + '\x08' * len(status))
    
    
def slow_function():
    """Block the calling thread for 3 seconds, then return 42."""
    delay_seconds = 3
    time.sleep(delay_seconds)
    answer = 42
    return answer


def supervisor():
    """Run ``slow_function`` while a spinner thread animates on stdout.

    Starts ``spin`` in a separate thread, calls ``slow_function`` in the
    current thread, then clears the shared flag and joins the spinner
    thread before returning the result.
    """
    stop_flag = Signal()
    worker = threading.Thread(target=spin, args=('thinking!', stop_flag))
    print('spinner object', worker)
    worker.start()
    answer = slow_function()
    # Tell the spinner loop to exit, then wait for the thread to finish.
    stop_flag.go = False
    worker.join()
    return answer


def main():
    """Run the supervisor and print its result."""
    print('Answer:', supervisor())
    
    
# main()  # 请在命令行中执行这个脚本

In [2]:
import asyncio
import itertools
import sys


async def spin(msg):
    """Animate a text spinner on stdout until this coroutine is cancelled.

    Each frame writes one spinner character followed by ``msg``, rewinds the
    cursor with backspaces, then sleeps 0.1 s without blocking the event
    loop. A ``CancelledError`` raised during the sleep ends the loop; the
    status line is then blanked out.
    """
    out = sys.stdout
    for frame in itertools.cycle('|/-\\'):
        status = ' '.join((frame, msg))
        rewind = '\x08' * len(status)
        out.write(status)
        out.flush()
        out.write(rewind)
        try:
            await asyncio.sleep(.1)
        except asyncio.CancelledError:
            # Cancellation is the stop signal for the spinner.
            break
    # Erase the last status line and return the cursor to its start.
    out.write(' ' * len(status) + '\x08' * len(status))


async def slow_function():
    """Simulate a slow I/O operation: sleep 3 seconds asynchronously, return 42."""
    delay_seconds = 3
    # Awaiting the sleep hands control back to the event loop meanwhile.
    await asyncio.sleep(delay_seconds)
    answer = 42
    return answer


async def supervisor():
    """Await ``slow_function`` while a spinner task animates on stdout.

    Schedules ``spin`` as a task, awaits ``slow_function``, then requests
    cancellation of the spinner task and returns the result.
    """
    spinner_task = asyncio.create_task(spin('thinking!'))
    print('spinner object:', spinner_task)
    answer = await slow_function()
    # Request cancellation; spin() treats CancelledError as its stop signal.
    spinner_task.cancel()
    return answer


def main():
    """Drive the ``supervisor`` coroutine to completion and print its result."""
    print('Answer:', asyncio.run(supervisor()))

In [3]:
import os
import time
import sys
import asyncio

import aiohttp  


# Country codes whose flags are downloaded (20 entries).
POP20_CC = 'CN IN US ID BR PK NG BD RU JP MX PH VN ET EG DE IR TR CD FR'.split()

# Root URL under which each flag is stored as <cc>/<cc>.gif.
BASE_URL = 'http://flupy.org/data/flags'

# Directory the downloaded images are written to.
DEST_DIR = 'downloads/'


def save_flag(img, filename):
    """Write the bytes ``img`` to ``filename`` inside ``DEST_DIR``."""
    target = os.path.join(DEST_DIR, filename)
    with open(target, 'wb') as out_file:
        out_file.write(img)


async def get_flag(session, cc):
    """Fetch the flag image for country code ``cc`` and return the response body.

    The URL is built from ``BASE_URL`` and the lower-cased country code; the
    request is issued through ``session``.
    """
    code = cc.lower()
    url = '{}/{cc}/{cc}.gif'.format(BASE_URL, cc=code)
    async with session.get(url) as resp:
        body = await resp.read()
        return body


def show(text):
    """Print ``text`` followed by a space (no newline) and flush stdout."""
    print(text, end=' ', flush=True)


async def download_one(session, cc):
    """Download the flag for ``cc``, echo the code, save the image, return ``cc``."""
    flag_bytes = await get_flag(session, cc)
    show(cc)
    filename = cc.lower() + '.gif'
    save_flag(flag_bytes, filename)
    return cc


async def download_many(cc_list):
    """Download the flags for every code in ``cc_list`` concurrently.

    One task per country code (in sorted order) is scheduled on a shared
    ``aiohttp.ClientSession``. Returns the number of results gathered.
    """
    async with aiohttp.ClientSession() as session:
        jobs = [asyncio.create_task(download_one(session, code)) for code in sorted(cc_list)]
        outcomes = await asyncio.gather(*jobs)

    return len(outcomes)


def main():
    """Download all flags in ``POP20_CC`` and report the count and elapsed time."""
    started = time.time()
    count = asyncio.run(download_many(POP20_CC))
    elapsed = time.time() - started
    print('\n{} flags downloaded in {:.2f}s'.format(count, elapsed))

In [4]:
import os
import sys
import time
import asyncio

import aiohttp 


# Country codes whose flags are downloaded (20 entries).
POP20_CC = ('CN IN US ID BR PK NG BD RU JP '
            'MX PH VN ET EG DE IR TR CD FR').split()

# Root URL of the flag image collection.
BASE_URL = 'http://flupy.org/data/flags'

# Directory the downloaded images are written to.
DEST_DIR = 'downloads/'


def save_flag(img, filename):
    """Store the image bytes ``img`` as ``filename`` under ``DEST_DIR``."""
    destination = os.path.join(DEST_DIR, filename)
    with open(destination, 'wb') as sink:
        sink.write(img)
        
    
def show(text):
    """Write ``text`` plus a trailing space to stdout and flush immediately."""
    stream = sys.stdout
    print(text, end=' ', file=stream)
    stream.flush()
    
    
async def get_flag(session, cc):
    """Fetch the flag image for country code ``cc`` and return the response body.

    The URL is built from ``BASE_URL`` and the lower-cased country code; the
    request is issued through ``session``.
    """
    # The images are GIF files: the extension must be '.gif' (download_one
    # saves the result as '<cc>.gif'); the previous '.fig' was a typo.
    url = '{}/{cc}/{cc}.gif'.format(BASE_URL, cc=cc.lower())
    async with session.get(url) as resp:
        return await resp.read()
    
    
async def download_one(session, cc):
    """Fetch one flag via ``session``, show its code, save it, and return ``cc``."""
    payload = await get_flag(session, cc)
    show(cc)
    save_flag(payload, '{}'.format(cc.lower() + '.gif'))
    return cc


async def download_many(cc_list):
    """Run ``download_one`` concurrently for every code in ``cc_list``.

    Tasks are created in sorted country-code order and share a single
    ``aiohttp.ClientSession``. Returns how many results were gathered.
    """
    async with aiohttp.ClientSession() as session:
        pending = []
        for code in sorted(cc_list):
            pending.append(asyncio.create_task(download_one(session, code)))
        finished = await asyncio.gather(*pending)

    return len(finished)


def main():
    """Time a full download of ``POP20_CC`` and print a one-line summary."""
    start = time.time()
    total = asyncio.run(download_many(POP20_CC))
    duration = time.time() - start
    summary = '\n{} flags downloaded in {:.2f}s'.format(total, duration)
    print(summary)

In [5]:
import asyncio

import aiohttp


async def get_flag(cc):
    """Fetch the flag image for country code ``cc`` and return the response body.

    Fixes relative to the previous version:
    * the URL extension was misspelled '.fig' (the files are '.gif');
    * the body was bound to ``iamge`` while ``image`` was returned, which
      raised NameError on every call;
    * the response is now used as an async context manager so it is
      released after the body is read.
    """
    url = '{}/{cc}/{cc}.gif'.format(BASE_URL, cc=cc.lower())
    async with aiohttp.request('GET', url) as resp:
        image = await resp.read()
    return image


async def download_one(cc):
    """Download the flag for ``cc``, echo the code, save the image, return ``cc``."""
    flag_bytes = await get_flag(cc)
    show(cc)
    target_name = cc.lower() + '.gif'
    save_flag(flag_bytes, target_name)
    return cc


def download_many(cc_list):
    """Download the flags for all codes in ``cc_list``; return how many finished.

    Blocks until every ``download_one`` task is done. Tasks that ended with
    an exception are still counted, because ``asyncio.wait`` reports them
    as done.

    Changes relative to the previous version:
    * coroutines are wrapped in tasks before being handed to
      ``asyncio.wait`` (passing bare coroutines is rejected by newer
      Python versions);
    * the event loop is managed by ``asyncio.run`` instead of manual
      ``get_event_loop()`` / ``close()`` calls;
    * an empty ``cc_list`` returns 0 instead of raising ``ValueError``.
    """
    codes = sorted(cc_list)
    if not codes:
        # asyncio.wait() refuses an empty collection; nothing to download.
        return 0

    async def _download_all():
        # Tasks must be created while the loop is running.
        tasks = [asyncio.ensure_future(download_one(cc)) for cc in codes]
        done, _ = await asyncio.wait(tasks)
        return done

    # The outermost coroutine is handed to asyncio.run(), which drives the
    # event loop until it completes; this call blocks in the meantime.
    return len(asyncio.run(_download_all()))

## 显示进度条并处理异常

In [6]:
import asyncio
import collections

import aiohttp
from aiohttp import web
import tqdm

from flags2_common import main, HTTPStatus, Result, save_flag

# default set low to avoid errors from remote site, such as
# 503 - Service Temporarily Unavailable
DEFAULT_CONCUR_REQ = 5
# NOTE(review): presumably the upper bound on the concurrency a caller may
# request; it is not referenced in this file beyond the commented-out
# main(...) call, so confirm against flags2_common.
MAX_CONCUR_REQ = 1000


class FetchError(Exception):
    """Raised when fetching a flag fails; carries the affected country code."""

    def __init__(self, country_code):
        # Kept as an attribute so handlers can report which download failed.
        self.country_code = country_code
        
        
async def get_flag(session, base_url, cc):
    """Fetch the flag image for ``cc`` from ``base_url`` and return its bytes.

    Raises:
        web.HTTPNotFound: if the server answers 404.
        aiohttp.HttpProcessingError: for any other non-200 status.
    """
    url = '{}/{cc}/{cc}.gif'.format(base_url, cc=cc.lower())
    async with session.get(url) as resp:
        if resp.status == 200:
            # The actual I/O is performed by the innermost awaitable, which
            # this coroutine chain delegates to (asyncio or a third-party
            # library implementing the coroutine protocol).
            return await resp.read()
        elif resp.status == 404:
            raise web.HTTPNotFound()
        else:
            # The keyword is `headers` (was misspelled `header`), matching
            # the equivalent call in http_get.
            # NOTE(review): the location of HttpProcessingError differs
            # between aiohttp versions — confirm against the installed one.
            raise aiohttp.HttpProcessingError(code=resp.status, message=resp.reason, headers=resp.headers)


async def download_one(session, cc, base_url, semaphore, verbose):
    """Download and save one flag; return ``Result(status, cc)``.

    The fetch runs while holding ``semaphore``. A 404 yields a
    ``not_found`` status; any other exception from the fetch is re-raised
    as ``FetchError`` chained to the original. When ``verbose`` is true,
    the outcome message is printed.
    """
    try:
        async with semaphore:
            flag_bytes = await get_flag(session, base_url, cc)
    except web.HTTPNotFound:
        status, msg = HTTPStatus.not_found, 'not found'
    except Exception as exc:
        raise FetchError(cc) from exc
    else:
        # This call is synchronous: it blocks the single thread shared by
        # this code and the asyncio event loop, so the whole application
        # stalls while the file is written. It should be run asynchronously.
        save_flag(flag_bytes, cc.lower() + '.gif')
        status, msg = HTTPStatus.ok, 'OK'

    if verbose and msg:
        print(cc, msg)

    return Result(status, cc)
            
            
async def download_coro(cc_list, base_url, verbose, concur_req):
    """Download all flags in ``cc_list`` and tally the outcomes.

    At most ``concur_req`` fetches hold the semaphore at once. Results are
    consumed in completion order; when ``verbose`` is false a tqdm progress
    bar wraps the iteration.

    Returns:
        collections.Counter mapping each status to its number of downloads.
    """
    counter = collections.Counter()
    semaphore = asyncio.Semaphore(concur_req)
    async with aiohttp.ClientSession() as session:
        to_do = [download_one(session, cc, base_url, semaphore, verbose) for cc in sorted(cc_list)]

        # as_completed yields awaitables as they finish, which lets the
        # progress bar advance as soon as each result is available.
        to_do_iter = asyncio.as_completed(to_do)
        if not verbose:
            to_do_iter = tqdm.tqdm(to_do_iter, total=len(cc_list))
        for future in to_do_iter:
            try:
                res = await future
            except FetchError as exc:
                country_code = exc.country_code
                try:
                    # The attribute is `args` (was `arg`, which raised
                    # AttributeError instead of reporting the error).
                    error_msg = exc.__cause__.args[0]
                except IndexError:
                    # The cause carries no arguments: use its class name.
                    error_msg = exc.__cause__.__class__.__name__
                if verbose and error_msg:
                    msg = '*** Error for {}:{}'
                    print(msg.format(country_code, error_msg))
                status = HTTPStatus.error
            else:
                status = res.status

            counter[status] += 1

        return counter
                
                    
def download_many(cc_list, base_url, verbose, concur_req):
    """Run ``download_coro`` to completion and return its status counter.

    Uses ``asyncio.run`` — as the other examples in this file do — instead
    of fetching and closing the event loop by hand. The loop is therefore
    created and closed by asyncio, even if the coroutine raises.
    """
    coro = download_coro(cc_list, base_url, verbose, concur_req)
    return asyncio.run(coro)

### 硬盘I/O操作(阻塞型函数save_flag)的异步执行

In [7]:
async def download_one(session, cc, base_url, semaphore, verbose):
    """Download one flag, save it in a worker thread, return ``Result(status, cc)``.

    The fetch runs while holding ``semaphore``. A 404 yields a
    ``not_found`` status; any other exception from the fetch is re-raised
    as ``FetchError`` chained to the original.

    The blocking ``save_flag`` call is delegated to the loop's default
    executor and awaited. The status is therefore only reported as OK once
    the file has been written, and an error raised by ``save_flag``
    propagates to the caller as-is (not wrapped in ``FetchError``) instead
    of being lost in a never-awaited future.
    """
    try:
        async with semaphore:
            image = await get_flag(session, base_url, cc)
    except web.HTTPNotFound:
        status = HTTPStatus.not_found
        msg = 'not found'
    except Exception as exc:
        raise FetchError(cc) from exc
    else:
        # The running loop maintains a default ThreadPoolExecutor.
        loop = asyncio.get_running_loop()
        # Passing None as the executor selects the loop's default executor.
        # Awaiting keeps the event loop free while the thread writes.
        await loop.run_in_executor(None, save_flag, image, cc.lower() + '.gif')
        status = HTTPStatus.ok
        msg = 'OK'

    if verbose and msg:
        print(cc, msg)

    return Result(status, cc)

In [8]:
import asyncio
import collections

import aiohttp
from aiohttp import web
import tqdm

from flags2_common import main, HTTPStatus, Result, save_flag

# default set low to avoid errors from remote site, such as
# 503 - Service Temporarily Unavailable
DEFAULT_CONCUR_REQ = 5
# NOTE(review): presumably the upper bound on the concurrency a caller may
# request; it is only referenced by the commented-out main(...) call in
# this file, so confirm against flags2_common.
MAX_CONCUR_REQ = 1000


class FetchError(Exception):
    """Signals a failed download for one country; see ``country_code``."""

    def __init__(self, country_code):
        # The code of the country whose download failed.
        self.country_code = country_code


async def http_get(session, url):
    """GET ``url`` through ``session`` and return the decoded or raw body.

    On a 200 response the body is parsed as JSON when the Content-type
    header mentions 'json' or the URL ends with 'json'; otherwise the raw
    bytes are returned.

    Raises:
        web.HTTPNotFound: on a 404 response.
        aiohttp.errors.HttpProcessingError: on any other non-200 status.
    """
    async with session.get(url) as res:
        if res.status == 404:
            raise web.HTTPNotFound()
        if res.status != 200:
            raise aiohttp.errors.HttpProcessingError(code=res.status, message=res.reason, headers=res.headers)

        ctype = res.headers.get('Content-type', '').lower()
        wants_json = 'json' in ctype or url.endswith('json')
        if wants_json:
            return await res.json()
        return await res.read()
        
        
async def get_country(session, base_url, cc):
    """Fetch ``<base_url>/<cc>/metadata.json`` and return its 'country' entry."""
    metadata_url = '{}/{cc}/metadata.json'.format(base_url, cc=cc.lower())
    return (await http_get(session, metadata_url))['country']


async def get_flag(session, base_url, cc):
    """Fetch ``<base_url>/<cc>/<cc>.gif`` via ``http_get`` and return the result."""
    flag_url = '{}/{cc}/{cc}.gif'.format(base_url, cc=cc.lower())
    image = await http_get(session, flag_url)
    return image


async def download_one(session, cc, base_url, semaphore, verbose):
    """Download a flag and its country name, save the image, return ``Result``.

    The image and the metadata are fetched in two separate requests, each
    while holding ``semaphore``. A 404 on either yields a ``not_found``
    status; any other exception from the fetches is re-raised as
    ``FetchError`` chained to the original. The file is saved as
    ``<country>-<cc>.gif`` with spaces in the country name replaced by
    underscores.

    The blocking ``save_flag`` call is delegated to the loop's default
    executor and awaited. OK is therefore only reported once the file has
    been written, and an error raised by ``save_flag`` propagates to the
    caller as-is (not wrapped in ``FetchError``) instead of being lost in
    a never-awaited future.
    """
    try:
        async with semaphore:
            image = await get_flag(session, base_url, cc)
        async with semaphore:
            country = await get_country(session, base_url, cc)
    except web.HTTPNotFound:
        status = HTTPStatus.not_found
        msg = 'not found'
    except Exception as exc:
        raise FetchError(cc) from exc
    else:
        country = country.replace(' ', '_')
        filename = '{}-{}.gif'.format(country, cc)
        loop = asyncio.get_running_loop()
        # None selects the loop's default executor; awaiting keeps the
        # event loop free while the worker thread writes the file.
        await loop.run_in_executor(None, save_flag, image, filename)
        status = HTTPStatus.ok
        msg = 'OK'

    if verbose and msg:
        print(cc, msg)

    return Result(status, cc)
# END FLAGS3_ASYNCIO


async def download_coro(cc_list, base_url, verbose, concur_req):
    """Download all flags in ``cc_list`` and tally the outcomes.

    At most ``concur_req`` fetches hold the semaphore at once. Results are
    consumed in completion order; when ``verbose`` is false a tqdm progress
    bar wraps the iteration.

    Returns:
        collections.Counter mapping each status to its number of downloads.
    """
    counter = collections.Counter()
    semaphore = asyncio.Semaphore(concur_req)
    async with aiohttp.ClientSession() as session:
        to_do = [download_one(session, cc, base_url, semaphore, verbose) for cc in sorted(cc_list)]

        # as_completed yields awaitables as they finish, which lets the
        # progress bar advance as soon as each result is available.
        to_do_iter = asyncio.as_completed(to_do)
        if not verbose:
            to_do_iter = tqdm.tqdm(to_do_iter, total=len(cc_list))
        for future in to_do_iter:
            try:
                res = await future
            except FetchError as exc:
                country_code = exc.country_code
                try:
                    # The attribute is `args` (was `arg`, which raised
                    # AttributeError instead of reporting the error).
                    error_msg = exc.__cause__.args[0]
                except IndexError:
                    # The cause carries no arguments: use its class name.
                    error_msg = exc.__cause__.__class__.__name__
                if verbose and error_msg:
                    msg = '*** Error for {}:{}'
                    print(msg.format(country_code, error_msg))
                status = HTTPStatus.error
            else:
                status = res.status

            counter[status] += 1

        return counter
                
                    
def download_many(cc_list, base_url, verbose, concur_req):
    """Run ``download_coro`` to completion and return its status counter.

    Uses ``asyncio.run`` — as the other examples in this file do — instead
    of fetching and closing the event loop by hand. The loop is therefore
    created and closed by asyncio, even if the coroutine raises.
    """
    coro = download_coro(cc_list, base_url, verbose, concur_req)
    return asyncio.run(coro)


# main(download_many, DEFAULT_CONCUR_REQ, MAX_CONCUR_REQ)

## 使用asyncio包编写TCP服务器

### 请在shell中执行脚本

In [14]:
import sys
import asyncio

from charfinder import UnicodeNameIndex


# Line terminator appended to every encoded line sent to the client.
CRLF = b'\r\n'
# Prompt written to the client before each query is read.
PROMPT = b'?> '

# Module-level index instance queried by handle_queries.
index = UnicodeNameIndex()


async def handle_queries(reader, writer):
    """Serve one client: prompt, read a query line, reply with matches.

    The session ends when the client disconnects (EOF) or sends a line
    whose first character is a control character (code point below 32).
    Blank lines are ignored and simply trigger a new prompt.

    Changes relative to the previous version:
    * an empty read (``readline()`` returning ``b''``, i.e. EOF) now ends
      the session; previously it was treated as a blank query and the loop
      kept prompting a client that was gone;
    * the writer is closed in a ``finally`` block, so the socket is also
      closed when an exception (e.g. a connection error) escapes the loop.
    """
    # The peer address does not change during the connection.
    client = writer.get_extra_info('peername')
    try:
        while True:
            writer.write(PROMPT)
            await writer.drain()
            data = await reader.readline()
            if not data:
                # EOF: the client closed its end of the connection.
                break
            try:
                query = data.decode().strip()
            except UnicodeDecodeError:
                # Undecodable input is mapped to a control character, which
                # ends the session below.
                query = '\x00'
            print('Received from {}: {!r}'.format(client, query))
            if query:
                if ord(query[:1]) < 32:
                    break
                lines = list(index.find_description_strs(query))
                if lines:
                    writer.writelines(line.encode() + CRLF for line in lines)
                writer.write(index.status(query, len(lines)).encode() + CRLF)

                await writer.drain()
                print('Sent {} results'.format(len(lines)))
    finally:
        print('Close the client socket')
        writer.close()
    
    
async def main(address='127.0.0.1', port=2323):
    """Start the TCP server on ``address``:``port`` and serve until stopped.

    ``port`` is converted with ``int`` so it may be given as a string
    (e.g. taken from the command line).
    """
    server = await asyncio.start_server(handle_queries, address, int(port))

    listening_on = server.sockets[0].getsockname()
    print('Serving on {}. Hit CTRL-C to stop.'.format(listening_on))

    async with server:
        await server.serve_forever()
        
        
# asyncio.run(main(*sys.argv[1:]))

In [None]:
import sys
import asyncio
from aiohttp import web

from charfinder import UnicodeNameIndex

# HTML template file read once at import time.
TEMPLATE_NAME = 'http_charfinder.html'
# Content type sent with every response.
CONTENT_TYPE = 'text/html; charset=UTF-8'
# Sample query words rendered as clickable links in the page.
SAMPLE_WORDS = ('bismillah chess cat circled Malayalam digit'
                ' Roman face Ethiopic black mark symbol dot'
                ' operator Braille hexagram').split()

# One table row per matching character.
ROW_TPL = '<tr><td>{code_str}</td><th>{char}</th><td>{name}</td></tr>'
# One link per sample word.
LINK_TPL = '<a href="/?query={0}" title="find &quot;{0}&quot;">{0}</a>'
# Sample-word links, sorted case-insensitively and comma-separated.
LINKS_HTML = ', '.join(LINK_TPL.format(word) for word in
                       sorted(SAMPLE_WORDS, key=str.upper))

index = UnicodeNameIndex()
# The page is served as UTF-8 (see CONTENT_TYPE), so read the template as
# UTF-8 explicitly rather than with the platform's default encoding.
with open(TEMPLATE_NAME, encoding='utf-8') as tpl:
    template = tpl.read()

# The links never change, so they are substituted once at import time.
template = template.replace('{links}', LINKS_HTML)


def home(request):
    """Handle ``GET /``: look up the 'query' parameter and render the page.

    Changes relative to the previous version:
    * the query is stripped, not split: ``.split()`` produced a list, which
      was rendered into the page as a list repr and is inconsistent with
      the TCP server in this file, which passes the index a string;
    * the query is HTML-escaped before being inserted into the template,
      because it comes straight from the request URL.
    """
    # Local import: keeps this block self-contained. The name `html` is not
    # bound, so it cannot clash with other names in the module.
    from html import escape

    # NOTE(review): `request.GET` and a charset inside `content_type` were
    # accepted by old aiohttp releases; confirm against the installed
    # version (newer releases expose `request.query`).
    query = request.GET.get('query', '').strip()
    print('Query: {!r}'.format(query))
    if query:
        descriptions = list(index.find_descriptions(query))
        res = '\n'.join(ROW_TPL.format(**vars(descr)) for descr in descriptions)
        # NOTE(review): if index.status() echoes the query, its result
        # should be escaped as well — verify against charfinder.
        msg = index.status(query, len(descriptions))
    else:
        descriptions = []
        res = ''
        msg = 'Enter words describing characters.'

    page = template.format(query=escape(query), results=res, message=msg)
    print('Sending {} results'.format(len(descriptions)))
    return web.Response(content_type=CONTENT_TYPE, text=page)
    

async def init(loop, address, port):
    """Create the web application, start listening, return the bound address.

    Registers ``home`` for ``GET /`` and starts a server on
    ``address``:``port`` using ``loop``. Returns the socket name of the
    first listening socket.
    """
    app = web.Application(loop=loop)
    app.router.add_route('GET', '/', home)
    handler = app.make_handler()
    # The method is `create_server`; `creat_server` raised AttributeError.
    server = await loop.create_server(handler, address, port)
    return server.sockets[0].getsockname()


def main(address="127.0.0.1", port=2345):
    """Start the HTTP server and run the event loop until CTRL-C.

    ``port`` is converted with ``int`` so it may be given as a string.
    The loop is closed in a ``finally`` block, so it is released even when
    start-up or serving fails with something other than KeyboardInterrupt.
    """
    port = int(port)
    loop = asyncio.get_event_loop()
    try:
        host = loop.run_until_complete(init(loop, address, port))
        # Message typo fixed ("HIt" -> "Hit"); it now matches the TCP server.
        print('Serving on {}. Hit CTRL-C to stop.'.format(host))
        try:
            loop.run_forever()
        except KeyboardInterrupt:
            # CTRL-C is the documented way to stop the server.
            pass
        print('Server shutting down.')
    finally:
        loop.close()