### `Copy`

In [5]:
import torch

##### Example 1

In [6]:
x = torch.tensor([1., 2., 3.], requires_grad=True)

Write a custom `torch.autograd.Function` whose backward pass returns the input that was saved during the forward pass.

In [8]:
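class Function(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # Stash the input on ctx so that backward can read it later.
        ctx.input = input
        return input

    @staticmethod
    def backward(ctx, grad_output):
        # Ignore grad_output and return the saved input as the gradient.
        return ctx.input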

In [9]:
out = Function.apply(x)

In [10]:
out

tensor([1., 2., 3.], grad_fn=<FunctionBackward>)

In [11]:
out.sum(dim=-1).backward()

In [12]:
x.grad == x

tensor([True, True, True])
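The same exercise can also be written with `ctx.save_for_backward`, which is the mechanism PyTorch provides for passing tensors from `forward` to `backward`. The following is a minimal sketch, not the notebook's own solution; the class name `SavedInput` and the variable `grads_match` are made up for illustration.

```python
import torch


class SavedInput(torch.autograd.Function):
    """Hypothetical variant of the example above using save_for_backward."""

    @staticmethod
    def forward(ctx, input):
        # Register the tensor with autograd instead of attaching it to ctx.
        ctx.save_for_backward(input)
        # Return a view so the output is a distinct tensor object.
        return input.view_as(input)

    @staticmethod
    def backward(ctx, grad_output):
        # saved_tensors is a tuple holding what forward saved.
        (saved,) = ctx.saved_tensors
        # As in the exercise, return the saved input instead of grad_output.
        return saved


x = torch.tensor([1., 2., 3.], requires_grad=True)
out = SavedInput.apply(x)
out.sum().backward()

# The gradient of x should equal x itself.
grads_match = torch.equal(x.grad, x.detach())
```

Here `grads_match` plays the same role as the `x.grad == x` check above, collapsed into a single boolean.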

##### Example 2

In [13]:
from contextlib import contextmanager

In [14]:
@contextmanager
def a():
    print("enter_1")    
    yield "yield_1"
    print("exit_1")

In [15]:
@contextmanager
def b():
    print("enter_2")
    yield "yield_2"    
    print("exit_2")

What is the output? Explain why.

In [16]:
with a(), b():
    print("hello")

enter_1
enter_2
hello
exit_2
exit_1


**Explanation**
- `with a(), b()`: Python's `with` statement accepts multiple context managers and enters them from left to right. `a()` is entered first, which runs its body up to the `yield` and prints *enter_1*; then `b()` is entered, printing *enter_2*.

- The body of the `with` block then runs, printing *hello*.

- When the block exits, the context managers are exited in the reverse order of entry. `b()` resumes after its `yield` and prints *exit_2*, then `a()` resumes and prints *exit_1*.
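The ordering described above can be checked without reading printed output. The sketch below records events in a list and compares two equivalent spellings of `with a(), b():`, a nested `with` and `contextlib.ExitStack`. The helper `tracked` and the list `events` are hypothetical names introduced for this illustration. Unlike the original `a()` and `b()`, `tracked` wraps its `yield` in `try`/`finally`, so the exit step would also run if the block raised.

```python
from contextlib import ExitStack, contextmanager

events = []  # records the order in which the steps run


@contextmanager
def tracked(name):
    # Code before the yield runs on entry.
    events.append(f"enter_{name}")
    try:
        yield name
    finally:
        # Code after the yield runs on exit, even on an exception.
        events.append(f"exit_{name}")


# Form 1: nested with statements, equivalent to `with tracked("1"), tracked("2"):`
with tracked("1"):
    with tracked("2"):
        events.append("hello")

nested_order = list(events)
events.clear()

# Form 2: ExitStack unwinds its context managers in last-in, first-out order.
with ExitStack() as stack:
    stack.enter_context(tracked("1"))
    stack.enter_context(tracked("2"))
    events.append("hello")

stack_order = list(events)
```

Both `nested_order` and `stack_order` follow the same enter-in-order, exit-in-reverse sequence as the printed output above.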

##### Example 3

In [None]:
x = torch.tensor([69])

In [None]:
x.storage()

 69
[torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 1]

In [None]:
x.new_empty([0]).set_(x.storage())

tensor([69])

##### Example 4

In [None]:
with torch.cuda.device(1):
    a = torch.tensor([1., 2.], device="cuda")

### `Wait`

In [2]:
import torch

In [None]:
def wait_stream(source_stream, target_stream):
    # Make source_stream wait for the work queued on target_stream.
    source_stream.wait_stream(target_stream)

In [None]:
source_stream.wait_stream(target_stream)

In [None]:
class Wait(torch.autograd.Function):
    @staticmethod
    def forward(ctx, prev_stream, next_stream, input):
        ctx.prev_stream = prev_stream
        ctx.next_stream = next_stream
        
        prev_stream.wait_stream(next_stream)
        return 

In [17]:
xs = [1, 2, 3, 4]

In [18]:
xs[:]

[1, 2, 3, 4]