<a href="https://colab.research.google.com/github/ycaxgjd/dd2424/blob/master/code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import copy
import pickle
import time
from datetime import datetime
from io import BytesIO

import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as T
from PIL import Image
from google.colab import files
from torchvision import datasets
from torchvision import transforms
from torchvision.models import resnet
from torchvision.models import resnet18
from torchvision.models import resnet34
from torchvision.models import vgg11
from torchvision.models import vgg19

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# On CUDA, let cuDNN auto-tune convolution algorithms (pays off when input sizes stay fixed).
if device.type == 'cuda':
    cudnn.benchmark = True


class VGG(nn.Module):
    """VGG-style CIFAR-10 classifier (conv + BatchNorm + ReLU) that records stage outputs.

    After every forward pass ``self.grams`` holds the output of each of the five
    convolutional stages (consumed by the attention-transfer loss).

    Args:
        key: network depth, a key of ``cfg`` (11 or 19).
        pretrained: if True, also build the matching ImageNet-pretrained
            torchvision VGG as ``self._pm``, freeze it, and use its ``avgpool``
            slot plus a new trainable 10-way ``classifier`` in forward().
    """

    # Layer spec per stage: an int is a 3x3 conv with that many output channels,
    # 'M' is a 2x2 max-pool. Every stage ends with 'M', so each halves H and W.
    cfg = {
        11: [[64, 'M'], [128, 'M'], [256, 256, 'M'], [512, 512, 'M'], [512, 512, 'M']],
        19: [[64, 64, 'M'], [128, 128, 'M'], [256, 256, 256, 256, 'M'], [512, 512, 512, 512, 'M'],
             [512, 512, 512, 512, 'M']],
    }
    # Factories for the torchvision models, called lazily in __init__ so that
    # defining this class does not download ImageNet weights, and each VGG
    # instance gets its own model instead of sharing (and mutating) one object.
    cfg_pm = {
        11: lambda: vgg11(pretrained=True),
        19: lambda: vgg19(pretrained=True),
    }

    def __init__(self, key, pretrained=False):
        super(VGG, self).__init__()

        self.pretrained = pretrained
        self.grams = None  # list of stage outputs, set by forward()

        self.in_channels = 3  # RGB input; advanced by _make_layers as stages are built
        if pretrained:
            self._pm = self.cfg_pm[key]()
            # Freeze the torchvision model *before* installing new modules so the
            # replacement classifier below stays trainable.
            for _para in self._pm.parameters():
                _para.requires_grad = False
            # NOTE(review): these stages are freshly initialised and trainable;
            # forward() does not use the pretrained self._pm.features weights.
            self._features1 = self._make_layers(self.cfg[key][0])
            self._features2 = self._make_layers(self.cfg[key][1])
            self._features3 = self._make_layers(self.cfg[key][2])
            self._features4 = self._make_layers(self.cfg[key][3])
            self._features5 = self._make_layers(self.cfg[key][4])
            # Adaptive pooling collapses any spatial size to 1x1 (a no-op for
            # 32x32 CIFAR inputs), so the head always sees 512 features.
            self._pm.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self._pm.classifier = self._make_classifier()
        else:
            self.features1 = self._make_layers(self.cfg[key][0])
            self.features2 = self._make_layers(self.cfg[key][1])
            self.features3 = self._make_layers(self.cfg[key][2])
            self.features4 = self._make_layers(self.cfg[key][3])
            self.features5 = self._make_layers(self.cfg[key][4])
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.classifier = self._make_classifier()

    @staticmethod
    def _make_classifier():
        """Fully connected head mapping the 512 pooled features to 10 class scores."""
        return nn.Sequential(
            nn.Linear(512, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 10),
        )

    def _make_layers(self, cfg):
        """Build one stage from a ``cfg`` spec; updates ``self.in_channels``."""
        layers = []
        for v in cfg:
            if v == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                conv2d = nn.Conv2d(self.in_channels, v, kernel_size=3, padding=1)
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                self.in_channels = v
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class scores for ``x`` and store each stage's output in ``self.grams``."""
        self.grams = list()
        if self.pretrained:
            stages = (self._features1, self._features2, self._features3, self._features4, self._features5)
            pool, head = self._pm.avgpool, self._pm.classifier
        else:
            stages = (self.features1, self.features2, self.features3, self.features4, self.features5)
            pool, head = self.avgpool, self.classifier

        for stage in stages:
            x = stage(x)
            self.grams.append(x)
        x = pool(x)
        x = x.view(x.size(0), -1)
        return head(x)


class BasicBlock(nn.Module):
    """ResNet basic block: two 3x3 conv/BN layers plus a residual connection.

    When the stride or the channel count changes, the identity shortcut is
    replaced by a 1x1 conv + BN projection so the two branches can be added.
    """
    expansion = 1  # output channels = planes * expansion

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        out_planes = planes * self.expansion

        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or inplanes != out_planes
        self.shortcut = nn.Sequential(
            nn.Conv2d(inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_planes),
        ) if needs_projection else nn.Sequential()

    def forward(self, x):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))."""
        residual = self.shortcut(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y += residual
        return self.relu(y)


class ResNet(nn.Module):
    """CIFAR-10 ResNet (3x3 stride-1 stem, stem max-pool skipped) that records stage outputs.

    After every forward pass ``self.grams`` holds the outputs of the four
    residual stages (consumed by the attention-transfer loss).

    Args:
        key: network depth, a key of ``cfg`` (18 or 34).
        pretrained: if True, wrap the matching ImageNet-pretrained torchvision
            ResNet as ``self._pm``, freeze its weights, and train only the
            replaced 3x3 stem conv and the new 10-way ``fc`` head.
    """

    # Number of BasicBlocks in each of the four stages.
    cfg = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],
    }
    # Factories for the torchvision models, called lazily in __init__ so that
    # defining this class does not download ImageNet weights, and each ResNet
    # instance gets its own model instead of sharing (and mutating) one object.
    cfg_pm = {
        18: lambda: resnet18(pretrained=True),
        34: lambda: resnet34(pretrained=True),
    }

    def __init__(self, key, pretrained=False):
        super(ResNet, self).__init__()

        self.pretrained = pretrained
        self.grams = None  # list of stage outputs, set by forward()
        if pretrained:
            self._pm = self.cfg_pm[key]()
            # Freeze the pretrained weights first; the modules assigned below are
            # new and keep requires_grad=True, so the stem and head can still learn.
            for _para in self._pm.parameters():
                _para.requires_grad = False
            self._pm.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            self._pm.fc = nn.Linear(self._pm.fc.in_features, 10)
        else:
            self.inplanes = 64  # running channel count consumed by _make_layer
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            self.bn1 = nn.BatchNorm2d(64)
            self.relu = nn.ReLU(inplace=True)
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # defined but unused in forward()
            self.layer1 = self._make_layer(64, self.cfg[key][0], stride=1)
            self.layer2 = self._make_layer(128, self.cfg[key][1], stride=2)
            self.layer3 = self._make_layer(256, self.cfg[key][2], stride=2)
            self.layer4 = self._make_layer(512, self.cfg[key][3], stride=2)
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.fc = nn.Linear(512 * BasicBlock.expansion, 10)

    def _make_layer(self, planes, blocks, stride=1):
        """Stack ``blocks`` BasicBlocks; only the first may downsample via ``stride``."""
        layers = list()
        layers.append(BasicBlock(self.inplanes, planes, stride))
        self.inplanes = planes * BasicBlock.expansion
        for _ in range(1, blocks):
            layers.append(BasicBlock(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class scores for ``x`` and store each stage's output in ``self.grams``."""
        self.grams = list()
        # torchvision's ResNet uses the same submodule names, so one code path serves both.
        net = self._pm if self.pretrained else self

        x = net.relu(net.bn1(net.conv1(x)))
        # The stem max-pool is intentionally skipped (presumably to keep spatial
        # resolution for small 32x32 inputs).
        for stage in (net.layer1, net.layer2, net.layer3, net.layer4):
            x = stage(x)
            self.grams.append(x)

        x = net.avgpool(x)
        x = x.view(x.size(0), -1)
        return net.fc(x)


class Teacher(object):
    """Train a teacher network on CIFAR-10, or reload its cached weights.

    The weights with the best test accuracy are pickled to
    ``sd4<teacher_id>.pkl``; when that file already exists ``run()`` loads it
    instead of training.

    Args:
        pretrained: forwarded to the model constructor (torchvision ImageNet weights).
        teacher_id: 'resnet34' or 'vgg19'.
    """

    def __init__(self, pretrained=False, teacher_id='resnet34'):
        self.pretrained = pretrained
        self.model_id = teacher_id
        self.model = None
        self.epochs = 200
        self.train_batch_size = 128
        self.test_batch_size = 128
        self.train_loader = None
        self.test_loader = None
        self.criterion = None
        self.optimizer = None
        self.scheduler = None

    @staticmethod
    def _format_duration(seconds):
        """Format a duration in seconds as ``HH:MM:SS``.

        Unlike ``datetime.fromtimestamp(...).strftime``, this ignores the local
        time zone and does not wrap around after 24 hours.
        """
        hours, remainder = divmod(int(seconds), 3600)
        minutes, secs = divmod(remainder, 60)
        return '%02d:%02d:%02d' % (hours, minutes, secs)

    @classmethod
    def _report(cls, split, acc, loss, seconds):
        """Print one epoch's metrics, e.g. ``Train:TrainAcc 91.234% | TrainLoss 0.123 | 00:00:42``."""
        print(split + ':' + ' | '.join([
            '%sAcc %.3f%%' % (split, acc * 100.0),
            '%sLoss %.3f' % (split, loss),
            cls._format_duration(seconds),
        ]))

    def load_data(self):
        """Create CIFAR-10 loaders (downloaded to ./data); training images are randomly flipped."""
        train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()])
        test_transform = transforms.Compose([transforms.ToTensor()])
        train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
        self.train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=self.train_batch_size,
                                                        shuffle=True)
        test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
        self.test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=self.test_batch_size, shuffle=False)

    def load_model(self):
        """Build the model, SGD optimizer, LR schedule and loss for ``model_id``.

        Raises:
            ValueError: if ``model_id`` is neither 'vgg19' nor 'resnet34'.
        """
        if self.model_id == 'vgg19':
            self.model = VGG(19, self.pretrained).to(device)
            weight_decay = 5e-4
        elif self.model_id == 'resnet34':
            self.model = ResNet(34, self.pretrained).to(device)
            weight_decay = 1e-4
        else:
            raise ValueError(f'unknown teacher_id: {self.model_id!r}')
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.1, momentum=0.9, weight_decay=weight_decay)

        # Halve the learning rate after epochs 75 and 150.
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[75, 150], gamma=0.5)
        self.criterion = nn.CrossEntropyLoss().to(device)

    def train(self):
        """Train for one epoch.

        Returns:
            tuple: (accuracy in [0, 1], mean loss per batch, elapsed seconds).
        """
        self.model.train()
        loss_sum = 0.0
        correct = 0
        seen = 0

        start = time.time()
        for data, target in self.train_loader:
            data, target = data.to(device), target.to(device)
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            loss_sum += loss.item()
            correct += (output.argmax(1) == target).sum().item()
            seen += target.size(0)

        # The criterion already averages within a batch, so average over batches.
        num_batches = max(len(self.train_loader), 1)
        return correct / max(seen, 1), loss_sum / num_batches, time.time() - start

    def test(self):
        """Evaluate on the test set.

        Returns:
            tuple: (accuracy in [0, 1], mean loss per batch, elapsed seconds).
        """
        self.model.eval()
        loss_sum = 0.0
        correct = 0
        seen = 0

        start = time.time()
        with torch.no_grad():
            for data, target in self.test_loader:
                data, target = data.to(device), target.to(device)
                output = self.model(data)
                loss_sum += self.criterion(output, target).item()
                correct += (output.argmax(1) == target).sum().item()
                seen += target.size(0)

        num_batches = max(len(self.test_loader), 1)
        return correct / max(seen, 1), loss_sum / num_batches, time.time() - start

    def run(self):
        """Load data and model, then reuse cached weights or train and cache them."""
        self.load_data()
        self.load_model()

        checkpoint = f'sd4{self.model_id}.pkl'
        try:
            with open(checkpoint, 'rb') as file:
                # Only unpickle checkpoints this program wrote itself: pickle can execute code.
                best_settings = pickle.load(file)
        except IOError:
            pass  # no cached weights yet (on Colab they could be uploaded via files.upload()); train below
        else:
            self.model.load_state_dict(best_settings)
            return

        best_acc = 0
        best_settings = None
        init_time = time.time()
        for epoch in range(self.epochs):
            print('==> Epoch %d/%d' % (epoch + 1, self.epochs))
            self._report('Train', *self.train())
            test_acc, test_loss, test_time = self.test()
            self._report('Test', test_acc, test_loss, test_time)
            # Advance the LR schedule once per finished epoch (milestones are epoch counts).
            self.scheduler.step()

            if best_acc < test_acc:
                best_acc = test_acc
                best_settings = copy.deepcopy(self.model.state_dict())

        print('==> Best TestAcc %.3f%%' % (best_acc * 100))
        print('==> Total Time ' + self._format_duration(time.time() - init_time))
        if best_settings is None:
            # No epoch beat 0% accuracy (or epochs == 0): cache the current weights.
            best_settings = copy.deepcopy(self.model.state_dict())
        self.model.load_state_dict(best_settings)

        with open(checkpoint, 'wb') as file:
            pickle.dump(best_settings, file)

class Student(object):
    """Train a student network on CIFAR-10, optionally guided by a teacher.

    Args:
        distill: add a temperature-scaled KL term against the teacher's logits.
        attention: add an attention-transfer term between student and teacher
            stage outputs (``model.grams``).
        student_id: 'resnet18' or 'vgg11'.
        teacher_id: id passed to ``Teacher`` when distill or attention is set.
    """

    def __init__(self, distill=False, attention=False, student_id='resnet18', teacher_id='resnet34'):
        self.distill = distill
        self.attention = attention
        self.teacher_id = teacher_id
        # Attention transfer also needs the teacher's stage outputs, so a teacher
        # is built whenever either technique is enabled.
        # NOTE(review): Teacher caches weights as sd4<teacher_id>.pkl regardless of
        # `pretrained`; a checkpoint written by Teacher(pretrained=False) (as
        # __main__ runs first) has different state-dict keys and will not load into
        # this pretrained=True teacher — confirm which variant is intended.
        self.teacher = Teacher(pretrained=True, teacher_id=teacher_id) if distill or attention else None
        self.model_id = student_id
        self.model = None
        self.epochs = 200
        self.train_batch_size = 128
        self.test_batch_size = 128
        self.train_loader = None
        self.test_loader = None
        self.criterion = None
        self.optimizer = None
        self.scheduler = None

    @staticmethod
    def _format_duration(seconds):
        """Format a duration in seconds as ``HH:MM:SS`` (time-zone independent, no 24h wrap)."""
        hours, remainder = divmod(int(seconds), 3600)
        minutes, secs = divmod(remainder, 60)
        return '%02d:%02d:%02d' % (hours, minutes, secs)

    @classmethod
    def _report(cls, split, acc, loss, seconds):
        """Print one epoch's metrics, e.g. ``Test:TestAcc 91.234% | TestLoss 0.123 | 00:00:04``."""
        print(split + ':' + ' | '.join([
            '%sAcc %.3f%%' % (split, acc * 100.0),
            '%sLoss %.3f' % (split, loss),
            cls._format_duration(seconds),
        ]))

    @staticmethod
    def _attention_map(feature):
        """Channel-mean of squared activations, flattened per sample and L2-normalised."""
        return F.normalize(feature.pow(2).mean(1).view(feature.size(0), -1))

    @staticmethod
    def _attention_loss(student_grams, teacher_grams):
        """Sum over stages of the mean squared difference between attention maps."""
        return sum((Student._attention_map(s) - Student._attention_map(t)).pow(2).mean()
                   for s, t in zip(student_grams, teacher_grams))

    def load_data(self):
        """Create CIFAR-10 loaders (downloaded to ./data); training images are randomly flipped."""
        train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()])
        test_transform = transforms.Compose([transforms.ToTensor()])
        train_set = datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)
        self.train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=self.train_batch_size,
                                                        shuffle=True)
        test_set = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
        self.test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=self.test_batch_size, shuffle=False)

    def load_model(self):
        """Prepare the (frozen) teacher if any, then build the student and its optimizer.

        Raises:
            ValueError: if ``model_id`` is neither 'vgg11' nor 'resnet18'.
        """
        if self.teacher is not None:
            self.teacher.run()
            # The teacher is a fixed target: freeze it and keep BatchNorm/Dropout in
            # inference mode (a teacher loaded from cache would otherwise stay in train mode).
            for para in self.teacher.model.parameters():
                para.requires_grad = False
            self.teacher.model.eval()

        if self.model_id == 'vgg11':
            self.model = VGG(11).to(device)
            weight_decay = 5e-4
        elif self.model_id == 'resnet18':
            self.model = ResNet(18).to(device)
            weight_decay = 1e-4
        else:
            raise ValueError(f'unknown student_id: {self.model_id!r}')
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.1, momentum=0.9, weight_decay=weight_decay)

        # Halve the learning rate after epochs 75 and 150.
        self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[75, 150], gamma=0.5)
        self.criterion = nn.CrossEntropyLoss().to(device)

    def train(self):
        """Train the student for one epoch.

        Per batch the loss is cross-entropy, plus
        ``alpha * T^2 * KL(teacher || student) + (1 - alpha) * CE`` when distilling
        (logits divided by T), plus ``beta * attention-transfer loss`` when enabled.

        Returns:
            tuple: (accuracy in [0, 1], mean loss per batch, elapsed seconds).
        """
        self.model.train()
        # NOTE(review): T < 1 sharpens the softmax; Hinton-style distillation
        # normally uses T > 1 — confirm this value is intended.
        temperature = 0.1
        alpha = 1.0
        beta = 0.1
        loss_sum = 0.0
        correct = 0
        seen = 0

        start = time.time()
        for data, target in self.train_loader:
            data, target = data.to(device), target.to(device)
            self.optimizer.zero_grad()
            output = self.model(data)
            ce_loss = self.criterion(output, target)
            loss = ce_loss

            if self.teacher is not None:
                # The teacher only supplies targets (its logits and model.grams).
                with torch.no_grad():
                    teacher_output = self.teacher.model(data)

            if self.distill:
                kl_loss = F.kl_div(F.log_softmax(output / temperature, dim=1),
                                   F.softmax(teacher_output / temperature, dim=1),
                                   reduction='batchmean')
                loss = loss + alpha * (temperature ** 2) * kl_loss + (1 - alpha) * ce_loss
            if self.attention:
                loss = loss + beta * self._attention_loss(self.model.grams, self.teacher.model.grams)

            loss.backward()
            self.optimizer.step()
            loss_sum += loss.item()
            correct += (output.argmax(1) == target).sum().item()
            seen += target.size(0)

        # The loss is already a per-batch mean, so average over batches.
        num_batches = max(len(self.train_loader), 1)
        return correct / max(seen, 1), loss_sum / num_batches, time.time() - start

    def test(self):
        """Evaluate the student on the test set.

        Returns:
            tuple: (accuracy in [0, 1], mean cross-entropy per batch, elapsed seconds).
        """
        self.model.eval()
        loss_sum = 0.0
        correct = 0
        seen = 0

        start = time.time()
        with torch.no_grad():
            for data, target in self.test_loader:
                data, target = data.to(device), target.to(device)
                output = self.model(data)
                loss_sum += self.criterion(output, target).item()
                correct += (output.argmax(1) == target).sum().item()
                seen += target.size(0)

        num_batches = max(len(self.test_loader), 1)
        return correct / max(seen, 1), loss_sum / num_batches, time.time() - start

    def run(self):
        """Train the student, restore its best-test-accuracy weights and plot attention maps."""
        self.load_data()
        self.load_model()
        best_acc = 0
        best_settings = None
        init_time = time.time()

        for epoch in range(self.epochs):
            print('==> Epoch %d/%d' % (epoch + 1, self.epochs))
            self._report('Train', *self.train())
            test_acc, test_loss, test_time = self.test()
            self._report('Test', test_acc, test_loss, test_time)
            # Advance the LR schedule once per finished epoch (milestones are epoch counts).
            self.scheduler.step()

            if best_acc < test_acc:
                best_acc = test_acc
                best_settings = copy.deepcopy(self.model.state_dict())

        print('==> Best TestAcc %.3f%%' % (best_acc * 100))
        print('==> Total Time ' + self._format_duration(time.time() - init_time))
        if best_settings is not None:
            self.model.load_state_dict(best_settings)

        # Plain students have no teacher; only plot models that exist.
        models = {self.model_id: self.model}
        if self.teacher is not None:
            models[self.teacher_id] = self.teacher.model
        Visualization.feature_map(models)


class ResNet34AT(resnet.ResNet):
    """torchvision ResNet whose forward() returns attention maps instead of logits.

    One map per residual stage is returned: the stage output squared and
    averaged over dimension 1 (channels, for NCHW input).
    """

    def forward(self, x):
        out = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        attention_maps = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
            attention_maps.append(out.pow(2).mean(1))
        return attention_maps


class Visualization:
    """Plot and save per-stage attention maps of models on two sample photos."""

    @staticmethod
    def _resnet34():
        """Return an ImageNet-pretrained ResNet-34 whose forward() yields attention maps."""
        model = ResNet34AT(resnet.BasicBlock, [3, 4, 6, 3])
        base_model = resnet34(pretrained=True)
        # Same block layout as torchvision's resnet34, so its state dict loads directly.
        model.load_state_dict(base_model.state_dict())
        return model

    @staticmethod
    def _show_and_save(image, title, **imshow_kwargs):
        """Display ``image`` titled ``title`` and save the figure as ``<title>.jpg`` (600 dpi)."""
        plt.imshow(image, **imshow_kwargs)
        plt.title(title)
        plt.savefig(f'{title}.jpg', dpi=600)
        plt.show()

    @staticmethod
    def feature_map(models=None, timeout=30):
        """Download the sample photos and plot each model's attention maps on them.

        For every photo the image itself is shown/saved as ``<pic_id>.jpg`` and,
        for every model, each stage's attention map (channel-mean of squared
        activations) as ``<model_id>-<pic_id>-g<i>.jpg``.

        Args:
            models: mapping of name -> model. Each model must expose ``grams``
                (list of stage outputs) after forward(), like VGG/ResNet in this
                file. If None, a pretrained ResNet-34 demo is used.
                Every model is switched to eval mode and moved to the CPU.
            timeout: seconds to wait for each image download.

        Raises:
            requests.HTTPError: if an image download returns an error status.
        """
        urls = {
            'bird': 'https://cdn.pixabay.com/photo/2019/05/10/19/48/bird-4194340_960_720.jpg',
            'stork': 'https://cdn.pixabay.com/photo/2019/05/08/01/05/stork-4187520_960_720.jpg',
        }
        is_demo = models is None
        if is_demo:
            # ResNet34AT.forward() already returns the attention maps.
            models = {
                'resnet34': Visualization._resnet34(),
            }

        # Resize the shorter side to 256 px and normalise with ImageNet statistics.
        # NOTE(review): the CIFAR models trained in this file saw un-normalised
        # ToTensor() inputs; confirm the ImageNet Normalize suits them too.
        preprocess = T.Compose([
            T.ToPILImage(),
            T.Resize(256),
            T.ToTensor(),
            T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        for pic_id, pic_url in urls.items():
            response = requests.get(pic_url, timeout=timeout)
            response.raise_for_status()
            # convert('RGB') guarantees the 3 channels that Normalize expects.
            im = np.ascontiguousarray(Image.open(BytesIO(response.content)).convert('RGB'), dtype=np.uint8)
            Visualization._show_and_save(im, pic_id)
            x = preprocess(im).unsqueeze(0)

            for model_id, model in models.items():
                model.eval()
                model.to(torch.device('cpu'))
                with torch.no_grad():
                    if is_demo:
                        gs = model(x)
                    else:
                        model(x)  # populates model.grams
                        gs = [g.pow(2).mean(1) for g in model.grams]

                for i, g in enumerate(gs):
                    Visualization._show_and_save(g[0].numpy(), f'{model_id}-{pic_id}-g{i}',
                                                 interpolation='bicubic')


if __name__ == '__main__':
    # Each line below is one experiment; uncomment the one(s) to run.
    # Teachers cache their best weights in sd4<teacher_id>.pkl.

    # -- baseline teachers --
    Teacher().run()  # ResNet-34
    # Teacher(teacher_id='vgg19').run()

    # -- baseline students (no teacher) --
    # Student().run()  # ResNet-18
    # Student(student_id='vgg11').run()

    # -- knowledge distillation --
    # Student(distill=True).run()
    # Student(distill=True, student_id='vgg11', teacher_id='vgg19').run()

    # -- distillation + attention transfer --
    # Student(distill=True, attention=True).run()
    # Student(distill=True, attention=True, student_id='vgg11', teacher_id='vgg19').run()

    # -- demo: attention maps of an ImageNet-pretrained ResNet-34 --
    # Visualization.feature_map()
