# Tutorial
This tutorial walks you through the key features of our customized DataParallel module and shows how to use them.
The code is tested on PyTorch 0.4.0 and Python 3.6; it should also work on other PyTorch versions.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

import nn as mynn

## Let's begin by defining some network classes

In [None]:
# simple networks
class simple_networkA(nn.Module):
    """Feature extractor made of a single linear layer (5 features in, 2 out)."""

    def __init__(self):
        super().__init__()
        self.model = nn.Linear(in_features=5, out_features=2)

    def forward(self, input):
        """Run ``input`` through the linear layer and return the result."""
        features = self.model(input)
        return features

class simple_networkB(nn.Module):
    """Prediction head made of a single linear layer (2 features in, 1 out)."""

    def __init__(self):
        super().__init__()
        self.model = nn.Linear(in_features=2, out_features=1)

    def forward(self, input):
        """Run ``input`` through the linear layer and return the result."""
        prediction = self.model(input)
        return prediction
    
# complex network
class complex_networkA(nn.Module):
    """Linear layer (5 features in, 2 out) that returns a nested dict of outputs.

    The same output tensor is exposed under three keys so the tutorial can
    show how individual entries of a nested result are treated.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Linear(in_features=5, out_features=2)

    def forward(self, input):
        """Return ``{'outputs': {'out1', 'out2'}, 'second_output'}`` for ``input``.

        Every leaf of the returned dict is the very same tensor object.
        """
        features = self.model(input)
        return {
            'outputs': {
                'out1': features,
                'out2': features,
            },
            'second_output': features,
        }
    
class complex_networkB(nn.Module):
    """Prediction head made of a single linear layer (2 features in, 1 out)."""

    def __init__(self):
        super().__init__()
        self.model = nn.Linear(in_features=2, out_features=1)

    def forward(self, input):
        """Run ``input`` through the linear layer and return the result."""
        return self.model(input)

## Test case one
Here we want to use one network to extract features and another network to make predictions from these features.
We show how to avoid redundant scattering and gathering of the features in this case. 

By setting the keyword `gather_to_one_device=False`, we avoid gathering all outputs onto a single device and instead return a tuple containing the per-device outputs. The tuple is safe to use because the outputs are sorted according to their GPU device ids.

By setting the keyword `inputs_are_scattered=True`, we assume the inputs are already scattered across all GPUs and thus skip the scattering step. The inputs must be sorted according to GPU device ids, matching the order of the outputs when `gather_to_one_device=False`. 

In [None]:
class TestCaseOne(nn.Module):
    """Two-stage pipeline: ``netA`` extracts features, ``netB`` predicts from them.

    Each sub-network has its own Adam optimizer. ``forward`` runs one full
    training step (forward pass, backward pass, update of both networks).
    """

    def __init__(self):
        super().__init__()

        self.netA = simple_networkA()
        self.netB = simple_networkB()

        self.optimizer_A = optim.Adam(self.netA.parameters(), lr=1e-4, betas=(0.5, 0.9))
        self.optimizer_B = optim.Adam(self.netB.parameters(), lr=1e-4, betas=(0.5, 0.9))

    def convert_data_parallel(self):
        """Move both sub-networks to the GPU and wrap them in ``mynn.DataParallel``.

        ``netA`` is configured not to gather its outputs and ``netB`` is
        configured to treat its inputs as already scattered, so the features
        can be handed from one to the other directly.
        """
        self.netA = mynn.DataParallel(self.netA.cuda(), gather_to_one_device=False)  # don't gather to one device
        self.netB = mynn.DataParallel(self.netB.cuda(), inputs_are_scattered=True)  # assume inputs are scattered

    def forward(self, input):
        """Run one training step on ``input`` and return the output of ``netB``.

        The loss is the mean of ``netB``'s output. The returned tensor is the
        one computed before the parameter update.
        """
        # Clear gradients left over from a previous call; otherwise they
        # accumulate and every step after the first uses stale gradients.
        self.optimizer_A.zero_grad()
        self.optimizer_B.zero_grad()

        A_out = self.netA(input)
        B_out = self.netB(A_out)

        loss = B_out.mean()
        loss.backward()

        # Update both networks; the loss depends on the parameters of each.
        self.optimizer_A.step()
        self.optimizer_B.step()

        # Try print out A's type
        print('type of A_out is ', type(A_out))  # a tuple once wrapped by convert_data_parallel
        print('size of A_out content is ', A_out[0].size())  # should be (16 / num_gpus, 2) for a batch of 16

        return B_out
        
# Build the test-case-one model and wrap its sub-networks for multi-GPU use.
test_model = TestCaseOne()
test_model.convert_data_parallel()

# A batch of 16 samples with 5 features each, all ones, moved to the GPU.
# NOTE: the name `input` shadows the Python builtin of the same name.
input = torch.ones(16, 5).cuda()
output = test_model(input)

## Test case two
Now consider a more complicated case. Suppose a network not only extracts features but also performs some kind of prediction. Since the prediction won't be used by any other network, we wish to gather it. But we still want to keep the features separate. In other words, we want only some of the outputs to be kept separate and the rest to be gathered. 

This can be done by passing `except_keywords=[...]`. See the example below. 

In [None]:
class TestCaseTwo(nn.Module):
    """Pipeline in which ``netA`` returns a nested dict and ``netB`` consumes part of it.

    The ``'out1'`` and ``'out2'`` entries of ``netA``'s result are combined
    element-wise and fed to ``netB``; ``'second_output'`` is only inspected.
    """

    def __init__(self):
        super().__init__()

        self.netA = complex_networkA()
        self.netB = complex_networkB()

        # Both optimizers share the same Adam settings.
        adam_settings = {'lr': 1e-4, 'betas': (0.5, 0.9)}
        self.optimizer_A = optim.Adam(self.netA.parameters(), **adam_settings)
        self.optimizer_B = optim.Adam(self.netB.parameters(), **adam_settings)

    def convert_data_parallel(self):
        """Move both sub-networks to the GPU and wrap them in ``mynn.DataParallel``."""
        # gather all except out1,out2
        gpu_net_a = self.netA.cuda()
        self.netA = mynn.DataParallel(gpu_net_a, except_keywords=['out1', 'out2'])

        gpu_net_b = self.netB.cuda()
        self.netB = mynn.DataParallel(gpu_net_b, inputs_are_scattered=True)

    def forward(self, input):
        """Print a summary of ``netA``'s outputs and return ``netB``'s prediction."""
        results = self.netA(input)

        # let's see the outputs of the except_keywords
        kept_apart = results['outputs']
        first_parts, second_parts = kept_apart['out1'], kept_apart['out2']
        print('out1 type is {}; out2 type is {}'.format(type(first_parts), type(second_parts)))
        print('out1 element size is ', first_parts[0].size())

        # check out the other outputs
        gathered = results['second_output']
        print('second output type is {}'.format(gathered.type()))
        print('second output size is ', gathered.size())

        # manipulating with the tuple outputs: pair the entries up and add
        # them one pair at a time (time cost is very small).
        summed_parts = []
        for left, right in zip(first_parts, second_parts):
            summed_parts.append(left + right)

        return self.netB(tuple(summed_parts))

# Build the test-case-two model and wrap its sub-networks for multi-GPU use.
test_model = TestCaseTwo()
test_model.convert_data_parallel()

# A batch of 16 samples with 5 features each, all ones, moved to the GPU.
# NOTE: the name `input` shadows the Python builtin of the same name.
input = torch.ones(16, 5).cuda()
output = test_model(input)

## Notes
We have another option, `sort_and_gather`, which sorts the outputs according to GPU device ids and gathers them together. 

For example, when a network B uses features extracted by a network A to make predictions and gathers the outputs, it is important that the output of B is aligned with the labels of A, so we probably need this option. Alternatively, you can always achieve the same thing by adding the label to `except_keywords` and including the label in the output. Keep in mind that it MUST be a CUDA tensor, because we have to sort according to GPU ids. 