# In [None]:
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F


class TeacherModel(nn.Module):
    """Two-layer fully connected teacher network.

    Layout: ``fc1 (input_dim -> hidden_dim) -> ReLU -> fc2 (hidden_dim -> output_dim)``.
    ``forward`` returns the raw output of ``fc2`` (no softmax applied).
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int) -> None:
        super().__init__()
        # The attribute names fc1/fc2 become the state_dict keys, so they are
        # part of the checkpoint format and must stay stable.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply fc1, a ReLU, then fc2 to ``x`` and return the unnormalized scores."""
        return self.fc2(torch.relu(self.fc1(x)))


class StudentModel(nn.Module):
    """Two-layer fully connected student network trained by distillation.

    It has the same layer layout as the teacher: ``fc1 -> ReLU -> fc2``.
    The hidden width is chosen by the caller. The output is returned as raw
    scores, before any softmax.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int) -> None:
        super().__init__()
        # Keep the fc1/fc2 names: they define the parameter keys in state_dict().
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through fc1 + ReLU, then fc2; no output activation."""
        activated = torch.relu(self.fc1(x))
        logits = self.fc2(activated)
        return logits


def distillation_loss(y, labels, teacher_output, T, alpha):
    """Return the knowledge-distillation loss for one batch.

    The loss blends a soft-target term with the ordinary hard-label
    cross-entropy:

        loss = alpha * T**2 * KL(softmax(teacher / T) || softmax(y / T))
               + (1 - alpha) * CE(y, labels)

    The ``T**2`` factor compensates for the ``1 / T**2`` shrinkage of the
    soft-target gradients, so changing ``T`` does not silently rescale that term.

    Args:
        y: Student logits. Classes are along dim 1 (e.g. ``(batch, num_classes)``).
        labels: Class-index targets accepted by ``F.cross_entropy`` for ``y``.
        teacher_output: Teacher logits with the same shape as ``y``. They are
            used purely as a target and are detached, so no gradient flows into
            the teacher.
        T: Softmax temperature (> 0). ``T > 1`` softens both distributions.
        alpha: Weight of the soft-target term. The hard-label term gets
            ``1 - alpha``.

    Returns:
        Scalar loss tensor.
    """
    # The teacher is a fixed target; never backpropagate into it.
    teacher_output = teacher_output.detach()

    student_log_probs = F.log_softmax(y / T, dim=1)
    teacher_probs = F.softmax(teacher_output / T, dim=1)

    # kl_div takes log-probabilities as input and probabilities as target.
    # "batchmean" matches the mathematical definition of the KL divergence.
    kld_loss = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")

    # Previously this weight was `T * T * 2.0 + alpha`: alpha was added, not
    # multiplied, so the two terms were not a convex blend.
    soft_loss = kld_loss * (alpha * T * T)
    hard_loss = F.cross_entropy(y, labels) * (1.0 - alpha)

    return soft_loss + hard_loss


# Toy demo: a single distillation step from a 256-unit teacher to a
# 128-unit student on one random input.
input_dim = 100
output_dim = 10
teacher = TeacherModel(input_dim, 256, output_dim)
student = StudentModel(input_dim, 128, output_dim)
# The teacher is frozen during distillation. eval() makes that explicit and
# keeps any future dropout/batch-norm layers in inference mode.
teacher.eval()
student.train()
optimizer = optim.Adam(student.parameters(), lr=0.001)

# Distillation hyperparameters. A temperature above 1 softens the teacher's
# distribution so the student also learns from the relative scores of the
# wrong classes. The previous T=0.1 sharpened it towards a one-hot target,
# which defeats the purpose of distillation.
temperature = 2.0
alpha = 0.5

input_data = torch.randn(1, input_dim)
with torch.no_grad():
    teacher_output = teacher(input_data)

optimizer.zero_grad()
student_output = student(input_data)
loss = distillation_loss(
    y=student_output,
    labels=torch.tensor([0]),
    teacher_output=teacher_output,
    T=temperature,
    alpha=alpha,
)
loss.backward()
optimizer.step()

print("Teacher Model Output:", teacher_output)
# Note: this is the student's output from before the optimizer step.
print("Student Model Output:", student_output)
print("Distillation loss:", loss.item())