In [2]:
! pip install --upgrade torch==1.6.0

Looking in indexes: https://mirror.baidu.com/pypi/simple/
Collecting torch==1.6.0
[?25l  Downloading https://mirror.baidu.com/pypi/packages/5d/5e/35140615fc1f925023f489e71086a9ecc188053d263d3594237281284d82/torch-1.6.0-cp37-cp37m-manylinux1_x86_64.whl (748.8MB)
[K     |████████████████████████████████| 748.8MB 9.1MB/s eta 0:00:011��████████████████████▎  | 686.2MB 8.0MB/s eta 0:00:08
Installing collected packages: torch
  Found existing installation: torch 1.4.0
    Uninstalling torch-1.4.0:
      Successfully uninstalled torch-1.4.0
Successfully installed torch-1.6.0


In [2]:
! pip install torchvision==0.7.0

Looking in indexes: https://mirror.baidu.com/pypi/simple/
Collecting torchvision==0.7.0
[?25l  Downloading https://mirror.baidu.com/pypi/packages/4d/b5/60d5eb61f1880707a5749fea43e0ec76f27dfe69391cdec953ab5da5e676/torchvision-0.7.0-cp37-cp37m-manylinux1_x86_64.whl (5.9MB)
[K     |████████████████████████████████| 5.9MB 13.6MB/s eta 0:00:01
Installing collected packages: torchvision
  Found existing installation: torchvision 0.5.0
    Uninstalling torchvision-0.5.0:
      Successfully uninstalled torchvision-0.5.0
Successfully installed torchvision-0.7.0


In [3]:
import torch
import random
import numpy as np
config = {
        'train_file_path': 'data/data100821/train.json',
        'dev_file_path': 'data/data100821/dev.json',
        'test_file_path': 'data/data100821/test.json',
        'output_path': '.',
        'model_path': 'data/data94445',
        'batch_size': 64,
        'num_epochs': 1,
        'max_seq_len': 64,
        'decay': 0.995,
        'kd_coeff': 1.0,
        'learning_rate': 2e-5,
        'warmup_ratio': 0.05,
        'weight_decay': 0.01,
        'use_bucket': True,
        'bucket_multiplier': 200,
        'device': 'cuda',
        'n_gpus': 0,
        'use_amp': True,  
        'logging_step': 400,
        'ema_start_step': 500,
        'ema_start': False,
        'seed': 2021
    }

if not torch.cuda.is_available():
    config['device'] = 'cpu'
else:
    config['n_gpus'] = torch.cuda.device_count()
    config['batch_size'] *= config['n_gpus']

def seed_everything(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    return seed

seed_everything(config['seed'])

2021

```
from time import sleep
from tqdm import tqdm
for i in tqdm(range(60*15), desc='现在是休息时间，看录播的同学可以跳过哦～'):
    sleep(1)
```

In [4]:
! pip install transformers==4.0.1

Looking in indexes: https://mirror.baidu.com/pypi/simple/


In [5]:
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained(config['model_path'])

In [6]:
import pandas as pd
import json
def parse_data(path, data_type='train'):
    sentence_a = []
    sentence_b = []
    labels = []
    with open(path, 'r', encoding='utf8') as f:
        for line in tqdm(f.readlines(), desc=f'Reading {data_type} data'):
            line = json.loads(line)
            sentence_a.append(line['sentence1'])
            sentence_b.append(line['sentence2'])
            if data_type != 'test':
                labels.append(int(line['label']))
            else:
                labels.append(0)
    df = pd.DataFrame(zip(sentence_a, sentence_b, labels), columns=['text_a', 'text_b', 'labels'])
    return df

In [7]:
def build_bert_inputs(inputs, label, sentence_a, sentence_b, tokenizer):
    inputs_dict = tokenizer.encode_plus(sentence_a, sentence_b, add_special_tokens=True,
                                        return_token_type_ids=True, return_attention_mask=True)
    inputs['input_ids'].append(inputs_dict['input_ids'])
    inputs['token_type_ids'].append(inputs_dict['token_type_ids'])
    inputs['attention_mask'].append(inputs_dict['attention_mask'])
    inputs['labels'].append(label)

In [8]:
from tqdm import tqdm
from collections import defaultdict
def read_data(config, tokenizer):
    train_df = parse_data(config['train_file_path'], data_type='train')
    dev_df = parse_data(config['dev_file_path'], data_type='dev')
    test_df = parse_data(config['test_file_path'], data_type='test')

    data_df = {'train': train_df, 'dev': dev_df, 'test': test_df}
    processed_data = {}

    for data_type, df in data_df.items():
        inputs = defaultdict(list)
        for i, row in tqdm(df.iterrows(), desc=f'Preprocessing {data_type} data', total=len(df)):
            label = row[2]
            sentence_a, sentence_b = row[0], row[1]
            build_bert_inputs(inputs, label, sentence_a, sentence_b, tokenizer)

        processed_data[data_type] = inputs

    return processed_data


In [9]:
data = read_data(config, tokenizer)

Reading train data: 100%|██████████| 34334/34334 [00:00<00:00, 260985.27it/s]
Reading dev data: 100%|██████████| 4316/4316 [00:00<00:00, 262955.07it/s]
Reading test data: 100%|██████████| 3861/3861 [00:00<00:00, 267752.52it/s]
Preprocessing train data: 100%|██████████| 34334/34334 [00:17<00:00, 1978.58it/s]
Preprocessing dev data: 100%|██████████| 4316/4316 [00:02<00:00, 1972.19it/s]
Preprocessing test data: 100%|██████████| 3861/3861 [00:01<00:00, 1936.16it/s]


In [10]:
class Collator:
    def __init__(self, max_seq_len, tokenizer):
        self.max_seq_len = max_seq_len
        self.tokenizer = tokenizer

    def pad_and_truncate(self, input_ids_list, token_type_ids_list,
                         attention_mask_list, labels_list, max_seq_len):
        input_ids = torch.zeros((len(input_ids_list), max_seq_len), dtype=torch.long)
        token_type_ids = torch.zeros_like(input_ids)
        attention_mask = torch.zeros_like(input_ids)
        for i in range(len(input_ids_list)):
            seq_len = len(input_ids_list[i])
            if seq_len <= max_seq_len:
                input_ids[i, :seq_len] = torch.tensor(input_ids_list[i], dtype=torch.long)
                token_type_ids[i, :seq_len] = torch.tensor(token_type_ids_list[i], dtype=torch.long)
                attention_mask[i, :seq_len] = torch.tensor(attention_mask_list[i], dtype=torch.long)
            else:
                input_ids[i] = torch.tensor(input_ids_list[i][:max_seq_len - 1] + [self.tokenizer.sep_token_id],
                                            dtype=torch.long)
                token_type_ids[i] = torch.tensor(token_type_ids_list[i][:max_seq_len], dtype=torch.long)
                attention_mask[i] = torch.tensor(attention_mask_list[i][:max_seq_len], dtype=torch.long)

        labels = torch.tensor(labels_list, dtype=torch.long)
        return input_ids, token_type_ids, attention_mask, labels

    def __call__(self, examples):
        input_ids_list, token_type_ids_list, attention_mask_list, labels_list = list(zip(*examples))
        cur_max_seq_len = max(len(input_id) for input_id in input_ids_list)
        max_seq_len = min(cur_max_seq_len, self.max_seq_len)

        input_ids, token_type_ids, attention_mask, labels = self.pad_and_truncate(input_ids_list, token_type_ids_list,
                                                                                  attention_mask_list, labels_list,
                                                                                  max_seq_len)

        data_dict = {
            'input_ids': input_ids,
            'token_type_ids': token_type_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

        return data_dict

In [11]:
collate_fn = Collator(config['max_seq_len'], tokenizer)

In [12]:
from torch.utils.data import Dataset
class AFQMCDataset(Dataset):

    def __init__(self, data_dict):
        super(AFQMCDataset, self).__init__()
        self.data_dict = data_dict

    def __getitem__(self, index):
        data = (self.data_dict['input_ids'][index], self.data_dict['token_type_ids'][index],
                self.data_dict['attention_mask'][index], self.data_dict['labels'][index])
        return data

    def __len__(self):
        return len(self.data_dict['input_ids'])

In [13]:
from torch.utils.data import Sampler
class SortedSampler(Sampler):
    """ Samples elements sequentially, always in the same order.

    Args:
        data (iterable): Iterable data.
        sort_key (callable): Specifies a function of one argument that is used to extract a
            numerical comparison key from each list element.

    Example:
        >>> list(SortedSampler(range(10), sort_key=lambda i: -i))
        [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]

    """

    def __init__(self, data, sort_key):
        super().__init__(data)
        self.data = data
        self.sort_key = sort_key
        zip_ = [(i, self.sort_key(row)) for i, row in enumerate(self.data)]
        zip_ = sorted(zip_, key=lambda r: r[1])
        self.sorted_indexes = [item[0] for item in zip_]

    def __iter__(self):
        return iter(self.sorted_indexes)

    def __len__(self):
        return len(self.data)

In [14]:
from torch.utils.data import BatchSampler, SubsetRandomSampler
import math
class BucketBatchSampler(BatchSampler):
    """ `BucketBatchSampler` toggles between `sampler` batches and sorted batches.

    Typically, the `sampler` will be a `RandomSampler` allowing the user to toggle between
    random batches and sorted batches. A larger `bucket_size_multiplier` is more sorted and vice
    versa.

    Background:
        ``BucketBatchSampler`` is similar to a ``BucketIterator`` found in popular libraries like
        ``AllenNLP`` and ``torchtext``. A ``BucketIterator`` pools together examples with a similar
        size length to reduce the padding required for each batch while maintaining some noise
        through bucketing.

        **AllenNLP Implementation:**
        https://github.com/allenai/allennlp/blob/master/allennlp/data/iterators/bucket_iterator.py

        **torchtext Implementation:**
        https://github.com/pytorch/text/blob/master/torchtext/data/iterator.py#L225

    Args:
        sampler (torch.data.utils.sampler.Sampler):
        batch_size (int): Size of mini-batch.
        drop_last (bool): If `True` the sampler will drop the last batch if its size would be less
            than `batch_size`.
        sort_key (callable, optional): Callable to specify a comparison key for sorting.
        bucket_size_multiplier (int, optional): Buckets are of size
            `batch_size * bucket_size_multiplier`.
    """

    def __init__(self,
                 sampler,
                 batch_size,
                 drop_last,
                 sort_key,
                 bucket_size_multiplier=100):
        super().__init__(sampler, batch_size, drop_last)
        self.sort_key = sort_key
        self.bucket_sampler = BatchSampler(sampler,
                                           min(batch_size * bucket_size_multiplier, len(sampler)),
                                           False)

    def __iter__(self):
        for bucket in self.bucket_sampler:
            sorted_sampler = SortedSampler(bucket, self.sort_key)
            for batch in SubsetRandomSampler(
                    list(BatchSampler(sorted_sampler, self.batch_size, self.drop_last))):
                yield [bucket[i] for i in batch]

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return math.ceil(len(self.sampler) / self.batch_size)

In [15]:
from torch.utils.data import RandomSampler, DataLoader

def build_dataloader(config, data, collate_fn):
    train_dataset = AFQMCDataset(data['train'])
    dev_dataset = AFQMCDataset(data['dev'])
    test_dataset = AFQMCDataset(data['test'])
    
    if config['use_bucket']:
        train_sampler = RandomSampler(train_dataset)
        bucket_sampler = BucketBatchSampler(train_sampler, batch_size=config['batch_size'],
                                            drop_last=False, sort_key=lambda x: len(train_dataset[x][0]),  # 以 input_id 长度作为排序的指标
                                            bucket_size_multiplier=config['bucket_multiplier'])
        train_dataloader = DataLoader(dataset=train_dataset, batch_sampler=bucket_sampler,
                                      num_workers=4, collate_fn=collate_fn)
    else:
        train_dataloader = DataLoader(dataset=train_dataset, batch_size=config['batch_size'],
                                      shuffle=True, num_workers=4, collate_fn=collate_fn)
    dev_dataloader = DataLoader(dataset=dev_dataset, batch_size=config['batch_size'],
                                shuffle=False, num_workers=4, collate_fn=collate_fn)
    test_dataloader = DataLoader(dataset=test_dataset, batch_size=config['batch_size'],
                                 shuffle=False, num_workers=4, collate_fn=collate_fn)
    return train_dataloader, dev_dataloader, test_dataloader

In [16]:
train_dataloader, dev_dataloader, test_dataloader = build_dataloader(config, data, collate_fn)

In [17]:
for i in train_dataloader:
    print(i)
    break

{'input_ids': tensor([[ 101,  955, 1446,  ..., 6820, 3621,  102],
        [ 101, 5709, 1446,  ..., 5709, 1446,  102],
        [ 101, 2769, 4638,  ...,  679,  749,  102],
        ...,
        [ 101, 5709, 1446,  ..., 3621, 1408,  102],
        [ 101, 1555, 2157,  ..., 1555, 2157,  102],
        [ 101, 2769, 3221,  ..., 3082,  868,  102]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        ...,
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1]]), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        ...,
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]]), 'labels': tensor([1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1,
        0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0,
        0, 0, 0

In [18]:
from sklearn import metrics
def evaluation(config, model, val_dataloader):
    model.eval()
    preds = []
    labels = []
    val_loss = 0.
    val_iterator = tqdm(val_dataloader, desc='Evaluation', total=len(val_dataloader))

    with torch.no_grad():
        for batch in val_iterator:
            labels.append(batch['labels'])
            batch_cuda = {item: value.to(config['device']) for item, value in list(batch.items())}
            loss, logits = model(**batch_cuda)[:2]

            if config['n_gpus'] > 1:
                loss = loss.mean()

            val_loss += loss.item()
            preds.append(logits.argmax(dim=-1).detach().cpu())

    avg_val_loss = val_loss / len(val_dataloader)
    labels = torch.cat(labels, dim=0).numpy()
    preds = torch.cat(preds, dim=0).numpy()
    f1 = metrics.f1_score(labels, preds)
    acc = metrics.accuracy_score(labels, preds)
    return avg_val_loss, f1, acc


# 自集成 和 自蒸馏

![自监督](https://img-blog.csdnimg.cn/e0bff606718a480b84b26b37c02c1651.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80MTI4NzA2MA==,size_16,color_FFFFFF,t_70)

![自集成](https://img-blog.csdnimg.cn/e7013e8b74424092be6658ff9ff866bc.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80MTI4NzA2MA==,size_16,color_FFFFFF,t_70)

![self1](https://img-blog.csdnimg.cn/9857a9457bba49a2953f06206a09b2cf.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80MTI4NzA2MA==,size_16,color_FFFFFF,t_70)

self-ensemble 为了进一步降低集成模型的复杂度，从而使用了一种更高效的集成方法，该方法将多个基本模型 与 参数平均相结合，而不是保留多个基本模型。\
使用知识蒸馏来提高微调效率。当前的BERT模型（学生模型），自集成模型（教师模型）。教师模型是 之前几个时间步长 的学生模型的平均值。

![在这里插入图片描述](https://img-blog.csdnimg.cn/9da9862031d64d628b10f1d6b5644729.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80MTI4NzA2MA==,size_16,color_FFFFFF,t_70)

![在这里插入图片描述](https://img-blog.csdnimg.cn/1ae9a6f0c02e43789351952ef79d8f2c.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80MTI4NzA2MA==,size_16,color_FFFFFF,t_70)

![在这里插入图片描述](https://img-blog.csdnimg.cn/2484928631e64646a4d17bc3452bc315.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80MTI4NzA2MA==,size_16,color_FFFFFF,t_70)

自集成模型，在单个训练阶段组合不同时间步长的中间模型。将每个时间步的BERT视为基础模型，并将它们组合成一个自集成模型。

![在这里插入图片描述](https://img-blog.csdnimg.cn/16055adb9e0f4668ac25025956364ef7.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80MTI4NzA2MA==,size_16,color_FFFFFF,t_70)

![在这里插入图片描述](https://img-blog.csdnimg.cn/3027d580da3d443ca50a38f674077bc2.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80MTI4NzA2MA==,size_16,color_FFFFFF,t_70)

![在这里插入图片描述](https://img-blog.csdnimg.cn/5e7a8e4517684d23a9175a61f9a83394.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80MTI4NzA2MA==,size_16,color_FFFFFF,t_70)

# EMA:指数移动平均

### \theta_{t} 取平均之后的参数  \tilde\theta_{t}当前时间步的参数
$$\theta_{t}   = 0.99\theta_{t-1} + (1-0.99)\tilde\theta_{t}$$
$$\theta_{t-1} = 0.99\theta_{t-2} + (1-0.99)\tilde\theta_{t-1}$$
$$\theta_{t-2} = 0.99\theta_{t-3} + (1-0.99)\tilde\theta_{t-2}$$
$$\theta_{t-3} = 0.99*0 + (1-0.99)\tilde\theta_{t-3}$$

$$\theta_{t}=0.99\theta_{t-1} + (1-0.99)\tilde\theta_{t}$$
$$\theta_{t}=0.99\theta_{t-1} + 0.01 * \tilde\theta_{t}$$
$$\theta_{t}=0.01*\tilde\theta_{t}+0.99\theta_{t-1}$$
$$\theta_{t}=0.01*\tilde\theta_{t}+0.99(0.99\theta_{t-2} + 0.01*\tilde\theta_{t-1})$$
$$\theta_{t}=0.01*\tilde\theta_{t}+0.01*0.99*\tilde\theta_{t-1}+0.99*0.99*\theta_{t-2}$$
$$\theta_{t}=0.01*\tilde\theta_{t}+0.01*0.99*\tilde\theta_{t-1}+0.99*0.99*(0.99*\theta_{t-3} + 0.01*\tilde\theta_{t-2})$$
$$\theta_{t}=0.01*\tilde\theta_{t}+0.01*0.99*\tilde\theta_{t-1}+0.01*0.99*0.99*\tilde\theta_{t-2}+0.99*0.99*0.99*\theta_{t-3}$$
$$\theta_{t}=0.01*\tilde\theta_{t}+0.01*0.99*\tilde\theta_{t-1}+0.01*0.99*0.99*\tilde\theta_{t-2}+0.01*0.99*0.99*0.99*\tilde\theta_{t-3})$$
$$\theta_{t}=0.01*0.99^{0}*\tilde\theta_{t}+0.01*0.99^{1}*\tilde\theta_{t-1}+0.01*0.99^{2}*\tilde\theta_{t-2}+0.01*0.99^{3}*\tilde\theta_{t-3})$$

$$\theta_{t} = decay\theta_{t-1} + (1-decay)\tilde\theta_{t}$$

自蒸馏平均：\
计算两个loss CE 和 MSE\
MSE（当前时间步的参数，前几个时间步的参数求平均（不包括当前时间步））

![在这里插入图片描述](https://img-blog.csdnimg.cn/a3c4790799e34d81940a380c102c2a24.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L3dlaXhpbl80MTI4NzA2MA==,size_16,color_FFFFFF,t_70)

In [19]:
class EMA:
    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}
        self.register()
    
    # EMA初始化
    def register(self):  
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                # 把当前模型参数保存到副本中
                self.shadow[name] = param.data.clone()
    
    # 每一步都要更新
    def update(self):  
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                # 更新的参数
                new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
                # 保存平均参数
                self.shadow[name] = new_average.clone()
    
    # 评估的时候用
    def apply_shadow(self): 
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                # 把当前参数备份
                self.backup[name] = param.data
                # 用维护的平均参数 替换当前模型的参数 进行模型评估
                param.data = self.shadow[name]
    
    # 恢复到训练时的参数
    def restore(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                assert name in self.backup
                param.data = self.backup[name]
        self.backup = {}

$$ \theta_{t} = decay\theta_{t-1} + (1-decay)\tilde\theta_{t}$$
$$ \theta_{t} = decay\theta_{t-1} + \tilde\theta_{t} - decay\tilde\theta_{t}$$
$$ \theta_{t} = decay(\theta_{t-1} - \tilde\theta_{t}) + \tilde\theta_{t} 

```
one_minus_decay = 1.0 - decay
for s_param, param in zip(kd_model.parameters(), parameters):
     # 指数移动平均
     s_param.sub_(one_minus_decay * (s_param - param))
```
s_param：平均参数
param：当前参数
$$\theta_{t} = \theta_{t-1} - (1 - decay) * (\theta_{t-1} - \tilde\theta_{t})$$
$$\theta_{t} = \theta_{t-1} - (\theta_{t-1} - \tilde\theta_{t}) + decay(\theta_{t-1} - \tilde\theta_{t})$$
$$\theta_{t} = \tilde\theta_{t} + decay(\theta_{t-1} - \tilde\theta_{t})$$

```             #\theta_{t-1} param \theta_{t-1}
                for s_param, param in zip(kd_model.parameters(), parameters):
                    s_param.sub_(one_minus_decay * (s_param - param))
```
$$ \theta_{t} = \theta_{t-1} - (1-decay)( \theta_{t-1} - \tilde\theta_{t})$$
$$ \theta_{t} = \theta_{t-1} - ( \theta_{t-1} - \tilde\theta_{t}) + decay( \theta_{t-1} - \tilde\theta_{t})$$
$$ \theta_{t} = \tilde\theta_{t} + decay( \theta_{t-1} - \tilde\theta_{t})$$

In [20]:
from transformers import AdamW, BertForSequenceClassification
from torch.cuda import amp
from extra_file.extra_optim import *
from extra_file.extra_pgd import *
from extra_file.extra_fgm import *
from extra_file.extra_loss import *
from tqdm import trange
import copy, os
def train(config, train_dataloader, dev_dataloader):
    model = BertForSequenceClassification.from_pretrained(config['model_path'])

    param_optimizer = list(model.named_parameters())
    scaler = amp.GradScaler(enabled=config['use_amp'])
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         "weight_decay": config['weight_decay']},
        {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=config['learning_rate'],
                      correct_bias=False, eps=1e-8)
    optimizer = Lookahead(optimizer, 5, 1)
    total_steps = config['num_epochs'] * len(train_dataloader)
    lr_scheduler = WarmupLinearSchedule(optimizer,
                                        warmup_steps=int(config['warmup_ratio'] * total_steps),
                                        t_total=total_steps)
    model.to(config['device'])
    epoch_iterator = trange(config['num_epochs'])
    global_steps = 0
    train_loss = 0.
    logging_loss = 0.
    best_acc = 0.
    best_model_path = ''

    if config['n_gpus'] > 1:
        model = nn.DataParallel(model)
    
    # -----------    new----------------#
    # 定义 MSE loss. (x-y)**2
    kd_loss_fct = nn.MSELoss()
    # 复制BERT模型，得到Average
    kd_model = copy.deepcopy(model)
    # Average不需要反向传播
    kd_model.eval()

    for _ in epoch_iterator:

        train_iterator = tqdm(train_dataloader, desc='Training', total=len(train_dataloader))
        model.train()
        for batch in train_iterator:
            batch_cuda = {item: value.to(config['device']) for item, value in list(batch.items())}
            with amp.autocast(enabled=config['use_amp']):
                loss, logits = model(**batch_cuda)[:2]
                if config['n_gpus'] > 1:
                    loss = loss.mean()
                
                # --------------------     new -----------------#
                with torch.no_grad():
                    # SDA 拿到标签
                    kd_logits = kd_model(**batch_cuda)[1]
                # SDA 使用 MSE计算损失
                kd_loss = kd_loss_fct(logits, kd_logits)
                # SDA 加权损失
                loss += config['kd_coeff'] * kd_loss
                # --------------------     new -----------------#

            scaler.scale(loss).backward()

            scaler.step(optimizer)
            scaler.update()
            lr_scheduler.step()
            optimizer.zero_grad()
        
            if config['ema_start']:
                ema.update()

            train_loss += loss.item()
            global_steps += 1
            
            # --------------------     new -----------------#
            # 修正偏差
            decay = min(config['decay'], (1 + global_steps) / (10 + global_steps))

            one_minus_decay = 1.0 - decay
            # SDA 更新Average参数
            with torch.no_grad():
                parameters = [p for p in model.parameters() if p.requires_grad]
                # 指数移动平均
                for s_param, param in zip(kd_model.parameters(), parameters):
                    s_param.sub_(one_minus_decay * (s_param - param))
            # --------------------     new -----------------#

            train_iterator.set_postfix_str(f'running training loss: {loss.item():.4f}')

            if global_steps % config['logging_step'] == 0:
                if global_steps >= config['ema_start_step'] and not config['ema_start']:
                    print('\n>>> EMA starting ...')
                    config['ema_start'] = True
                    # --------------------     new -----------------#
                    # if 多张GPU卡， model 会放在 model.module 属性里， else 返回 model
                    ema = EMA(model.module if hasattr(model, 'module') else model, decay=0.999)
                    # --------------------     new -----------------#
                print_train_loss = (train_loss - logging_loss) / config['logging_step']
                logging_loss = train_loss

                if config['ema_start']:
                    ema.apply_shadow()
                val_loss, f1, acc = evaluation(config, model, dev_dataloader)

                print_log = f'\n>>> training loss: {print_train_loss:.6f}, valid loss: {val_loss:.6f}, '

                if acc > best_acc:
                    model_save_path = os.path.join(config['output_path'],
                                                   f'checkpoint-{global_steps}-{acc:.6f}')
                    model_to_save = model.module if hasattr(model, 'module') else model
                    model_to_save.save_pretrained(model_save_path)
                    best_acc = acc
                    best_model_path = model_save_path

                print_log += f'valid f1: {f1:.6f}, valid acc: {acc:.6f}'

                print(print_log)
                model.train()
                if config['ema_start']:
                    ema.restore()

    return model, best_model_path

In [21]:
train(config, train_dataloader, dev_dataloader)

Some weights of the model checkpoint at data/data94445 were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at da


>>> training loss: 0.607915, valid loss: 0.535067, valid f1: 0.458794, valid acc: 0.704819



Training:  74%|███████▍  | 400/537 [52:56<2:40:59, 70.51s/it, running training loss: 0.5476][A
Training:  75%|███████▍  | 401/537 [52:56<1:57:01, 51.63s/it, running training loss: 0.5476][A
Training:  75%|███████▍  | 401/537 [53:03<1:57:01, 51.63s/it, running training loss: 0.5423][A
Training:  75%|███████▍  | 402/537 [53:03<1:26:02, 38.24s/it, running training loss: 0.5423][A
Training:  75%|███████▍  | 402/537 [53:11<1:26:02, 38.24s/it, running training loss: 0.5135][A
Training:  75%|███████▌  | 403/537 [53:11<1:05:12, 29.20s/it, running training loss: 0.5135][A
Training:  75%|███████▌  | 403/537 [53:18<1:05:12, 29.20s/it, running training loss: 0.6159][A
Training:  75%|███████▌  | 404/537 [53:18<49:45, 22.45s/it, running training loss: 0.6159]  [A
Training:  75%|███████▌  | 404/537 [53:24<49:45, 22.45s/it, running training loss: 0.5871][A
Training:  75%|███████▌  | 405/537 [53:24<38:39, 17.57s/it, running training loss: 0.5871][A
Training:  75%|███████▌  | 405/537 [53:32<3

(BertForSequenceClassification(
   (bert): BertModel(
     (embeddings): BertEmbeddings(
       (word_embeddings): Embedding(21128, 768, padding_idx=1)
       (position_embeddings): Embedding(512, 768)
       (token_type_embeddings): Embedding(2, 768)
       (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
       (dropout): Dropout(p=0.1, inplace=False)
     )
     (encoder): BertEncoder(
       (layer): ModuleList(
         (0): BertLayer(
           (attention): BertAttention(
             (self): BertSelfAttention(
               (query): Linear(in_features=768, out_features=768, bias=True)
               (key): Linear(in_features=768, out_features=768, bias=True)
               (value): Linear(in_features=768, out_features=768, bias=True)
               (dropout): Dropout(p=0.1, inplace=False)
             )
             (output): BertSelfOutput(
               (dense): Linear(in_features=768, out_features=768, bias=True)
               (LayerNorm): LayerNorm((768