# StreamingCNN

To train deep convolutional neural networks, the input data and the activations need to be kept in memory. Given the limited memory available in current GPUs, this limits the maximum dimensions of the input data. StreamingCNN allows training a convolutional neural network while holding only parts of the image in memory.

**This notebook shows numerical equivalence to a conventional forward and backward pass.**
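The idea behind processing an image in parts can be illustrated in one dimension. The following is a minimal sketch in plain Python, not StreamingCNN's implementation: the helper names `conv1d_valid` and `conv1d_tiled` are made up for illustration. It shows that a padding-free convolution computed tile by tile matches the full result, provided neighbouring tiles overlap by `kernel_size - 1` inputs.

```python
def conv1d_valid(signal, kernel):
    """Padding-free 1D cross-correlation in pure Python."""
    k = len(kernel)
    return [sum(signal[i + j] * kernel[j] for j in range(k))
            for i in range(len(signal) - k + 1)]


def conv1d_tiled(signal, kernel, tile_out):
    """Compute the same output tile by tile (hypothetical helper).

    A tile that produces `tile_out` outputs needs tile_out + k - 1 inputs,
    so neighbouring tiles share k - 1 input values.
    """
    k = len(kernel)
    n_out = len(signal) - k + 1
    out = []
    for start in range(0, n_out, tile_out):
        stop = min(start + tile_out, n_out)
        tile = signal[start:stop + k - 1]  # only this slice is needed at once
        out.extend(conv1d_valid(tile, kernel))
    return out


signal = [1, 4, 2, 8, 5, 7, 3, 6, 9, 0]
kernel = [1, 0, -1]
full = conv1d_valid(signal, kernel)
tiled = conv1d_tiled(signal, kernel, tile_out=3)
print(full == tiled)  # True
```

Integer inputs are used so that the comparison is exact; with floating-point values the two results can differ in the last bits.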

In [1]:
import torch
from scnn import StreamingCNN



Using /home/sharaf/.cache/torch_extensions as PyTorch extensions root...
Emitting ninja build file /home/sharaf/.cache/torch_extensions/cpp_functions/build.ninja...
Building extension module cpp_functions...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
ninja: no work to do.
Loading extension module cpp_functions...


In [2]:
torch.set_printoptions(precision=10)

# Model definition

We initialize a small example network here. All layers are supported, except for feature-wide operations (BatchNormalization).

In [3]:
padding = 0

stream_net = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, kernel_size=3, padding=padding), torch.nn.ReLU(),
    torch.nn.Conv2d(16, 16, kernel_size=3, padding=padding), torch.nn.ReLU(),
    torch.nn.MaxPool2d(2),
    torch.nn.Conv2d(16, 16, kernel_size=3, padding=padding), torch.nn.ReLU(),
    torch.nn.Conv2d(16, 16, kernel_size=3, padding=padding), torch.nn.ReLU(),
    torch.nn.MaxPool2d(2),
    torch.nn.Conv2d(16, 16, kernel_size=3, padding=padding), torch.nn.ReLU(),
    torch.nn.Conv2d(16, 16, kernel_size=3, padding=padding), torch.nn.ReLU(),
    torch.nn.MaxPool2d(2))

We enlarge the weights a bit to increase the gradient sizes (better for comparison).

In [4]:
for i, layer in enumerate(stream_net.modules()):
    if isinstance(layer, torch.nn.Conv2d):
        layer.weight.data *= 2.5
        layer.bias.data.zero_()

In [5]:
print(stream_net)

Sequential(
  (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
  (3): ReLU()
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
  (6): ReLU()
  (7): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
  (8): ReLU()
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
  (11): ReLU()
  (12): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
  (13): ReLU()
  (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)


# Configurations

In [6]:
tile_size = 512
img_size = 1024

cuda = True  # execute this notebook on the GPU
verbose = True   # enable / disable logging
dtype = torch.double  # test with double precision
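With `padding = 0`, every layer shrinks the feature map, so it helps to know the output size for the image and for a single tile. The sketch below applies the standard output-size formula for convolution and pooling layers, `floor((size + 2 * padding - kernel) / stride) + 1`. The `LAYERS` list and the `output_size` function are illustrative names, not part of `scnn`.

```python
# (kind, kernel_size, stride) for each spatial layer of the example network.
# ReLU is omitted because it does not change the spatial size.
LAYERS = [("conv", 3, 1), ("conv", 3, 1), ("pool", 2, 2)] * 3


def output_size(size, layers=LAYERS, padding=0):
    """Spatial output size along one axis (floor mode, PyTorch's default)."""
    for kind, kernel, stride in layers:
        pad = padding if kind == "conv" else 0
        size = (size + 2 * pad - kernel) // stride + 1
    return size


print(output_size(1024))  # 124 (full image)
print(output_size(512))   # 60  (single tile)
```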

In [7]:
stream_net.type(dtype)
if cuda: stream_net.cuda()

# Configure StreamingCNN

<font color='#FF0000'>**IMPORTANT:**</font> setting ```gather_gradients``` to ```True``` makes the class save all the gradients of the intermediate feature maps. This is needed because we want to compare the feature map gradients between streaming and conventional backpropagation. However, this also counteracts the memory gains of StreamingCNN. If you want to test the memory efficiency, set ```gather_gradients``` to ```False```.

In [8]:
sCNN = StreamingCNN(stream_net,
                    tile_shape=(1, 3, tile_size, tile_size),
                    verbose=verbose,  # use the flag from the configuration cell
                    gather_gradients=True)

Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1)) 
 Lost(top:0.0, left:0.0, bottom:0.0, right:0.0)
Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1)) 
 Lost(top:0.0, left:0.0, bottom:0.0, right:0.0)
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) 
 Lost(top:0.0, left:0.0, bottom:0.0, right:0.0)
Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1)) 
 Lost(top:0.0, left:0.0, bottom:0.0, right:0.0)
Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1)) 
 Lost(top:0.0, left:0.0, bottom:0.0, right:0.0)
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) 
 Lost(top:0.0, left:0.0, bottom:0.0, right:0.0)
Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1)) 
 Lost(top:0.0, left:0.0, bottom:0.0, right:0.0)
Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1)) 
 Lost(top:0.0, left:0.0, bottom:0.0, right:0.0)
MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) 
 Lost(top:0.0, left:0.0, bottom:0.0, right:0.0)

 Output lost Lost(top:0.0, left:0.



If the ```verbose``` flag is ```True```, StreamingCNN prints, for every layer in the network, the overlap required to reconstruct the feature maps and gradients. The larger this overlap, the more tiles have to be processed. It is always beneficial to increase the tile size as much as possible to make use of all the GPU memory.
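To get a feeling for why tiles must overlap, one can compute the receptive field of the example network. This is a rough sketch that covers the forward pass only. It is not how StreamingCNN derives its `Lost` statistics, and the function name `receptive_field` is an assumption made for illustration. Two neighbouring tiles whose outputs are adjacent need to share `receptive_field - total_stride` input pixels.

```python
# (kernel_size, stride) for each conv / pool layer of the example network
LAYERS = [(3, 1), (3, 1), (2, 2)] * 3


def receptive_field(layers):
    """Return (receptive field, total stride) along one axis."""
    rf, jump = 1, 1
    for kernel, stride in layers:
        rf += (kernel - 1) * jump  # each layer widens the field by (k - 1) * jump
        jump *= stride             # distance between neighbouring outputs, in input pixels
    return rf, jump


rf, total_stride = receptive_field(LAYERS)
min_overlap = rf - total_stride
print(rf, total_stride, min_overlap)  # 36 8 28
```

A deeper network or larger kernels increase the receptive field, so a larger share of every tile is spent on overlap. This is one reason larger tiles are more efficient.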


# Generate random image and fake label

In [9]:
image = torch.FloatTensor(3, img_size, img_size).normal_(0, 1)
target = torch.tensor(50.)  # large value so we get larger gradients

image = image.type(dtype)
target = target.type(dtype)

if cuda:
    target = target.cuda()
    image = image.cuda()

In [10]:
criterion = torch.nn.BCELoss()

# Run through network using streaming

In [11]:
if image.dim() == 3:
    image = image.unsqueeze(0)
stream_output = sCNN.forward(image); stream_output.max()

To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  ../aten/src/ATen/native/BinaryOps.cpp:467.)
  return torch.floor_divide(other, self)


Number of tiles in forward: 9


  0%|          | 0/3 [00:00<?, ?it/s]

100%|██████████| 3/3 [00:00<00:00, 513.38it/s]


tensor(7.9242553953, device='cuda:0', dtype=torch.float64)
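The reported tile count can be sanity-checked with simple arithmetic. The sketch below assumes that each tile contributes its whole output, which is consistent with the all-zero `Lost` values printed earlier; the library's actual tiling logic may differ. The output sizes follow from the padding-free layer arithmetic: a 1024-pixel input yields 124 outputs per axis and a 512-pixel tile yields 60.

```python
import math


def tiles_per_axis(image_out, tile_out):
    """Number of tiles needed to cover the output along one axis."""
    return math.ceil(image_out / tile_out)


image_out = 124  # output size for the 1024 x 1024 image
tile_out = 60    # output size for one 512 x 512 tile

n = tiles_per_axis(image_out, tile_out)
print(n, n * n)  # 3 9
```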

In [12]:
stream_output.requires_grad = True

In [13]:
output = torch.sigmoid(torch.mean(stream_output)); output

tensor(0.7766771896, device='cuda:0', dtype=torch.float64,
       grad_fn=<SigmoidBackward>)

In [14]:
loss = criterion(output, target); loss

tensor(-60.8211881103, device='cuda:0', dtype=torch.float64,
       grad_fn=<BinaryCrossEntropyBackward>)
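The negative loss may look surprising. Binary cross-entropy normally expects targets in the range [0, 1]; with the deliberately large target of 50, the `(1 - target)` term becomes negative and dominates. The following sketch reproduces the printed value by hand, assuming the standard formula `-(t * log(o) + (1 - t) * log(1 - o))`; the function name `bce` is illustrative.

```python
import math


def bce(output, target):
    """Binary cross-entropy for a single prediction (natural logarithm)."""
    return -(target * math.log(output) + (1 - target) * math.log(1 - output))


output = 0.7766771896  # sigmoid value printed by the notebook

print(bce(output, 50.0))  # about -60.82, matching the notebook
print(bce(output, 1.0))   # about 0.25, the loss for a conventional target of 1
```

Since this notebook only compares gradients between two methods, the sign of the loss does not matter, as long as both passes use the same target.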

In [15]:
loss.backward()

In [16]:
full_gradients = sCNN.backward(image, stream_output.grad)

401.3333333333333 9 572.0
Number of tiles in backprop: 9


100%|██████████| 3/3 [00:00<00:00, 126.98it/s]


In [17]:
sCNN.disable()

Save the gradients of the Conv2d layers to compare with the conventional method:

In [18]:
streaming_conv_gradients = []

for i, layer in enumerate(stream_net.modules()):
    if isinstance(layer, torch.nn.Conv2d):
        if layer.weight.grad is not None:
            streaming_conv_gradients.append(layer.weight.grad.clone()) 

# Compare to conventional method

We reset the gradients and add hooks to the network to gather the gradients of the intermediate feature maps to compare with streaming.

In [19]:
for i, layer in enumerate(stream_net.modules()):
    if isinstance(layer, torch.nn.Conv2d):
        if layer.weight.grad is not None:
            layer.weight.grad.data.zero_()
            layer.bias.grad.data.zero_()

In [20]:
conventional_gradients = []
inps = []

def save_grad(module, grad_in, grad_out):
    global conventional_gradients
    conventional_gradients.append(grad_out[0].clone())
        
for i, layer in enumerate(stream_net.modules()):
    if isinstance(layer, torch.nn.Conv2d):
        layer.register_backward_hook(save_grad)

This output should be the same as the streaming output; if so, the loss will also be the same:

In [21]:
conventional_output = stream_net(image); conventional_output.max()

tensor(7.9242553953, device='cuda:0', dtype=torch.float64,
       grad_fn=<MaxBackward1>)

In [22]:
# NOTE: sometimes output can be slightly bigger 
# (if tiles do not fit nicely on input image according to output stride)
# In that case this check may fail.
max_error = torch.abs(stream_output - conventional_output).max().item()

if max_error < 1e-7:
    print("Equal output to streaming")
else:
    print("NOT equal output to streaming")
    print("error:", max_error)

Equal output to streaming


In [23]:
output = torch.sigmoid(torch.mean(conventional_output)); output

tensor(0.7766771896, device='cuda:0', dtype=torch.float64,
       grad_fn=<SigmoidBackward>)

In [24]:
loss = criterion(output, target); loss

tensor(-60.8211881103, device='cuda:0', dtype=torch.float64,
       grad_fn=<BinaryCrossEntropyBackward>)

In [25]:
loss.backward()

# Compare the gradients of the feature maps

This cell concatenates the intermediate feature map gradients of all tiles and compares them with the feature map gradients calculated by the conventional method.

In [26]:
equal_eps = 1e-17  # because we are comparing floats the difference is almost never exactly 0

# Find the first Conv2d layer that exists in sCNN.gradients
conv_layer = None
for m in stream_net.modules():
    if isinstance(m, torch.nn.Conv2d) and m in sCNN.gradients:
        conv_layer = m
        break

if conv_layer is not None and len(sCNN.gradients[conv_layer]) == 9:

    layer_dict = dict(stream_net.named_modules())
    i = -1
    for name in layer_dict:
        mod = layer_dict[name]
        if isinstance(mod, torch.nn.Conv2d) and mod in sCNN.gradients and len(name) > 0:
            i += 1

            # StreamingCNN streams from top-left to bottom-right. 
            # First concat the columns, then the rows:
            a = torch.cat((sCNN.gradients[mod][0][0], 
                           sCNN.gradients[mod][1][0], 
                           sCNN.gradients[mod][2][0]), dim=2)
            b = torch.cat((sCNN.gradients[mod][3][0], 
                           sCNN.gradients[mod][4][0], 
                           sCNN.gradients[mod][5][0]), dim=2)
            c = torch.cat((sCNN.gradients[mod][6][0], 
                           sCNN.gradients[mod][7][0], 
                           sCNN.gradients[mod][8][0]), dim=2)
            
            str_grad = torch.cat((a, b, c), dim=1)

            # Compare streaming and conventional:
            max_error = torch.abs(str_grad - conventional_gradients[-(i + 1)][0])
            max_error = max_error.max().item()
            
            if max_error < equal_eps:
                print(name, "feature map gradient - equal to non-streaming")
            else:
                print(name, "feature map - NOT equal, max error:", max_error)
else:
    print("No suitable Conv2d layer with 9 tiles found in sCNN.gradients.")


No suitable Conv2d layer with 9 tiles found in sCNN.gradients.
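The stitching order used in the cell above (first join the tiles of one grid row along the width, then stack the rows along the height) can be checked on a toy example. This is a plain-Python sketch with nested lists instead of tensors; `stitch` is a hypothetical helper, not part of `scnn`.

```python
def stitch(tiles, grid_cols):
    """Stitch row-major 2D tiles (lists of rows) into one 2D list."""
    rows = []
    for r in range(0, len(tiles), grid_cols):
        row_tiles = tiles[r:r + grid_cols]
        # Join the tiles of one grid row side by side (along the width) ...
        for lines in zip(*row_tiles):
            rows.append([value for line in lines for value in line])
        # ... and let successive grid rows stack along the height.
    return rows


# Split a 6 x 6 "feature map" into a 3 x 3 grid of 2 x 2 tiles, top-left first.
full = [[r * 6 + c for c in range(6)] for r in range(6)]
tiles = [[row[c:c + 2] for row in full[r:r + 2]]
         for r in range(0, 6, 2) for c in range(0, 6, 2)]

print(len(tiles))                          # 9
print(stitch(tiles, grid_cols=3) == full)  # True
```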


# Compare the gradients of the conv2d layers

Save the gradients of the Conv2d layers from the conventional pass to compare with streaming:

In [27]:
normal_conv_gradients = []
j = 0
for i, layer in enumerate(stream_net.modules()):
    if isinstance(layer, torch.nn.Conv2d):
        if layer.weight.grad is not None:
            normal_conv_gradients.append(layer.weight.grad) 
            print('Conv layer', j, '\t', layer)
            j += 1

Conv layer 0 	 Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
Conv layer 1 	 Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
Conv layer 2 	 Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
Conv layer 3 	 Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
Conv layer 4 	 Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))
Conv layer 5 	 Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1))


In [28]:
print('Streaming', '\n')

for i in range(len(streaming_conv_gradients)):
    print("Conv layer", i, "\t average gradient size:", 
          float(torch.mean(torch.abs(streaming_conv_gradients[i].data))))

Streaming 



In [29]:
print('Conventional', '\n')
for i in range(len(normal_conv_gradients)):
    print("Conv layer", i, "\t average gradient size:", 
          float(torch.mean(torch.abs(normal_conv_gradients[i].data))))

Conventional 

Conv layer 0 	 average gradient size: 0.5335104974390105
Conv layer 1 	 average gradient size: 1.0143260372301737
Conv layer 2 	 average gradient size: 2.00521271720441
Conv layer 3 	 average gradient size: 0.8939830323267149
Conv layer 4 	 average gradient size: 2.268768575334102
Conv layer 5 	 average gradient size: 1.9859810559858897


In [30]:
for i in range(len(streaming_conv_gradients)):
    diff = torch.abs(streaming_conv_gradients[i].data - normal_conv_gradients[i].data)
    max_diff = diff.max()
    print("Conv layer", i, "\t max difference between kernel gradients:", 
          float(max_diff))

As you can see, the gradients of the Conv2d layers computed by the two methods are (almost) numerically equivalent. The small differences are due to loss of significance in floating-point calculations.
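A small illustration of why the results need not be bit-identical: floating-point addition is not associative, so summing the same contributions in a different order (for example tile by tile instead of all at once, which is presumably what happens to the kernel gradients here) can change the last bits of the result. The sketch uses plain Python floats, which are double precision like the tensors in this notebook.

```python
import sys

left = (0.1 + 0.2) + 0.3
right = 0.1 + (0.2 + 0.3)

print(left == right)           # False
print(abs(left - right))       # a difference on the order of 1e-16
print(sys.float_info.epsilon)  # machine epsilon for double precision
```

Differences of this size are far below any value that matters for training, but they explain why a strict equality check between the two methods can fail.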