In [None]:
import torch
from model.Backbone import Resnet
from model.Head import Fullyconnectedhead
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import numpy as np
from torchvision.transforms import Compose, RandomHorizontalFlip, RandomGrayscale, ToTensor, Normalize
from torchsummary import summary
from torch.optim import Adam
from torch.nn import MSELoss,  Sequential
from tqdm import tqdm

In [None]:
# Training configuration: batch sizes, optimizer learning rate and RNG seed.
TrainBS = 64
TestBS = 64
Learning_Rate = 0.001
# Fixed integer seed so "Restart & Run All" reproduces the same run.
# (A float from np.random.uniform() in [0, 1) is truncated to 0 by
# torch.manual_seed's int() cast anyway, and left NumPy unseeded.)
Random_Seed = 0
np.random.seed(Random_Seed)
torch.manual_seed(Random_Seed)

In [None]:
# CIFAR-10 train/test loaders.
# NOTE(review): '/data/' is an absolute path at the filesystem root — consider
# making it configurable for other machines.
DATA_ROOT = '/data/'
# Per-channel (x - 0.5) / 0.5 maps ToTensor's [0, 1] range to [-1, 1];
# shared by both splits so train and test see the same input scaling.
CIFAR_Normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# Training split: light augmentation (horizontal flip, random grayscale) + shuffling.
Train_Data = DataLoader(
    dataset=CIFAR10(train=True,
                    root=DATA_ROOT,
                    download=True,
                    transform=Compose([RandomHorizontalFlip(),
                                       RandomGrayscale(),
                                       ToTensor(),
                                       CIFAR_Normalize])),
    batch_size=TrainBS,
    shuffle=True,
)

# Test split: no augmentation, and no shuffling — evaluation metrics do not
# depend on sample order, so a fixed order keeps evaluation deterministic.
Test_Data = DataLoader(
    dataset=CIFAR10(train=False,
                    root=DATA_ROOT,
                    download=True,
                    transform=Compose([ToTensor(),
                                       CIFAR_Normalize])),
    batch_size=TestBS,
    shuffle=False,
)

In [None]:
# Backbone + head, followed by a layer-by-layer summary.
# NOTE(review): `Fullyconnectedhead` is passed without being called; nn.Sequential
# only accepts Module instances, so if this is a class it needs to be instantiated
# here — confirm against model/Head.py.
# NOTE(review): CIFAR-10 images are 3x32x32, but the summary probes 3x224x224 —
# confirm which input size the backbone is meant for.
Network = Sequential(Resnet(34), Fullyconnectedhead)
# torchsummary defaults to device="cuda" and would feed CUDA tensors to this
# CPU-resident model, so the device is stated explicitly.
summary(Network, (3, 224, 224), device="cpu")

In [None]:
# Teacher model plus the objects used to train it.
Teacher_Network = Resnet(34)
# NOTE(review): mean-squared error is an unusual choice for 10-way
# classification; CrossEntropyLoss is the conventional one — confirm intent.
Loss_Function = MSELoss()
# Adam over every teacher parameter, at the configured learning rate.
optimizer = Adam(Teacher_Network.parameters(), lr=Learning_Rate)

In [None]:
# Train the teacher network for `epochs` passes over the training set.
# The loss is computed on the raw per-class outputs so gradients flow back into
# Teacher_Network. (Taking argmax and re-wrapping it with
# torch.tensor(..., requires_grad=True) creates a new leaf tensor that is
# disconnected from the model, so optimizer.step() would never change the weights.)
epochs = 10
for epoch in range(1, epochs + 1):
    Teacher_Network.train()
    correct = 0
    total_loss = 0.0
    for data, target in tqdm(Train_Data, desc='epoch {}'.format(epoch)):
        optimizer.zero_grad()
        # NOTE(review): assumes the network yields 10 scores per image
        # (CIFAR-10 classes) — confirm Resnet(34)'s output shape.
        logits = Teacher_Network(data).view(-1, 10)
        # One-hot float targets match the (N, 10) logits; this works with the
        # MSELoss defined above and would also work with CrossEntropyLoss.
        target_onehot = torch.nn.functional.one_hot(target, num_classes=logits.size(1)).float()
        loss = Loss_Function(logits, target_onehot)
        loss.backward()
        optimizer.step()
        # .item() yields a plain float; summing the tensor itself would keep
        # every batch's autograd graph alive for the whole epoch.
        total_loss += loss.item()
        correct += (logits.argmax(dim=1) == target).sum().item()
    avg_loss = total_loss / len(Train_Data)
    print('Train_Epoch:{}\t Loss:{:.6f}\t Acc:{:.1f}'.format(
        epoch, avg_loss, 100. * correct / len(Train_Data.dataset)))

In [None]:
# Persist the weights of the network that was actually trained above.
# (`Network` is built only for the summary and never trained, so saving it
# would store its initial, untrained weights.)
torch.save(Teacher_Network.state_dict(), 'KD-resnet.pth')