# 基于类的上下文管理器

In [1]:
class FileManager:
    """Class-based context manager that opens a file on entry and closes it on exit.

    Each special method prints a trace line so the order of the calls is
    visible when the object is used in a ``with`` statement.

    Attributes:
        name: Path handed to ``open()``.
        mode: Mode string handed to ``open()``.
        file: The open file object, or ``None`` before ``__enter__`` has run.
    """

    def __init__(self, name, mode):
        """Remember the target path and mode; nothing is opened yet."""
        print('calling __init__ method')
        self.name, self.mode, self.file = name, mode, None

    def __enter__(self):
        """Open the file and return the file object (bound by ``as``)."""
        print('calling __enter__ method')
        handle = open(self.name, self.mode)
        self.file = handle
        return handle

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Close the file if one is held.

        Returns ``None``, so an exception raised inside the ``with`` body
        keeps propagating.
        """
        print('calling __exit__ method')
        if not self.file:
            return
        self.file.close()
            
# Demo: constructing FileManager runs __init__, entering the block runs
# __enter__ (its return value, the open file, is bound to f), and leaving
# the block runs __exit__, which closes the file. The recorded output shows
# the trace lines in exactly that order around the print below.
with FileManager('test.txt', 'w') as f:
    print('ready to write to file')
    f.write('hello world')
    


calling __init__ method
calling __enter__ method
ready to write to file
calling __exit__ method


In [2]:
# Same usage without the extra print: only the three trace lines from
# __init__, __enter__ and __exit__ appear in the recorded output.
with FileManager('test.txt', 'w') as f:
    f.write('hello world')

calling __init__ method
calling __enter__ method
calling __exit__ method


In [5]:
class Foo:
    """Context manager that reports and suppresses any exception raised in its block.

    Every special method prints a trace line so the call order can be observed.
    """

    def __init__(self):
        print('__init__ called')

    def __enter__(self):
        """Return the manager itself as the ``as`` target."""
        print('__enter__ called')
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Print the exception details, if any, and always return ``True``.

        Returning ``True`` tells the ``with`` statement that the exception
        has been handled, so it is not re-raised after the block.
        """
        print('__exit__ called')
        if not exc_type:
            # Block finished normally: nothing to report.
            return True

        details = (
            ('exc_type: {}', exc_type),
            ('exc_value: {}', exc_value),
            ('exc_traceback: {}', exc_tb),
        )
        for template, value in details:
            print(template.format(value))
        print('exception handled')
        return True
    
# Raise inside the block so that Foo.__exit__ receives the exception details.
# Foo.__exit__ returns True, and the recorded output ends at
# "exception handled" with no error shown after the block.
# Note: the recorded output still shows a traceback object reaching __exit__
# even though with_traceback(None) is applied before the raise.
with Foo() as obj:
    raise Exception('exception raised').with_traceback(None)


__init__ called
__enter__ called
__exit__ called
exc_type: <class 'Exception'>
exc_value: exception raised
exc_traceback: <traceback object at 0x7f2d94135108>
exception handled


# 单元测试

In [6]:
import unittest

# The sorting function under test.
def sort(arr):
    """Sort ``arr`` in place into ascending order with a simple exchange sort.

    For each position ``i`` the element there is compared with every later
    element and swapped whenever it is strictly greater, so after the inner
    loop position ``i`` holds the smallest remaining value.

    Args:
        arr: A mutable, indexable sequence (e.g. a list) of mutually
            comparable items. Empty and single-element inputs are left as is.

    Returns:
        None. The sequence is modified in place.
    """
    size = len(arr)
    for i in range(size):
        for j in range(i + 1, size):
            # Strict comparison: elements that compare equal are already in
            # an acceptable order, so swapping them would be wasted writes.
            if arr[i] > arr[j]:
                arr[i], arr[j] = arr[j], arr[i]


# 编写子类继承 unittest.TestCase
class TestSort(unittest.TestCase):

   # 以 test 开头的函数将会被测试
   def test_sort(self):
        arr = [3, 4, 1, 5, 6]
        sort(arr)
        # assert 结果跟我们期待的一样
        self.assertEqual(arr, [1, 3, 4, 5, 6])

if __name__ == '__main__':
    # Under Jupyter, run the unit tests as follows: an explicit argv list is
    # passed (its single entry is a placeholder, as its text says) together
    # with exit=False.
    unittest.main(argv=['first-arg-is-ignored'], exit=False)
    
    # When running from the command line, use this instead:
    # unittest.main()
    


.
----------------------------------------------------------------------
Ran 1 test in 0.003s

OK


## 使用 mock 的单元测试

In [7]:
import unittest
from unittest.mock import MagicMock

class A(unittest.TestCase):
    """Demo of MagicMock: the test replaces ``m2`` and ``m3`` on a fresh instance.

    ``m1`` is the method under test; ``m2`` and ``m3`` are empty collaborators
    that the test swaps out for mocks.
    """

    def m1(self):
        """Call ``m2`` and forward its return value to ``m3``."""
        self.m3(self.m2())

    def m2(self):
        """Placeholder collaborator; does nothing and returns ``None``."""

    def m3(self, val):
        """Placeholder collaborator; ignores ``val`` and returns ``None``."""

    def test_m1(self):
        """``m1`` must call ``m2`` and hand its result to ``m3``."""
        target = A()
        target.m2 = MagicMock(return_value="custom_val")
        target.m3 = MagicMock()

        target.m1()

        # Verify that m2 was called.
        self.assertTrue(target.m2.called)
        # Verify that m3 was called with the value returned by m2.
        target.m3.assert_called_with("custom_val")
        
# Run the tests only when this code executes as the main module.
# The recorded output below reports 2 tests, so TestSort from the earlier
# cell was presumably collected along with A.test_m1.
if __name__ == '__main__':
    unittest.main(argv=['first-arg-is-ignored'], exit=False)

..
----------------------------------------------------------------------
Ran 2 tests in 0.006s

OK


## mock side_effect

In [11]:
from unittest.mock import MagicMock
def side_effect(arg):
    """Return 1 when ``arg`` is negative and 2 otherwise (zero included)."""
    return 1 if arg < 0 else 2
# Build a mock whose calls are delegated to side_effect(); the function's
# return value becomes the return value of the mock call.
mock = MagicMock(side_effect=side_effect)
# Negative argument: side_effect takes its first branch, so this yields 1.
mock(-1)

1

In [10]:
# Non-negative argument: side_effect takes its else branch, so this yields 2.
mock(1)

2