In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt
from time import time
from time import sleep
import itertools
from datetime import timedelta
import diffdist.functional as distops

import torch
import torchvision
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.autograd import Variable

import torch.nn.functional as F
from torchvision import datasets, transforms
from torch import nn, optim

In [None]:
# ToTensor scales MNIST pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
trainset = datasets.MNIST('./data', download=True, train=True, transform=transform)
valset = datasets.MNIST('./data', download=True, train=False, transform=transform)

# Training batches are shuffled every epoch.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=512, shuffle=True, num_workers=2)
# Evaluation gains nothing from shuffling; a fixed order keeps validation runs reproducible.
valloader = torch.utils.data.DataLoader(valset, batch_size=512, shuffle=False)

In [None]:
class MyConvModel(nn.Module):
    """Convolutional pipeline stage: conv -> ReLU -> 2x2 max-pool.

    Args:
        input_num_filters: number of channels in the incoming feature map.
        output_num_filters: number of channels produced by the convolution.
        filter_size: side length of the square convolution kernel.
    """

    def __init__(self, input_num_filters, output_num_filters, filter_size):
        super().__init__()
        square_kernel = (filter_size, filter_size)
        self.conv1 = nn.Conv2d(input_num_filters, output_num_filters, kernel_size=square_kernel)
        self.pool = nn.MaxPool2d(2, 2)
        # Per-stage Adam optimizer. NOTE(review): MyPipelineModel creates its own
        # optimizers over the same parameters, so this one appears unused there.
        self.optimizer = optim.Adam(self.parameters(), lr=0.001)

    def forward(self, x):
        """Return the pooled, rectified convolution of ``x``."""
        activated = F.relu(self.conv1(x))
        return self.pool(activated)

class MyOutputModel(nn.Module):
    """Classifier pipeline stage: flatten -> fc1 -> ReLU -> fc2 (logits).

    Args:
        input_size: number of features after flattening all non-batch dims.
        hidden_size: width of the hidden layer.
        ouput_size: number of output logits (parameter name kept as-is,
            typo included, because callers may pass it by keyword).
    """

    def __init__(self, input_size, hidden_size, ouput_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, ouput_size)
        # Per-stage Adam optimizer. NOTE(review): MyPipelineModel creates its own
        # optimizers over the same parameters, so this one appears unused there.
        self.optimizer = optim.Adam(self.parameters(), lr=0.001)

    def forward(self, x):
        """Flatten everything but the batch dim and return class logits."""
        flat = x.flatten(start_dim=1)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

In [None]:
class MyPipelineModel(nn.Module):
    """Pipeline-parallel trainer in which each process owns one stage of ``models``.

    Every batch is cut into micro-batches of ``split_size`` samples. Activations
    flow forward from rank ``r`` to rank ``r + 1``. Gradients with respect to
    those activations flow back from ``r + 1`` to ``r``. As a result, every
    stage's parameters receive real gradients before its optimizer steps.

    Args:
        models: indexable sequence of stages; stage ``i`` runs on rank ``i``.
        num_worker: number of processes; must equal ``len(models)``.
        split_size: micro-batch size.
        epochs: number of passes over the data in ``forward``.
        lr: Adam learning rate used for every stage.
        master_port: TCP port used for the process-group rendezvous.

    Raises:
        ValueError: if ``num_worker`` differs from ``len(models)``.
    """

    def __init__(self, models, num_worker, split_size=32, epochs=5, lr=0.001, master_port='3000'):
        super(MyPipelineModel, self).__init__()
        if num_worker != len(models):
            raise ValueError(
                "num_worker ({}) must equal the number of stages ({})".format(num_worker, len(models)))
        self.models = models
        self.num_worker = num_worker
        self.split_size = split_size
        # One Adam optimizer per stage; each process only ever steps its own.
        self.optimizers = [optim.Adam(stage.parameters(), lr=lr) for stage in self.models]
        self.criterion = nn.CrossEntropyLoss()
        self.epochs = epochs
        self.master_port = master_port
        self.running_loss = 0

    def run(self, x):
        """Start one process per stage, train on ``x``, and wait for all of them.

        Raises:
            RuntimeError: if any worker process exits with a non-zero code.
        """
        processes = []
        for rank in range(self.num_worker):
            p = mp.Process(target=self.init_process, args=(rank, self.num_worker, x, self.forward))
            p.start()
            processes.append(p)
        for p in processes:
            p.join()
        failed = [rank for rank, p in enumerate(processes) if p.exitcode != 0]
        if failed:
            raise RuntimeError("pipeline worker(s) {} exited abnormally".format(failed))

    def init_process(self, rank, size, x, fn, backend='gloo'):
        """Join the process group as ``rank`` of ``size``, run ``fn``, then tear the group down."""
        torch.set_num_threads(2)
        # 'lo' is the Linux loopback interface name; other OSes may need a different value.
        os.environ['GLOO_SOCKET_IFNAME'] = 'lo'
        os.environ['MASTER_ADDR'] = 'localhost'
        os.environ['MASTER_PORT'] = str(self.master_port)
        dist.init_process_group(backend, rank=rank, world_size=size)
        try:
            fn(rank, size, x)
        finally:
            dist.destroy_process_group()

    def forward(self, rank, size, x):
        """Train stage ``rank`` for ``self.epochs`` epochs over the batches in ``x``.

        Args:
            rank: this process's stage index.
            size: world size (equal to ``num_worker`` when launched via ``run``).
            x: iterable of ``(images, labels)`` batches. Every process iterates
                it, so all processes must see the same batch sizes in the same order.
        """
        time0 = time()
        for e in range(self.epochs):
            self.running_loss = 0
            num_batches = 0
            for images, labels in x:
                batch_loss = self._train_batch(rank, images, labels)
                if batch_loss is not None:
                    self.running_loss += batch_loss
                num_batches += 1

            if rank == self.num_worker - 1:
                print("Epoch {} - Training loss: {}".format(e, self.running_loss / max(num_batches, 1)))
                print("\nTraining Time (in seconds) =", (time() - time0))

    def _train_batch(self, rank, images, labels):
        """Run one forward, backward and optimizer step of this stage; return the loss on the last stage, else None."""
        first, last = 0, self.num_worker - 1
        model = self.models[rank]
        optimizer = self.optimizers[rank]
        optimizer.zero_grad()

        # The loss must use the labels of the images rank 0 actually fed in, so they
        # are shipped explicitly rather than relying on identical shuffling per process.
        if first != last:
            if rank == first:
                dist.send(labels.contiguous(), dst=last)
            elif rank == last:
                received = torch.empty_like(labels)
                dist.recv(received, src=first)
                labels = received

        # Forward: only the first stage reads images. The other stages only use
        # the split count to know how many activations to receive.
        inputs, outputs = [], []
        for micro in images.split(self.split_size, dim=0):
            if rank == first:
                inp = micro
            else:
                # Leaf tensor that requires grad, so the gradient we owe upstream lands in inp.grad.
                inp = self._recv_tensor(rank - 1).requires_grad_()
            out = model(inp)
            if rank != last:
                self._send_tensor(out, rank + 1)
            inputs.append(inp)
            outputs.append(out)

        # Backward: the last stage starts from the loss; the others continue from
        # the activation gradients sent back by the next stage.
        loss_value = None
        if rank == last:
            loss = self.criterion(torch.cat(outputs), labels)
            loss.backward()
            loss_value = loss.item()
        else:
            grads = [self._recv_tensor(rank + 1) for _ in outputs]
            torch.autograd.backward(outputs, grads)

        if rank != first:
            for inp in inputs:
                self._send_tensor(inp.grad, rank - 1)

        optimizer.step()
        return loss_value

    @staticmethod
    def _send_tensor(tensor, dst):
        """Send ``tensor`` (as float32) to ``dst``, preceded by its ndim and shape so the receiver can size a buffer."""
        payload = tensor.detach().float().contiguous()
        dist.send(torch.tensor([payload.dim()], dtype=torch.long), dst=dst)
        dist.send(torch.tensor(list(payload.shape), dtype=torch.long), dst=dst)
        dist.send(payload, dst=dst)

    @staticmethod
    def _recv_tensor(src):
        """Receive a tensor sent by ``_send_tensor`` from ``src`` and return it as float32."""
        ndim = torch.zeros(1, dtype=torch.long)
        dist.recv(ndim, src=src)
        shape = torch.zeros(int(ndim.item()), dtype=torch.long)
        dist.recv(shape, src=src)
        buffer = torch.zeros(tuple(shape.tolist()))
        dist.recv(buffer, src=src)
        return buffer

In [None]:
# Three pipeline stages, one per worker process:
#   1x28x28 --conv5+pool--> 16x12x12 --conv5+pool--> 32x4x4 (= 512 features) --MLP--> 10 logits
models = nn.ModuleList([
    MyConvModel(1, 16, 5),
    MyConvModel(16, 32, 5),
    MyOutputModel(512, 128, 10),
])

model = MyPipelineModel(models, len(models))
model.run(trainloader)