In [1]:
from dfa_lib import alter_nn as dfa_nn# contains dfa layers
import torch
import numpy as np

from torch import Tensor


class DFANetwork(torch.nn.Module):
    
    def __init__(self, in_features, out_features):
        """
        :param in_features: dimension of input features (784 for MNIST)

        1) Create a ErrorDFA.
        2) Create other LinearDFA-s with a reference to error_layer
        """
        super(DFANetwork, self).__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.f = torch.nn.ReLU()
            
        self.last = dfa_nn.ErrorDFA(100, self.out_features)
        self.linear = dfa_nn.LinearDFA(784, 100, error_layer = self.last)


    def forward(self, inputs):
        """
        forward pass, which is same for conventional feed-forward net
        :param inputs: inputs with shape [batch_size, in_features]
        :return: logit outputs from the network
        """
        
        x = self.f(self.linear(inputs))
        x = self.last(x)

        return x

In [2]:
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import os

BATCH_SIZE = 128

train_loader = DataLoader(datasets.FashionMNIST('./data', train=True, download=True,
                                         transform=transforms.Compose([
                                                                      transforms.ToTensor()
                                                                      ])),
                                         batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(datasets.FashionMNIST('./data', train=False, download=True,
                                        transform=transforms.Compose([
                                                                     transforms.ToTensor()
                                                                     ])),
                                        batch_size=BATCH_SIZE, shuffle=True)


In [3]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_dfa = DFANetwork(in_features=784, out_features = 10)

losses_list = []
accuracy_list = []
val_losses_list = []
val_accuracy_list = []

In [4]:
learning_rate = 1e-3
optimizer_dfa = torch.optim.Adam(model_dfa.parameters(), lr=learning_rate)
loss = torch.nn.CrossEntropyLoss()

epochs = 3

for epoch in tqdm(range(epochs)):

    # model_dfa.train(True)
    
    for idx_batch, (inputs, targets) in enumerate(train_loader):
        
        inputs = inputs.view(inputs.size()[0], -1)
        
        outputs_dfa = model_dfa(inputs)
        loss_fa = loss(outputs_dfa, targets)

        model_dfa.zero_grad()
        loss_fa.backward()
        optimizer_dfa.step()

        if idx_batch % 100 == 0:
            print("iter ", idx_batch, "loss ", loss_fa.item())

  0%|          | 0/3 [00:00<?, ?it/s]

iter  0 loss  2.4744317531585693
iter  100 loss  0.92950838804245
iter  200 loss  0.8423228859901428
iter  300 loss  0.650940477848053
iter  400 loss  0.6328248381614685
iter  0 loss  0.552724301815033
iter  100 loss  0.5623449683189392
iter  200 loss  0.5478946566581726
iter  300 loss  0.42704537510871887
iter  400 loss  0.4071912467479706
iter  0 loss  0.5503165125846863
iter  100 loss  0.43831688165664673
iter  200 loss  0.5131833553314209
iter  300 loss  0.37797069549560547
iter  400 loss  0.42457616329193115


In [5]:
from sklearn.metrics import accuracy_score

def evaluate(model, dataloader, loss_fn):

    y_pred_list = []
    y_true_list = []
    losses = []
    loss_fn = torch.nn.CrossEntropyLoss()
    # проходимся по батчам даталоадера
    for i, batch in enumerate(tqdm(dataloader)):

        # Due to the matrix multiplication we have to control the sizes of matrixes. Here we forgot about last matrix with wrong shape
        if i == 312:
            continue

        # так получаем текущий батч
        X_batch, y_batch = batch

        # выключаем подсчет любых градиентов
        with torch.no_grad():

            # получаем ответы сети на батч
            X_batch = X_batch.view(X_batch.size()[0], -1)
            logits = model(X_batch)
            X_batch, y_batch = X_batch, y_batch

            # вычисляем значение лосс-функции на батче
            loss = loss_fn(logits, y_batch)
            loss = loss.item()

            # сохраняем лосс на текущем батче в массив
            losses.append(loss)

            # для каждого элемента батча понимаем,
            # к какому классу от 0 до 9 отнесла его сеть
            y_pred = torch.argmax(logits, dim=1)

        # сохраняем в массивы правильные ответы на текущий батч
        # и ответы сети на текущий батч
        y_pred_list.extend(y_pred.cpu().numpy())
        y_true_list.extend(y_batch.numpy())

    # считаем accuracy между ответам сети и правильными ответами
    accuracy = accuracy_score(y_pred_list, y_true_list)

    return accuracy, np.mean(losses)

In [6]:
from sklearn.metrics import accuracy_score

def evaluate(model, dataloader, loss_fn):

    y_pred_list = []
    y_true_list = []
    losses = []
    loss_fn = torch.nn.CrossEntropyLoss()
    # проходимся по батчам даталоадера
    for i, batch in enumerate(tqdm(dataloader)):

        # Due to the matrix multiplication we have to control the sizes of matrixes. Here we forgot about last matrix with wrong shape
        if i == 312:
            continue

        # так получаем текущий батч
        X_batch, y_batch = batch
        
        # выключаем подсчет любых градиентов
        with torch.no_grad():

            # получаем ответы сети на батч
            X_batch = X_batch.view(X_batch.shape[0], -1)
            logits = model(X_batch)

            # вычисляем значение лосс-функции на батче
            loss = loss_fn(logits, y_batch)
            loss = loss.item()

            # сохраняем лосс на текущем батче в массив
            losses.append(loss)

            # для каждого элемента батча понимаем,
            # к какому классу от 0 до 9 отнесла его сеть
            y_pred = torch.argmax(logits, dim=1)

        # сохраняем в массивы правильные ответы на текущий батч
        # и ответы сети на текущий батч
        y_pred_list.extend(y_pred.cpu().numpy())
        y_true_list.extend(y_batch.numpy())

    # считаем accuracy между ответам сети и правильными ответами
    accuracy = accuracy_score(y_pred_list, y_true_list)

    return accuracy, np.mean(losses)

In [7]:
acc, ls = evaluate(model_dfa, test_loader, loss)
print('accuracy is: ', round(acc,4))

  0%|          | 0/79 [00:00<?, ?it/s]

accuracy is:  0.8396
