In [0]:
# Import Libraries
from models_to_prune import *
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

import numpy as np

import torchvision.transforms as tf
import torchvision.datasets as ds
import torch.utils.data as data

import os
import time
import torch

from torch.utils.data import TensorDataset, DataLoader

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# See layer shapes
cnn = BasicCNN()
for name, layer in cnn.named_modules():
    if 'conv' in name:
        filters = layer.weight.data.clone()
        print(name,':',filters.size())
        pooled_filter = torch.squeeze(F.avg_pool2d(filters,
                                                   filters.size()[-1]))
        print("pooled :",pooled_filter.size())

# interpet 4d tensors as set of 3d blocks.  

conv1 : torch.Size([64, 3, 3, 3])
pooled : torch.Size([64, 3])
conv2 : torch.Size([128, 64, 3, 3])
pooled : torch.Size([128, 64])
conv3 : torch.Size([256, 128, 3, 3])
pooled : torch.Size([256, 128])
conv4 : torch.Size([512, 256, 3, 3])
pooled : torch.Size([512, 256])


In [0]:
# Create "dataset" of pooled layers
# Convert set of 3d blocks to set of flat 2d maps. 

# Create pad-tensor container, same size as biggest layer
filter_repeats = 1000 # each filter layer will be repeated this many times
feat_size = 16  # size of 2d maps
state_rep = torch.zeros([filter_repeats, 512, feat_size,feat_size]) # set of N padded [512,16,16] 
                                                        # tensors for each of the 4 layers  
for i in range(filter_repeats):
    cnn = BasicCNN()
    for name, layer in cnn.named_modules():
        if 'conv' in name:
            filters = layer.weight.data.clone()
            pooled_filter = torch.squeeze(F.avg_pool2d(filters,
                                                       filters.size()[-1]))
            conv_layer_num = int(name[-1])
            size = pooled_filter.size()
            #if conv_layer_num == 1:
            #    pads = (feat_size//2) - size[-1]//2
            #    state_rep[i, :size[0], feat_size//2, pads-1 :-pads] = pooled_filter  # copy in center
            #elif conv_layer_num == 2:
            #    pads = (feat_size//2) - 4
            #    state_rep[i+filter_repeats, :size[0], pads:-pads, pads:-pads] = pooled_filter.view(size[0],8,8)
            #elif conv_layer_num == 3:
            #    pads_r = (feat_size//2) - 4
            #    pads_c = (feat_size//2) - 8
            #    state_rep[i+filter_repeats*2, :size[0], :8, :16] = pooled_filter.view(size[0],8,16)
            if conv_layer_num == 4:
                state_rep[i] = pooled_filter.view(size[0],16,16) # same size as init state_rep
                #state_rep[i+filter_repeats*3] = pooled_filter.view(size[0],16,16) # same size as init state_rep
                #print(state_rep[i+filter_repeats*3][0])

val_rep = filter_repeats//10
validation = torch.zeros([val_rep*4, 512, feat_size,feat_size]) # set of N padded [512,16,16] 
                                                        # tensors for each of the 4 layers  
for i in range(val_reps):
    cnn = BasicCNN()
    for name, layer in cnn.named_modules():
        if 'conv' in name:
            filters = layer.weight.data.clone()
            pooled_filter = torch.squeeze(F.avg_pool2d(filters,
                                                       filters.size()[-1]))
            conv_layer_num = int(name[-1])
            size = pooled_filter.size()
            if conv_layer_num == 1:
                pads = (feat_size//2) - size[-1]//2
                validation[i, :size[0], feat_size//2, pads-1 :-pads] = pooled_filter  # copy in center
            elif conv_layer_num == 2:
                pads = (feat_size//2) - 4
                validation[i+val_rep, :size[0], pads:-pads, pads:-pads] = pooled_filter.view(size[0],8,8)
            elif conv_layer_num == 3:
                pads_r = (feat_size//2) - 4
                pads_c = (feat_size//2) - 8
                validation[i+val_rep*2, :size[0], :8, :16] = pooled_filter.view(size[0],8,16)
            elif conv_layer_num == 4:
                validation[i+val_rep*3] = pooled_filter.view(size[0],16,16) # same size as init state_rep

In [0]:
# Build Autoencoder Class, modified from https://github.com/L1aoXingyu

class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        self.encoder = nn.Sequential( # input size is [512,16,16]
            nn.Conv2d(512, 64, 3),  # b, 64, 14, 14
            nn.ReLU(True),
            nn.MaxPool2d(4, stride=2),  # b, 64, 6, 6
            nn.Conv2d(64, 16, 3),  # b, 16, 4, 4
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=1),  # b, 16, 3, 3
            nn.Flatten(), #from dim=1 to -1
            nn.Linear(16*3*3,100)
        )
        
        self.latent_to_map = nn.Linear(100, 16*3*3)
        self.decoder = nn.Sequential(    
            nn.ConvTranspose2d(16, 64, 3, stride=1),  # b, 64,4,4 
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 256, 5, stride=2),  # b, 256, 10, 10
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 512, 4),  # b, 512, 16, 16
            nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.latent_to_map(x).view(-1,16,3,3) 
        x = self.decoder(x)
        return x


In [53]:
num_epochs = 10
batch_size = 32
learning_rate = 1e-3

state_rep = state_rep.to(device)
train_dl = DataLoader(state_rep, batch_size=batch_size, shuffle=True)
valid_dl = DataLoader(validation, batch_size=batch_size, shuffle=True)

model = autoencoder().cuda()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                             weight_decay=1e-5)

for epoch in range(num_epochs):
    for i, data in enumerate(train_dl):
        model.train()
        data = Variable(data).cuda()
        # ===================forward=====================
        output = model(data)
        loss = criterion(output, data)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
         # ===================log========================
        if i % 1 == 0:
            model.eval()
            val_loss = 0
            num_batches = 0
            ave_val_loss = 0
            with torch.no_grad():
                for val in valid_dl:
                    val = Variable(val).cuda()
                    val_output = model(val)
                    val_loss += criterion(val_output,val)
                    num_batches += 1
                ave_val_loss = val_loss/num_batches


            print('epoch [{}/{}], loss:{}, val_loss:{}'
                .format(epoch+1, num_epochs, loss.item(), ave_val_loss))

torch.save(model.state_dict(), './conv_autoencoder.pth')

epoch [1/10], loss:5.780967330792919e-05, val_loss:1.929750396811869e-05
epoch [1/10], loss:2.7832160412799567e-05, val_loss:1.26446939248126e-05
epoch [1/10], loss:2.1137650037417188e-05, val_loss:1.1772473953897133e-05
epoch [1/10], loss:2.022811895585619e-05, val_loss:1.1493663805595133e-05
epoch [1/10], loss:2.005392343562562e-05, val_loss:1.0184913662669715e-05
epoch [1/10], loss:1.872719258244615e-05, val_loss:9.423513802175876e-06
epoch [1/10], loss:1.8009019186138175e-05, val_loss:9.48896831687307e-06
epoch [1/10], loss:1.8074239051202312e-05, val_loss:9.104628588829655e-06
epoch [1/10], loss:1.772091127349995e-05, val_loss:8.566620635974687e-06
epoch [1/10], loss:1.712510857032612e-05, val_loss:8.14661143522244e-06
epoch [1/10], loss:1.6748643247410655e-05, val_loss:8.111560418910813e-06


KeyboardInterrupt: ignored

In [19]:
Win = 10 
stride = 1
kernel = 8
(Win-1)*stride + kernel-1

16

In [7]:
4000/16

250.0

In [15]:
print(validation.size())


torch.Size([400, 512, 16, 16])
