In [22]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.optim import Adam
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda
from torch.utils.data import DataLoader
from torchvision import transforms
import random
from scipy.signal import convolve2d
import numpy as np

In [24]:
# define model
def mask_generator():
    random_index = np.random.randint(5,10)
    random_image = np.random.randint(2, size=(1,28,28)).squeeze().astype(np.float32)
    filter = np.array([[1, 2, 1], [2, 4, 2], [1, 2, 1]]) / 16
    for i in range(random_index):
        random_image = convolve2d(random_image, filter, mode='same', boundary='symm')
    mask = (random_image > 0.5).astype(np.float32)
    return mask

def negative_data_generator(image_batch):
    idxs = torch.randperm(image_batch.shape[0])
    x = image_batch[idxs]
    mask = torch.from_numpy(mask_generator())
    hybrid_images = x * mask + (1-mask) * x
    return hybrid_images

class FFLayer(nn.Module):
  def __init__(self,input_dim,output_dim,threshold=2.0,epochs=100):
      super().__init__()
      self.lin = nn.Linear(input_dim,output_dim)
      self.relu = nn.ReLU()
      self.optim = torch.optim.Adam(self.parameters(),lr=0.03)
      self.epochs = epochs
      self.threshold = threshold

  def forward(self,x):
    # x has shape B x D
    x = x / (x.norm(p=2,dim=1,keepdim=True)+1e-4)
    x = self.lin(x)
    return self.relu(x)

  def forward_forward(self,x_pos,x_neg):
    # x has shape B x h x w
    for epoch in range(self.epochs):
      h_pos = self.forward(x_pos)
      h_neg = self.forward(x_neg)
      good_pos = (self.threshold-h_pos**2).mean(dim=1)
      good_neg = (h_neg**2-self.threshold).mean(dim=1)
      self.optim.zero_grad()
      loss = torch.log(1+torch.exp(torch.cat([good_pos,good_neg]))).mean()
      loss.backward()
      self.optim.step()

    return self.forward(x_pos).detach(), self.forward(x_neg).detach()


class FF(nn.Module):
  def __init__(self,input_dim, hidden_dims,device,threshold=2.0,epochs=100):
      super().__init__()
      self.epochs = epochs
      self.threshold = threshold
      self.layers = []
      self.layers.append(FFLayer(input_dim,hidden_dims[0],threshold,epochs).to(device))
      for i in range(1,len(hidden_dims)):
        self.layers.append(FFLayer(hidden_dims[i-1],hidden_dims[i],threshold,epochs).to(device))

  def train(self,x_pos, x_neg):
    h_pos = x_pos.reshape((x_pos.shape[0],-1))
    h_neg = x_neg.reshape((x_pos.shape[0],-1))
    for layer in self.layers:
      h_pos, h_neg = layer.forward_forward(h_pos,h_neg)

  def test(self,x_pos,x_neg):
    with torch.no_grad():
      h_pos = x_pos.reshape((x_pos.shape[0],-1))
      h_neg = x_neg.reshape((x_pos.shape[0],-1))
      losses = []
      for layer in self.layers:
        h_pos = layer.forward(h_pos)
        h_neg = layer.forward(h_neg)
        good_pos = (self.threshold-h_pos**2).mean(dim=1)
        good_neg = (h_neg**2-self.threshold).mean(dim=1)
        loss = torch.log(1+torch.exp(torch.cat([good_pos,good_neg]))).mean()
        losses.append(loss.item())

      final_loss = np.array(losses).mean()
      return final_loss

  def inference(self,x):
    with torch.no_grad():
      x = x.reshape((x.shape[0],-1))
      for layer in self.layers:
        x = layer.forward(x)
      return x

In [25]:
# load dataset
transform = Compose([ToTensor()])
trainset = MNIST('.',train=True,download=True,transform=transform)
testset = MNIST('.',train=False, download=True,transform=transform)

batch_size = 1024
trainLoader = DataLoader(trainset,batch_size,shuffle=True)
testLoader = DataLoader(testset,batch_size,shuffle=True)

In [26]:
#train FF
input_dim = 784
hidden_dims = [512,512,512,512]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = FF(input_dim, hidden_dims,device,epochs=100).to(device)

for x_pos,_ in tqdm(trainLoader):
  x_neg =  negative_data_generator(x_pos)
  x_neg = x_neg.to(device)
  x_pos = x_pos.to(device)
  model.train(x_pos,x_neg)

100%|██████████| 59/59 [00:52<00:00,  1.12it/s]


In [18]:
# test FF
testLoss = 0.0
for x_pos,_ in tqdm(testLoader):
  x_neg =  negative_data_generator(x_pos)
  x_neg = x_neg.to(device)
  x_pos = x_pos.to(device)
  testLoss += model.test(x_pos,x_neg)
testLoss = testLoss / len(testLoader)

trainLoss = 0.0
for x_pos,_ in tqdm(trainLoader):
  x_neg =  negative_data_generator(x_pos)
  x_neg = x_neg.to(device)
  x_pos = x_pos.to(device)
  trainLoss += model.test(x_pos,x_neg)
trainLoss = trainLoss / len(trainLoader)
print(f'train loss: {trainLoss}')
print(f'test loss: {testLoss}')

100%|██████████| 10/10 [00:01<00:00,  6.51it/s]
100%|██████████| 59/59 [00:06<00:00,  8.62it/s]

train loss: 0.6931644491219925
test loss: 0.693164199590683





In [27]:
torch.save(model.state_dict(), './FF.pth')

In [31]:
# define classifier
class LinearClassifier(nn.Module):
  def __init__(self,input_dim, num_classes, hidden_dims,device,threshold=2.0,epochs=100):
     super().__init__()
     self.ff = FF(input_dim, hidden_dims,device,threshold=2.0,epochs=100).to(device)
     self.ff.load_state_dict(torch.load('./FF.pth'))
     self.lin = nn.Linear(hidden_dims[-1],num_classes)

  def forward(self,x):
    x = self.ff.inference(x)
    x = self.lin(x)
    return x

با دیتای منفی و هزینه که شبکه آموزش داده شده، در عمل یاد میگیرد که ساختار ارقام را کد کند. جون تصویر منفی ساختار نامنظمی دارد درنتیجه با کم کردن حساسایت روی آن ها و افزایش حاسایت روی تصویر واقعی یاد میگیرد که  که اعداد تک رقمی چگونه نوشته میشوند. سپس با این بردار ویژگی آنها را طبقه بندی میکند.

چون هدف این بوده که
طبقه بند خطی باشد چون مان یک بردار ویژگی یاد گرفتیم و امنتظار داریم که به خوبی ارقام را تشخیص ندهد یعنی بتواند با مرزهای خطی ارقام را جدا کند. اگر بخواهیم از لایه های غیر خطی استفاده کنیم نمی توان کارایی روش غیر نظارتی را سنجید.








پیاده سازی کلی مانند قسمت قبل است فقط اینجا با آموزش شبکه FF ان را در حالت eval قرار داده و تابع خطی را آزمایش میکنیم.

In [34]:
# train classifier
input_dim = 784
num_classes = 10
hidden_dims = [512,512,512,512]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
classifier = LinearClassifier(input_dim,num_classes,hidden_dims,device).to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(classifier.parameters(),lr=0.03)

losses = []
for epoch in range(100):
  current_loss = 0.0
  for image,target in tqdm(trainLoader):
    image = image.to(device)
    target = target.to(device)
    logits = classifier(image)
    optimizer.zero_grad()
    loss = criterion(logits,target)
    loss.backward()
    optimizer.step()
    current_loss += loss.item()/len(trainLoader)

  print(f'epoch {epoch}: loss = {current_loss}')



100%|██████████| 59/59 [00:07<00:00,  7.98it/s]


epoch 0: loss = 2.041261945740651


100%|██████████| 59/59 [00:07<00:00,  8.00it/s]


epoch 1: loss = 1.6143200639951027


100%|██████████| 59/59 [00:07<00:00,  7.75it/s]


epoch 2: loss = 1.334612925173872


100%|██████████| 59/59 [00:06<00:00,  8.79it/s]


epoch 3: loss = 1.1506698131561277


100%|██████████| 59/59 [00:07<00:00,  8.07it/s]


epoch 4: loss = 1.0251260116948921


100%|██████████| 59/59 [00:06<00:00,  9.09it/s]


epoch 5: loss = 0.9337073998936155


100%|██████████| 59/59 [00:07<00:00,  8.39it/s]


epoch 6: loss = 0.8657706339480515


100%|██████████| 59/59 [00:06<00:00,  8.67it/s]


epoch 7: loss = 0.8114706247539842


100%|██████████| 59/59 [00:07<00:00,  8.33it/s]


epoch 8: loss = 0.7687552934986048


100%|██████████| 59/59 [00:06<00:00,  8.70it/s]


epoch 9: loss = 0.7327522079823382


100%|██████████| 59/59 [00:06<00:00,  8.54it/s]


epoch 10: loss = 0.7028719520164747


100%|██████████| 59/59 [00:06<00:00,  8.78it/s]


epoch 11: loss = 0.6776516285993285


100%|██████████| 59/59 [00:06<00:00,  8.60it/s]


epoch 12: loss = 0.6552817942732472


100%|██████████| 59/59 [00:06<00:00,  8.45it/s]


epoch 13: loss = 0.6358742410853757


100%|██████████| 59/59 [00:06<00:00,  8.71it/s]


epoch 14: loss = 0.6182057594848891


100%|██████████| 59/59 [00:06<00:00,  8.43it/s]


epoch 15: loss = 0.6033615237575467


100%|██████████| 59/59 [00:06<00:00,  8.71it/s]


epoch 16: loss = 0.5892314355252155


100%|██████████| 59/59 [00:07<00:00,  8.21it/s]


epoch 17: loss = 0.5772546509564934


100%|██████████| 59/59 [00:06<00:00,  8.89it/s]


epoch 18: loss = 0.5652756802106307


100%|██████████| 59/59 [00:07<00:00,  7.71it/s]


epoch 19: loss = 0.5556299681380642


100%|██████████| 59/59 [00:06<00:00,  9.42it/s]


epoch 20: loss = 0.5456192973306623


100%|██████████| 59/59 [00:07<00:00,  8.05it/s]


epoch 21: loss = 0.5375448707806862


100%|██████████| 59/59 [00:06<00:00,  8.77it/s]


epoch 22: loss = 0.5280053135702166


100%|██████████| 59/59 [00:07<00:00,  7.99it/s]


epoch 23: loss = 0.5206291761438725


100%|██████████| 59/59 [00:08<00:00,  7.13it/s]


epoch 24: loss = 0.5141333043575286


100%|██████████| 59/59 [00:10<00:00,  5.46it/s]


epoch 25: loss = 0.5071858985949372


100%|██████████| 59/59 [00:11<00:00,  5.07it/s]


epoch 26: loss = 0.5009856067471584


100%|██████████| 59/59 [00:08<00:00,  6.84it/s]


epoch 27: loss = 0.4953023912542959


100%|██████████| 59/59 [00:07<00:00,  7.53it/s]


epoch 28: loss = 0.4903998627501019


100%|██████████| 59/59 [00:07<00:00,  8.32it/s]


epoch 29: loss = 0.48474173677169663


100%|██████████| 59/59 [00:08<00:00,  7.15it/s]


epoch 30: loss = 0.4799586286989309


100%|██████████| 59/59 [00:09<00:00,  5.92it/s]


epoch 31: loss = 0.4750177153086258


100%|██████████| 59/59 [00:08<00:00,  6.98it/s]


epoch 32: loss = 0.4693046909267619


100%|██████████| 59/59 [00:07<00:00,  7.71it/s]


epoch 33: loss = 0.465481366646492


100%|██████████| 59/59 [00:07<00:00,  7.43it/s]


epoch 34: loss = 0.46137577396328183


100%|██████████| 59/59 [00:08<00:00,  7.15it/s]


epoch 35: loss = 0.4575110997183848


100%|██████████| 59/59 [00:07<00:00,  7.79it/s]


epoch 36: loss = 0.4538111974627284


100%|██████████| 59/59 [00:09<00:00,  6.55it/s]


epoch 37: loss = 0.4502555147065955


100%|██████████| 59/59 [00:08<00:00,  6.77it/s]


epoch 38: loss = 0.44781948348223155


100%|██████████| 59/59 [00:06<00:00,  9.09it/s]


epoch 39: loss = 0.44379192035076975


100%|██████████| 59/59 [00:08<00:00,  6.72it/s]


epoch 40: loss = 0.44140335967985256


100%|██████████| 59/59 [00:08<00:00,  6.58it/s]


epoch 41: loss = 0.43722896959822055


100%|██████████| 59/59 [00:07<00:00,  7.79it/s]


epoch 42: loss = 0.43427862605806117


100%|██████████| 59/59 [00:07<00:00,  7.74it/s]


epoch 43: loss = 0.4312577732538774


100%|██████████| 59/59 [00:07<00:00,  8.29it/s]


epoch 44: loss = 0.43008724966291645


100%|██████████| 59/59 [00:08<00:00,  6.82it/s]


epoch 45: loss = 0.42623628751706266


100%|██████████| 59/59 [00:07<00:00,  8.33it/s]


epoch 46: loss = 0.42351938708353837


100%|██████████| 59/59 [00:08<00:00,  6.72it/s]


epoch 47: loss = 0.4209214656029719


100%|██████████| 59/59 [00:09<00:00,  6.35it/s]


epoch 48: loss = 0.419063282214989


100%|██████████| 59/59 [00:07<00:00,  7.74it/s]


epoch 49: loss = 0.41733555369457953


100%|██████████| 59/59 [00:08<00:00,  6.68it/s]


epoch 50: loss = 0.4140650443101333


100%|██████████| 59/59 [00:09<00:00,  6.46it/s]


epoch 51: loss = 0.4120960811437187


100%|██████████| 59/59 [00:07<00:00,  7.98it/s]


epoch 52: loss = 0.4102864002777359


100%|██████████| 59/59 [00:08<00:00,  6.57it/s]


epoch 53: loss = 0.40778759980605817


100%|██████████| 59/59 [00:07<00:00,  7.88it/s]


epoch 54: loss = 0.4058827641656844


100%|██████████| 59/59 [00:09<00:00,  6.55it/s]


epoch 55: loss = 0.4046553444054166


100%|██████████| 59/59 [00:08<00:00,  6.67it/s]


epoch 56: loss = 0.40223513669886846


100%|██████████| 59/59 [00:08<00:00,  6.67it/s]


epoch 57: loss = 0.4009721349861662


100%|██████████| 59/59 [00:08<00:00,  7.17it/s]


epoch 58: loss = 0.39861322257478354


100%|██████████| 59/59 [00:06<00:00,  9.47it/s]


epoch 59: loss = 0.39786739470595006


100%|██████████| 59/59 [00:07<00:00,  7.99it/s]


epoch 60: loss = 0.39550309969206987


100%|██████████| 59/59 [00:06<00:00,  9.39it/s]


epoch 61: loss = 0.393477736893347


100%|██████████| 59/59 [00:07<00:00,  8.22it/s]


epoch 62: loss = 0.39216697216033924


100%|██████████| 59/59 [00:06<00:00,  9.01it/s]


epoch 63: loss = 0.3907690204806247


100%|██████████| 59/59 [00:06<00:00,  8.56it/s]


epoch 64: loss = 0.3890613009363917


100%|██████████| 59/59 [00:07<00:00,  7.73it/s]


epoch 65: loss = 0.38743507255942117


100%|██████████| 59/59 [00:07<00:00,  8.10it/s]


epoch 66: loss = 0.3861359832650525


100%|██████████| 59/59 [00:09<00:00,  6.03it/s]


epoch 67: loss = 0.3847824567455357


100%|██████████| 59/59 [00:06<00:00,  8.81it/s]


epoch 68: loss = 0.38381241792339377


100%|██████████| 59/59 [00:08<00:00,  6.72it/s]


epoch 69: loss = 0.38195349604396495


100%|██████████| 59/59 [00:07<00:00,  7.91it/s]


epoch 70: loss = 0.380756554967266


100%|██████████| 59/59 [00:08<00:00,  7.27it/s]


epoch 71: loss = 0.37963384895001434


100%|██████████| 59/59 [00:06<00:00,  9.46it/s]


epoch 72: loss = 0.37772972401926086


100%|██████████| 59/59 [00:07<00:00,  8.21it/s]


epoch 73: loss = 0.3770415166677055


100%|██████████| 59/59 [00:06<00:00,  9.65it/s]


epoch 74: loss = 0.37588274731474397


100%|██████████| 59/59 [00:07<00:00,  7.96it/s]


epoch 75: loss = 0.3740102769964831


100%|██████████| 59/59 [00:06<00:00,  9.57it/s]


epoch 76: loss = 0.37384556360163934


100%|██████████| 59/59 [00:07<00:00,  8.21it/s]


epoch 77: loss = 0.37241362111043114


100%|██████████| 59/59 [00:06<00:00,  9.36it/s]


epoch 78: loss = 0.37175049317085146


100%|██████████| 59/59 [00:07<00:00,  7.98it/s]


epoch 79: loss = 0.3700496751373097


100%|██████████| 59/59 [00:06<00:00,  9.37it/s]


epoch 80: loss = 0.368935231940221


100%|██████████| 59/59 [00:07<00:00,  8.43it/s]


epoch 81: loss = 0.36794444817607685


100%|██████████| 59/59 [00:06<00:00,  9.25it/s]


epoch 82: loss = 0.367724998522613


100%|██████████| 59/59 [00:07<00:00,  8.20it/s]


epoch 83: loss = 0.36647129210375123


100%|██████████| 59/59 [00:06<00:00,  9.30it/s]


epoch 84: loss = 0.36493553953655694


100%|██████████| 59/59 [00:07<00:00,  8.14it/s]


epoch 85: loss = 0.3638018374725924


100%|██████████| 59/59 [00:06<00:00,  9.27it/s]


epoch 86: loss = 0.36270877971487536


100%|██████████| 59/59 [00:07<00:00,  8.33it/s]


epoch 87: loss = 0.3625281645079791


100%|██████████| 59/59 [00:06<00:00,  9.32it/s]


epoch 88: loss = 0.36149643588874303


100%|██████████| 59/59 [00:07<00:00,  8.23it/s]


epoch 89: loss = 0.359965439065028


100%|██████████| 59/59 [00:06<00:00,  9.40it/s]


epoch 90: loss = 0.3588505523689721


100%|██████████| 59/59 [00:07<00:00,  8.31it/s]


epoch 91: loss = 0.3580066849619656


100%|██████████| 59/59 [00:06<00:00,  9.42it/s]


epoch 92: loss = 0.3577702429334998


100%|██████████| 59/59 [00:07<00:00,  8.31it/s]


epoch 93: loss = 0.3564003635261019


100%|██████████| 59/59 [00:06<00:00,  9.52it/s]


epoch 94: loss = 0.355677250078169


100%|██████████| 59/59 [00:07<00:00,  8.08it/s]


epoch 95: loss = 0.3549898196074921


100%|██████████| 59/59 [00:06<00:00,  9.35it/s]


epoch 96: loss = 0.3543677213838544


100%|██████████| 59/59 [00:07<00:00,  8.23it/s]


epoch 97: loss = 0.3534873456267987


100%|██████████| 59/59 [00:06<00:00,  9.27it/s]


epoch 98: loss = 0.35268126150309037


100%|██████████| 59/59 [00:07<00:00,  8.42it/s]

epoch 99: loss = 0.35201816387095697





In [35]:
torch.save(classifier.state_dict(), './classifier.pth')

In [39]:
# test classifier
input_dim = 784
num_classes = 10
hidden_dims = [512,512,512,512]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
classifier = LinearClassifier(input_dim,num_classes,hidden_dims,device).to(device)
classifier.load_state_dict(torch.load('./classifier.pth'))


losses = []

acc_test = 0.0
for image,target in tqdm(testLoader):
  image = image.to(device)
  target = target.to(device)
  logits = classifier(image)
  idxs = logits.argmax(dim=1)
  acc_test += (idxs == target).sum().item()/len(target)

acc_test /= len(testLoader)

acc_train = 0.0
for image,target in tqdm(trainLoader):
  image = image.to(device)
  target = target.to(device)
  logits = classifier(image)
  idxs = logits.argmax(dim=1)
  acc_train += (idxs == target).sum().item()/len(target)

acc_train /= len(trainLoader)

print(f'Train accuracy: {acc_train*100}%')
print(f'Test accuracy: {acc_test*100}%')

100%|██████████| 10/10 [00:00<00:00, 10.09it/s]
100%|██████████| 59/59 [00:06<00:00,  9.16it/s]

Train accuracy: 89.65261903434434%
Test accuracy: 89.83458227040816%



