In [41]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """Residual unit: two 3x3 conv/batch-norm stages plus a skip connection.

    Args:
        inchannel: Number of channels of the incoming feature map.
        outchannel: Number of channels produced by the unit.
        stride: Stride of the first 3x3 convolution (and of the 1x1
            projection on the skip path, when one is needed).
    """

    def __init__(self, inchannel, outchannel, stride=1):
        super().__init__()

        # Main path: conv -> BN -> ReLU -> conv -> BN (no ReLU before the sum).
        main_path = [
            nn.Conv2d(inchannel, outchannel, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
            nn.ReLU(inplace=True),
            nn.Conv2d(outchannel, outchannel, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(outchannel),
        ]
        self.left = nn.Sequential(*main_path)

        # Skip path: an empty Sequential passes x through unchanged; a 1x1
        # conv + BN projection is used whenever the main path changes the
        # spatial size or the channel count.
        shape_preserved = (stride == 1 and inchannel == outchannel)
        if shape_preserved:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(inchannel, outchannel, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(outchannel),
            )

    def forward(self, x):
        """Return ReLU(main_path(x) + skip_path(x))."""
        summed = self.left(x)
        summed += self.shortcut(x)
        return F.relu(summed)

class ResNet(nn.Module):
    """Residual network: a 3x3 stem, four residual stages and a linear classifier.

    The stages produce 64, 128, 256 and 512 channels with two blocks each;
    stages 2-4 halve the spatial resolution in their first block.

    Args:
        ResidualBlock: Block class, called as ``block(in_channels,
            out_channels, stride)`` and expected to return an ``nn.Module``.
        num_classes: Size of the output of the final linear layer.
    """

    def __init__(self, ResidualBlock, num_classes=10):
        super(ResNet, self).__init__()
        # Running channel count; make_layer reads and updates it so each stage
        # starts from the previous stage's output width.
        self.inchannel = 64
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.layer1 = self.make_layer(ResidualBlock, 64,  2, stride=1)
        self.layer2 = self.make_layer(ResidualBlock, 128, 2, stride=2)
        self.layer3 = self.make_layer(ResidualBlock, 256, 2, stride=2)
        self.layer4 = self.make_layer(ResidualBlock, 512, 2, stride=2)
        self.fc = nn.Linear(512, num_classes)

    def make_layer(self, block, channels, num_blocks, stride):
        """Stack ``num_blocks`` blocks into one stage.

        Only the first block uses ``stride``; the remaining blocks use stride 1
        (e.g. ``stride=2, num_blocks=2`` gives strides ``[2, 1]``).
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for block_stride in strides:
            layers.append(block(self.inchannel, channels, block_stride))
            self.inchannel = channels
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of 3-channel images to ``(batch, num_classes)`` logits."""
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling. For a 4x4 final feature map (e.g. 32x32
        # inputs) this equals the former fixed ``avg_pool2d(out, 4)``, but it
        # also yields exactly 512 features for any other input resolution.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        out = self.fc(out)
        return out


def ResNet18(num_classes=10):
    """Build a ``ResNet`` assembled from ``ResidualBlock`` units.

    Args:
        num_classes: Number of output logits; defaults to 10, which is what
            the no-argument call has always produced.

    Returns:
        A freshly initialised ``ResNet`` instance.
    """
    return ResNet(ResidualBlock, num_classes=num_classes)

In [42]:
import numpy as np
from PIL import Image
class Log2Trans(object):
    """Image transform that rebuilds every pixel from an approximate log2.

    For a non-zero pixel ``n`` with ``2**k <= n < 2**(k + 1)`` the exponent is
    approximated linearly from whichever bracketing power of two is nearer:

    * nearer to ``2**k``:                ``z = k + (n / 2**k - 1)``
    * nearer to ``2**(k + 1)`` (or tie): ``z = (k + 1) - (1 - n / 2**(k + 1))``

    The output pixel is ``2**z`` cast back to the input dtype (for integer
    images the cast truncates). A zero pixel gives ``z = -1``, i.e. ``0.5``.

    Parameters
    ----------
    data_type : numpy floating dtype, default ``np.float16``
        Precision used for all intermediate arithmetic, including the
        fractional part of the approximated exponent.
    """

    def __init__(self, data_type = np.float16):
        self.data_type = data_type

    def __call__(self, img):
        """Apply the approximation to ``img``.

        Parameters
        ----------
        img : PIL image or array-like
            Any array shape is accepted, e.g. ``(h, w, c)`` or ``(h, w)``.

        Returns
        -------
        The result of ``Image.fromarray`` on the reconstructed array, which
        has the same shape and dtype as the input.
        """
        pixels = np.asarray(img)
        work = pixels.astype(self.data_type)

        # Zeros are replaced by 1 before the logarithm so that log2 is defined.
        k = np.floor(np.log2(np.where(work != 0, work, 1)))
        lower = np.power(2, k)
        upper = np.power(2, k + 1)
        nearer_lower = (work - lower) < (upper - work)

        # The fractional parts are kept in the configured precision rather
        # than a hard-coded float16, so ``data_type`` is honoured throughout.
        frac_from_lower = (work / lower - 1).astype(self.data_type, copy=False)
        frac_from_upper = (1 - work / upper).astype(self.data_type, copy=False)
        z = np.where(nearer_lower, k + frac_from_lower, k + 1 - frac_from_upper)

        # Reconstruct in float64; broadcasting the scalar base works for any
        # input rank, so no explicit (h, w, c) shape is required.
        recon = np.power(2.0, z.astype(np.float64))
        return Image.fromarray(recon.astype(pixels.dtype))

    def __repr__(self):
        return self.__class__.__name__+'()'

In [43]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import argparse
import os

# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# ---- Hyper-parameters ------------------------------------------------------
EPOCH = 100        # exclusive upper bound of the epoch loop
pre_epoch = 0      # first epoch index of the epoch loop
BATCH_SIZE = 128   # mini-batch size of the training loader
LR = 0.1           # base learning rate

# ---- Data pipeline ---------------------------------------------------------
# Candidate augmentations for training images.
transforms_list = [
    Log2Trans(),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
]

# NOTE: RandomChoice applies exactly ONE randomly picked entry of
# transforms_list to each sample, not all of them.
transform_train = transforms.Compose([
    transforms.RandomChoice(transforms_list),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Test images are only converted to tensors and normalised.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=100,
    shuffle=False,
    num_workers=2,
)

# CIFAR-10 class names.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

# ---- Model and loss --------------------------------------------------------
net = ResNet18().to(device)
criterion = nn.CrossEntropyLoss()
# No optimizer is created here; the training cell constructs it.



cuda
Files already downloaded and verified
Files already downloaded and verified


In [44]:
# Training
def _learning_rate_for(epoch):
    """Step schedule: 0.1 up to epoch index 80, 0.01 up to 185, then 0.001."""
    if epoch <= 80:
        return 0.1
    if epoch <= 185:
        return 0.01
    return 0.001


if __name__ == "__main__":
    os.makedirs('./model/', exist_ok=True)
    best_acc = 85  # best_acc.txt is only written once test accuracy exceeds this
    print("Start Training, Resnet-18!")

    # Build the optimizer ONCE so SGD's momentum buffers survive across
    # epochs; only the learning rate is changed by the schedule.
    optimizer = optim.SGD(net.parameters(), lr=_learning_rate_for(pre_epoch),
                          momentum=0.9, weight_decay=5e-4)
    length = len(trainloader)  # batches per epoch, used for the global iteration counter

    with open("acc.txt", "w") as f, open("log.txt", "w") as f2:
        for epoch in range(pre_epoch, EPOCH):
            epoch_lr = _learning_rate_for(epoch)
            for group in optimizer.param_groups:
                group['lr'] = epoch_lr

            print('\nEpoch: %d' % (epoch + 1))
            net.train()
            sum_loss = 0.0
            correct = 0
            total = 0
            for i, (inputs, labels) in enumerate(trainloader):
                # data preparation
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()

                # forward + backward
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # running statistics, kept as plain Python numbers
                sum_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

                stats = (epoch + 1, (i + 1 + epoch * length),
                         sum_loss / (i + 1), 100. * correct / total)
                print('[epoch:%d, iter:%d] Loss: %.03f | Acc: %.3f%% ' % stats)
                f2.write('%03d  %05d |Loss: %.03f | Acc: %.3f%% ' % stats)
                f2.write('\n')
                f2.flush()

            # test accuracy of each epoch
            print("Waiting Test!")
            net.eval()  # switch once, not once per batch
            correct = 0
            total = 0
            with torch.no_grad():
                for images, labels in testloader:
                    images, labels = images.to(device), labels.to(device)
                    outputs = net(images)
                    _, predicted = torch.max(outputs, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
            acc = 100. * correct / total  # plain float, so best_acc never holds a tensor
            print('accuracy：%.3f%%' % acc)

            # acc.txt
            print('Saving model......')
            torch.save(net.state_dict(), '%s/net_%03d.pth' % ('./model/', epoch + 1))
            f.write("EPOCH=%03d,Accuracy= %.3f%%" % (epoch + 1, acc))
            f.write('\n')
            f.flush()

            # best_acc.txt
            if acc > best_acc:
                with open("best_acc.txt", "w") as f3:
                    f3.write("EPOCH=%d,best_acc= %.3f%%" % (epoch + 1, acc))
                best_acc = acc
        print("Training Finished, TotalEPOCH=%d" % EPOCH)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
[epoch:88, iter:34166] Loss: 0.033 | Acc: 99.061% 
[epoch:88, iter:34167] Loss: 0.034 | Acc: 99.057% 
[epoch:88, iter:34168] Loss: 0.033 | Acc: 99.064% 
[epoch:88, iter:34169] Loss: 0.033 | Acc: 99.070% 
[epoch:88, iter:34170] Loss: 0.033 | Acc: 99.071% 
[epoch:88, iter:34171] Loss: 0.034 | Acc: 99.067% 
[epoch:88, iter:34172] Loss: 0.034 | Acc: 99.068% 
[epoch:88, iter:34173] Loss: 0.034 | Acc: 99.069% 
[epoch:88, iter:34174] Loss: 0.034 | Acc: 99.064% 
[epoch:88, iter:34175] Loss: 0.034 | Acc: 99.061% 
[epoch:88, iter:34176] Loss: 0.034 | Acc: 99.062% 
[epoch:88, iter:34177] Loss: 0.034 | Acc: 99.058% 
[epoch:88, iter:34178] Loss: 0.034 | Acc: 99.054% 
[epoch:88, iter:34179] Loss: 0.034 | Acc: 99.045% 
[epoch:88, iter:34180] Loss: 0.034 | Acc: 99.046% 
[epoch:88, iter:34181] Loss: 0.034 | Acc: 99.033% 
[epoch:88, iter:34182] Loss: 0.035 | Acc: 99.034% 
[epoch:88, iter:34183] Loss: 0.035 | Acc: 99.035% 
[epoch:88, iter:3