# CLS

## 1. Setup

In [1]:
!nvidia-smi

Sun Dec 12 21:58:18 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.45.01    Driver Version: 455.45.01    CUDA Version: 11.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  TITAN RTX           On   | 00000000:01:00.0 Off |                  N/A |
| 41%   22C    P8    16W / 280W |      1MiB / 24219MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

In [2]:
import os
import gc
import time
import copy
import math
import hydra
import GPUtil 
import itertools
from tqdm.auto import tqdm
from typing import List

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F

import pytorch_lightning as pl

from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig
from transformers import get_scheduler
from transformers import BatchEncoding
from transformers.data.data_collator import DataCollatorForWholeWordMask
from datasets import load_dataset, concatenate_datasets

In [3]:
with hydra.initialize('../configs'):
    config = hydra.compose('config.yaml', overrides=['working_dir=../', 'model.mlm=false', 'data.batch_size=16'])

## 2. Data

In [4]:
class DataModule(pl.LightningDataModule):
    def __init__(self, tokenizer, batch_size=8):
        super().__init__()
        self.tokenizer = tokenizer
        self.batch_size=batch_size
        
    def setup(self, stage=None):
        self.dataset = load_dataset('text', data_files=os.path.join('../data', 'kowiki.txt'))['train']
        self.dataset.set_transform(lambda batch: transform(batch, self.tokenizer, 512))
        self.dataset = self.dataset.train_test_split(test_size=0.01)
        self.train_dataset, self.eval_dataset = self.dataset['train'], self.dataset['test']
        

    def collate_fn(self, batch):
        batch = BatchEncoding(batch)
        batch['attention_mask'] = batch.input_ids.ne(self.tokenizer.pad_token_id).float()
        return batch

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)
    
    def validation_dataloader(self):
        return torch.utils.data.DataLoader(self.eval_dataset, batch_size=self.batch_size, shuffle=False)

    
def transform(batch, tokenizer, max_length):
    new_batch = []
    for text in batch['text']:
        text = slice_text(text)
        new_batch.append(text)
    
    return tokenizer(new_batch, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')


def slice_text(text, max_char_length=1024):
    if len(text) > max_char_length:
        idx = np.random.randint(low=0, high=len(text)-max_char_length)
        text = text[idx : idx+max_char_length]
    return text

## 3. Model

In [5]:
def select_layers(model: torch.nn.Module, indices: List[int]):
    model.encoder.layer = nn.ModuleList([l for i, l in enumerate(model.encoder.layer) if i in indices])
    return model

def select_indices_from_embedding(embedding: torch.Tensor, num_features: int = 384, corr_threshold: float = 0.3):
    corr = torch.corrcoef(embedding.transpose(-1, -2))
    corr = corr - torch.eye(corr.size(0))
    
    removed_indices = []
    for idx in range(corr.size(0)):
        most_sim_idx = corr[idx].argmax()
        if corr[idx][most_sim_idx] > corr_threshold and idx > most_sim_idx:
            removed_indices.append(idx)
            
    norm = embedding.norm(dim=0)
    norm_indices = (-norm).argsort()

    indices = [int(i) for i in norm_indices if i not in removed_indices]
    indices = indices[:num_features]
    return indices

def select_weight(weight, indices, dim):
    if type(dim) == int:
        indices = torch.tensor(indices)
        return torch.index_select(weight, dim, indices.to(weight.device))
    
    else:
        for d in dim:
            weight = select_weight(weight, indices, d)
        return weight
    

def select_embedding(embedding, indices):
    embedding.weight.data = select_weight(embedding.weight.data, indices, dim=1)
    embedding.embedding_dim = len(indices)
    return embedding

def select_layernorm(layernorm, indices):
    layernorm.weight.data = select_weight(layernorm.weight.data, indices, dim=0)
    layernorm.bias.data = select_weight(layernorm.bias.data, indices, dim=0)
    layernorm.normalized_shape = (len(indices), )
    return layernorm

def select_linear(linear, indices, dims=[0, 1]):
    linear.weight.data = select_weight(linear.weight.data, indices, dim=dims)
    if 0 in dims:
        linear.bias.data = select_weight(linear.bias.data, indices, dim=0)
        
    if 1 in dims:
        linear.in_features = len(indices)
    if 0 in dims:
        linear.out_features = len(indices)
    return linear

def select_bert_embeddings(bert_embeddings, indices):
    bert_embeddings.word_embeddings = select_embedding(bert_embeddings.word_embeddings, indices)
    bert_embeddings.position_embeddings = select_embedding(bert_embeddings.position_embeddings, indices)
    bert_embeddings.token_type_embeddings = select_embedding(bert_embeddings.token_type_embeddings, indices)
    bert_embeddings.LayerNorm = select_layernorm(bert_embeddings.LayerNorm, indices)
    return bert_embeddings


def select_bert_layer(bert_layer, indices):
    bert_layer.attention.self.all_head_size = len(indices)
    bert_layer.attention.self.attention_head_size = len(indices) // bert_layer.attention.self.num_attention_heads
    
    bert_layer.attention.self.query = select_linear(bert_layer.attention.self.query, indices)
    bert_layer.attention.self.key = select_linear(bert_layer.attention.self.key, indices)
    bert_layer.attention.self.value = select_linear(bert_layer.attention.self.value, indices)
    
    bert_layer.attention.output.dense = select_linear(bert_layer.attention.output.dense, indices)
    bert_layer.attention.output.LayerNorm = select_layernorm(bert_layer.attention.output.LayerNorm, indices)
    
    bert_layer.intermediate.dense = select_linear(bert_layer.intermediate.dense, indices, dims=[1])
    bert_layer.output.dense = select_linear(bert_layer.output.dense, indices, dims=[0])
    bert_layer.output.LayerNorm = select_layernorm(bert_layer.output.LayerNorm, indices)
    return bert_layer

def select_bert_pooler(bert_pooler, indices):
    bert_pooler.dense = select_linear(bert_pooler.dense, indices)
    return bert_pooler

def select_bert_model(bert_model, layer_indices, weight_indices):
    bert_model = select_layers(bert_model, layer_indices)
    bert_model.embeddings = select_bert_embeddings(bert_model.embeddings, weight_indices)
    bert_model.encoder.layer = nn.ModuleList([select_bert_layer(l, weight_indices) for l in bert_model.encoder.layer])
    bert_model.pooler = select_bert_pooler(bert_model.pooler, weight_indices)
    return bert_model

In [6]:
def to_distill(model):
    model.base_model.encoder.layer[0].attention.self.__class__._forward = bert_self_attention_forward
    for layer in model.base_model.encoder.layer:
        layer.attention.self.forward = layer.attention.self._forward
    
    for param in model.parameters():
        param.requires_grad = False
    return model


def bert_self_attention_forward(
    self,
    hidden_states,
    attention_mask=None,
    head_mask=None,
    encoder_hidden_states=None,
    encoder_attention_mask=None,
    past_key_value=None,
    output_attentions=False,
):
    mixed_query_layer = self.query(hidden_states)
    mixed_key_layer = self.key(hidden_states)
    mixed_value_layer = self.value(hidden_states)
    
    query_layer = self.transpose_for_scores(mixed_query_layer)
    key_layer = self.transpose_for_scores(mixed_key_layer)
    value_layer = self.transpose_for_scores(mixed_value_layer)
    
    self.q = mixed_query_layer # (Batch, Seq, Dim)
    self.k = mixed_key_layer # (Batch, Seq, Dim)
    self.v = mixed_value_layer # (Batch, Seq, Dim)

    if self.is_decoder:
        past_key_value = (key_layer, value_layer)

    attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

    if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
        seq_length = hidden_states.size()[1]
        position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
        position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
        distance = position_ids_l - position_ids_r
        positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
        positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

        if self.position_embedding_type == "relative_key":
            relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
            attention_scores = attention_scores + relative_position_scores
        elif self.position_embedding_type == "relative_key_query":
            relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
            relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
            attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

    attention_scores = attention_scores / math.sqrt(self.attention_head_size)
    if attention_mask is not None:
        attention_scores = attention_scores + attention_mask

    attention_probs = nn.Softmax(dim=-1)(attention_scores)
    attention_probs = self.dropout(attention_probs)

    if head_mask is not None:
        attention_probs = attention_probs * head_mask

    context_layer = torch.matmul(attention_probs, value_layer)

    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
    context_layer = context_layer.view(*new_context_layer_shape)

    outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
    if self.is_decoder:
        outputs = outputs + (past_key_value,)
    return outputs


def get_qkvs(model):
    attns = [l.attention.self for l in model.base_model.encoder.layer]
    qkvs = [{'q': a.q, 'k': a.k, 'v': a.v} for a in attns]    
    return qkvs

def transpose_for_scores(h, num_heads):
    batch_size, seq_length, dim = h.size()
    head_size = dim // num_heads
    h = h.view(batch_size, seq_length, num_heads, head_size)
    return h.permute(0, 2, 1, 3) # (batch, num_heads, seq_length, head_size)


def attention(h1, h2, num_heads, attention_mask=None):
    assert h1.size() == h2.size()
    head_size = h1.size(-1) // num_heads
    h1 = transpose_for_scores(h1, num_heads) # (batch, num_heads, seq_length, head_size)
    h2 = transpose_for_scores(h2, num_heads) # (batch, num_heads, seq_length, head_size)

    attn = torch.matmul(h1, h2.transpose(-1, -2)) # (batch_size, num_heads, seq_length, seq_length)
    attn = attn / math.sqrt(head_size)
    if attention_mask is not None:
        attention_mask = attention_mask[:, None, None, :]
        attention_mask = (1 - attention_mask) * -10000.0
        attn = attn + attention_mask

    return attn


def kl_div_loss(s, t, temperature=1.):
    if len(s.size()) != 2:
        s = s.view(-1, s.size(-1))
        t = t.view(-1, t.size(-1))

    s = F.log_softmax(s / temperature, dim=-1)
    t = F.softmax(t / temperature, dim=-1)
    return F.kl_div(s, t, reduction='batchmean')

def minilm_loss(t, s, num_relation_heads, attention_mask=None):
    attn_t = attention(t, t, num_relation_heads, attention_mask)
    attn_s = attention(s, s, num_relation_heads, attention_mask)
    loss = kl_div_loss(attn_s, attn_t)
    return loss

## 4. Main

In [7]:
tokenizer = AutoTokenizer.from_pretrained('klue/bert-base')

In [8]:
data_module = DataModule(tokenizer)
data_module.setup()
loader = iter(data_module.train_dataloader())

Using custom data configuration default-82324f4e586d6530
Reusing dataset text (/root/.cache/huggingface/datasets/text/default-82324f4e586d6530/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5)


  0%|          | 0/1 [00:00<?, ?it/s]

In [9]:
batch = next(loader)

In [10]:
teacher = AutoModel.from_pretrained('klue/bert-base', output_hidden_states=True, output_attentions=True)

Some weights of the model checkpoint at klue/bert-base were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [11]:
config = AutoConfig.from_pretrained('klue/bert-base', num_hidden_layers=3, hidden_size=384)
student = AutoModel.from_config(config)
student = to_distill(student)

In [12]:
teacher_outputs = teacher(**batch)
student_outputs = student(**batch)

In [17]:
teacher_cls = teacher_outputs.last_hidden_state[:4, 0]
student_cls = student_outputs.last_hidden_state[:4, 0]

In [65]:
def transpose_for_scores(h, num_heads):
    batch_size, seq_length, dim = h.size()
    head_size = dim // num_heads
    h = h.view(batch_size, seq_length, num_heads, head_size)
    return h.permute(0, 2, 1, 3) # (batch, num_heads, seq_length, head_size)


def attention(h1, h2, num_heads, attention_mask=None):
    assert h1.size() == h2.size()
    head_size = h1.size(-1) // num_heads
    h1 = transpose_for_scores(h1, num_heads) # (batch, num_heads, seq_length, head_size)
    h2 = transpose_for_scores(h2, num_heads) # (batch, num_heads, seq_length, head_size)

    attn = torch.matmul(h1, h2.transpose(-1, -2)) # (batch_size, num_heads, seq_length, seq_length)
    attn = attn / math.sqrt(head_size)
    if attention_mask is not None:
        attention_mask = attention_mask[:, None, None, :]
        attention_mask = (1 - attention_mask) * -10000.0
        attn = attn + attention_mask

    return attn


def kl_div_loss(s, t, temperature):
    if len(s.size()) != 2:
        s = s.view(-1, s.size(-1))
        t = t.view(-1, t.size(-1))

    s = F.log_softmax(s / temperature, dim=-1)
    t = F.softmax(t / temperature, dim=-1)
    return F.kl_div(s, t, reduction='batchmean')


def minilm_loss(t, s, num_relation_heads, attention_mask=None, temperature=1.0):
    attn_t = attention(t, t, num_relation_heads, attention_mask)
    attn_s = attention(s, s, num_relation_heads, attention_mask)
    loss = kl_div_loss(attn_s, attn_t, temperature=temperature)
    return loss


In [67]:
teacher_cls = teacher_cls.unsqueeze(0)
student_cls = student_cls.unsqueeze(0)

In [68]:
minilm_loss(teacher_cls, student_cls, 48, None, 1)

tensor(0.2475, grad_fn=<DivBackward0>)

In [80]:
attention(teacher_cls, teacher_cls, 48).softmax(dim=-1)[0][5]

tensor([[0.4406, 0.1450, 0.3133, 0.1011],
        [0.1187, 0.6125, 0.1160, 0.1528],
        [0.2786, 0.1261, 0.4840, 0.1113],
        [0.0363, 0.0671, 0.0450, 0.8516]], grad_fn=<SelectBackward0>)

In [48]:
def pdist(e, squared=False, eps=1e-12):
    e_square = e.pow(2).sum(dim=1)
    prod = e @ e.t()
    res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)

    if not squared:
        res = res.sqrt()

    res = res.clone()
    res[range(len(e)), range(len(e))] = 0
    return res

In [50]:
pdist(student_cls)

tensor([[0.0000, 9.0273, 9.9646, 9.9388],
        [9.0273, 0.0000, 9.6253, 8.2866],
        [9.9646, 9.6253, 0.0000, 9.3097],
        [9.9388, 8.2866, 9.3097, 0.0000]])

In [51]:
td = torch.cdist(teacher_cls, teacher_cls, p=2)
sd = torch.cdist(student_cls, student_cls, p=2)

In [56]:
mean_td = td[td>0].mean()
td /= mean_td

In [57]:
def cdist(v):
    d = torch.cdist(v, v, p=2)
    m = d[d>0].mean()
    return d / m

In [62]:
td = cdist(teacher_cls)
sd = cdist(student_cls)

In [64]:
F.smooth_l1_loss(sd, td, reduction='mean')

tensor(0.0039, grad_fn=<SmoothL1LossBackward0>)

In [26]:
torch.cdist(student_cls, student_cls)

tensor([[0.0000, 9.0273, 9.9646, 9.9388],
        [9.0273, 0.0000, 9.6253, 8.2866],
        [9.9646, 9.6253, 0.0000, 9.3097],
        [9.9388, 8.2866, 9.3097, 0.0000]])

In [20]:
distance_loss_fn(student_cls, teacher_cls)

tensor(2.2821, grad_fn=<MseLossBackward0>)

In [31]:
dist_t = torch.cdist(t, t)
dist_s = torch.cdist(s, s)

In [34]:
dist_t / t.size(-1)

tensor([[0.0000, 0.0263, 0.0257, 0.0248, 0.0221, 0.0256, 0.0277, 0.0276],
        [0.0263, 0.0000, 0.0294, 0.0313, 0.0246, 0.0247, 0.0282, 0.0227],
        [0.0257, 0.0294, 0.0000, 0.0257, 0.0260, 0.0276, 0.0308, 0.0331],
        [0.0248, 0.0313, 0.0257, 0.0000, 0.0272, 0.0293, 0.0333, 0.0340],
        [0.0221, 0.0246, 0.0260, 0.0272, 0.0000, 0.0257, 0.0277, 0.0268],
        [0.0256, 0.0247, 0.0276, 0.0293, 0.0257, 0.0000, 0.0306, 0.0289],
        [0.0277, 0.0282, 0.0308, 0.0333, 0.0277, 0.0306, 0.0000, 0.0321],
        [0.0276, 0.0227, 0.0331, 0.0340, 0.0268, 0.0289, 0.0321, 0.0000]])

In [35]:
dist_s / s.size(-1)

tensor([[0.0000, 0.0231, 0.0237, 0.0232, 0.0216, 0.0235, 0.0227, 0.0288],
        [0.0231, 0.0000, 0.0218, 0.0209, 0.0211, 0.0210, 0.0211, 0.0265],
        [0.0237, 0.0218, 0.0000, 0.0220, 0.0215, 0.0220, 0.0225, 0.0267],
        [0.0232, 0.0209, 0.0220, 0.0000, 0.0203, 0.0220, 0.0222, 0.0272],
        [0.0216, 0.0211, 0.0215, 0.0203, 0.0000, 0.0216, 0.0209, 0.0257],
        [0.0235, 0.0210, 0.0220, 0.0220, 0.0216, 0.0000, 0.0226, 0.0277],
        [0.0227, 0.0211, 0.0225, 0.0222, 0.0209, 0.0226, 0.0000, 0.0266],
        [0.0288, 0.0265, 0.0267, 0.0272, 0.0257, 0.0277, 0.0266, 0.0000]])

In [24]:
teacher_outputs = teacher(**batch)
student_outputs = student(**batch)

In [53]:
teacher_cls = teacher_outputs.last_hidden_state[:, 0, :]
student_cls = student_outputs.last_hidden_state[:, 0, :]

teacher_qkv = get_qkvs(teacher)[-1]# (batch, head, seq, head_dim)
student_qkv = get_qkvs(student)[-1] # (batch, head, seq, head_dim)

loss_q = distance_loss_fn(student_qkv['q'], teacher_qkv['q'])
loss_k = distance_loss_fn(student_qkv['k'], teacher_qkv['k'])
loss_v = distance_loss_fn(student_qkv['v'], teacher_qkv['v'])
loss_cls = distance_loss_fn(student_cls, teacher_cls)

In [54]:
loss_q, loss_k, loss_v, loss_cls

(tensor(8.5106), tensor(8.9481), tensor(3.4897), tensor(0.9064))

In [24]:
teacher_cls = teacher_outputs.last_hidden_state[:, 0, :]
student_cls = student_outputs.last_hidden_state[:, 0, :]

In [87]:
torch.cdist(teacher_cls, teacher_cls)

tensor([[ 0.0000, 23.8793, 24.8112, 23.0736, 21.2155, 22.8808, 22.5494, 22.5182],
        [23.8793,  0.0000,  9.9119, 19.6570, 23.2631, 21.9032, 19.6912, 22.2391],
        [24.8112,  9.9119,  0.0000, 20.7842, 23.0034, 22.7179, 20.4372, 23.2159],
        [23.0736, 19.6570, 20.7842,  0.0000, 20.2961, 19.3592, 18.5026, 20.6066],
        [21.2155, 23.2631, 23.0034, 20.2961,  0.0000, 23.4661, 24.1644, 22.4794],
        [22.8808, 21.9032, 22.7179, 19.3592, 23.4661,  0.0000, 19.8730, 21.4704],
        [22.5494, 19.6912, 20.4372, 18.5026, 24.1644, 19.8730,  0.0000, 20.0151],
        [22.5182, 22.2391, 23.2159, 20.6066, 22.4794, 21.4704, 20.0151,  0.0000]])

In [84]:
def distance_loss_fn(s, t):
    dist_t = torch.cdist(t, t)
    dist_s = torch.cdist(s, s)
    return F.huber_loss(dist_s, dist_t)

In [81]:
dist_loss(teacher_cls, student_cls)

tensor(6.2299, grad_fn=<HuberLossBackward0>)

In [90]:
torch.cdist(teacher_qkvs[-1]['q'], teacher_qkvs[-1]['q']).size()

torch.Size([8, 512, 512])

In [82]:
dist_loss(teacher_qkvs[-1]['q'], student_qkvs[-1]['q'])

tensor(23.0016)

In [77]:
teacher_cls_dist = torch.cdist(teacher_cls, teacher_cls)
student_cls_

tensor([[ 0.0000, 23.8793, 24.8112, 23.0736, 21.2155, 22.8808, 22.5494, 22.5182],
        [23.8793,  0.0000,  9.9119, 19.6570, 23.2631, 21.9032, 19.6912, 22.2391],
        [24.8112,  9.9119,  0.0000, 20.7842, 23.0034, 22.7179, 20.4372, 23.2159],
        [23.0736, 19.6570, 20.7842,  0.0000, 20.2961, 19.3592, 18.5026, 20.6066],
        [21.2155, 23.2631, 23.0034, 20.2961,  0.0000, 23.4661, 24.1644, 22.4794],
        [22.8808, 21.9032, 22.7179, 19.3592, 23.4661,  0.0000, 19.8730, 21.4704],
        [22.5494, 19.6912, 20.4372, 18.5026, 24.1644, 19.8730,  0.0000, 20.0151],
        [22.5182, 22.2391, 23.2159, 20.6066, 22.4794, 21.4704, 20.0151,  0.0000]])

In [36]:
teacher_cls_sim = sim_matrix(teacher_cls, teacher_cls)
student_cls_sim = sim_matrix(student_cls, student_cls)

In [61]:
teacher_qkvs = get_qkvs(model)
student_qkvs = get_qkvs(student)

In [62]:
teacher_dist = torch.cdist(teacher_qkvs[-1]['q'], teacher_qkvs[-1]['q'])
student_dist = torch.cdist(student_qkvs[-1]['q'], student_qkvs[-1]['q'])

In [76]:
F.huber_loss(student_dist, teacher_dist)

tensor(23.0016)

In [74]:
F.mse_loss(student_dist, teacher_dist) * 0.1

tensor(59.7368)

In [None]:
torch.cdist()

In [48]:
torch.cdist(teacher_cls, teacher_cls)

tensor([[ 0.0000, 23.8793, 24.8112, 23.0736, 21.2155, 22.8808, 22.5494, 22.5182],
        [23.8793,  0.0000,  9.9119, 19.6570, 23.2631, 21.9032, 19.6912, 22.2391],
        [24.8112,  9.9119,  0.0000, 20.7842, 23.0034, 22.7179, 20.4372, 23.2159],
        [23.0736, 19.6570, 20.7842,  0.0000, 20.2961, 19.3592, 18.5026, 20.6066],
        [21.2155, 23.2631, 23.0034, 20.2961,  0.0000, 23.4661, 24.1644, 22.4794],
        [22.8808, 21.9032, 22.7179, 19.3592, 23.4661,  0.0000, 19.8730, 21.4704],
        [22.5494, 19.6912, 20.4372, 18.5026, 24.1644, 19.8730,  0.0000, 20.0151],
        [22.5182, 22.2391, 23.2159, 20.6066, 22.4794, 21.4704, 20.0151,  0.0000]])

In [40]:
teacher_cls_sim

tensor([[1.0000, 0.5393, 0.5030, 0.5679, 0.6387, 0.5788, 0.5872, 0.5904],
        [0.5393, 1.0000, 0.9199, 0.6831, 0.5612, 0.6101, 0.6819, 0.5964],
        [0.5030, 0.9199, 1.0000, 0.6461, 0.5713, 0.5809, 0.6577, 0.5605],
        [0.5679, 0.6831, 0.6461, 1.0000, 0.6645, 0.6940, 0.7179, 0.6519],
        [0.6387, 0.5612, 0.5713, 0.6645, 1.0000, 0.5554, 0.5243, 0.5904],
        [0.5788, 0.6101, 0.5809, 0.6940, 0.5554, 1.0000, 0.6775, 0.6254],
        [0.5872, 0.6819, 0.6577, 0.7179, 0.5243, 0.6775, 1.0000, 0.6715],
        [0.5904, 0.5964, 0.5605, 0.6519, 0.5904, 0.6254, 0.6715, 1.0000]])

In [47]:
tokenizer.decode(batch['input_ids'][5])

'[CLS] 이산화 셀레늄 ( 는 셀레늄과 산소로 이루어진 화합물이다. 화학식은 SeO2이다. 사슬 모양의 거대 분자이다. 상온에서는 백색의 흡습성 결정으로 존재한다. 액체는 등청색이며, 기체는 황록색이다. [UNK] 승화한다. 비중은 3. 954이다. 물, 에탄올, 아세트산에 잘 녹는다. 환원되기 쉬운 물질이다. 공기 중의 먼지에 의해서도 분해된다. 물에 녹일 경우 아셀렌산을 형성하며 녹는다. 셀레늄을 공기 또는 산소 중에서 연소시켜 얻을 수 있다. 이산화 셀레늄은 주로 다음과 같은 용도로 사용된다. 유기물의 합성에서 산화제나 촉매로 사용된다. 셀레늄 화합물 합성의 원료로 사용된다. 이산화 셀레늄은 독성이 있는 화합물이다. 만성적으로 접촉할 경우 창백함, 설태, 위장 질환, 신경쇠약 등을 일으킬 수 있다. 이산화 셀레늄과 같이 물에 녹을 수 있는 셀레늄 화합물은 독성이 강한데, 이는 물질이 체내의 설파이드릴 효소를 공격하기 때문인 것으로 추정된다 化 學 大 [UNK] 典 [UNK] 集 [UNK] 員 會 편, 성용길, 김창홍 역, 《 화학대사전 》, 서울 世 和, 2001. https : / / web. archive. org / web / 20080608083320 / http : / / www. jtbaker. com / msds / englishhtml / s1130. htm 분류 : 셀레늄 화합물 분류 : 산화물 분류 : 산성 산화물 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [

In [39]:
student_cls_sim

tensor([[1.0000, 0.8776, 0.8585, 0.8832, 0.8530, 0.8797, 0.8698, 0.8731],
        [0.8776, 1.0000, 0.8613, 0.8838, 0.8692, 0.8800, 0.8721, 0.8759],
        [0.8585, 0.8613, 1.0000, 0.8597, 0.8437, 0.8624, 0.8521, 0.8508],
        [0.8832, 0.8838, 0.8597, 1.0000, 0.8629, 0.8833, 0.8891, 0.8800],
        [0.8530, 0.8692, 0.8437, 0.8629, 1.0000, 0.8797, 0.8810, 0.8672],
        [0.8797, 0.8800, 0.8624, 0.8833, 0.8797, 1.0000, 0.8844, 0.8748],
        [0.8698, 0.8721, 0.8521, 0.8891, 0.8810, 0.8844, 1.0000, 0.8748],
        [0.8731, 0.8759, 0.8508, 0.8800, 0.8672, 0.8748, 0.8748, 1.0000]],
       grad_fn=<MmBackward0>)

In [35]:
F.cosine_embedding_loss(student_cls, student_cls, target=teacher_cls_sim)

RuntimeError: 0D or 1D target tensor expected, multi-target not supported

In [30]:
F.cosine_similarity(teacher_cls, teacher_cls, dim=1)

tensor([1., 1., 1., 1., 1., 1., 1., 1.])