## 1. Setup

In [43]:
import math
import copy
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoModel, AutoTokenizer, AutoConfig
from transformers import BatchEncoding
from datasets import load_dataset

## 2. Models

In [None]:
teacher = AutoModel.from_pretrained('klue/bert-base').cuda()

student = AutoModel.from_pretrained('../ckpt/transformers/').cuda()

config = AutoConfig.from_pretrained('../ckpt/transformers/')
random = AutoModel.from_config(config).cuda()

In [None]:
for param in teacher.parameters():
    param.requires_grad = False

In [None]:
for param in student.parameters():
    param.requires_grad = False

In [None]:
for param in random.parameters():
    param.requires_grad = False

## 3. Embedding

In [None]:
del t, s, r
del td, sd, rd

In [6]:
t = teacher.embeddings.word_embeddings.weight[1:, 1:]
t = torch.matmul(t, t.transpose(-1, -2))
# t = t.softmax(dim=1)

In [7]:
s = student.embeddings.word_embeddings.weight[1:, 1:]
s = torch.matmul(s, s.transpose(-1, -2))
# s = s.softmax(dim=1)

In [8]:
r = random.embeddings.word_embeddings.weight[1:, 1:]
r = torch.matmul(r, r.transpose(-1, -2))
# r = r.softmax(dim=1)

In [24]:
s[i:i+1].size()

torch.Size([1, 31999])

In [21]:
l = 0
for i in tqdm(range(len(t))):
    _l = F.kl_div(F.log_softmax(s[i:i+1], dim=-1), F.softmax(t[i:i+1], dim=-1))
    l += _l

  0%|          | 0/31999 [00:00<?, ?it/s]

In [19]:
_l

tensor(1.1323e-06, device='cuda:0')

In [22]:
l

tensor(0.0255, device='cuda:0')

In [20]:
l

tensor(0.0175, device='cuda:0')

## 4. Hidden States

In [None]:
tokenizer = AutoTokenizer.from_pretrained('klue/bert-base')

In [32]:
dataset = load_dataset('text', data_files='../data/kowiki.txt')['train']

Using custom data configuration default-82324f4e586d6530
Reusing dataset text (/root/.cache/huggingface/datasets/text/default-82324f4e586d6530/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5)


  0%|          | 0/1 [00:00<?, ?it/s]

In [46]:
dataset.set_transform(lambda batch: tokenizer(batch['text'], return_tensors='pt', max_length=512, truncation=True, padding='max_length', return_token_type_ids=False))

In [47]:
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=False)

In [48]:
d = next(iter(loader))
d = BatchEncoding(d)
d = d.to('cuda')

In [56]:
to = teacher(**d, output_attentions=True, output_hidden_states=True)
so = student(**d, output_attentions=True, output_hidden_states=True)
ro = random(**d, output_attentions=True, output_hidden_states=True)

In [63]:
to.last_hidden_state.size()

torch.Size([4, 512, 768])

In [62]:
so.last_hidden_state.size()

torch.Size([4, 512, 384])