# "Decoupled Neural Interfaces using Synthetic Gradients" paper implementation - https://arxiv.org/pdf/1608.05343.pdf

# In [2]:
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms


class mnist():
    """MNIST train/test data loaders bundled with basic dataset metadata.

    Args:
        args: object exposing ``batch_size``; the value is used for both
            the training and the test loader.

    Attributes:
        train_loader: shuffled loader over the training split.
        test_loader: unshuffled loader over the test split.
        input_dims: 784.
        num_classes: 10.
        in_channel: 1.
        num_train: number of samples in the training split.
    """

    def __init__(self, args):
        data_root = './data'

        # Only the training split asks torchvision to download the data.
        training_set = dsets.MNIST(
            root=data_root,
            train=True,
            transform=transforms.ToTensor(),
            download=True,
        )
        held_out_set = dsets.MNIST(
            root=data_root,
            train=False,
            transform=transforms.ToTensor(),
        )

        make_loader = torch.utils.data.DataLoader
        self.train_loader = make_loader(
            dataset=training_set, batch_size=args.batch_size, shuffle=True)
        self.test_loader = make_loader(
            dataset=held_out_set, batch_size=args.batch_size, shuffle=False)

        # Static dataset metadata read by the rest of the code.
        self.in_channel = 1
        self.num_classes = 10
        self.input_dims = 784
        self.num_train = len(training_set)


class cifar10():
    """CIFAR-10 train/test data loaders bundled with basic dataset metadata.

    Args:
        args: object optionally exposing ``batch_size``. When the attribute
            is missing the loaders fall back to a batch size of 100.

    Attributes:
        train_loader: shuffled loader over the augmented training split.
        test_loader: unshuffled loader over the test split.
        num_classes: 10.
        in_channel: 3.
        num_train: number of samples in the training split.
    """

    def __init__(self, args):
        # Honour the configured batch size like the mnist wrapper does; 100
        # is kept as the fallback for callers whose args carry no batch_size.
        batch_size = getattr(args, 'batch_size', 100)

        transform = self.image_transform()
        train_dataset = dsets.CIFAR10(root='./data/',
                                      train=True,
                                      transform=transform,
                                      download=True)

        # NOTE(review): the training transform crops to 28x28 while the test
        # transform below does no resize/crop, so test tensors keep the
        # dataset's native size - confirm the model accepts both sizes.
        test_dataset = dsets.CIFAR10(root='./data/',
                                     train=False,
                                     transform=transforms.ToTensor())

        self.train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                                        batch_size=batch_size,
                                                        shuffle=True)

        self.test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                                       batch_size=batch_size,
                                                       shuffle=False)
        self.num_classes = 10
        self.in_channel = 3
        self.num_train = len(train_dataset)

    def image_transform(self):
        """Return the training augmentation pipeline.

        The pipeline resizes to 40, applies a random horizontal flip, takes a
        random 28x28 crop and converts the image to a tensor.
        """
        # transforms.Scale was renamed to Resize and later removed from
        # torchvision; prefer Resize and only fall back on old releases.
        resize = getattr(transforms, 'Resize', None) or transforms.Scale
        return transforms.Compose([
            resize(40),
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(28),
            transforms.ToTensor()])


# In [None]:
import torch
import torch.nn as nn
import numpy as np

class dni_linear(nn.Module):
    """Three-layer fully connected DNI module.

    Maps a batch of ``input_dims``-wide vectors to an output of the same
    width through two hidden Linear -> BatchNorm1d -> ReLU stages followed by
    a final linear projection. Presumably the output is the synthetic
    gradient for the input activations (see the paper named in the file
    header) - confirm against the training loop.

    Args:
        input_dims: width of the input vectors, and of the output.
        num_classes: width of the label vector that is concatenated to the
            input when ``conditioned`` is true.
        dni_hidden_size: width of both hidden layers.
        conditioned: when true, ``forward`` concatenates ``y`` onto ``x``.
    """

    def __init__(self, input_dims, num_classes, dni_hidden_size=1024, conditioned=False):
        super(dni_linear, self).__init__()
        self.conditioned = conditioned
        # Conditioning widens the first layer by the label vector.
        if self.conditioned:
            dni_input_dims = input_dims + num_classes
        else:
            dni_input_dims = input_dims
        self.layer1 = nn.Sequential(
            nn.Linear(dni_input_dims, dni_hidden_size),
            nn.BatchNorm1d(dni_hidden_size),
            nn.ReLU(),
        )
        self.layer2 = nn.Sequential(
            nn.Linear(dni_hidden_size, dni_hidden_size),
            nn.BatchNorm1d(dni_hidden_size),
            nn.ReLU(),
        )
        # Project back to the input width so the output matches x.
        self.layer3 = nn.Linear(dni_hidden_size, input_dims)

    def forward(self, x, y=None):
        """Run the module on ``x``.

        Args:
            x: 2-D batch whose second dimension is ``input_dims`` wide.
            y: label batch concatenated to ``x`` along dimension 1. Required
                when the module is conditioned, ignored otherwise.

        Returns:
            Tensor with ``input_dims`` features per sample.

        Raises:
            ValueError: if the module is conditioned and ``y`` is None.
        """
        if self.conditioned:
            # Explicit raise instead of assert: asserts vanish under -O.
            if y is None:
                raise ValueError('y is required when conditioned=True')
            x = torch.cat((x, y), 1)
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        return out

class dni_Conv2d(nn.Module):
    """Three-layer convolutional DNI module.

    Applies two Conv2d -> BatchNorm2d -> ReLU stages and a final Conv2d, all
    with 5x5 kernels and padding 2, and produces ``input_dims`` output
    channels. Presumably the output is the synthetic gradient for the input
    feature map (see the paper named in the file header) - confirm against
    the training loop.

    Args:
        input_dims: number of channels of the input, and of the output.
        input_size: sequence of trailing (spatial) dimensions of the input;
            the embedded label is reshaped to ``[-1, 1] + input_size``.
        num_classes: width of the label vector fed to the label embedding.
        dni_hidden_size: number of channels of both hidden layers.
        conditioned: when true, ``forward`` embeds ``y`` as one extra input
            channel.
    """

    def __init__(self, input_dims, input_size, num_classes, dni_hidden_size=64, conditioned=False):
        super(dni_Conv2d, self).__init__()
        self.conditioned = conditioned
        # Conditioning adds a single channel holding the embedded label.
        if self.conditioned:
            dni_input_dims = input_dims + 1
        else:
            dni_input_dims = input_dims

        self.input_size = list(input_size)
        # Cast to a plain int: np.prod returns a NumPy integer, which not
        # every torch release accepts as a layer size. The embedding is
        # created even when unconditioned so state_dict keys stay unchanged.
        self.label_emb = nn.Linear(num_classes, int(np.prod(self.input_size)))

        self.layer1 = nn.Sequential(
            nn.Conv2d(dni_input_dims, dni_hidden_size, kernel_size=5, padding=2),
            nn.BatchNorm2d(dni_hidden_size),
            nn.ReLU())
        self.layer2 = nn.Sequential(
            nn.Conv2d(dni_hidden_size, dni_hidden_size, kernel_size=5, padding=2),
            nn.BatchNorm2d(dni_hidden_size),
            nn.ReLU())
        self.layer3 = nn.Sequential(
            nn.Conv2d(dni_hidden_size, input_dims, kernel_size=5, padding=2))

    def forward(self, x, y=None):
        """Run the module on ``x``.

        Args:
            x: batch with ``input_dims`` channels on dimension 1 and trailing
                dimensions equal to ``input_size``.
            y: label batch with ``num_classes`` features. Required when the
                module is conditioned, ignored otherwise.

        Returns:
            Tensor with ``input_dims`` channels.

        Raises:
            ValueError: if the module is conditioned and ``y`` is None.
        """
        if self.conditioned:
            # Explicit raise instead of assert: asserts vanish under -O.
            if y is None:
                raise ValueError('y is required when conditioned=True')
            # Embed the label and lay it out as one extra channel of x.
            y = self.label_emb(y)
            y = y.view([-1, 1] + self.input_size)
            x = torch.cat((x, y), 1)
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        return out