In [None]:
import os
from google.colab import drive
# Mount Google Drive so the project folder is reachable from the Colab VM.
drive.mount('/content/drive')

path = "/content/drive/My Drive/Spam-classification-master/"

# Work from the project folder so later relative paths ('./', 'vocab.txt', ...) resolve.
os.chdir(path)
os.listdir(path)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


['SMSSpamCollection.txt',
 '垃圾邮件分类.ipynb',
 'vocab.txt',
 'bert-large-uncased',
 'bert-base-uncased',
 'SMSSpamCollection.csv']

# Data

In [None]:
# Raw data format (SMSSpamCollection.txt), one message per line:
#   ham  -> not spam
#   spam -> spam
# followed by a tab and the message text.

# 2. Imports
import csv

import pandas as pd
from sklearn import linear_model
from sklearn.feature_extraction.text import TfidfVectorizer # TF-IDF text vectoriser from sklearn's feature-extraction module

# 3. Load the dataset
path = './'
filename = 'SMSSpamCollection.txt'
# Tab-separated, no header row. QUOTE_NONE keeps '"' as ordinary text: with the
# default quoting, a message containing an unmatched double quote can swallow the
# following lines into a single field and silently drop rows.
df = pd.read_csv(path + filename, delimiter='\t', header=None, quoting=csv.QUOTE_NONE)
# Labels (column 0) and raw message text (column 1)
y, X_train = df[0], df[1]

In [None]:
!pip install pytorch_pretrained_bert

# Model

In [3]:
import random
import re
from math import sqrt as msqrt

import torch
import torch.functional as F
from torch import nn
from torch.optim import Adadelta
from torch.utils.data import DataLoader, Dataset
#from pytorch_pretrained_bert import BertModel, BertTokenizer
import numpy as np


In [None]:
# ---- sequence / vocabulary ----
max_len = 50        # maximum sequence length (rows of the position-embedding table)
max_vocab = 30522   # vocabulary size (rows of the word-embedding table)
max_pred = 5

# ---- transformer dimensions ----
n_heads = 12
n_layers = 12
n_segs = 2          # number of segment (token-type) ids

d_k = d_v = 64                 # per-head width
d_model = n_heads * d_k        # 768
d_ff = d_model * 4             # feed-forward inner width

p_dropout = .1

# ---- BERT masking probabilities (mask / random replace / keep) ----
p_mask = .8
p_replace = .1
p_do_nothing = 1 - p_mask - p_replace

# ---- adapter settings ----
hidden_size = 64
init_scale = 1e-3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def get_pad_mask(tokens, pad_idx=0):
    '''
    Build a padding mask for self-attention.

    tokens: [batch, seq_len] token ids; [PAD] is assumed to have id `pad_idx`.
    Returns a bool tensor [batch, seq_len, seq_len] whose entry [b, i, j] is True
    when position j of sequence b is padding (the same row is repeated for every i).
    '''
    n_batch, n_pos = tokens.size()
    is_pad = tokens.eq(pad_idx)                      # [batch, seq_len]
    return is_pad.unsqueeze(1).expand(n_batch, n_pos, n_pos)

class Embeddings(nn.Module):
    """
    Sum of word, position and segment (token-type) embeddings, followed by
    LayerNorm and dropout.
    """
    def __init__(self):
        super(Embeddings, self).__init__()
        self.seg_emb = nn.Embedding(n_segs, d_model)
        self.word_emb = nn.Embedding(max_vocab, d_model)

        self.pos_emb = nn.Embedding(max_len, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p_dropout)

    def forward(self, x, seg):
        '''
        x:   [batch, seq_len] token ids; seq_len must not exceed max_len
             (the position table has max_len rows)
        seg: segment ids, broadcast-compatible with x (typically [batch, seq_len])
        return: [batch, seq_len, d_model]
        '''
        # Create position ids on the input's own device rather than the global
        # `device`, so the module works wherever it (and its input) actually lives.
        pos = torch.arange(x.size(1), dtype=torch.long, device=x.device)
        pos = pos.unsqueeze(0).expand_as(x)

        x = self.norm(self.word_emb(x) + self.pos_emb(pos) + self.seg_emb(seg))
        return self.dropout(x)
        # return: [batch, seq_len, d_model]

class ScaledDotProductAttention(nn.Module):
    """Computes softmax(Q K^T / sqrt(d)) V, suppressing masked positions."""
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        '''
        Q, K:      [..., seq_len, d] (d = per-head width)
        V:         [..., seq_len, d_v]
        attn_mask: bool tensor broadcastable to the scores [..., seq_len, seq_len];
                   True entries are excluded from attention.
        return:    context [..., seq_len, d_v]
        '''
        # Scale by the actual per-head width (equals the global d_k in this model).
        scores = torch.matmul(Q, K.transpose(-1, -2)) / msqrt(Q.size(-1))
        # A large negative score gives masked positions ~0 weight after softmax.
        # (No per-call debug print: this runs for every layer on every batch.)
        scores = scores.masked_fill(attn_mask, -1e9)
        attn = torch.softmax(scores, dim=-1)
        return torch.matmul(attn, V)


class MultiHeadAttention(nn.Module):
    """Project Q/K/V, split into n_heads heads, attend per head, merge and project back."""
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_K = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.W_V = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)

    def forward(self, Q, K, V, attn_mask):
        '''
        Q, K, V:   [batch, seq_len, d_model]
        attn_mask: [batch, seq_len, seq_len], shared by every head
        return:    [batch, seq_len, d_model]
        '''
        bsz = Q.size(0)

        def split_heads(projected, head_dim):
            # [batch, seq_len, n_heads * head_dim] -> [batch, n_heads, seq_len, head_dim]
            return projected.view(bsz, -1, n_heads, head_dim).transpose(1, 2)

        q_heads = split_heads(self.W_Q(Q), d_k)
        k_heads = split_heads(self.W_K(K), d_k)
        v_heads = split_heads(self.W_V(V), d_v)

        # Copy the mask once per head: [batch, n_heads, seq_len, seq_len]
        head_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)
        context = ScaledDotProductAttention()(q_heads, k_heads, v_heads, head_mask)

        # Merge heads: [batch, n_heads, seq_len, d_v] -> [batch, seq_len, n_heads * d_v]
        merged = context.transpose(1, 2).contiguous().view(bsz, -1, n_heads * d_v)
        return self.fc(merged)


def gelu(x):
    '''
    Gaussian Error Linear Unit, exact erf form:

        gelu(x) = x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2)))

    (The common tanh approximation is
     0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3))).)
    '''
    # Phi(x): the standard normal CDF evaluated element-wise.
    normal_cdf = 0.5 * (1. + torch.erf(x / msqrt(2.)))
    return x * normal_cdf


class FeedForwardNetwork(nn.Module):
    """Position-wise FFN: Linear(d_model -> d_ff) -> GELU -> dropout -> Linear(d_ff -> d_model)."""
    def __init__(self):
        super(FeedForwardNetwork, self).__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(p_dropout)
        self.gelu = gelu

    def forward(self, x):
        '''x: [..., d_model] -> [..., d_model]'''
        x = self.gelu(self.fc1(x))
        # Dropout goes after the activation. Applied before the non-linear GELU,
        # the training-time expectation is (1 - p) * gelu(x / (1 - p)), which
        # differs from the eval-time gelu(x).
        x = self.dropout(x)
        return self.fc2(x)



def truncated_normal_(tensor,mean=0,std=init_scale):
    """Fill `tensor` in place with approximately truncated-normal values and return it.

    For each element, 4 standard-normal candidates are drawn. The one picked is the
    index returned by `max` over the "inside (-2, 2)" flags, i.e. normally the first
    candidate within 2 std. If none of the 4 is inside, index 0 is used, so truncation
    is not strictly guaranteed. The result is then scaled by `std` and shifted by `mean`.
    """
    with torch.no_grad():
        candidates = tensor.new_empty(tensor.shape + (4,)).normal_()
        inside = (candidates < 2) & (candidates > -2)
        chosen = inside.max(-1, keepdim=True)[1]
        picked = candidates.gather(-1, chosen).squeeze(-1)
        tensor.data.copy_(picked)
        tensor.data.mul_(std).add_(mean)
        return tensor


def feedforward_adapter(input_tensor):
    """Functional bottleneck adapter with a residual connection.

    input_tensor: 2-D tensor [batch, in_size] (tensordot contracts dim 1, so a
    3-D input would contract the wrong axis).
    Computes input + W2(gelu(W1 input + b1)) + b2, with an inner width of
    hidden_size * n_heads.

    NOTE(review): the weights are created afresh on every call and are not
    registered with any nn.Module, so they are never trained. Use the Adapter
    module for anything that should learn.
    """
    in_size = input_tensor.shape[1]
    inner = hidden_size * n_heads
    factory = dict(device=input_tensor.device, dtype=input_tensor.dtype)

    w_1 = nn.Parameter(torch.empty(in_size, inner, **factory))
    w_1 = truncated_normal_(w_1, mean=0, std=init_scale)
    # The bias must match the inner width. It was [1, hidden_size], which cannot
    # broadcast against [batch, hidden_size * n_heads]. It is zero-initialised
    # because torch.Tensor(...) returns uninitialised memory.
    b_1 = nn.Parameter(torch.zeros(1, inner, **factory))
    # Contract dim 1 of the input with dim 0 of w_1.
    net = torch.tensordot(input_tensor, w_1, [[1], [0]]) + b_1
    net = gelu(net)

    w_2 = nn.Parameter(torch.empty(inner, in_size, **factory))
    w_2 = truncated_normal_(w_2, mean=0, std=init_scale)
    b_2 = nn.Parameter(torch.zeros(1, in_size, **factory))
    net = torch.tensordot(net, w_2, [[1], [0]]) + b_2

    # Residual connection.
    return net + input_tensor

class Adapter(nn.Module):
    """Residual adapter: x + a_2(gelu(a_1(x))); both projections are bias-free.

    Args:
        in_size: feature width of the input and output (default 768, as before).
        bottleneck: inner width; None means d_model (the previous fixed value).
            A value much smaller than in_size (e.g. the global hidden_size) gives
            the usual bottleneck adapter.
    """
    def __init__(self, in_size=768, bottleneck=None):
        super(Adapter,self).__init__()
        if bottleneck is None:
            bottleneck = d_model
        self.a_1 = nn.Linear(in_size, bottleneck, bias=False)
        self.a_2 = nn.Linear(bottleneck, in_size, bias=False)
        self.gelu = gelu

    def forward(self,input_tensor):
        # No debug prints here: forward runs for every layer on every batch.
        net = self.a_2(self.gelu(self.a_1(input_tensor)))
        return net + input_tensor


class EncoderLayer(nn.Module):
    """Post-norm Transformer encoder layer with a residual adapter after each sublayer.

        x -> self-attention -> adapter1 -> (+ x)        -> norm1
          -> FFN            -> adapter2 -> (+ residual) -> norm2
    """
    def __init__(self):
        super(EncoderLayer, self).__init__()
        self.enc_attn = MultiHeadAttention()
        self.adapter1 = Adapter()
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = FeedForwardNetwork()
        self.adapter2 = Adapter()
        self.norm2 = nn.LayerNorm(d_model)


    def forward(self, x, pad_mask):
        '''
        Post-norm: LayerNorm is applied after each residual addition
        (see https://openreview.net/pdf?id=B1x8anVFPr for a pre-/post-norm comparison).

        x:        [batch, seq_len, d_model]
        pad_mask: attention mask forwarded unchanged to self.enc_attn
        return:   [batch, seq_len, d_model]
        '''
        # Attention sublayer; the adapter transforms the sublayer output.
        attn_out = self.adapter1(self.enc_attn(x, x, x, pad_mask))
        x = self.norm1(x + attn_out)

        # Feed-forward sublayer. adapter2 now transforms the FFN output, matching
        # adapter1 (it was previously applied to the FFN input).
        ffn_out = self.adapter2(self.ffn(x))
        x = self.norm2(x + ffn_out)
        return x


class Pooler(nn.Module):
    """Dense layer + tanh applied to the hidden state of the first position."""
    def __init__(self):
        super(Pooler, self).__init__()
        self.fc = nn.Linear(d_model, d_model)
        self.tanh = nn.Tanh()

    def forward(self, x):
        '''
        x: [batch, d_model], the first-position output. Returns the same shape.
        '''
        projected = self.fc(x)
        return self.tanh(projected)

In [None]:
class BERT(nn.Module):
    """BERT-style encoder stack with a 2-way classification head on the first position.

    Args:
        n_layers: number of EncoderLayer blocks to stack.
    """
    def __init__(self, n_layers):
        super(BERT, self).__init__()
        self.embedding = Embeddings()
        self.encoders = nn.ModuleList(EncoderLayer() for _ in range(n_layers))
        self.pooler = Pooler()
        self.gelu = gelu
        self.classify = nn.Linear(d_model, 2)

    def forward(self, tokens, segments,mask):
        """
        tokens, segments: token ids and segment ids (presumably [batch, seq_len]).
        mask: accepted for call compatibility but not used. The padding mask is
              derived from token id 0 by get_pad_mask.
        Returns classification logits whose last dimension is 2.
        """
        pad_mask = get_pad_mask(tokens)
        hidden = self.embedding(tokens, segments)
        for encoder in self.encoders:
            hidden = encoder(hidden, pad_mask)

        # Classify from the pooled first-position representation.
        first_token = self.pooler(hidden[:, 0])
        return self.classify(first_token)

In [None]:
!pip install torchinfo

Collecting torchinfo
  Downloading torchinfo-1.6.5-py3-none-any.whl (21 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.6.5


In [None]:
# Instantiate the adapter-augmented BERT classifier with n_layers encoder layers
# (randomly initialised at this point).
adaptermodel = BERT(n_layers)
from torchinfo import summary
#summary(adaptermodel)

In [None]:
!ls

bert-base-uncased   SMSSpamCollection.csv  vocab.txt
bert-large-uncased  SMSSpamCollection.txt  垃圾邮件分类.ipynb


In [None]:
from pytorch_pretrained_bert import BertModel, BertTokenizer
import numpy as np
import torch

# Load the pretrained BERT model.
# NOTE(review): the cell output shows a ~408 MB download, so 'bert-base-uncased'
# was resolved as a model shortcut name, not as the local 'bert-base-uncased'
# folder in the working directory. To use the local copy (bert_config.json +
# weights), pass an explicit path such as './bert-base-uncased' — TODO confirm.
bert = BertModel.from_pretrained('bert-base-uncased')


100%|██████████| 407873900/407873900 [00:09<00:00, 41254557.61B/s]


In [None]:
# Parameter tensors of the pretrained checkpoint, keyed by the checkpoint's own module names.
pretrained_dict = bert.state_dict()


In [None]:
# Current parameters of the adapter model, keyed by this notebook's module names.
model_dict = adaptermodel.state_dict()

In [None]:
import re

# pytorch_pretrained_bert's BertModel names its parameters e.g.
# 'embeddings.word_embeddings.weight' and 'encoder.layer.0.attention.self.query.weight',
# whereas this notebook uses 'embedding.word_emb.weight' and 'encoders.0.enc_attn.W_Q.weight'.
# Filtering by `k in model_dict` therefore kept nothing. load_state_dict still
# reported success only because model_dict was loaded back into itself.
# Translate the names first.
BERT_KEY_RULES = [
    (r'^embeddings\.word_embeddings\.', 'embedding.word_emb.'),
    (r'^embeddings\.position_embeddings\.', 'embedding.pos_emb.'),
    (r'^embeddings\.token_type_embeddings\.', 'embedding.seg_emb.'),
    (r'^embeddings\.LayerNorm\.', 'embedding.norm.'),
    (r'^encoder\.layer\.(\d+)\.attention\.self\.query\.', r'encoders.\1.enc_attn.W_Q.'),
    (r'^encoder\.layer\.(\d+)\.attention\.self\.key\.', r'encoders.\1.enc_attn.W_K.'),
    (r'^encoder\.layer\.(\d+)\.attention\.self\.value\.', r'encoders.\1.enc_attn.W_V.'),
    (r'^encoder\.layer\.(\d+)\.attention\.output\.dense\.', r'encoders.\1.enc_attn.fc.'),
    (r'^encoder\.layer\.(\d+)\.attention\.output\.LayerNorm\.', r'encoders.\1.norm1.'),
    (r'^encoder\.layer\.(\d+)\.intermediate\.dense\.', r'encoders.\1.ffn.fc1.'),
    (r'^encoder\.layer\.(\d+)\.output\.dense\.', r'encoders.\1.ffn.fc2.'),
    (r'^encoder\.layer\.(\d+)\.output\.LayerNorm\.', r'encoders.\1.norm2.'),
    (r'^pooler\.dense\.', 'pooler.fc.'),
]


def rename_bert_key(key):
    """Return this model's name for checkpoint parameter `key`, or None if no rule matches."""
    for pattern, replacement in BERT_KEY_RULES:
        new_key, n_subs = re.subn(pattern, replacement, key)
        if n_subs:
            return new_key
    return None


transferable = {}
for key, tensor in pretrained_dict.items():
    new_key = rename_bert_key(key)
    # Names without a counterpart are skipped, e.g. the attention q/k/v biases
    # (W_Q/W_K/W_V are built with bias=False).
    if new_key is None or new_key not in model_dict:
        continue
    target_shape = model_dict[new_key].shape
    if new_key == 'embedding.pos_emb.weight':
        # The checkpoint may hold more positions than max_len; keep the first rows.
        tensor = tensor[:target_shape[0]]
    if tensor.shape == target_shape:
        transferable[new_key] = tensor

print(f'Copying {len(transferable)} of {len(model_dict)} tensors from the pretrained checkpoint')
# Overwrite the matching entries of the model's own state_dict, then load it.
model_dict.update(transferable)
adaptermodel.load_state_dict(model_dict)

<All keys matched successfully>

In [None]:
# Per-module parameter counts of the adapter model (torchinfo).
summary(adaptermodel)

Layer (type:depth-idx)                   Param #
BERT                                     --
├─Embeddings: 1-1                        --
│    └─Embedding: 2-1                    1,536
│    └─Embedding: 2-2                    23,440,896
│    └─Embedding: 2-3                    38,400
│    └─LayerNorm: 2-4                    1,536
│    └─Dropout: 2-5                      --
├─ModuleList: 1-2                        --
│    └─EncoderLayer: 2-6                 --
│    │    └─MultiHeadAttention: 3-1      2,359,296
│    │    └─Adapter: 3-2                 1,179,648
│    │    └─LayerNorm: 3-3               1,536
│    │    └─FeedForwardNetwork: 3-4      4,722,432
│    │    └─Adapter: 3-5                 1,179,648
│    │    └─LayerNorm: 3-6               1,536
│    └─EncoderLayer: 2-7                 --
│    │    └─MultiHeadAttention: 3-7      2,359,296
│    │    └─Adapter: 3-8                 1,179,648
│    │    └─LayerNorm: 3-9               1,536
│    │    └─FeedForwardNetwork: 3-10     4,722

# Tokenizer


In [None]:
# Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.

""" Tokenization classes (adapted from the Google BERT tokenization code). """

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import unicodedata
import six


def convert_to_unicode(text):
    """Return `text` as a `str`, decoding `bytes` as UTF-8 and dropping undecodable bytes.

    Raises ValueError for any other input type, and on interpreters other than
    Python 3 (the Python 2 branch has been removed).
    """
    if not six.PY3:
        raise ValueError("Not running on Python2 or Python 3?")
    if isinstance(text, bytes):
        return text.decode("utf-8", "ignore")
    if isinstance(text, str):
        return text
    raise ValueError("Unsupported string type: %s" % (type(text)))


def printable_text(text):
    """Return `text` encoded in a way suitable for print or logging.

    On Python 3 this is exactly convert_to_unicode: `str` passes through, `bytes`
    is decoded as UTF-8 (undecodable bytes dropped), and anything else raises
    ValueError.
    """
    return convert_to_unicode(text)


def load_vocab(vocab_file):
    """Load a vocabulary file into an OrderedDict mapping token -> id.

    Each line holds one token. The id is the 0-based line number, and surrounding
    whitespace is stripped. The file is read explicitly as UTF-8 so the result
    does not depend on the platform's default encoding.
    """
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        for index, line in enumerate(reader):
            vocab[line.strip()] = index
    return vocab


def convert_tokens_to_ids(vocab, tokens):
    """Map every token to its id via `vocab` (KeyError for tokens not in the vocab)."""
    return [vocab[tok] for tok in tokens]


def whitespace_tokenize(text):
    """Strip `text` and split it on runs of whitespace; blank input gives []."""
    stripped = text.strip()
    return stripped.split() if stripped else []


class FullTokenizer(object):
    """End-to-end tokenization: BasicTokenizer followed by WordpieceTokenizer."""

    def __init__(self, vocab_file, do_lower_case=True):
        self.vocab = load_vocab(vocab_file)
        self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case)
        self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)

    def tokenize(self, text):
        """Return the WordPiece tokens of `text`."""
        return [piece
                for word in self.basic_tokenizer.tokenize(text)
                for piece in self.wordpiece_tokenizer.tokenize(word)]

    def convert_tokens_to_ids(self, tokens):
        """Map tokens to ids using this tokenizer's vocabulary."""
        return convert_tokens_to_ids(self.vocab, tokens)

    def convert_to_unicode(self, text):
        """Coerce `text` to `str` via the module-level convert_to_unicode."""
        return convert_to_unicode(text)



class BasicTokenizer(object):
    """Runs basic tokenization (punctuation splitting, lower casing, CJK splitting, etc.)."""

    def __init__(self, do_lower_case=True, tokenize_chinese_chars=True):
        """Constructs a BasicTokenizer.
        Args:
          do_lower_case: Whether to lower case the input (accents are stripped too).
          tokenize_chinese_chars: Whether to surround each CJK ideograph with
            spaces so it becomes its own token, as the reference Google BERT
            tokenizer does. Otherwise a run of CJK characters stays a single
            "word" for WordPiece.
        """
        self.do_lower_case = do_lower_case
        self.tokenize_chinese_chars = tokenize_chinese_chars

    def tokenize(self, text):
        """Tokenizes a piece of text."""
        text = convert_to_unicode(text)
        text = self._clean_text(text)
        if self.tokenize_chinese_chars:
            text = self._tokenize_chinese_chars(text)
        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if self.do_lower_case:
                token = token.lower()
                token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text (NFD, then drop combining marks 'Mn')."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text):
        """Splits punctuation on a piece of text; each punctuation char becomes its own piece."""
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK ideograph so it is split into its own token."""
        output = []
        for char in text:
            if self._is_chinese_char(ord(char)):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether code point `cp` lies in a CJK Unified Ideographs block."""
        # Same ranges as the reference BERT tokenizer. Hangul, Hiragana and
        # Katakana are deliberately excluded.
        return ((0x4E00 <= cp <= 0x9FFF) or
                (0x3400 <= cp <= 0x4DBF) or
                (0x20000 <= cp <= 0x2A6DF) or
                (0x2A700 <= cp <= 0x2B73F) or
                (0x2B740 <= cp <= 0x2B81F) or
                (0x2B820 <= cp <= 0x2CEAF) or
                (0xF900 <= cp <= 0xFAFF) or
                (0x2F800 <= cp <= 0x2FA1F))

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xfffd or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)


class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """Split whitespace-separated words into WordPiece sub-tokens.

        Greedy longest-match-first against `self.vocab`. Continuation pieces carry
        a "##" prefix, e.g. "unaffable" -> ["un", "##aff", "##able"]. A word longer
        than `max_input_chars_per_word`, or one that cannot be fully covered by
        vocabulary pieces, becomes a single `unk_token`.

        Args:
          text: A single token or whitespace separated tokens, normally already
            passed through `BasicTokenizer`.
        Returns:
          A list of wordpiece tokens.
        """
        pieces = []
        for word in whitespace_tokenize(convert_to_unicode(text)):
            if len(word) > self.max_input_chars_per_word:
                pieces.append(self.unk_token)
                continue
            word_pieces = self._greedy_longest_match(word)
            if word_pieces is None:
                pieces.append(self.unk_token)
            else:
                pieces.extend(word_pieces)
        return pieces

    def _greedy_longest_match(self, word):
        """Return the vocab pieces that cover `word`, or None if some position has no match."""
        found = []
        start = 0
        while start < len(word):
            # Try the longest remaining substring first, shrinking from the right.
            for end in range(len(word), start, -1):
                candidate = word[start:end]
                if start > 0:
                    candidate = "##" + candidate
                if candidate in self.vocab:
                    found.append(candidate)
                    start = end
                    break
            else:
                return None
        return found


def _is_whitespace(char):
    """Checks whether `chars` is a whitespace character."""
    # \t, \n, and \r are technically contorl characters but we treat them
    # as whitespace since they are generally considered as such.
    if char == " " or char == "\t" or char == "\n" or char == "\r":
        return True
    cat = unicodedata.category(char)
    if cat == "Zs":
        return True
    return False


def _is_control(char):
    """Checks whether `chars` is a control character."""
    # These are technically control characters but we count them as whitespace
    # characters.
    if char == "\t" or char == "\n" or char == "\r":
        return False
    cat = unicodedata.category(char)
    if cat.startswith("C"):
        return True
    return False


def _is_punctuation(char):
    """Checks whether `chars` is a punctuation character."""
    cp = ord(char)
    # We treat all non-letter/number ASCII as punctuation.
    # Characters such as "^", "$", and "`" are not in the Unicode
    # Punctuation class but we treat them as punctuation anyways, for
    # consistency.
    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
            (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
        return True
    cat = unicodedata.category(char)
    if cat.startswith("P"):
        return True
    return False

In [None]:
!pip install fire

Collecting fire
  Downloading fire-0.4.0.tar.gz (87 kB)
[?25l[K     |███▊                            | 10 kB 19.1 MB/s eta 0:00:01[K     |███████▌                        | 20 kB 7.6 MB/s eta 0:00:01[K     |███████████▏                    | 30 kB 4.7 MB/s eta 0:00:01[K     |███████████████                 | 40 kB 4.4 MB/s eta 0:00:01[K     |██████████████████▊             | 51 kB 3.7 MB/s eta 0:00:01[K     |██████████████████████▍         | 61 kB 4.3 MB/s eta 0:00:01[K     |██████████████████████████▏     | 71 kB 4.3 MB/s eta 0:00:01[K     |██████████████████████████████  | 81 kB 4.8 MB/s eta 0:00:01[K     |████████████████████████████████| 87 kB 2.9 MB/s 
Building wheels for collected packages: fire
  Building wheel for fire (setup.py) ... [?25l[?25hdone
  Created wheel for fire: filename=fire-0.4.0-py2.py3-none-any.whl size=115942 sha256=ab1d977bcbde3ed71c5cdd2febf231e539c65d844b3e53a3b45f883d658da6bc
  Stored in directory: /root/.cache/pip/wheels/8a/67/fb/2e8a12f

In [None]:
# Prepare dataset


# Dataset

In [None]:
# Lower-casing WordPiece tokenizer built from vocab.txt in the current working directory.
tokenizer = FullTokenizer(vocab_file='vocab.txt', do_lower_case=True)

In [None]:
class Pipeline(object):
    """Base class for a preprocessing step: subclasses implement __call__(instance)."""
    def __init__(self):
        super().__init__()

    def __call__(self, instance):
        raise NotImplementedError


class Tokenizing(Pipeline):
    """Normalise a (label, text) pair and tokenize the text; returns (label, tokens)."""
    def __init__(self, preprocessor, tokenize):
        super().__init__()
        self.preprocessor = preprocessor  # applied to both label and text, e.g. normalisation
        self.tokenize = tokenize          # text -> list of tokens

    def __call__(self, instance):
        label, text = instance
        clean_label = self.preprocessor(label)
        tokens = self.tokenize(self.preprocessor(text))
        return (clean_label, tokens)

In [None]:
!pip install txt

Collecting txt
  Downloading txt-2020.11.12-py3-none-any.whl (3.8 kB)
Installing collected packages: txt
Successfully installed txt-2020.11.12


In [None]:
import csv
class CsvDataset(Dataset):
    """ Dataset Class for CSV file

    Reads `file` with csv.reader and turns rows into instances via
    `get_instances` (implemented by subclasses). Each instance is run through the
    `pipeline` callables in order, and every field is stored as one LongTensor
    with one entry per instance.
    """
    labels = None
    def __init__(self, file, pipeline=()): # csv file and pipeline objects
        Dataset.__init__(self)
        data = []
        # newline='' is what the csv module expects for files it parses (quoted
        # fields may contain newlines). An explicit encoding avoids locale dependence.
        with open(file, "r", encoding="utf-8", newline='') as f:
            # list of split lines: each line is also a list
            lines = csv.reader(f, delimiter=',')
            for instance in self.get_instances(lines): # instance : tuple of fields
                for proc in pipeline: # a bunch of pre-processing
                    instance = proc(instance)
                data.append(instance)

        # Transpose instances into one tensor per field.
        self.tensors = [torch.tensor(x, dtype=torch.long) for x in zip(*data)]

    def __len__(self):
        # An empty file produces no tensors at all; report 0 rather than raising IndexError.
        return self.tensors[0].size(0) if self.tensors else 0

    def __getitem__(self, index):
        return tuple(tensor[index] for tensor in self.tensors)

    def get_instances(self, lines):
        """ get instance array from (csv-separated) line list; subclasses must override """
        # A bare `raise` here only produced "RuntimeError: No active exception to reraise".
        raise NotImplementedError

In [None]:
import itertools

In [None]:
class MRPC(CsvDataset):
    """ SMS spam dataset: each CSV row is (label, text) with label in `labels`. """
    labels = ("ham", "spam") # label names
    def __init__(self, file, pipeline=[]):
        super().__init__(file, pipeline)

    def get_instances(self, lines):
        # Every row is yielded, including any header row (none is skipped).
        # Columns beyond the second are ignored.
        for row in lines:
            label, text = row[0], row[1]
            yield label, text

In [None]:
def dataset_class(task):
    """ Return the Dataset class registered for `task` (KeyError if unknown). """
    return {'mrpc': MRPC}[task]

In [None]:
class AddSpecialTokensWithTruncation(Pipeline):
    """ Add special tokens [CLS], [SEP] with truncation

    Keeps at most max_len - 2 text tokens, so `[CLS] tokens [SEP]` is at most
    max_len long (for max_len >= 2).
    """
    def __init__(self, max_len=50):
        super().__init__()
        self.max_len = max_len

    def __call__(self, instance):
        label, tokens_a = instance
        budget = self.max_len - 2   # room left after [CLS] and [SEP]
        wrapped = ['[CLS]'] + tokens_a[:budget] + ['[SEP]']
        return (label, wrapped)

In [None]:
class TokenIndexing(Pipeline):
    """ Convert tokens into token indexes and do zero-padding

    Produces (input_ids, segment_ids, input_mask, label_id). Every list is padded
    with zeros up to max_len, and nothing is truncated. The mask is 1 for real
    tokens and 0 for padding. Segment ids are all 0 (single sentence).
    """
    def __init__(self, indexer, labels, max_len=512):
        super().__init__()
        self.indexer = indexer # function : tokens to indexes
        # label name -> label index, in the order given by `labels`
        self.label_map = {name: idx for idx, name in enumerate(labels)}
        self.max_len = max_len

    def __call__(self, instance):
        label, tokens_a = instance

        input_ids = self.indexer(tokens_a)
        n_real = len(tokens_a)
        label_id = self.label_map[label]

        padding = [0] * (self.max_len - len(input_ids))
        input_ids.extend(padding)
        segment_ids = [0] * n_real + padding
        input_mask = [1] * n_real + padding

        return (input_ids, segment_ids, input_mask, label_id)

In [None]:
# Key into dataset_class()'s registry, and the CSV holding (label, text) rows.
task='mrpc'
data_file='SMSSpamCollection.csv'

In [None]:
# Uses the global max_len from the configuration cell rather than re-hardcoding 50.
# The position-embedding table was sized with that value, so padding sequences
# to a different length here could index past it.
TaskDataset = dataset_class(task)

# tokenize -> add [CLS]/[SEP] with truncation -> map to ids and zero-pad
pipeline = [Tokenizing(tokenizer.convert_to_unicode, tokenizer.tokenize),
            AddSpecialTokensWithTruncation(max_len),
            TokenIndexing(tokenizer.convert_tokens_to_ids, TaskDataset.labels, max_len)]
dataset = TaskDataset(data_file, pipeline)

In [None]:
from torch.utils.data import Dataset, DataLoader

In [None]:
# Shuffled mini-batches of (input_ids, segment_ids, input_mask, label_id) tensors.
data_iter = DataLoader(dataset, batch_size=20, shuffle=True)

# Classifier

In [None]:
!pip install loralib

Collecting loralib
  Downloading loralib-0.1.1-py3-none-any.whl (8.8 kB)
Installing collected packages: loralib
Successfully installed loralib-0.1.1


In [None]:
model = adaptermodel
lr = 1e-3  # NOTE(review): Adadelta's lr is a multiplier (PyTorch default 1.0); 1e-3 gives very small steps.
epochs = 500
log_every = 10
criterion = nn.CrossEntropyLoss()

# Adapter tuning: freeze the backbone and train only the adapters, the LayerNorms
# and the classification head. The previous loop did the opposite: it froze the
# "adapter" parameters and trained everything else.
TRAINABLE_KEYS = ("adapter", "norm", "classify")
for name, param in model.named_parameters():
    param.requires_grad = any(key in name for key in TRAINABLE_KEYS)

model.to(device)
optimizer = Adadelta((p for p in model.parameters() if p.requires_grad), lr=lr)

# training
for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0
    for one_batch in data_iter:
        input_ids, segment_ids, input_mask, labels = [t.to(device) for t in one_batch]

        logits_cls = model(input_ids, segment_ids, input_mask)
        loss = criterion(logits_cls, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    # Log the mean batch loss once per logging interval, not once per batch.
    if (epoch + 1) % log_every == 0:
        print(f'Epoch:{epoch + 1} \t loss: {epoch_loss / max(len(data_iter), 1):.6f}')