# Parallel Computing

### Checkpointing

##### ThreadLocal

##### Example 1

In [2]:
# import threading

# local = threading.local()

# # Get the current thread
# thread = threading.current_thread()

# # Set an attribute 
# local.x = 1

# # Accessing local.x from the same thread 
# # will give you the same value
# print(local.x) # 1

# # But a different thread will have a 
# # different copy of that variable!

# def func():
#     print(local.x) # AttributeError! This thread has its own copy, where x was never set
    
# thread2 = threading.Thread(target=func)
# thread2.start()

##### Example 1.1

##### Example 2

In [1]:
import torch

In [12]:
class WithoutGradient(torch.autograd.Function):
    """Autograd function that detaches its input on the way forward and
    hands the incoming gradient back unchanged on the way backward."""

    @staticmethod
    def forward(ctx, input):
        """Return a detached version of ``input``."""
        detached = input.detach()
        return detached

    @staticmethod
    def backward(ctx, grad_input):
        """Pass the received gradient straight through.

        Note: despite its name, ``grad_input`` is the gradient autograd
        supplies for forward's output; it is returned as the gradient for
        forward's ``input``.
        """
        passthrough = grad_input
        return passthrough

In [13]:
x = torch.tensor(1., requires_grad=True)

In [14]:
# Backpropagate through WithoutGradient; the added constant does not change
# the gradient that reaches `x`.
(WithoutGradient.apply(x) + 23123).backward()

In [15]:
x.grad

tensor(1.)

##### Example 2

In [99]:
# timeline = []

# class Log(torch.autograd.Function):
#     @staticmethod
#     def forward(ctx, name, x):
#         ctx.name = name
#         timeline.append(f"{name}:forward")
#         return x
    
#     @staticmethod
#     def backward(ctx, grad_output):
#         name = ctx.name
#         timeline.append(f"{name}:backward")
#         return None, grad_output

In [5]:
# function = partial(Log.apply, "x")

### `Recompute`

##### Example 3

In [134]:
# Empty placeholder tensor that requires grad; it is the `phony` argument
# passed to Recompute/Checkpoint in the cells below.
phony = torch.empty(0, requires_grad=True)
# Random 2x4 input that requires grad, so its `.grad` can be inspected after backward.
input = torch.randn((2, 4), requires_grad=True)

In [135]:
# Imported in this cell because, in top-to-bottom order, it comes before the
# notebook's `from torch import nn` cell; without it a fresh
# "Restart & Run All" fails here with NameError.
from torch import nn

# The layer used as the checkpointed function: 4 input features -> 2 outputs.
function = nn.Linear(4, 2)

In [136]:
# `collections` is not imported in any of the visible cells above, so import
# it here; otherwise this cell raises NameError on a fresh kernel.
import collections

# Single-slot buffer shared between Recompute (which appends to it) and
# Checkpoint (which pops from it).
recomputed = collections.deque(maxlen=1)

In [137]:
from functools import partial

import torch
from torch import nn

In [138]:
function, recomputed

(Linear(in_features=4, out_features=2, bias=True), deque([], maxlen=1))

In [139]:
phony, input

(tensor([], requires_grad=True),
 tensor([[ 0.0984, -2.1375, -1.2984,  0.5310],
         [ 0.8394,  0.4371, -0.2785,  0.1812]], requires_grad=True))

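Reimplement the activation recomputation step of gradient checkpointing: during the backward pass, rerun the function on a detached copy of the input and store the recomputed output together with that input leaf in `recomputed`.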

**Hint**: `recomputed.append((x, y))`

In [140]:
class Recompute(torch.autograd.Function):
    """Autograd function whose backward pass reruns ``function`` on a detached
    copy of ``input`` and stores the result in ``recomputed``.

    The forward pass computes nothing: it only remembers its arguments on
    ``ctx`` and returns ``phony`` unchanged.
    """

    @staticmethod
    def forward(ctx, phony, recomputed, function, input):
        """Stash the arguments needed by backward and return ``phony``."""
        ctx.input = input
        ctx.function = function
        ctx.recomputed = recomputed
        return phony

    @staticmethod
    def backward(ctx, grad_output):
        """Recompute ``function(input)`` and append ``(output, input_leaf)``
        to ``ctx.recomputed``.

        Returns one gradient slot per forward argument
        (phony, recomputed, function, input).
        """
        original = ctx.input

        # Detached copy of the input with the same requires_grad flag as the original.
        leaf = original.detach()
        leaf.requires_grad_(original.requires_grad)

        # Run the function with gradient tracking enabled so the recomputed
        # output carries a graph back to `leaf`.
        with torch.enable_grad():
            activations = ctx.function(leaf)

        ctx.recomputed.append((activations, leaf))

        # The last slot is whatever `.grad` the leaf currently holds
        # (None when nothing has been accumulated on it).
        return None, None, None, leaf.grad

In [141]:
# Rebinds `phony` to the output of Recompute, so from here on it carries a
# RecomputeBackward grad_fn.
phony = Recompute.apply(
    phony, recomputed,
    function, input
)

In [142]:
# Triggers Recompute.backward, which reruns `function` and fills `recomputed`.
phony.sum().backward()

In [143]:
recomputed

deque([(tensor([[ 0.4161,  0.5159],
                [-0.0299, -0.6883]], grad_fn=<AddmmBackward0>),
        tensor([[ 0.0984, -2.1375, -1.2984,  0.5310],
                [ 0.8394,  0.4371, -0.2785,  0.1812]], requires_grad=True))],
      maxlen=1)

### `Checkpoint`

##### Example 1

In [144]:
import torch

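Reimplement the `Checkpoint` autograd function used in checkpointing: its forward pass runs the function without building a graph, and its backward pass takes the pair stored in `recomputed` (the recomputed activations and the input leaf) to compute the input gradient.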

**Hint**: `recomputed.pop()`, `array.extend([x])`

In [145]:
class Checkpoint(torch.autograd.Function):
    """Autograd function that runs ``function`` without building a graph in
    forward, and in backward backpropagates through the recomputed output
    found in ``recomputed``.

    ``recomputed`` must already hold an ``(output, input_leaf)`` pair by the
    time backward runs (in this notebook, ``Recompute.backward`` appends it).
    """

    @staticmethod
    def forward(ctx, phony, recomputed, function, input):
        """Run ``function(input)`` under ``torch.no_grad()`` and return the result.

        ``phony`` is accepted but not used in the computation; only
        ``recomputed`` is kept on ``ctx`` for the backward pass.
        """
        ctx.recomputed = recomputed

        with torch.no_grad():
            output = function(input)

        return output

    @staticmethod
    def backward(ctx, grad_output):
        """Backpropagate ``grad_output`` through the recomputed output.

        Returns one gradient slot per forward argument
        (phony, recomputed, function, input).

        Raises:
            IndexError: if ``ctx.recomputed`` is empty when popped.
        """
        output, input_leaf = ctx.recomputed.pop()

        # Only run the inner backward when the recomputed output has a graph.
        if output.requires_grad:
            torch.autograd.backward(output, grad_output)

        # `input_leaf.grad` is None when nothing was accumulated on the leaf,
        # which is exactly the value to report in that case.
        return None, None, None, input_leaf.grad

In [146]:
phony

tensor([], grad_fn=<RecomputeBackward>)

In [147]:
recomputed

deque([(tensor([[ 0.4161,  0.5159],
                [-0.0299, -0.6883]], grad_fn=<AddmmBackward0>),
        tensor([[ 0.0984, -2.1375, -1.2984,  0.5310],
                [ 0.8394,  0.4371, -0.2785,  0.1812]], requires_grad=True))],
      maxlen=1)

In [148]:
function, input

(Linear(in_features=4, out_features=2, bias=True),
 tensor([[ 0.0984, -2.1375, -1.2984,  0.5310],
         [ 0.8394,  0.4371, -0.2785,  0.1812]], requires_grad=True))

In [149]:
# Forward pass through Checkpoint; `recomputed` already holds the pair
# appended by the earlier Recompute backward call.
output = Checkpoint.apply(phony, recomputed, function, input)

In [150]:
output

tensor([[ 0.4161,  0.5159],
        [-0.0299, -0.6883]], grad_fn=<CheckpointBackward>)

In [151]:
# Runs Checkpoint.backward, which pops the stored pair from `recomputed`
# and backpropagates through the recomputed output.
output.sum().backward()

In [152]:
input.grad

tensor([[-0.6793, -0.1394, -0.5098,  0.7668],
        [-0.6793, -0.1394, -0.5098,  0.7668]])

### `Checkpointing`

##### Example 1

In [2]:
import torch

In [None]:
phony > checkpoint.forward()

In [None]:
140379543834512

In [5]:
from collections import deque

In [4]:
class Checkpointing:
    """Bundles a function, its input, and the single-slot ``recomputed``
    buffer shared between the checkpoint and recompute steps."""

    def __init__(self, function, input):
        """Store ``function`` and ``input`` and create an empty one-slot buffer."""
        self.function = function
        self.input = input

        # Holds at most one (output, input_leaf) pair at a time.
        self.recomputed = deque(maxlen=1)

    def checkpoint(self):
        """Run the function through ``Checkpoint`` and return its output."""
        # An empty tensor that requires grad, mirroring the `phony` used in
        # the notebook's earlier Checkpoint example.
        phony = torch.empty(0, requires_grad=True)
        return Checkpoint.apply(
            phony, self.recomputed, self.function, self.input
        )

    def recompute(self):
        """Placeholder for the recomputation step; currently a no-op.

        TODO: hook ``Recompute`` into the graph of the checkpointed output so
        that ``self.recomputed`` is filled before ``Checkpoint.backward`` pops
        from it. Until then, calling backward on the result of
        ``checkpoint()`` finds the buffer empty.
        """
        pass

In [3]:
def checkpoint(function, input):
    """Run ``function`` on ``input`` via a ``Checkpointing`` helper.

    Calls the helper's ``checkpoint()`` first, then its ``recompute()``, and
    returns the value produced by ``checkpoint()``.
    """
    checkpointing = Checkpointing(function, input)
    result = checkpointing.checkpoint()
    # The return value of recompute() is not used.
    checkpointing.recompute()
    return result