<h1>Federated Learning - GTEx_V8 Example</h1>
<h2>Populate remote PyGrid nodes with labeled tensors </h2>
In this notebook, we will train a model using federated approach.

**NOTE:** At the time of running this notebook, we were running the grid components in background mode.  

Components:
 - PyGrid Network (http://localhost:5000)
 - PyGrid Node h1 (http://localhost:3000)
 - PyGrid Node h2 (http://localhost:3001)
 
Code implementation for this notebook has been referred from <a href="https://github.com/OpenMined/PySyft/blob/master/examples/tutorials/grid/federated_learning/mnist/Fed.Learning%20MNIST%20%5B%20Part-2%20%5D%20-%20Train%20a%20Model.ipynb">Fed.Learning MNIST [ Part-2 ] - Train a Model</a> tutorial

In [34]:
import syft as sy
from syft.grid.public_grid import PublicGridNetwork
import torch as th
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [35]:
hook = sy.TorchHook(th)
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        
        self.fc1 = nn.Linear(18420, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 6)

    def forward(self, x):
        # make sure input tensor is flattened
        x = x.view(x.shape[0], -1)
        
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.log_softmax(self.fc4(x), dim=1)
        return x


device = th.device("cuda:0" if th.cuda.is_available() else "cpu")

if(th.cuda.is_available()):
    th.set_default_tensor_type(th.cuda.FloatTensor)
    
# model = Net()
# model.to(device)
# optimizer = optim.SGD(model.parameters(), lr=0.01)
# criterion = nn.CrossEntropyLoss()

model = Net()
model.to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()



In [36]:
GRID_ADDRESS = 'localhost'
GRID_PORT = '5000'

my_grid = PublicGridNetwork(hook,"http://" + GRID_ADDRESS + ":" + GRID_PORT)

In [37]:
data = my_grid.search("#X", "#gtex_v8", "#dataset")
target = my_grid.search("#Y", "#gtex_v8", "#dataset")

In [38]:
data

{'h1': [(Wrapper)>[PointerTensor | me:93469325155 -> h1:6072499217]
  	Tags: #balanced #dataset #gtex_v8 #X 
  	Shape: torch.Size([600, 18420])
  	Description: The input datapoints to the GTEx_V8 dataset....],
 'h2': [(Wrapper)>[PointerTensor | me:84007623869 -> h2:62422029600]
  	Tags: #balanced #gtex_v8 #X #dataset 
  	Shape: torch.Size([600, 18420])
  	Description: The input datapoints to the GTEx_V8 dataset....]}

In [39]:
target

{'h1': [(Wrapper)>[PointerTensor | me:29868727086 -> h1:55841676130]
  	Tags: #balanced #dataset #Y #gtex_v8 
  	Shape: torch.Size([600])
  	Description: The input labels to the GTEx_V8 dataset....],
 'h2': [(Wrapper)>[PointerTensor | me:55464176909 -> h2:46460376007]
  	Tags: #balanced #gtex_v8 #Y #dataset 
  	Shape: torch.Size([600])
  	Description: The input labels to the GTEx_V8 dataset....]}

In [40]:
data = list(data.values())
target = list(target.values())

In [41]:
len(data[1]), len(data[0][0])

(1, 600)

In [42]:
def epoch_total_size(data):
    total = 0
    for i in range(len(data)):
        for j in range(len(data[i])):
            total += data[i][j].shape[0]
            
    return total

In [43]:
N_EPOCS = 3
SAVE_MODEL = True
SAVE_MODEL_PATH = './models'

def train(epoch):
    model.train()
    epoch_total = epoch_total_size(data)
    current_epoch_size = 0
    for i in range(len(data)):
        for j in range(len(data[i])):
            current_epoch_size += len(data[i][j])
            print('STEP 1')
            worker = data[i][j].location
            print('STEP 2')
            model.send(worker)
            print('STEP 3')
            optimizer.zero_grad()
            pred = model(data[i][j])
            print('STEP 4')
            loss = criterion(pred, target[i][j])
            print('STEP 5')
            loss.backward()
            print('STEP 6')
            optimizer.step()
            print('STEP 7')
            model.get()
            print('STEP 8')
            loss = loss.get()
            print('STEP 9')
            print('Train Epoch: {} | With {} data |: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                      epoch, worker.id, current_epoch_size, epoch_total,
                            100. *  current_epoch_size / epoch_total, loss.item()))

for epoch in range(N_EPOCS):
    train(epoch)

STEP 1
STEP 2
STEP 3
STEP 4
STEP 5
STEP 6
STEP 7
STEP 8
STEP 9
STEP 1
STEP 2
STEP 3
STEP 4
STEP 5
STEP 6
STEP 7
STEP 8
STEP 9
STEP 1
STEP 2
STEP 3
STEP 4
STEP 5
STEP 6
STEP 7
STEP 8
STEP 9
STEP 1
STEP 2
STEP 3
STEP 4
STEP 5
STEP 6
STEP 7
STEP 8
STEP 9
STEP 1
STEP 2
STEP 3
STEP 4
STEP 5
STEP 6
STEP 7
STEP 8
STEP 9
STEP 1
STEP 2
STEP 3
STEP 4
STEP 5
STEP 6
STEP 7
STEP 8
STEP 9
