In [7]:
import torch.nn as nn
import torch.nn.functional as F
import os
from torch.autograd import Variable
from torchvision.utils import save_image
import torch

In [8]:
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder,self).__init__()
        self.encoder=nn.Sequential(
        nn.Linear(28*28,1024),
        nn.ReLU(),
        nn.Linear(1024,2048),
        nn.ReLU(),
        nn.Linear(2048,4096),
        nn.ReLU()
        )
        self.decoder=nn.Sequential(
        nn.Linear(4096,2048),
        nn.ReLU(),
        nn.Linear(2048,1024),
        nn.ReLU(),
        nn.Linear(1024,28*28),
        nn.ReLU()
        )
    
    def forward(self,x):
        x=self.encoder(x)
        d=x
        x=self.decoder(x)
        return d,x
    
    
import torchvision
import torchvision.transforms as transforms
    
beta1=0.5
learning_rate=0.01
assets_dir = './ass/'
if not os.path.isdir(assets_dir):
    os.mkdir(assets_dir)
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])
realimages = torchvision.datasets.MNIST(root='./data', train=True,download=True, transform=transform)
train_loader= torch.utils.data.DataLoader(realimages, batch_size=100,shuffle=True, num_workers=2)

auto=Autoencoder()
auto
loss_criterion=nn.MSELoss()
Auto_opt = torch.optim.SGD( auto.parameters(), lr=learning_rate, momentum=0.9 )


def to_img(x):
    x = 0.5 * (x + 1)
    x = x.clamp(0, 1)
    x = x.view(x.size(0), 1, 28, 28)
    return x

def kl_divergence(p, q):
    '''
    args:
        2 tensors `p` and `q`
    returns:
        kl divergence between the softmax of `p` and `q`
    '''
    p = F.softmax(p)
    q = F.softmax(q)

    s1 = torch.sum(p * torch.log(p / q))
    s2 = torch.sum((1 - p) * torch.log((1 - p) / (1 - q)))
    return s1 + s2
rho = torch.FloatTensor([0.2 for _ in range(4096)]).unsqueeze(0)

In [None]:
for epoch in range(50):
    for data,_ in train_loader:
        
        Auto_opt.zero_grad()
        data=data.view(-1,28*28)
        acti_dist,output=auto(data)
        
        sparsity=kl_divergence(rho,acti_dist)
        
        loss=loss_criterion(output,data)+0.01*sparsity
        loss.backward()
        
        Auto_opt.step()
        print('epoch [{}/{}], loss:{:.4f}'
          .format(epoch + 1, 30, loss.data[0]))
        
        pic = to_img(output.cpu().data)
        save_image(pic, './ass/image_{}.png'.format(epoch))

    

        
        

  from ipykernel import kernelapp as app


epoch [1/30], loss:0.9410
epoch [1/30], loss:0.9395
epoch [1/30], loss:0.9382
epoch [1/30], loss:0.9409
epoch [1/30], loss:0.9405
epoch [1/30], loss:0.9393
epoch [1/30], loss:0.9385
epoch [1/30], loss:0.9422
epoch [1/30], loss:0.9411
epoch [1/30], loss:0.9392
epoch [1/30], loss:0.9362
epoch [1/30], loss:0.9387
epoch [1/30], loss:0.9359
epoch [1/30], loss:0.9412
epoch [1/30], loss:0.9393
epoch [1/30], loss:0.9392
epoch [1/30], loss:0.9413
epoch [1/30], loss:0.9336
epoch [1/30], loss:0.9363
epoch [1/30], loss:0.9345
epoch [1/30], loss:0.9375
epoch [1/30], loss:0.9366
epoch [1/30], loss:0.9335
epoch [1/30], loss:0.9326
epoch [1/30], loss:0.9358
epoch [1/30], loss:0.9363
epoch [1/30], loss:0.9325
epoch [1/30], loss:0.9351
epoch [1/30], loss:0.9334
epoch [1/30], loss:0.9316
epoch [1/30], loss:0.9315
epoch [1/30], loss:0.9322
epoch [1/30], loss:0.9332
epoch [1/30], loss:0.9327
epoch [1/30], loss:0.9336
epoch [1/30], loss:0.9344
epoch [1/30], loss:0.9339
epoch [1/30], loss:0.9336
epoch [1/30]

epoch [1/30], loss:0.9247
epoch [1/30], loss:0.9281
epoch [1/30], loss:0.9236
epoch [1/30], loss:0.9295
epoch [1/30], loss:0.9242
epoch [1/30], loss:0.9254
epoch [1/30], loss:0.9245
epoch [1/30], loss:0.9282
epoch [1/30], loss:0.9249
epoch [1/30], loss:0.9264
epoch [1/30], loss:0.9254
epoch [1/30], loss:0.9280
epoch [1/30], loss:0.9260
epoch [1/30], loss:0.9248
epoch [1/30], loss:0.9257
epoch [1/30], loss:0.9232
epoch [1/30], loss:0.9253
epoch [1/30], loss:0.9248
epoch [1/30], loss:0.9245
epoch [1/30], loss:0.9252
epoch [1/30], loss:0.9270
epoch [1/30], loss:0.9262
epoch [1/30], loss:0.9280
epoch [1/30], loss:0.9238
epoch [1/30], loss:0.9257
epoch [1/30], loss:0.9268
epoch [1/30], loss:0.9240
epoch [1/30], loss:0.9251
epoch [1/30], loss:0.9283
epoch [1/30], loss:0.9258
epoch [1/30], loss:0.9240
epoch [1/30], loss:0.9272
epoch [1/30], loss:0.9253
epoch [1/30], loss:0.9266
epoch [1/30], loss:0.9194
epoch [1/30], loss:0.9265
epoch [1/30], loss:0.9258
epoch [1/30], loss:0.9261
epoch [1/30]

epoch [2/30], loss:0.9274
epoch [2/30], loss:0.9257
epoch [2/30], loss:0.9260
epoch [2/30], loss:0.9259
epoch [2/30], loss:0.9256
epoch [2/30], loss:0.9277
epoch [2/30], loss:0.9283
epoch [2/30], loss:0.9286
epoch [2/30], loss:0.9222
epoch [2/30], loss:0.9252
epoch [2/30], loss:0.9282
epoch [2/30], loss:0.9217
epoch [2/30], loss:0.9263
epoch [2/30], loss:0.9247
epoch [2/30], loss:0.9226
epoch [2/30], loss:0.9268
epoch [2/30], loss:0.9287
epoch [2/30], loss:0.9249
epoch [2/30], loss:0.9284
epoch [2/30], loss:0.9264
epoch [2/30], loss:0.9263
epoch [2/30], loss:0.9235
epoch [2/30], loss:0.9263
epoch [2/30], loss:0.9242
epoch [2/30], loss:0.9255
epoch [2/30], loss:0.9272
epoch [2/30], loss:0.9251
epoch [2/30], loss:0.9275
epoch [2/30], loss:0.9249
epoch [2/30], loss:0.9279
epoch [2/30], loss:0.9257
epoch [2/30], loss:0.9263
epoch [2/30], loss:0.9271
epoch [2/30], loss:0.9208
epoch [2/30], loss:0.9240
epoch [2/30], loss:0.9290
epoch [2/30], loss:0.9262
epoch [2/30], loss:0.9237
epoch [2/30]

epoch [2/30], loss:0.9275
epoch [2/30], loss:0.9289
epoch [2/30], loss:0.9313
epoch [2/30], loss:0.9268
epoch [2/30], loss:0.9243
epoch [2/30], loss:0.9258
epoch [2/30], loss:0.9256
epoch [2/30], loss:0.9256
epoch [2/30], loss:0.9232
epoch [2/30], loss:0.9271
epoch [2/30], loss:0.9238
epoch [2/30], loss:0.9247
epoch [2/30], loss:0.9273
epoch [2/30], loss:0.9248
epoch [2/30], loss:0.9278
epoch [2/30], loss:0.9254
epoch [2/30], loss:0.9235
epoch [2/30], loss:0.9240
epoch [2/30], loss:0.9247
epoch [2/30], loss:0.9223
epoch [2/30], loss:0.9281
epoch [2/30], loss:0.9265
epoch [2/30], loss:0.9242
epoch [2/30], loss:0.9237
epoch [2/30], loss:0.9240
epoch [2/30], loss:0.9293
epoch [2/30], loss:0.9254
epoch [2/30], loss:0.9269
epoch [2/30], loss:0.9292
epoch [2/30], loss:0.9226
epoch [2/30], loss:0.9277
epoch [2/30], loss:0.9283
epoch [2/30], loss:0.9214
epoch [2/30], loss:0.9254
epoch [2/30], loss:0.9245
epoch [2/30], loss:0.9239
epoch [2/30], loss:0.9285
epoch [2/30], loss:0.9253
epoch [2/30]

epoch [3/30], loss:0.9232
epoch [3/30], loss:0.9261
epoch [3/30], loss:0.9241
epoch [3/30], loss:0.9295
epoch [3/30], loss:0.9261
epoch [3/30], loss:0.9288
epoch [3/30], loss:0.9257
epoch [3/30], loss:0.9263
epoch [3/30], loss:0.9254
epoch [3/30], loss:0.9272
epoch [3/30], loss:0.9225
epoch [3/30], loss:0.9249
epoch [3/30], loss:0.9289
epoch [3/30], loss:0.9258
epoch [3/30], loss:0.9219
epoch [3/30], loss:0.9256
epoch [3/30], loss:0.9240
epoch [3/30], loss:0.9251
epoch [3/30], loss:0.9251
epoch [3/30], loss:0.9264
epoch [3/30], loss:0.9260
epoch [3/30], loss:0.9283
epoch [3/30], loss:0.9282
epoch [3/30], loss:0.9244
epoch [3/30], loss:0.9250
epoch [3/30], loss:0.9248
epoch [3/30], loss:0.9287
epoch [3/30], loss:0.9278
epoch [3/30], loss:0.9283
epoch [3/30], loss:0.9258
epoch [3/30], loss:0.9262
epoch [3/30], loss:0.9259
epoch [3/30], loss:0.9226
epoch [3/30], loss:0.9262
epoch [3/30], loss:0.9230
epoch [3/30], loss:0.9252
epoch [3/30], loss:0.9243
epoch [3/30], loss:0.9238
epoch [3/30]

epoch [3/30], loss:0.9236
epoch [3/30], loss:0.9300
epoch [3/30], loss:0.9262
epoch [3/30], loss:0.9239
epoch [3/30], loss:0.9248
epoch [3/30], loss:0.9255
epoch [3/30], loss:0.9262
epoch [3/30], loss:0.9214
epoch [3/30], loss:0.9222
epoch [3/30], loss:0.9262
epoch [3/30], loss:0.9237
epoch [3/30], loss:0.9268
epoch [3/30], loss:0.9224
epoch [3/30], loss:0.9268
epoch [3/30], loss:0.9255
epoch [3/30], loss:0.9262
epoch [3/30], loss:0.9262
epoch [3/30], loss:0.9254
epoch [3/30], loss:0.9247
epoch [3/30], loss:0.9262
epoch [3/30], loss:0.9263
epoch [3/30], loss:0.9255
epoch [3/30], loss:0.9273
epoch [3/30], loss:0.9270
epoch [3/30], loss:0.9234
epoch [3/30], loss:0.9249
epoch [3/30], loss:0.9243
epoch [3/30], loss:0.9278
epoch [3/30], loss:0.9244
epoch [3/30], loss:0.9225
epoch [3/30], loss:0.9207
epoch [3/30], loss:0.9229
epoch [3/30], loss:0.9244
epoch [3/30], loss:0.9252
epoch [3/30], loss:0.9281
epoch [3/30], loss:0.9267
epoch [3/30], loss:0.9234
epoch [3/30], loss:0.9244
epoch [3/30]

epoch [4/30], loss:0.9262
epoch [4/30], loss:0.9268
epoch [4/30], loss:0.9260
epoch [4/30], loss:0.9265
epoch [4/30], loss:0.9243
epoch [4/30], loss:0.9270
epoch [4/30], loss:0.9237
epoch [4/30], loss:0.9262
epoch [4/30], loss:0.9286
epoch [4/30], loss:0.9244
epoch [4/30], loss:0.9268
epoch [4/30], loss:0.9273
epoch [4/30], loss:0.9228
epoch [4/30], loss:0.9296
epoch [4/30], loss:0.9217
epoch [4/30], loss:0.9240
epoch [4/30], loss:0.9246
epoch [4/30], loss:0.9260
epoch [4/30], loss:0.9243
epoch [4/30], loss:0.9232
epoch [4/30], loss:0.9293
epoch [4/30], loss:0.9274
epoch [4/30], loss:0.9217
epoch [4/30], loss:0.9245
epoch [4/30], loss:0.9275
epoch [4/30], loss:0.9228
epoch [4/30], loss:0.9239
epoch [4/30], loss:0.9240
epoch [4/30], loss:0.9272
epoch [4/30], loss:0.9273
epoch [4/30], loss:0.9244
epoch [4/30], loss:0.9221
epoch [4/30], loss:0.9221
epoch [4/30], loss:0.9235
epoch [4/30], loss:0.9263
epoch [4/30], loss:0.9239
epoch [4/30], loss:0.9258
epoch [4/30], loss:0.9252
epoch [4/30]

epoch [4/30], loss:0.9293
epoch [4/30], loss:0.9236
epoch [4/30], loss:0.9253
epoch [4/30], loss:0.9272
epoch [4/30], loss:0.9248
epoch [4/30], loss:0.9248
epoch [4/30], loss:0.9257
epoch [4/30], loss:0.9221
epoch [4/30], loss:0.9273
epoch [4/30], loss:0.9242
epoch [4/30], loss:0.9249
epoch [4/30], loss:0.9262
epoch [4/30], loss:0.9263
epoch [4/30], loss:0.9236
epoch [4/30], loss:0.9221
epoch [4/30], loss:0.9252
epoch [4/30], loss:0.9266
epoch [4/30], loss:0.9265
epoch [4/30], loss:0.9224
epoch [4/30], loss:0.9247
epoch [4/30], loss:0.9251
epoch [4/30], loss:0.9256
epoch [4/30], loss:0.9250
epoch [4/30], loss:0.9250
epoch [4/30], loss:0.9269
epoch [4/30], loss:0.9209
epoch [4/30], loss:0.9246
epoch [4/30], loss:0.9256
epoch [4/30], loss:0.9243
epoch [4/30], loss:0.9271
epoch [4/30], loss:0.9265
epoch [4/30], loss:0.9273
epoch [4/30], loss:0.9213
epoch [4/30], loss:0.9282
epoch [4/30], loss:0.9265
epoch [4/30], loss:0.9265
epoch [4/30], loss:0.9241
epoch [4/30], loss:0.9292
epoch [4/30]

epoch [5/30], loss:0.9270
epoch [5/30], loss:0.9268
epoch [5/30], loss:0.9249
epoch [5/30], loss:0.9250
epoch [5/30], loss:0.9265
epoch [5/30], loss:0.9271
epoch [5/30], loss:0.9256
epoch [5/30], loss:0.9281
epoch [5/30], loss:0.9245
epoch [5/30], loss:0.9263
epoch [5/30], loss:0.9268
epoch [5/30], loss:0.9248
epoch [5/30], loss:0.9248
epoch [5/30], loss:0.9246
epoch [5/30], loss:0.9263
epoch [5/30], loss:0.9234
epoch [5/30], loss:0.9243
epoch [5/30], loss:0.9250
epoch [5/30], loss:0.9243
epoch [5/30], loss:0.9282
epoch [5/30], loss:0.9276
epoch [5/30], loss:0.9248
epoch [5/30], loss:0.9276
epoch [5/30], loss:0.9266
epoch [5/30], loss:0.9265
epoch [5/30], loss:0.9258
epoch [5/30], loss:0.9259
epoch [5/30], loss:0.9266
epoch [5/30], loss:0.9229
epoch [5/30], loss:0.9251
epoch [5/30], loss:0.9243
epoch [5/30], loss:0.9238
epoch [5/30], loss:0.9258
epoch [5/30], loss:0.9245
epoch [5/30], loss:0.9256
epoch [5/30], loss:0.9236
epoch [5/30], loss:0.9258
epoch [5/30], loss:0.9265
epoch [5/30]

epoch [5/30], loss:0.9246
epoch [5/30], loss:0.9247
epoch [5/30], loss:0.9242
epoch [5/30], loss:0.9276
epoch [5/30], loss:0.9233
epoch [5/30], loss:0.9273
epoch [5/30], loss:0.9247
epoch [5/30], loss:0.9226
epoch [5/30], loss:0.9285
epoch [5/30], loss:0.9248
epoch [5/30], loss:0.9246
epoch [5/30], loss:0.9235
epoch [5/30], loss:0.9270
epoch [5/30], loss:0.9265
epoch [5/30], loss:0.9245
epoch [5/30], loss:0.9232
epoch [5/30], loss:0.9255
epoch [5/30], loss:0.9250
epoch [5/30], loss:0.9217
epoch [5/30], loss:0.9259
epoch [5/30], loss:0.9246
epoch [5/30], loss:0.9267
epoch [5/30], loss:0.9255
epoch [5/30], loss:0.9292
epoch [5/30], loss:0.9253
epoch [5/30], loss:0.9232
epoch [5/30], loss:0.9277
epoch [5/30], loss:0.9248
epoch [5/30], loss:0.9243
epoch [5/30], loss:0.9267
epoch [5/30], loss:0.9226
epoch [5/30], loss:0.9253
epoch [5/30], loss:0.9232
epoch [5/30], loss:0.9270
epoch [5/30], loss:0.9242
epoch [5/30], loss:0.9209
epoch [5/30], loss:0.9269
epoch [5/30], loss:0.9269
epoch [5/30]

epoch [6/30], loss:0.9290
epoch [6/30], loss:0.9268
epoch [6/30], loss:0.9210
epoch [6/30], loss:0.9263
epoch [6/30], loss:0.9269
epoch [6/30], loss:0.9278
epoch [6/30], loss:0.9251
epoch [6/30], loss:0.9249
epoch [6/30], loss:0.9279
epoch [6/30], loss:0.9259
epoch [6/30], loss:0.9253
epoch [6/30], loss:0.9276
epoch [6/30], loss:0.9284
epoch [6/30], loss:0.9270
epoch [6/30], loss:0.9264
epoch [6/30], loss:0.9243
epoch [6/30], loss:0.9278
epoch [6/30], loss:0.9254
epoch [6/30], loss:0.9281
epoch [6/30], loss:0.9258
epoch [6/30], loss:0.9221
epoch [6/30], loss:0.9252
epoch [6/30], loss:0.9255
epoch [6/30], loss:0.9258
epoch [6/30], loss:0.9237
epoch [6/30], loss:0.9259
epoch [6/30], loss:0.9259
epoch [6/30], loss:0.9231
epoch [6/30], loss:0.9266
epoch [6/30], loss:0.9239
epoch [6/30], loss:0.9253
epoch [6/30], loss:0.9229
epoch [6/30], loss:0.9246
epoch [6/30], loss:0.9256
epoch [6/30], loss:0.9247
epoch [6/30], loss:0.9277
epoch [6/30], loss:0.9228
epoch [6/30], loss:0.9267
epoch [6/30]

epoch [6/30], loss:0.9259
epoch [6/30], loss:0.9246
epoch [6/30], loss:0.9250
epoch [6/30], loss:0.9254
epoch [6/30], loss:0.9267
epoch [6/30], loss:0.9259
epoch [6/30], loss:0.9259
epoch [6/30], loss:0.9211
epoch [6/30], loss:0.9275
epoch [6/30], loss:0.9245
epoch [6/30], loss:0.9256
epoch [6/30], loss:0.9265
epoch [6/30], loss:0.9231
epoch [6/30], loss:0.9272
epoch [6/30], loss:0.9256
epoch [6/30], loss:0.9265
epoch [6/30], loss:0.9245
epoch [6/30], loss:0.9298
epoch [6/30], loss:0.9253
epoch [6/30], loss:0.9296
epoch [6/30], loss:0.9245
epoch [6/30], loss:0.9227
epoch [6/30], loss:0.9239
epoch [6/30], loss:0.9235
epoch [6/30], loss:0.9253
epoch [6/30], loss:0.9249
epoch [6/30], loss:0.9257
epoch [6/30], loss:0.9255
epoch [6/30], loss:0.9242
epoch [6/30], loss:0.9245
epoch [6/30], loss:0.9236
epoch [6/30], loss:0.9281
epoch [6/30], loss:0.9283
epoch [6/30], loss:0.9264
epoch [6/30], loss:0.9282
epoch [6/30], loss:0.9230
epoch [6/30], loss:0.9270
epoch [6/30], loss:0.9249
epoch [6/30]

epoch [7/30], loss:0.9296
epoch [7/30], loss:0.9234
epoch [7/30], loss:0.9253
epoch [7/30], loss:0.9263
epoch [7/30], loss:0.9254
epoch [7/30], loss:0.9247
epoch [7/30], loss:0.9254
epoch [7/30], loss:0.9259
epoch [7/30], loss:0.9250
epoch [7/30], loss:0.9264
epoch [7/30], loss:0.9284
epoch [7/30], loss:0.9266
epoch [7/30], loss:0.9257
epoch [7/30], loss:0.9254
epoch [7/30], loss:0.9286
epoch [7/30], loss:0.9226
epoch [7/30], loss:0.9277
epoch [7/30], loss:0.9246
epoch [7/30], loss:0.9262
epoch [7/30], loss:0.9233
epoch [7/30], loss:0.9259
epoch [7/30], loss:0.9241
epoch [7/30], loss:0.9254
epoch [7/30], loss:0.9286
epoch [7/30], loss:0.9264
epoch [7/30], loss:0.9241
epoch [7/30], loss:0.9249
epoch [7/30], loss:0.9263
epoch [7/30], loss:0.9266
epoch [7/30], loss:0.9231
epoch [7/30], loss:0.9240
epoch [7/30], loss:0.9249
epoch [7/30], loss:0.9280
epoch [7/30], loss:0.9277
epoch [7/30], loss:0.9241
epoch [7/30], loss:0.9279
epoch [7/30], loss:0.9250
epoch [7/30], loss:0.9256
epoch [7/30]