In [152]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.optim import Adam
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda
from torch.utils.data import DataLoader
from torchvision import transforms
import random

In [146]:
def prepare_data(x:torch.Tensor, y:torch.Tensor, num_classes:int=10):
    """Build positive and negative samples by writing a label into each input row.

    The first ``num_classes`` features of every row are overwritten with a
    one-hot vector: the true label for the positive copy and a label drawn
    uniformly from the ``num_classes - 1`` wrong labels for the negative copy.
    ``x`` itself is not modified.

    Args:
        x: 2-D tensor of shape (batch, features) with features >= num_classes.
        y: integer tensor of shape (batch,) with values in [0, num_classes).
        num_classes: number of label slots to overwrite (default 10).

    Returns:
        Tuple ``(x_real, x_fake)`` of tensors with the same shape, dtype and
        device as ``x``.

    Raises:
        ValueError: if ``num_classes < 2`` or ``x`` is not 2-D with at least
            ``num_classes`` columns.
    """
    if num_classes < 2:
        raise ValueError("num_classes must be at least 2 to draw a wrong label")
    if x.dim() != 2 or x.shape[1] < num_classes:
        raise ValueError("x must be 2-D with at least num_classes columns")

    n = x.shape[0]
    # Keep every index tensor on x's device so GPU inputs work as well as CPU ones.
    y = y.to(device=x.device, dtype=torch.long)
    rows = torch.arange(n, device=x.device)

    onehot_real = torch.zeros((n,num_classes),dtype=x.dtype,device=x.device)
    onehot_real[rows,y] = 1.0
    x_real = x.clone()
    x_real[:,:num_classes] = onehot_real

    # Draw a position among the num_classes-1 wrong labels, then skip over the
    # true label: positions below y map to themselves, the rest shift up by one.
    idxs = torch.randint(0,num_classes-1,(n,),device=x.device)
    wrong = idxs + (idxs >= y).to(idxs.dtype)
    onehot_fake = torch.zeros((n,num_classes),dtype=x.dtype,device=x.device)
    onehot_fake[rows,wrong] = 1.0
    x_fake = x.clone()
    x_fake[:,:num_classes] = onehot_fake
    return x_real, x_fake

def prepare_data_test(x:torch.Tensor, label):
    """Return a copy of ``x`` whose first 10 features hold the one-hot code of ``label``.

    ``label`` is used as a column index for every row, so an int writes the
    same candidate label into the whole batch. ``x`` itself is left untouched.
    """
    n_rows = x.shape[0]
    # new_zeros inherits dtype and device from x.
    candidate = x.new_zeros((n_rows, 10))
    candidate[torch.arange(n_rows), label] = 1.0

    labelled = x.clone()
    labelled[:, :10] = candidate
    return labelled

# This class acts as our baseline layer
class Layer(nn.Module):
    """A single linear + ReLU layer that is trained on its own, without backprop from later layers.

    The layer's "goodness" for a sample is the mean of its squared activations.
    Training pushes the goodness of positive samples above ``threshold`` and
    the goodness of negative samples below it.

    Args:
        input_dim: number of input features.
        output_dim: number of output units.
        device: stored on the instance; the layer is not moved here.
        threshold: goodness threshold separating positive from negative samples.
        epochs: number of full-batch optimisation steps run by ``forward_forward``.
        lr: learning rate of the layer's private Adam optimiser.
    """

    def __init__(self, input_dim, output_dim, device, threshold=2.0, epochs=1000, lr=0.03):
        super().__init__()
        self.lin = nn.Linear(input_dim,output_dim)
        self.relu = nn.ReLU()
        self.threshold = threshold
        self.epochs = epochs
        # Each layer owns its optimiser because it is trained independently.
        self.optim = torch.optim.Adam(self.parameters(),lr=lr)
        self.device = device

    def forward(self,x:torch.Tensor):
        """L2-normalise each row of ``x`` and apply the linear map followed by ReLU."""
        # The small epsilon keeps all-zero rows from producing 0/0 = NaN.
        x = x / (x.norm(p=2,dim=1,keepdim=True)+1e-4)
        return self.relu(self.lin(x))

    def forward_forward(self,x_pos, x_neg):
        """Train this layer on one batch of positive/negative samples.

        Runs ``self.epochs`` optimiser steps on the same batch and prints the
        loss every 100 steps.

        Returns:
            Tuple of gradient-free activations ``(h_pos, h_neg)`` computed with
            the trained weights.
        """
        for epoch in range(self.epochs):
            h_pos = self.forward(x_pos)
            h_neg = self.forward(x_neg)
            # Positive margin is large when positive goodness is below the
            # threshold; negative margin is large when negative goodness is above it.
            pos_margin = (self.threshold-h_pos**2).mean(dim=1)
            neg_margin = (h_neg**2-self.threshold).mean(dim=1)
            # softplus(z) == log(1 + exp(z)) but does not overflow for large z.
            loss = nn.functional.softplus(torch.cat([pos_margin,neg_margin])).mean()
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()
            if epoch % 100 == 0:
                print(f'epoch {epoch}: loss = {loss.item()}')
        # The outputs only feed the next layer, so no graph is needed.
        with torch.no_grad():
            return self.forward(x_pos), self.forward(x_neg)

    def predict(self,h:torch.Tensor):
        """Run the layer in inference mode.

        Uses the same normalisation as ``forward`` so inference matches
        training and all-zero rows do not turn into NaN.

        Returns:
            Tuple ``(activations, goodness)`` where ``goodness`` has one value
            per row: the mean of ``activations**2 - threshold``.
        """
        with torch.no_grad():
            h = self.forward(h)
            goodness = (h**2-self.threshold).mean(dim=1)
        return h, goodness


# This class defines the entire model
class MLP(nn.Module):
    """Stack of ``Layer`` modules trained one after another on a single batch.

    Args:
        input_dim: number of input features of the first layer.
        hidden_dims: output sizes of the successive layers (at least one).
        device: device every layer is moved to and inputs are sent to.
        threshold: goodness threshold handed to every layer.
        epochs: number of optimisation steps each layer runs during training.
    """

    def __init__(self, input_dim, hidden_dims, device, threshold=2.0, epochs=100):
        super().__init__()
        assert len(hidden_dims) >= 1
        self.device = device
        # len(hidden_dims) + 1, kept as in the original; not read inside this class.
        self.num_layers = len(hidden_dims) + 1
        self.epochs = epochs
        dims = [input_dim, *hidden_dims]
        # ModuleList registers the layers as submodules, so .to(), .parameters()
        # and .state_dict() on the model reach them. `epochs` is forwarded so
        # the constructor argument actually controls training length.
        self.layers = nn.ModuleList([
            Layer(d_in, d_out, device, threshold, epochs=epochs).to(device)
            for d_in, d_out in zip(dims[:-1], dims[1:])
        ])

    def train(self, trainLoader=True):
        """Train every layer in turn on the first batch of ``trainLoader``.

        This method shadows ``nn.Module.train(mode)``, which ``eval()`` and
        other PyTorch internals call with a bool. A bool argument is therefore
        forwarded to the base implementation, which returns ``self``.
        """
        if isinstance(trainLoader, bool):
            return super().train(trainLoader)
        # Only the first batch is used; the same batch is reused for every step.
        x,y = next(iter(trainLoader))
        x_pos, x_neg = prepare_data(x,y)
        h_pos = x_pos.to(self.device)
        h_neg = x_neg.to(self.device)
        for i,layer in enumerate(self.layers):
            print(f"Layer {i+1}:")
            # Each layer trains on the detached outputs of the previous one.
            h_pos, h_neg = layer.forward_forward(h_pos,h_neg)

    def test(self,testLoader,name):
        """Print the classification accuracy over every batch of ``testLoader``.

        For each of the 10 candidate labels the label is written into the
        input, the goodness is averaged over all layers, and the label with
        the highest average goodness is taken as the prediction.

        Raises:
            ValueError: if ``testLoader`` yields no samples.
        """
        correct = 0
        total = 0
        with torch.no_grad():
            for x, label in testLoader:
                goodness_targets = []
                for i in range(10):
                    z = prepare_data_test(x,i).to(self.device)
                    goodness_layers = []
                    for layer in self.layers:
                        z, goodness = layer.predict(z)
                        goodness_layers.append(goodness[:,None])
                    goodness_targets.append(torch.cat(goodness_layers,dim=1).mean(dim=1,keepdim=True))
                max_index = torch.cat(goodness_targets,dim=1).argmax(dim=1)
                correct += (max_index.cpu()==label).sum().item()
                total += len(label)
        if total == 0:
            raise ValueError("testLoader yielded no samples")
        accuracy = correct/total
        print(f"Accuracy of Model on {name}: {accuracy*100}%")

In [147]:
# Load MNIST Dataset
# Every image is converted to a tensor and flattened to a 1-D vector.
transform = Compose([ToTensor(), Lambda(torch.flatten)])
trainset = MNIST(root='.', train=True, download=True, transform=transform)
testset = MNIST(root='.', train=False, download=True, transform=transform)

# The batch size is large enough that each loader yields its split in one batch
# (MNIST has 60000 training images).
batch_size = 60000
trainLoader = DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True)
testLoader = DataLoader(dataset=testset, batch_size=batch_size, shuffle=True)

In [148]:
# Train Model
# Seed torch's RNGs so weight init, batch shuffling and the sampling of
# negative labels are repeatable across runs.
torch.manual_seed(0)
input_dim = 784
hidden_dims = [512,512]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 1000 epochs per layer matches the training log recorded below this cell.
model = MLP(input_dim, hidden_dims,device,epochs=1000).to(device)
model.train(trainLoader)

Layer 1:
epoch 0: loss = 1.1267484426498413
epoch 100: loss = 0.6906740665435791
epoch 200: loss = 0.6093816757202148
epoch 300: loss = 0.5042091012001038
epoch 400: loss = 0.4367710053920746
epoch 500: loss = 0.3928852379322052
epoch 600: loss = 0.36190348863601685
epoch 700: loss = 0.33817917108535767
epoch 800: loss = 0.3184700310230255
epoch 900: loss = 0.3012816309928894
Layer 2:
epoch 0: loss = 1.1266560554504395
epoch 100: loss = 0.4548075497150421
epoch 200: loss = 0.3663649260997772
epoch 300: loss = 0.3197694718837738
epoch 400: loss = 0.2913164794445038
epoch 500: loss = 0.2713039815425873
epoch 600: loss = 0.2559657692909241
epoch 700: loss = 0.24394433200359344
epoch 800: loss = 0.23424845933914185
epoch 900: loss = 0.22619059681892395


In [149]:
# Test Model
# Report accuracy on the training split first, then on the held-out split.
model.test(testLoader=trainLoader, name='train set')
model.test(testLoader=testLoader, name='test set')

Accuracy of Model on train set: 91.55499999999999%
Accuracy of Model on test set: 91.75999999999999%


تابع هزینه به شیوهٔ تعریف‌شده خوب است، چون در یکی از ترم‌های آن برای بیشتر شدن آستانه نسبت به دادهٔ درست هزینه قرار داده شده و به‌صورت معکوس، همین کار برای دادهٔ غلط انجام شده است. در نتیجه با کم شدن این هزینه دقت شبکه بالا می‌رود. همچنین برای قرار دادن برچسب در کنار خود داده، آن را به بردار one-hot تبدیل می‌کنیم و سپس 10 پیکسل اول تصویر را با این بردار جایگزین می‌کنیم؛ چون 10 پیکسل اول عموماً صفر هستند، این کار تأثیری روی تصویر ندارد. هرچند این مدل به دلیل این شیوهٔ تعریف داده‌ها interpretable نیست، ولی دقت خوبی دارد.