In [1]:
# GPT-4 生成的代码

## GPT Generation

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [None]:
'''
DataLoader是torch中用于加载数据的一个重要类，它将dataset划分成小批次batch size，并提供迭代器接受以方便训练时batch-wise读取数据。
return: DataLoader函数返回一个包含两个元素的元组(data, target).
    - data，当前batch input data，通常是一个四维tensor, shape=(batch_size, channels, height, width)
    - target, 当前batch label，通常是一个一维tensor, shape=(batch_size,)
'''

### teacher model

In [5]:
class TeacherNet(nn.Module):
    def __init__(self):
        super(TeacherNet, self).__init__()
        self.fc1 = nn.Linear(28*28, 400)
        self.fc2 = nn.Linear(400, 200)
        self.fc3 = nn.Linear(200, 10)
    
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

### student model

In [7]:
class StudentNet(nn.Module):
    def __init__(self):
        super(StudentNet, self).__init__()
        self.fc1 = nn.Linear(28*28, 100)
        self.fc2 = nn.Linear(100, 10)
    
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

### loss function

In [9]:
'''
KLDivLoss，用于计算KL散度(Kullback-Leibler)，用于衡量两个概率分布之间差异的非对称度量。
D_KL(P||Q)=Σpi log(pi/qi), pi是真实概率，qi是近似概率
KD Loss = α*T^2*KL_loss + (1-α)*cross_entropy_loss
'''
# 损失函数
def distillation_loss(student_logits, teacher_logits, targets, temperature, alpha):
    soft_loss = nn.KLDivLoss()(torch.log_softmax(student_logits / temperature, dim=1),  # 学生output 与 teacher output
                              torch.softmax(teacher_logits / temperature, dim=1)) * (temperature ** 2) # 指数运算
    hard_loss = nn.CrossEntropyLoss()(student_logits, targets)  # 学生output 与 target
    return alpha * soft_loss + (1-alpha) * hard_loss

### training teacher model

In [25]:
def train_teacher():
    teacher_model.train()
    optimizer_teacher = optim.Adam(teacher_model.parameters(), lr=0.001)
    for epoch in range(5):
        for batch_idx, (data, target) in enumerate(train_dataloader):
            optimizer_teacher.zero_grad()
            output = teacher_model(data)
            loss = nn.CrossEntropyLoss()(output, target)
            loss.backward()
            optimizer_teacher.step()
        print(f"Teacher Epoch: {epoch+1}, Loss: {loss.item()}")

### training student model

In [22]:
def train_student(temperature, alpha):
    teacher_model.eval()
    student_model.train()
    for epoch in range(5):
        for batch_idx, (data, target) in enumerate(train_dataloader):
            optimizer.zero_grad()
            teacher_output = teacher_model(data).detach()
            student_output = student_model(data)
            loss = distillation_loss(student_output, teacher_output, target, temperature, alpha)
            loss.backward()
            optimizer.step()
        print(f"Student Epoch: {epoch+1}, Loss: {loss.item()}")

### test model

In [28]:
def test_model(model):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_dataloader:
            output = model(data)
            test_loss += nn.CrossEntropyLoss(reduction='sum')(output, target).item()  # reduction=sum是对所有所有样本loss求和(即不取平均)
            pred = output.argmax(dim=1, keepdim=True)  # argmax返回最大值索引index; keepdim=True表示保持tensor维度。
            correct += pred.eq(target.view_as(pred)).sum().item()
    
    test_loss /= len(test_dataloader.dataset)  # 平均test loss
    accuracy = 100. * correct / len(test_dataloader.dataset)  # 平均accuracy
    print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_dataloader.dataset)} ({accuracy:.2f}%)')

### main

In [18]:
# 数据准备
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081))
])
# dataset download
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

# create dataloader
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=1000, shuffle=True) 

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|███████████████████████████████████████████████████████████████████| 9912422/9912422 [00:02<00:00, 4955023.13it/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|████████████████████████████████████████████████████████████████████████| 28881/28881 [00:00<00:00, 148332.99it/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|████████████████████████████████████████████████████████████████████| 1648877/1648877 [00:02<00:00, 554616.50it/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|█████████████████████████████████████████████████████████████████████████| 4542/4542 [00:00<00:00, 2279589.42it/s]

Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw






In [None]:
# initialize model
teacher_model = TeacherNet()
student_model = StudentNet()

# define optimizer
optimizer = optim.Adam(student_model.parameters(), lr=0.001)

# training and testing
train_teacher()
train_student(temperature=2.0, alpha=0.7)

In [29]:
test_model(student_model)

Test set: Average loss: 0.0925, Accuracy: 9739/10000 (97.39%)
