<h1>Federated Learning - GTEx_V8 Example</h1>
<h2>Train a model on labeled tensors hosted by remote PyGrid nodes</h2>
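In this notebook, we train a model using a federated approach: the data stays on the remote PyGrid nodes, and the model is sent to each node in turn to be trained on that node's local data.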

**NOTE:** When this notebook was run, the grid components listed below were already running in the background.

Components:
 - PyGrid Network (http://localhost:5000)
 - PyGrid Node h1 (http://localhost:3000)
 - PyGrid Node h2 (http://localhost:3001)
 
The code in this notebook is adapted from the <a href="https://github.com/OpenMined/PySyft/blob/master/examples/tutorials/grid/federated_learning/mnist/Fed.Learning%20MNIST%20%5B%20Part-2%20%5D%20-%20Train%20a%20Model.ipynb">Fed.Learning MNIST [ Part-2 ] - Train a Model</a> tutorial.

In [1]:
import syft as sy
from syft.grid.public_grid import PublicGridNetwork
import torch as th
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch
from syft.federated.floptimizer import Optims

### Parameter Cell -->

In [1]:
# Address and port of the PyGrid Network. The network is contacted as a
# client, so use a connectable host name: 0.0.0.0 is a bind-all address and
# cannot be used as a destination on every platform (e.g. Windows).
GRID_ADDRESS = 'localhost'
# Kept as a string because it is concatenated into the network URL.
GRID_PORT = '5000'
# Number of training epochs (name kept as-is; the training loop reads it).
N_EPOCS = 20
# Model-saving options.
# NOTE(review): not referenced by the training cells in this notebook — confirm they are used.
SAVE_MODEL = True
SAVE_MODEL_PATH = './models'

In [2]:
# Hook PyTorch with PySyft so tensors and modules gain the remote-execution
# methods used later in this notebook (e.g. `.send()` / `.get()`).
hook = sy.TorchHook(th)

In [3]:

class Net(nn.Module):
    """Fully connected classifier over flattened input vectors.

    Layers: in_features -> 256 -> 128 -> 64 -> num_classes, with ReLU between
    the hidden layers.

    ``forward`` returns raw, unnormalised logits. ``F.cross_entropy`` /
    ``nn.CrossEntropyLoss`` already apply ``log_softmax`` internally, so a
    softmax inside the model would normalise twice and flatten the gradients.
    Use ``F.softmax(logits, dim=1)`` explicitly when probabilities are needed;
    ``argmax`` over the logits already yields the predicted class.

    Args:
        in_features: Number of features per sample after flattening
            (default 18420, the input size this notebook was written for).
        num_classes: Number of output classes (default 6).
    """

    def __init__(self, in_features=18420, num_classes=6):
        super().__init__()

        self.fc1 = nn.Linear(in_features, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, num_classes)

    def forward(self, x):
        """Map a batch of inputs to logits of shape (batch, num_classes)."""
        # Flatten every dimension except the batch dimension.
        x = x.view(x.shape[0], -1)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return self.fc4(x)


# Prefer the first GPU when one is available; otherwise use the CPU.
device = th.device("cuda:0" if th.cuda.is_available() else "cpu")

if th.cuda.is_available():
    # Make newly created float tensors default to CUDA tensors.
    th.set_default_tensor_type(th.cuda.FloatTensor)

# Seed the RNG before building the model so weight initialisation is reproducible.
SEED = 0
th.manual_seed(SEED)

# Learning rate shared by every per-worker Adam optimizer.
LEARNING_RATE = 0.003

model = Net()
model.to(device)

# One optimizer per PyGrid node. These ids must match the worker ids of the
# remote tensors (`.location.id`), because the training loop looks the
# optimizer up by that id.
workers = ['h1', 'h2']
optims = Optims(workers, optim=optim.Adam(params=model.parameters(), lr=LEARNING_RATE))

In [4]:
# Connect to the PyGrid Network configured in the parameter cell.
grid_url = "http://" + GRID_ADDRESS + ":" + GRID_PORT
my_grid = PublicGridNetwork(hook, grid_url)

In [5]:
# Display the grid client object as a quick check that it was created.
my_grid

<syft.grid.public_grid.PublicGridNetwork at 0x158421610>

In [6]:
# Both searches share the dataset tags and differ only in the role tag:
# "#X" selects the inputs and "#Y" the labels. Each result is a dict keyed by
# worker id whose values are lists of remote tensors.
DATASET_TAGS = ("#gtex_v8", "#dataset")
data = my_grid.search("#X", *DATASET_TAGS)
target = my_grid.search("#Y", *DATASET_TAGS)

In [7]:
# Worker ids that host tensors tagged as GTEx_V8 inputs (#X).
data.keys()

dict_keys(['h1', 'h2'])

In [8]:
# Inspect the label search result: for each worker id, a list of pointers to remote label tensors.
target

{'h1': [(Wrapper)>[PointerTensor | me:75314703439 -> h1:93296785878]
  	Tags: #gtex_v8 #balanced #Y #dataset 
  	Shape: torch.Size([600])
  	Description: The input labels to the GTEx_V8 dataset....],
 'h2': [(Wrapper)>[PointerTensor | me:38129987735 -> h2:6753975204]
  	Tags: #gtex_v8 #balanced #Y #dataset 
  	Shape: torch.Size([600])
  	Description: The input labels to the GTEx_V8 dataset....]}

In [9]:
# Convert both search results to per-worker lists, aligned by worker id.
# `data` and `target` come from two independent grid searches, so pairing
# them by dict order alone could match one worker's inputs with another
# worker's labels.
worker_ids = list(data.keys())
if set(worker_ids) != set(target.keys()):
    raise ValueError(
        "Input and label searches returned different workers: {} vs {}".format(
            sorted(worker_ids), sorted(target.keys())))
data = [data[worker_id] for worker_id in worker_ids]
target = [target[worker_id] for worker_id in worker_ids]

In [10]:
# Sanity check: number of input tensors held by the second worker, and the
# length of the first worker's first input tensor.
len(data[1]), len(data[0][0])

(1, 600)

In [11]:
def epoch_total_size(data):
    """Total number of rows across every worker's tensors.

    Args:
        data: A list with one entry per worker, each a list of tensors whose
            first dimension is counted as the number of samples.

    Returns:
        The sum of ``shape[0]`` over every tensor of every worker (0 if empty).
    """
    return sum(
        tensor.shape[0]
        for worker_tensors in data
        for tensor in worker_tensors
    )

In [12]:
# Show which worker hosts the first input tensor of each of the two workers.
data[0][0].location, data[1][0].location

(<Federated Worker id:h1>, <Federated Worker id:h2>)

In [13]:
# Keep a list of the model's parameter tensors for inspection.
params = list(model.parameters())

In [14]:
def train(epoch):
    """Run one epoch of federated training over every worker's batches.

    For each remote batch:
    - the model is sent to the worker holding that batch;
    - one optimisation step is taken there, using the optimizer that
      ``optims`` holds for that worker id;
    - the model is retrieved again;
    - the batch's mean loss and accuracy are fetched back and printed.

    Uses the notebook globals ``model``, ``optims``, ``data`` and ``target``.
    ``data[i]`` and ``target[i]`` are parallel lists of remote input / label
    tensors for the i-th worker.

    Args:
        epoch: Epoch index; used only in the progress message.

    Raises:
        ValueError: If ``data`` and ``target`` do not have the same structure.
    """
    if len(data) != len(target):
        raise ValueError(
            "data has {} workers but target has {}".format(len(data), len(target)))

    model.train()
    epoch_total = epoch_total_size(data)
    samples_seen = 0
    for worker_batches, worker_labels in zip(data, target):
        if len(worker_batches) != len(worker_labels):
            raise ValueError(
                "worker has {} input tensors but {} label tensors".format(
                    len(worker_batches), len(worker_labels)))

        for batch, labels in zip(worker_batches, worker_labels):
            worker = batch.location
            # Count rows the same way epoch_total_size does.
            samples_seen += batch.shape[0]
            opt = optims.get_optim(worker.id)

            model.send(worker)
            try:
                opt.zero_grad()
                pred = model(batch)
                loss = F.cross_entropy(pred, labels)
                loss.backward()
                opt.step()

                # The predicted class is the argmax over the model outputs.
                n_correct = torch.sum(torch.argmax(pred, dim=1) == labels)

                # F.cross_entropy already averages over the batch, so the
                # fetched loss is the per-sample mean. Only the correct-count
                # needs dividing by the batch size.
                batch_loss = loss.get().item()
                batch_acc = n_correct.get().item() / labels.shape[0]
            finally:
                # Always bring the model back, even if a remote step fails,
                # so it is not left stranded on the worker.
                model.get()

            print('Train Epoch: {} | With {} data |: [{}/{} ({:.0f}%)]\tTrain Loss: {:.6f} | Train Acc: {:.3f}'.format(
                      epoch, worker.id, samples_seen, epoch_total,
                            100. * samples_seen / epoch_total, batch_loss, batch_acc))

# Run N_EPOCS epochs of federated training; `train` prints per-batch progress.
# NOTE(review): SAVE_MODEL / SAVE_MODEL_PATH from the parameter cell are not
# used by the cells in this notebook — confirm the trained model is saved somewhere.
for epoch in range(N_EPOCS):
    train(epoch)

