In [1]:
import torch
import numpy as np

from dfa_lib import emu_nn

In [2]:
class LinearDFANetwork(torch.nn.Module):
    """
    Small convolutional classifier for single-channel square images (MNIST here),
    whose hidden fully-connected layer is trained with direct feedback alignment (DFA).

    Forward order:
        Conv2d(1 -> 32, 3x3) -> ReLU -> MaxPool2d(2x2) -> Dropout(0.25)
        -> Flatten -> emu_nn.LinearDFA(flat -> hidden) -> ReLU
        -> emu_nn.ErrorDFA(hidden -> out_features)

    ``self.last`` (an ``emu_nn.ErrorDFA``) is created first because ``fc1`` needs a
    reference to it as its ``error_layer``.
    """

    # Convolutional front-end hyper-parameters. They are used both to build the
    # layers and to compute the flattened feature size fed into fc1.
    _CONV_CHANNELS = 32
    _CONV_KERNEL = 3
    _POOL_KERNEL = 2

    def __init__(self, in_features, out_features, image_size=28, hidden_features=128):
        """
        :param in_features: number of pixels per input image (784 for MNIST). Stored
            for reference only; layer sizes are derived from ``image_size``.
        :param out_features: number of output classes.
        :param image_size: side length of the square, single-channel input images.
            The default of 28 reproduces the original flattened size of 5408.
        :param hidden_features: width of the DFA hidden layer (default 128).
        :raises ValueError: if ``image_size`` is too small for conv1 + pool1.
        """
        super(LinearDFANetwork, self).__init__()

        self.in_features = in_features
        self.out_features = out_features

        # Output layer first: fc1 below receives it as its error_layer.
        # (Swap for torch.nn.Linear(hidden_features, out_features) to get a
        # plain-backprop baseline.)
        self.last = emu_nn.ErrorDFA(hidden_features, self.out_features)
        self.f = torch.nn.ReLU()

        self.conv1 = torch.nn.Conv2d(
            in_channels=1,
            out_channels=self._CONV_CHANNELS,
            kernel_size=(self._CONV_KERNEL, self._CONV_KERNEL),
        )
        self.pool1 = torch.nn.MaxPool2d(kernel_size=(self._POOL_KERNEL, self._POOL_KERNEL))
        self.do1 = torch.nn.Dropout(0.25)
        self.flatten = torch.nn.Flatten()
        # NOTE(review): continue_BP=True presumably lets ordinary backprop continue
        # from fc1 into conv1 — confirm in dfa_lib.emu_nn.
        self.fc1 = emu_nn.LinearDFA(
            self._flat_features(image_size), hidden_features,
            error_layer=self.last, continue_BP=True,
        )

    @classmethod
    def _flat_features(cls, image_size):
        """Number of features after conv1 (no padding, stride 1) and pool1 (stride = kernel, floor)."""
        conv_side = image_size - cls._CONV_KERNEL + 1
        pooled_side = conv_side // cls._POOL_KERNEL
        if pooled_side <= 0:
            raise ValueError(
                "image_size=%r is too small for a %dx%d conv followed by %dx%d pooling"
                % (image_size, cls._CONV_KERNEL, cls._CONV_KERNEL, cls._POOL_KERNEL, cls._POOL_KERNEL)
            )
        return cls._CONV_CHANNELS * pooled_side * pooled_side

    def forward(self, inputs):
        """
        Forward pass. It is the same as for a conventional feed-forward net; DFA only
        changes how gradients reach fc1 during training.

        :param inputs: image batch of shape (batch_size, 1, image_size, image_size).
            The single channel is fixed by conv1, and the spatial size must match
            the ``image_size`` the network was built with, or fc1's input size will not match.
        :return: output of ``self.last``, used as class logits by the training code.
        """
        l0 = self.f(self.conv1(inputs))   # convolution + ReLU
        l1 = self.pool1(l0)               # 2x2 max-pooling
        l2 = self.do1(l1)                 # dropout (active only in train mode)
        l3 = self.flatten(l2)             # (batch_size, flat_features)
        l4 = self.f(self.fc1(l3))         # DFA-trained hidden layer + ReLU
        output = self.last(l4)

        return output

In [3]:
# @title Загрузка датасета
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
from tqdm.notebook import tqdm
import os

BATCH_SIZE = 32
# Fraction of the official MNIST training set used for training; the rest becomes validation.
TRAIN_FRACTION = 0.8
# Seed for the train/validation split, so the same images land in each subset on every run.
SPLIT_SEED = 42

# Download MNIST via torchvision (cached under ./data). ToTensor converts each
# image to a float tensor with values in [0, 1].
train_data = datasets.MNIST(root="./data", train=True, download=True, transform=transforms.ToTensor())
test_data = datasets.MNIST(root="./data", train=False, download=True, transform=transforms.ToTensor())

# Split the official training part into train / validation subsets.
train_size = int(len(train_data) * TRAIN_FRACTION)
val_size = len(train_data) - train_size

# A dedicated, seeded generator makes the split reproducible without touching
# the global torch RNG (which still drives weight init, dropout and shuffling).
train_data, val_data = torch.utils.data.random_split(
    train_data,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Batch iterators; only the training loader shuffles.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

In [4]:
# @title Создание модели
# Build the DFA network and move it to the GPU when one is available, otherwise the CPU.
# NOTE(review): the original note says an "expo" value is printed for each layer at
# construction time — presumably by dfa_lib.emu_nn; nothing in this cell prints it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model_fa = LinearDFANetwork(in_features=28 * 28, out_features=10)
model_fa = model_fa.to(device)

learning_rate = 1e-3
optimizer_fa = torch.optim.Adam(params=model_fa.parameters(), lr=learning_rate)
loss = torch.nn.CrossEntropyLoss()

# History buffers filled by the training cell: per-batch train loss/accuracy,
# and per-epoch validation loss/accuracy.
losses_list, accuracy_list = [], []
val_losses_list, val_accuracy_list = [], []

In [5]:
# Is a CUDA GPU visible to this kernel? False means `device` resolved to the CPU.
bool(torch.cuda.is_available())

False

In [6]:
# @title Оценка точности (evaluate)
from sklearn.metrics import accuracy_score

def evaluate(model, dataloader, loss_fn, device=None, progress=True):
    """
    Compute the classification accuracy and mean loss of ``model`` over ``dataloader``.

    The model is switched to eval mode for the pass, which disables dropout. Its
    previous train/eval mode is restored afterwards, even on error. No gradients
    are tracked.

    :param model: torch.nn.Module returning class scores of shape
        (batch_size, num_classes). The prediction is the argmax over dim 1.
    :param dataloader: iterable of ``(inputs, targets)`` pairs with integer class labels.
    :param loss_fn: callable ``loss_fn(logits, targets)`` returning a scalar tensor,
        e.g. ``torch.nn.CrossEntropyLoss()``.
    :param device: device each batch is moved to. Defaults to the device of the
        model's first parameter, or the CPU if the model has no parameters.
    :param progress: wrap the dataloader in a tqdm progress bar.
    :return: ``(accuracy, mean_loss)``. ``accuracy`` is the fraction of correctly
        classified samples. ``mean_loss`` is the unweighted mean of the per-batch
        loss values. Both are NaN if the dataloader yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    batches = tqdm(dataloader) if progress else dataloader

    batch_losses = []
    num_correct = 0
    num_seen = 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for inputs, targets in batches:
                inputs = inputs.to(device)
                targets = targets.to(device)

                logits = model(inputs)
                # Use the caller's loss function (the old code silently replaced it).
                batch_losses.append(loss_fn(logits, targets).item())

                predictions = torch.argmax(logits, dim=1)
                num_correct += (predictions == targets).sum().item()
                num_seen += targets.numel()
    finally:
        # Give the model back in whatever mode the caller had it.
        model.train(was_training)

    if num_seen == 0:
        return float("nan"), float("nan")

    accuracy = num_correct / num_seen
    return accuracy, np.mean(batch_losses)

In [7]:
# @title Обучение
# Number of batches between prints of the running train loss/accuracy.
LOG_EVERY = 500
epochs = 1
for epoch in tqdm(range(epochs)):

    # Training pass: parameters are updated after every batch.
    model_fa.train(True)
    for idx_batch, (inputs, targets) in enumerate(tqdm(train_loader)):
        inputs = inputs.to(device)
        targets = targets.to(device)

        outputs_fa = model_fa(inputs)
        loss_fa = loss(outputs_fa, targets)

        # Store plain Python floats (not 0-d tensors) so the history lists hold no
        # tensors and work directly with numpy / matplotlib.
        losses_list.append(loss_fa.item())
        model_answers = torch.argmax(outputs_fa, dim=1)
        train_accuracy = (model_answers == targets).float().mean().item()
        accuracy_list.append(train_accuracy)

        model_fa.zero_grad()
        loss_fa.backward()
        optimizer_fa.step()

        if (idx_batch + 1) % LOG_EVERY == 0:
            print(f"Средние train лосс и accuracy на последних {LOG_EVERY} итерациях:",
                  np.mean(losses_list[-LOG_EVERY:]), np.mean(accuracy_list[-LOG_EVERY:]), end='\n')

    # Validation pass at the end of each epoch.
    model_fa.train(False)
    val_accuracy, val_loss = evaluate(model_fa, val_loader, loss_fn=loss)
    val_losses_list.append(val_loss)
    val_accuracy_list.append(val_accuracy)
    print("Эпоха {}/{}: val лосс и accuracy:".format(epoch + 1, epochs),
          val_loss, val_accuracy, end='\n')

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1500 [00:00<?, ?it/s]

Средние train лосс и accuracy на последних 500 итерациях: 1.7642462046146392 0.569875
Средние train лосс и accuracy на последних 500 итерациях: 0.5755884757637978 0.8333125
Средние train лосс и accuracy на последних 500 итерациях: 0.4009249278157949 0.884875


  0%|          | 0/375 [00:00<?, ?it/s]

Эпоха 1/1: val лосс и accuracy: 0.3228111638128757 0.9103333333333333


In [8]:
# Final evaluation on the held-out MNIST test split. Switch to eval mode (no dropout)
# explicitly instead of relying on the training cell having done it.
model_fa.eval()
acc, ls = evaluate(model_fa, test_loader, loss)
print('accuracy is: ', round(acc, 4))

  0%|          | 0/313 [00:00<?, ?it/s]

accuracy is:  0.915
