In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from  torchvision import transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import torchvision
from torchinfo import summary

In [22]:
torch.manual_seed(2022)
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")  #自动选择GPU 没有GPU则自动切换cpu

In [23]:
torch.backends.cudnn.benchmark=True

In [24]:
train_data=torchvision.datasets.MNIST(
    root="./",
    train=True,
    transform=transforms.ToTensor(),
    download=True
)

test_data=torchvision.datasets.MNIST(
    root="./",
    train=False,
    transform=transforms.ToTensor(),
    download=True
)

train_loader=DataLoader(dataset=train_data,batch_size=100,shuffle=True)
test_loader=DataLoader(dataset=test_data,batch_size=100,shuffle=False)

In [25]:
class Teacher_model(nn.Module):
    def __init__(self,num_class=10):
        super(Teacher_model,self).__init__()
        self.classfier=nn.Sequential(
            nn.Linear(784,1200),
            nn.Dropout(0.5),
            nn.ReLU(inplace=True),
            nn.Linear(1200,1200),
            nn.Dropout(0.5),
            nn.ReLU(inplace=True),
            nn.Linear(1200,num_class)
        )
    def forward(self,input):
        input=input.view(-1,784)
        out=self.classfier(input)

        return out

In [26]:
model=Teacher_model().to(device)

In [27]:
summary(model)

Layer (type:depth-idx)                   Param #
Teacher_model                            --
├─Sequential: 1-1                        --
│    └─Linear: 2-1                       942,000
│    └─Dropout: 2-2                      --
│    └─ReLU: 2-3                         --
│    └─Linear: 2-4                       1,441,200
│    └─Dropout: 2-5                      --
│    └─ReLU: 2-6                         --
│    └─Linear: 2-7                       12,010
Total params: 2,395,210
Trainable params: 2,395,210
Non-trainable params: 0

In [28]:
criterion=nn.CrossEntropyLoss()
optimizer=torch.optim.Adam(model.parameters(),lr=1e-4)

In [29]:
epochs_max=6
for epoch in range(epochs_max):
    model.train()

    for data,target in tqdm(train_loader):
        data=data.to(device)
        target=target.to(device)

        output=model(data)
        loss=criterion(output,target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()
    num_correct=0
    num_sample=0

    with torch.no_grad():
        for x,y in test_loader:
            x=x.to(device)
            y=y.to(device)

            pred=model(x)
            pred=pred.max(1).indices
            num_correct+=(pred==y).sum()
            num_sample+=pred.size(0)
        acc=(num_correct/num_sample).item()


    print("epoch:{}, \t Accuracy:{:.4f}".format(epoch+1,acc))


100%|██████████| 600/600 [00:03<00:00, 154.01it/s]
  2%|▎         | 15/600 [00:00<00:04, 146.02it/s]

epoch:1, 	 Accuracy:0.9249


100%|██████████| 600/600 [00:03<00:00, 153.43it/s]
  2%|▎         | 15/600 [00:00<00:04, 141.89it/s]

epoch:2, 	 Accuracy:0.9457


100%|██████████| 600/600 [00:03<00:00, 153.55it/s]
  2%|▎         | 15/600 [00:00<00:03, 148.96it/s]

epoch:3, 	 Accuracy:0.9573


100%|██████████| 600/600 [00:03<00:00, 153.39it/s]
  5%|▌         | 30/600 [00:00<00:03, 146.86it/s]

epoch:4, 	 Accuracy:0.9647


100%|██████████| 600/600 [00:03<00:00, 153.20it/s]
  2%|▎         | 15/600 [00:00<00:04, 141.90it/s]

epoch:5, 	 Accuracy:0.9685


100%|██████████| 600/600 [00:03<00:00, 153.51it/s]


epoch:6, 	 Accuracy:0.9724


In [30]:
class Student_model(nn.Module):
    def __init__(self,num_class=10):
        super(Student_model,self).__init__()
        self.classfier=nn.Sequential(
            nn.Linear(784,20),
            nn.ReLU(inplace=True),
            nn.Linear(20,20),
            nn.ReLU(inplace=True),
            nn.Linear(20,num_class)
        )
    def forward(self,input):
        input=input.view(-1,784)
        out=self.classfier(input)

        return out

In [31]:
student_model=Student_model().to(device)
optimizer=torch.optim.Adam(student_model.parameters(),lr=1e-4)
summary(student_model)

Layer (type:depth-idx)                   Param #
Student_model                            --
├─Sequential: 1-1                        --
│    └─Linear: 2-1                       15,700
│    └─ReLU: 2-2                         --
│    └─Linear: 2-3                       420
│    └─ReLU: 2-4                         --
│    └─Linear: 2-5                       210
Total params: 16,330
Trainable params: 16,330
Non-trainable params: 0

In [32]:
epochs_max=6
for epoch in range(epochs_max):
    student_model.train()

    for data,target in tqdm(train_loader):
        data=data.to(device)
        target=target.to(device)

        output=student_model(data)
        loss=criterion(output,target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    student_model.eval()
    num_correct=0
    num_sample=0

    with torch.no_grad():
        for x,y in test_loader:
            x=x.to(device)
            y=y.to(device)

            pred=student_model(x)
            pred=pred.max(1).indices
            num_correct+=(pred==y).sum()
            num_sample+=pred.size(0)
        acc=(num_correct/num_sample).item()


    print("epoch:{}, \t Accuracy:{:.4f}".format(epoch+1,acc))

100%|██████████| 600/600 [00:04<00:00, 142.64it/s]
  2%|▏         | 14/600 [00:00<00:04, 137.62it/s]

epoch:1, 	 Accuracy:0.6930


100%|██████████| 600/600 [00:03<00:00, 152.97it/s]
  2%|▎         | 15/600 [00:00<00:03, 147.93it/s]

epoch:2, 	 Accuracy:0.8228


100%|██████████| 600/600 [00:04<00:00, 145.88it/s]
  2%|▏         | 14/600 [00:00<00:04, 134.16it/s]

epoch:3, 	 Accuracy:0.8558


100%|██████████| 600/600 [00:04<00:00, 144.48it/s]
  3%|▎         | 16/600 [00:00<00:03, 154.74it/s]

epoch:4, 	 Accuracy:0.8724


100%|██████████| 600/600 [00:03<00:00, 159.95it/s]
  2%|▏         | 14/600 [00:00<00:04, 138.99it/s]

epoch:5, 	 Accuracy:0.8818


100%|██████████| 600/600 [00:04<00:00, 143.91it/s]


epoch:6, 	 Accuracy:0.8886


### Knowldge Distillation way

In [41]:
model.eval()

student_model.train()

temp=9
alpha=0.3
soft_loss=nn.KLDivLoss(reduction="batchmean")
optimizer=torch.optim.Adam(student_model.parameters(),lr=1e-4)

In [42]:
epochs_max=24
for epoch in range(epochs_max):

    for data,target in tqdm(train_loader):
        data=data.to(device)
        target=target.to(device)
        with torch.no_grad():
            teacher_output=model(data)

        output=student_model(data)
        loss_hard=criterion(output,target)

        distillation_loss=soft_loss(F.softmax(output/temp,dim=1) ,F.softmax(teacher_output/temp,dim=1))

        loss=loss_hard*alpha+(1-alpha)*distillation_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()
    num_correct=0
    num_sample=0

    with torch.no_grad():
        for x,y in test_loader:
            x=x.to(device)
            y=y.to(device)

            pred=student_model(x)
            pred=pred.max(1).indices
            num_correct+=(pred==y).sum()
            num_sample+=pred.size(0)
        acc=(num_correct/num_sample).item()


    print("epoch:{}, \t Accuracy:{:.4f}".format(epoch+1,acc))


100%|██████████| 600/600 [00:04<00:00, 148.31it/s]
  2%|▏         | 14/600 [00:00<00:04, 136.29it/s]

epoch:1, 	 Accuracy:0.9467


100%|██████████| 600/600 [00:04<00:00, 141.95it/s]
  2%|▏         | 13/600 [00:00<00:04, 127.72it/s]

epoch:2, 	 Accuracy:0.9454


100%|██████████| 600/600 [00:04<00:00, 142.39it/s]
  2%|▏         | 14/600 [00:00<00:04, 132.86it/s]

epoch:3, 	 Accuracy:0.9457


100%|██████████| 600/600 [00:04<00:00, 140.55it/s]
  5%|▍         | 28/600 [00:00<00:04, 134.49it/s]

epoch:4, 	 Accuracy:0.9464


100%|██████████| 600/600 [00:04<00:00, 142.32it/s]
  2%|▏         | 14/600 [00:00<00:04, 138.87it/s]

epoch:5, 	 Accuracy:0.9471


100%|██████████| 600/600 [00:04<00:00, 142.10it/s]
  2%|▏         | 14/600 [00:00<00:04, 132.00it/s]

epoch:6, 	 Accuracy:0.9471


100%|██████████| 600/600 [00:04<00:00, 142.56it/s]
  2%|▏         | 14/600 [00:00<00:04, 134.44it/s]

epoch:7, 	 Accuracy:0.9472


100%|██████████| 600/600 [00:04<00:00, 141.51it/s]
  5%|▍         | 28/600 [00:00<00:04, 135.76it/s]

epoch:8, 	 Accuracy:0.9469


100%|██████████| 600/600 [00:04<00:00, 142.66it/s]
  2%|▏         | 14/600 [00:00<00:04, 133.17it/s]

epoch:9, 	 Accuracy:0.9471


100%|██████████| 600/600 [00:04<00:00, 141.43it/s]
  2%|▏         | 14/600 [00:00<00:04, 135.94it/s]

epoch:10, 	 Accuracy:0.9470


100%|██████████| 600/600 [00:04<00:00, 142.46it/s]
  2%|▏         | 14/600 [00:00<00:04, 138.84it/s]

epoch:11, 	 Accuracy:0.9480


100%|██████████| 600/600 [00:04<00:00, 142.37it/s]
  2%|▏         | 14/600 [00:00<00:04, 132.80it/s]

epoch:12, 	 Accuracy:0.9478


100%|██████████| 600/600 [00:04<00:00, 142.30it/s]
  2%|▏         | 14/600 [00:00<00:04, 135.74it/s]

epoch:13, 	 Accuracy:0.9479


100%|██████████| 600/600 [00:04<00:00, 142.26it/s]
  2%|▏         | 13/600 [00:00<00:04, 129.94it/s]

epoch:14, 	 Accuracy:0.9479


100%|██████████| 600/600 [00:04<00:00, 142.52it/s]
  2%|▏         | 14/600 [00:00<00:04, 134.38it/s]

epoch:15, 	 Accuracy:0.9474


100%|██████████| 600/600 [00:04<00:00, 141.60it/s]
  2%|▏         | 14/600 [00:00<00:04, 133.29it/s]

epoch:16, 	 Accuracy:0.9478


100%|██████████| 600/600 [00:04<00:00, 143.05it/s]
  2%|▏         | 14/600 [00:00<00:04, 134.55it/s]

epoch:17, 	 Accuracy:0.9485


100%|██████████| 600/600 [00:04<00:00, 142.16it/s]
  2%|▏         | 14/600 [00:00<00:04, 132.26it/s]

epoch:18, 	 Accuracy:0.9489


100%|██████████| 600/600 [00:04<00:00, 138.29it/s]
  2%|▏         | 13/600 [00:00<00:04, 128.06it/s]

epoch:19, 	 Accuracy:0.9488


100%|██████████| 600/600 [00:04<00:00, 141.57it/s]
  2%|▏         | 14/600 [00:00<00:04, 132.65it/s]

epoch:20, 	 Accuracy:0.9481


100%|██████████| 600/600 [00:04<00:00, 142.23it/s]
  2%|▏         | 14/600 [00:00<00:04, 136.14it/s]

epoch:21, 	 Accuracy:0.9484


100%|██████████| 600/600 [00:04<00:00, 142.89it/s]
  2%|▏         | 14/600 [00:00<00:04, 132.57it/s]

epoch:22, 	 Accuracy:0.9497


100%|██████████| 600/600 [00:04<00:00, 141.94it/s]
  2%|▏         | 13/600 [00:00<00:04, 127.39it/s]

epoch:23, 	 Accuracy:0.9484


100%|██████████| 600/600 [00:04<00:00, 142.31it/s]


epoch:24, 	 Accuracy:0.9492
