In [19]:
import torch
import torch.nn.functional as F
import numpy as np
import random
def set_all_seeds(seed):
  random.seed(seed)
  # os.environ(“PYTHONHASHSEED”) = str(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed(seed)
  torch.backends.cudnn.deterministic = True

In [23]:
set_all_seeds(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
B, C, D, H, W = 16, 2, 16, 16, 8
student_logits = torch.randn(B, C, D, H, W, device=device)
teacher_logits = torch.randn_like(student_logits)

def lwf_loss(student_logits, teacher_logits, T=10.0):
    loss_kl = F.kl_div(
            F.log_softmax(student_logits / T, dim=1),
            F.softmax(teacher_logits / T, dim=1),
            reduction='none'
        ).mean() * (T * T)

    return loss_kl

def test_lwf_loss(student_logits, teacher_logits, T=2.0):
    loss_kl = F.kl_div(
            F.log_softmax(student_logits / T, dim=1),
            F.softmax(teacher_logits / T, dim=1),
            reduction='batchmean'
        ) * (T * T)
    return loss_kl

Using device: cuda


In [27]:
print('lwf_loss t=2', lwf_loss(student_logits, teacher_logits, T=2.0))
print('lwf_loss t=10', lwf_loss(student_logits, teacher_logits, T=10.0))
print('lwf_loss t=10000', lwf_loss(student_logits, teacher_logits, T=10000.0))
print('test_lwf_loss', test_lwf_loss(student_logits, teacher_logits))

lwf_loss t=2 tensor(0.2246, device='cuda:0')
lwf_loss t=10 tensor(0.2485, device='cuda:0')
lwf_loss t=10000 tensor(1.1715, device='cuda:0')
test_lwf_loss tensor(919.9381, device='cuda:0')
