In [1]:
import torch.nn as nn

class SentimentTransformer(nn.Module):
    """Classifier made of a transformer encoder, dropout and a linear head.

    The head is applied to position 0 (second axis) of the encoder's first
    output element.
    """

    def __init__(self, transformer, dropout, num_labels):
        """Wire the encoder to a freshly created classification head.

        Args:
            transformer: encoder module; must expose ``config.emb_dim`` or
                ``config.hidden_size``, which sets the head's input width.
            dropout: dropout probability applied before the head.
            num_labels: number of output classes (width of the logits).
        """
        super().__init__()

        self._num_labels = num_labels
        self._transformer = transformer

        # Prefer `emb_dim` when the config has it, otherwise use `hidden_size`.
        encoder_config = self._transformer.config
        if hasattr(encoder_config, "emb_dim"):
            self._transformer_output_size = encoder_config.emb_dim
        else:
            self._transformer_output_size = encoder_config.hidden_size

        self._head_dropout = nn.Dropout(dropout)
        self._classification_head = nn.Linear(
            in_features=self._transformer_output_size,
            out_features=self._num_labels,
        )

        self._loss = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, labels=None):
        """Compute logits and, when ``labels`` is given, the cross-entropy loss.

        Args:
            input_ids: passed positionally to the encoder.
            attention_mask: passed to the encoder as ``attention_mask``.
            labels: optional targets for ``nn.CrossEntropyLoss``.

        Returns:
            ``logits`` when ``labels`` is None, otherwise ``(loss, logits)``.
        """
        encoder_outputs = self._transformer(input_ids, attention_mask=attention_mask)
        # First output element, sliced to index 0 along the second axis.
        first_position = encoder_outputs[0][:, 0]
        logits = self._classification_head(self._head_dropout(first_position))

        if labels is None:
            return logits
        return self._loss(logits, labels), logits


In [2]:
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/98/87/ef312eef26f5cecd8b17ae9654cdd8d1fae1eb6dbd87257d6d73c128a4d0/transformers-4.3.2-py3-none-any.whl (1.8MB)
[K     |████████████████████████████████| 1.8MB 16.4MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |████████████████████████████████| 890kB 49.3MB/s 
Collecting tokenizers<0.11,>=0.10.1
[?25l  Downloading https://files.pythonhosted.org/packages/fd/5b/44baae602e0a30bcc53fbdbc60bd940c15e143d252d658dfdefce736ece5/tokenizers-0.10.1-cp36-cp36m-manylinux2010_x86_64.whl (3.2MB)
[K     |████████████████████████████████| 3.2MB 56.3MB/s 
Building wheels for collected packages: sacremoses
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Created wheel for sacremoses: filename=sacremoses-0.0.43-cp36-none-any.whl size=893261 sha256=6219bc070a

In [7]:
from transformers import AutoModel

# Name of the pretrained checkpoint the teacher encoder is loaded from.
MODEL_NAME = "bert-base-cased"

# Teacher = pretrained encoder wrapped with the classification head.
transformer = AutoModel.from_pretrained(MODEL_NAME)
teacher_model = SentimentTransformer(transformer=transformer, dropout=0.1, num_labels=2)

In [8]:
import torch
# NOTE: torch.load unpickles the file, so only point it at checkpoints you trust.
# map_location="cpu" lets a checkpoint that contains CUDA tensors load on a runtime
# without a GPU; load_state_dict then copies the values into the model's own parameters.
teacher_model_weights = torch.load("drive/MyDrive/imdb_bert.th", map_location="cpu")
# strict=False tolerates missing/unexpected keys; the returned report is the cell output.
teacher_model.load_state_dict(teacher_model_weights, strict=False)

<All keys matched successfully>

## Инициализация архитектуры student-модели

In [9]:
from transformers import BertConfig, BertModel
# Student encoder: the same checkpoint's config with num_hidden_layers overridden to 6,
# instantiated directly from that config rather than through from_pretrained.
student_model_transformer_config = BertConfig.from_pretrained(MODEL_NAME, num_hidden_layers=6)
student_transformer = BertModel(student_model_transformer_config)
student_model = SentimentTransformer(transformer=student_transformer, dropout=0.1, num_labels=2)

## Перенос весов с модели-учителя в модель-ученика

In [10]:
def transfer_weights(teacher_model, student_model, teacher_layers=(0, 2, 4, 7, 9, 11)):
    """Initialise ``student_model`` from the weights of ``teacher_model``.

    Entries whose names start with ``_transformer.embeddings``,
    ``_transformer.pooler`` or ``_classification_head`` are copied under the
    same name. Encoder layer ``i`` of the student is filled from encoder
    layer ``teacher_layers[i]`` of the teacher.

    Args:
        teacher_model: module whose ``state_dict()`` supplies the weights.
        student_model: module that receives them; updated in place through
            ``load_state_dict``.
        teacher_layers: teacher encoder-layer indices listed in student-layer
            order. Student layers with an index beyond
            ``len(teacher_layers) - 1`` keep their current weights.

    Returns:
        The same ``student_model`` object, after the weights were loaded.

    Raises:
        KeyError: if the teacher has no entry for a name that must be copied.
    """
    teacher_state = teacher_model.state_dict()
    student_state = student_model.state_dict()

    same_name_prefixes = (
        "_transformer.embeddings",
        "_transformer.pooler",
        "_classification_head",
    )
    layer_prefix = "_transformer.encoder.layer."

    # Only values of existing keys are reassigned, so iterating the dict is safe.
    for name in student_state:
        if name.startswith(same_name_prefixes):
            student_state[name] = teacher_state[name]
            continue
        if not name.startswith(layer_prefix):
            continue

        # Parse the layer index exactly: a bare prefix test would let "layer.1"
        # match "layer.10", and a blanket str.replace would also rewrite digits
        # that appear later in the parameter name.
        index_text, _, suffix = name[len(layer_prefix):].partition(".")
        if not index_text.isdigit():
            continue
        student_index = int(index_text)
        if student_index >= len(teacher_layers):
            continue

        teacher_name = f"{layer_prefix}{teacher_layers[student_index]}.{suffix}"
        student_state[name] = teacher_state[teacher_name]

    student_model.load_state_dict(student_state)
    return student_model

In [12]:
# Copy the teacher's weights into the student; transfer_weights returns the student it was given.
student_model = transfer_weights(teacher_model=teacher_model, student_model=student_model)