# Parallel Computing

### Checkpointing

##### Thread-local storage (`threading.local`)

##### Example 1

In [2]:
import threading

# Attributes set on a threading.local() object are private to the thread
# that set them; every other thread sees its own, separate copy.
local = threading.local()

# Get the current thread (the one running this cell).
thread = threading.current_thread()

# Set an attribute from this thread.
local.x = 1

# Accessing local.x from the same thread gives back the same value.
print(local.x)  # 1

# A different thread has its own copy of `local`, in which `x` was never
# set. Reading `local.x` directly there would raise AttributeError, so use
# getattr with a default to show that the attribute is missing.
def func():
    print(getattr(local, "x", None))  # None: no `x` in this thread's copy

thread2 = threading.Thread(target=func)
thread2.start()
# Wait for the worker so its output is printed before the cell finishes.
thread2.join()

##### Example 1.1

##### Example 2: detaching inside a custom `autograd.Function` does not block gradients

In [1]:
import torch

In [12]:
class WithoutGradient(torch.autograd.Function):
    """Identity op whose forward pass detaches its input.

    Despite the name, this does NOT stop gradients. The output of a custom
    ``Function`` is attached to the autograd graph through this class's
    ``backward``, which hands the incoming gradient straight back to
    ``input``. Detaching inside ``forward`` therefore has no effect on the
    gradient: ``(WithoutGradient.apply(x) + c).backward()`` leaves
    ``x.grad == 1``.
    """

    @staticmethod
    def forward(ctx, input):
        # Return a detached copy; gradient flow is defined by `backward`,
        # not by this tensor's history.
        return input.detach()

    @staticmethod
    def backward(ctx, grad_output):
        # `grad_output` is the gradient w.r.t. this op's output. It is
        # returned unchanged as the gradient w.r.t. `input` (straight-through).
        return grad_output

In [13]:
# Scalar leaf tensor that gradients will be accumulated into.
x = torch.tensor(1.0, requires_grad=True)

In [14]:
# The added constant only shifts the value; it does not change d(out)/dx.
WithoutGradient.apply(x).add(23123).backward()

In [15]:
# Gradient accumulated into x by the backward call in the previous cell.
x.grad

tensor(1.)

##### Example 3: recomputing a function during the backward pass

In [196]:
# Zero-element tensor that requires grad. It carries no data; here it is
# passed to Recompute only to connect that op into the autograd graph.
phony = torch.empty((0,), requires_grad=True)

In [197]:
# Fresh scalar leaf for this example (rebinds the `x` used in Example 2).
x = torch.tensor(1.0, requires_grad=True)

In [198]:
# Ordered event log that Log appends "<name>:forward"/"<name>:backward" to.
timeline = list()

In [199]:
class Log(torch.autograd.Function):
    """Identity autograd op that records when its passes run.

    ``forward`` appends ``"<name>:forward"`` to the global ``timeline`` list
    and returns ``x`` unchanged. The matching ``backward`` appends
    ``"<name>:backward"`` and passes the gradient through unchanged.
    """

    @staticmethod
    def forward(ctx, name, x):
        # Keep the label so backward logs under the same name.
        ctx.name = name
        event = f"{name}:forward"
        timeline.append(event)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        timeline.append(f"{ctx.name}:backward")
        # `name` is a plain label and gets no gradient; `x` receives the
        # incoming gradient as-is.
        grads = (None, grad_output)
        return grads

In [200]:
# Show both inputs to Recompute before it is applied.
(x, phony)

(tensor(1., requires_grad=True), tensor([], requires_grad=True))

In [201]:
# Nothing has gone through Log yet, so the event log is still empty.
timeline

[]

In [202]:
from functools import partial
import torch

In [203]:
class Recompute(torch.autograd.Function):
    """Defer ``function(input)`` until the backward pass, then re-run it.

    ``forward`` performs no computation. It stores ``function``, saves
    ``input`` and returns ``phony`` unchanged, so this node joins the
    autograd graph only through ``phony``.

    When a gradient reaches ``phony``, ``backward`` re-executes ``function``
    with grad mode enabled on a fresh leaf copy of ``input``. It then appends
    ``(output, input_leaf)`` to ``ctx.recomputed``.

    ``backward`` does not itself backpropagate through the recomputed
    output. It therefore returns no gradient for any argument; in
    particular, ``input`` receives none from this node.
    """

    @staticmethod
    def forward(ctx, phony, function, input):
        ctx.recomputed = []
        ctx.function = function
        # Save the tensor with save_for_backward instead of a plain ctx
        # attribute, as PyTorch recommends. Autograd then raises if `input`
        # is modified in place before backward, instead of silently
        # recomputing on the changed value.
        ctx.save_for_backward(input)
        return phony

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # New leaf sharing the data, so the recomputed graph is separate
        # from the original one.
        input_leaf = input.detach().requires_grad_(input.requires_grad)

        # Autograd runs backward with grad mode off (unless create_graph is
        # set), so re-enable it to give the recomputed output a graph.
        with torch.enable_grad():
            output = ctx.function(input_leaf)

        ctx.recomputed.append((output, input_leaf))

        # One gradient per forward argument: (phony, function, input).
        # Nothing has called backward on `output`, so there is no gradient
        # for `input_leaf` to forward here; the old `input_leaf.grad` check
        # could only ever see None.
        return None, None, None

In [204]:
# Defer `Log.apply("x", x)`: Recompute's forward only records the function
# and its input, then hands `phony` back.
log_x = partial(Log.apply, "x")
phony = Recompute.apply(phony, log_x, x)
del log_x

In [205]:
# Backpropagate through `phony`; this runs Recompute.backward, which
# re-executes the deferred function.
torch.sum(phony).backward()

In [206]:
# Only the recomputation's forward is logged. Recompute.backward never
# backpropagates through the recomputed output, so Log's backward does not run.
timeline

['x:forward']