In [17]:
import torch
from torch import nn,optim
from torch.utils.data import TensorDataset,DataLoader
import collections
import numpy as np
from transformers import *
import pandas as pd

tokenizer_class=BertTokenizer
tokenizer = tokenizer_class.from_pretrained('bert-base-uncased')
n_unit=768

def dfs_freeze(model):
    for name,child in model.named_children():
        for param in child.parameters():
            param.requires_grad=False
        dfs_freeze(child)
        
class BertClassifier(nn.Module):
    def __init__(self,n_classes=4):
        super(BertClassifier,self).__init__()
        self.bert_model = BertModel.from_pretrained('bert-base-uncased') 
        self.fc=nn.Linear(n_unit, n_classes)

    def forward(self,ids):
        seg_ids = torch.zeros_like(ids) 
        attention_mask = (ids > 0)
        last= self.bert_model(input_ids=ids, token_type_ids=seg_ids, attention_mask=attention_mask)
        x=last["last_hidden_state"][:,0,:]
        logit = self.fc(x.view(-1,n_unit))
        return logit

def list2tensor(data,max_len,pad):
    new_list=[]
    for d in data:
        if len(d)>max_len:
            d=d[:max_len]
        else:
            d+=[pad]*(max_len-len(d))
        new_list.append(d)
    return torch.tensor(new_list,dtype=torch.int64)

def accuracy(pred,label):
    pred=torch.argmax(pred,dim=-1)
    leng=len(pred)
    return (pred==label).sum().item()/leng

def count(file):
    all_word=[]
    with open(file,"r",encoding="utf-8") as tf:
        tags=[]
        tag_num={"b":0,"t":1,"e":2,"m":3}
        for line in tf:
            line.strip();
            tag,sen=line.split("\t")
            tags.append(tag_num[tag])
            words=sen.strip().split(" ")
            for word in words:
                all_word.append(word)
        freq=collections.Counter(all_word)
        freq_sort=sorted(freq.items(),key=lambda x:x[1],reverse=True)
        #print(freq_sort)
        word_number={w:i+1 for i,(w,f) in enumerate(freq_sort) if f>=2}
        return word_number,tags

def id_v(file,word_id):
    with open(file,"r",encoding="utf-8") as tf:
        ids=[]
        for line in tf:
            line.strip();
            tag,sen=line.split("\t")
            words=sen.strip().split(" ")
            sent_ids=[]
            for word in words:
                if word in word_id:
                    sent_ids.append(word_id[word])
                else:
                    sent_ids.append(0)
            ids.append(sent_ids)
        return ids
    
def df2id(df):
    tokenized = df[1].apply((lambda x: tokenizer.encode(x, add_special_tokens=True)))
    return tokenized
    
word_id,Y_train=count("train.feature.txt")
word_id_valid,Y_valid=count("valid.feature.txt")
max_len=15
pad=0

train=pd.read_csv('train.feature.txt', header=None, sep='\t')
X_train=df2id(train)
X_train=list2tensor(X_train,max_len,pad)
Y_train=torch.tensor(Y_train,dtype=torch.int64)



model=BertClassifier()
dfs_freeze(model)
model.fc.requires_grad_(True)

dataset_train=TensorDataset(X_train,Y_train)
loader=DataLoader(dataset_train,batch_size=1024)
loss_fn=nn.CrossEntropyLoss()
optimizer=torch.optim.AdamW(model.parameters(),lr=0.01)

epoch=10
for num in range(epoch):
    for X,Y in loader:
        Y_pred=model(X)
        loss=loss_fn(Y_pred,Y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        Y_pred=model(X_train)
        loss=loss_fn(Y_pred,Y_train)
        ac=accuracy(Y_pred,Y_train)
        
        print(f"epoch:{num}\ttrain:\tloss:{loss}\taccuracy:{ac}")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


epoch:0	train:	loss:0.5298143625259399	accuracy:0.8131785847997005
epoch:1	train:	loss:0.42633259296417236	accuracy:0.845376263571696


KeyboardInterrupt: 