In [37]:
import torch as t
import torch.nn as nn
import math
import torch.nn.functional as F
from torchvision import datasets, transforms

In [3]:
class Binarize(t.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return t.sign(input + 1e-20)

    @staticmethod
    def backward(ctx, grad_output):
        input = ctx.saved_tensors[0]
        grad_output[input > 1] = 0
        grad_output[input < -1] = 0
        return grad_output

In [5]:
class BinarizedLinear(nn.Module):
    def __init__(self, in_features, out_features, binarize_input=True):
        super(BinarizedLinear, self).__init__()
        self.binarize_input = binarize_input
        self.weight = nn.Parameter(t.Tensor(out_features, in_features))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x):
        if self.binarize_input:
            x = Binarize.apply(x)
        w = Binarize.apply(self.weight)
        out = t.matmul(x, w.t())
        return out

In [6]:
model = nn.Sequential(BinarizedLinear(784, 2048, False),
                      nn.BatchNorm1d(2048),
                      BinarizedLinear(2048, 2048),
                      nn.BatchNorm1d(2048),
                      BinarizedLinear(2048, 2048),
                      nn.Dropout(0.5),
                      nn.BatchNorm1d(2048),
                      nn.Linear(2048, 10))


In [7]:
TRAIN_BATCH_SIZE = 64
TEST_BATCH_SIZE = 1000
LR = 0.01
EPOCH = 100
LOG_INTERVAL = 100


In [8]:

# Adam优化器
optimizer = t.optim.Adam(model.parameters(), lr=LR)
# 学习率调整器
scheduler = t.optim.lr_scheduler.StepLR(optimizer, step_size=40, gamma=0.1)

In [38]:
transforms_compose = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])
train_loader = t.utils.data.dataloader(
    datasets.FashionMNIST('./data', train=True, download=False, transform=transforms_compose),
    batch_size=TRAIN_BATCH_SIZE, shuffle=True)
test_loader = t.utils.data.dataloader(datasets.FashionMNIST('./data', train=False, transform=transforms_compose),
                                      batch_size=TEST_BATCH_SIZE, shuffle=False)



Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ./FashionMNIST/FashionMNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/26421880 [00:00<?, ?it/s]

Extracting ./FashionMNIST/FashionMNIST/raw/train-images-idx3-ubyte.gz to ./FashionMNIST/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ./FashionMNIST/FashionMNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/29515 [00:00<?, ?it/s]

Extracting ./FashionMNIST/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ./FashionMNIST/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ./FashionMNIST/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/4422102 [00:00<?, ?it/s]

Extracting ./FashionMNIST/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ./FashionMNIST/FashionMNIST/raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ./FashionMNIST/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/5148 [00:00<?, ?it/s]

Extracting ./FashionMNIST/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ./FashionMNIST/FashionMNIST/raw

Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


TypeError: 'module' object is not callable