In [21]:
import numpy as np
import os
import matplotlib.pyplot as plt
from time import time
from time import sleep
import threading, queue
import concurrent.futures
from datetime import timedelta

import torch
import torchvision
import torch.distributed as dist
import torch.multiprocessing as mp

import torch.nn.functional as F
from torchvision import datasets, transforms
from torch import nn, optim

In [26]:
# Load MNIST. ToTensor scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps them to [-1, 1].
DATA_DIR = './data'
BATCH_SIZE = 8192

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
trainset = datasets.MNIST(DATA_DIR, download=True, train=True, transform=transform)
valset = datasets.MNIST(DATA_DIR, download=True, train=False, transform=transform)
print(type(trainset))
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
# Evaluation order does not affect validation metrics, so there is no need to shuffle it.
valloader = torch.utils.data.DataLoader(valset, batch_size=BATCH_SIZE, shuffle=False)

<class 'torchvision.datasets.mnist.MNIST'>


In [27]:
class MyConvModel(nn.Module):
    """First pipeline stage: square convolution, ReLU, then 2x2 max-pooling.

    Args:
        input_num_filters: channel count of the incoming tensor.
        output_num_filters: channel count produced by the convolution.
        filter_size: side length of the square convolution kernel.
    """

    def __init__(self, input_num_filters, output_num_filters, filter_size):
        super().__init__()
        kernel = (filter_size, filter_size)
        self.conv1 = nn.Conv2d(input_num_filters, output_num_filters, kernel_size=kernel)
        # Window 2, stride 2: halves each spatial dimension (rounding down).
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        activated = F.relu(self.conv1(x))
        return self.pool(activated)

class MyOutputModel(nn.Module):
    """Final pipeline stage: conv, ReLU, 2x2 max-pool, then a two-layer MLP head.

    Args:
        input_num_filters: channel count of the incoming tensor.
        output_num_filters: channel count produced by the convolution.
        filter_size: side length of the square convolution kernel.
        input_size: flattened feature count after pooling; it must equal
            output_num_filters * pooled_height * pooled_width.
        hidden_size: width of the hidden fully connected layer.
        ouput_size: number of output logits (name spelled as in the existing API).
    """

    def __init__(self, input_num_filters, output_num_filters, filter_size, input_size, hidden_size, ouput_size):
        super().__init__()
        self.conv1 = nn.Conv2d(input_num_filters, output_num_filters, kernel_size=(filter_size,) * 2)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, ouput_size)

    def forward(self, x):
        features = self.pool(F.relu(self.conv1(x)))
        # Keep the batch dimension; collapse channels and spatial dims into one feature axis.
        hidden = F.relu(self.fc1(features.flatten(start_dim=1)))
        return self.fc2(hidden)

In [30]:
class MyPipelineModel(nn.Module):
    """Two-stage classifier trained with micro-batch pipelining across threads.

    ``models[0]`` is the first stage and ``models[1]`` the second. During
    training, each batch is split into micro-batches of ``split_size`` samples.
    While stage 1 runs on micro-batch k-1, stage 0 runs on micro-batch k in a
    worker thread.

    Args:
        models: the two stage modules, ``[stage0, stage1]``. They are wrapped in
            ``nn.ModuleList`` so their parameters are registered (and reach the
            optimizer) even if a plain Python list is passed.
        split_size: micro-batch size used by the pipelined training loop.
        lr: Adam learning rate.
        epochs: number of passes over the data per ``train(data)`` / ``run(data)`` call.
    """

    def __init__(self, models, split_size=2048, lr=0.001, epochs=5):
        super(MyPipelineModel, self).__init__()
        # Register the stages before building the optimizer; a plain list would
        # leave self.parameters() empty and Adam would reject it.
        self.models = nn.ModuleList(models)
        self.split_size = split_size
        self.optimizer = optim.Adam(self.parameters(), lr=lr)
        self.criterion = nn.CrossEntropyLoss()
        self.epochs = epochs
        # NOTE(review): not read anywhere in this class; kept for compatibility.
        self.batch_size = 100
        # Sum of per-batch losses for the epoch currently (or most recently) trained.
        self.running_loss = 0

    def forward(self, x):
        """Run both stages sequentially (no pipelining) and return stage 1's output."""
        return self.models[1](self.models[0](x))

    def train(self, x=True):
        """Fit on ``x``, or switch train/eval mode when ``x`` is a bool.

        ``nn.Module.eval()`` calls ``self.train(False)``. A bool argument therefore
        keeps the standard ``nn.Module.train(mode)`` meaning and returns ``self``.
        Any other ``x`` is treated as an iterable of ``(images, labels)`` batches.
        In that case the method runs ``self.epochs`` epochs of pipelined training
        and returns None.
        """
        if isinstance(x, bool):
            return super(MyPipelineModel, self).train(x)
        self._fit(x)

    def run(self, x):
        """Train on ``x``, an iterable of ``(images, labels)`` batches."""
        # The previous version started threads calling self.train(x, rank, size),
        # a signature train() never accepted, and then joined self.backward_q /
        # self.tensor_q, which were never created, so it could only raise.
        self._fit(x)

    def _pipelined_forward(self, images, executor):
        """Run both stages over ``images`` in micro-batches.

        Stage 0 on split k overlaps with stage 1 on split k-1. Outputs are
        concatenated back in input order, so they line up with the batch's labels.
        """
        splits = iter(images.split(self.split_size, dim=0))
        pending = self.models[0](next(splits))
        outputs = []
        for chunk in splits:
            out_future = executor.submit(self.models[1], pending)
            in_future = executor.submit(self.models[0], chunk)
            outputs.append(out_future.result())
            pending = in_future.result()
        outputs.append(self.models[1](pending))
        return torch.cat(outputs)

    def _fit(self, data):
        """Train for ``self.epochs`` epochs over ``data``.

        After each epoch, prints the mean per-batch loss and the cumulative elapsed time.
        """
        # Make sure a prior eval() does not leave submodules in eval mode while training.
        super(MyPipelineModel, self).train(True)
        time0 = time()
        # One executor for the whole run instead of one per micro-batch.
        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
            for e in range(self.epochs):
                self.running_loss = 0
                num_batches = 0
                for images, labels in data:
                    self.optimizer.zero_grad()
                    output = self._pipelined_forward(images, executor)
                    loss = self.criterion(output, labels)
                    loss.backward()
                    self.optimizer.step()
                    self.running_loss += loss.item()
                    num_batches += 1
                # Average over the batches actually drawn from `data`, not a global loader.
                print("Epoch {} - Training loss: {}".format(e, self.running_loss / max(num_batches, 1)))
                print("\nTraining Time (in seconds) =", (time() - time0))
            

In [29]:
# Seed torch's global RNG so weight init and DataLoader shuffling are reproducible.
SEED = 0
torch.manual_seed(SEED)

# For 28x28 MNIST inputs: stage 0 maps 1x28x28 -> 16x12x12; stage 1 maps
# 16x12x12 -> 32x4x4 (= 512 flattened features) -> 128 hidden -> 10 logits.
models = nn.ModuleList()
models.append(MyConvModel(1, 16, 5))
models.append(MyOutputModel(16, 32, 5, 512, 128, 10))

model = MyPipelineModel(models)
model.train(trainloader)

Epoch 0 - Training loss: 2.1331629157066345

Training Time (in seconds) = 11.415699005126953


KeyboardInterrupt: 