In [23]:
import torch
from torch import nn,optim
from torch.utils.data import TensorDataset,DataLoader
import collections
import numpy as np
    
class RNN(nn.Module):
    def __init__(self,vocab,emb,pad,out,hid,num_layers,emb_weights=None,bidirectional=False):
        super().__init__()        
        if emb_weights!=None:
            self.emb=nn.Embedding.from_pretrained(emb_weights,padding_idx=pad)
        else:
            self.emb=nn.Embedding(vocab,emb,padding_idx=pad)
        self.rnn=nn.RNN(emb,hid,bidirectional=bidirectional,batch_first=True)
        self.rnn2=nn.RNN(2*hid,hid,bidirectional=bidirectional,batch_first=True)
        self.rnn3=nn.RNN(2*hid,hid,bidirectional=bidirectional,batch_first=True)
        self.rnn4=nn.RNN(2*hid,hid,bidirectional=bidirectional,batch_first=True)
        self.linear=nn.Linear(2*hid,out)
        self.softmax=torch.nn.Softmax()
        
    def forward(self,x,h=None):
        x=self.emb(x)
        y,h=self.rnn(x,h)
        y,h=self.rnn2(y,h)
        y,h=self.rnn3(y,h)
        y,h=self.rnn4(y,h)
        y=y[:,-1,:]
        y=self.linear(y)
        return y
    
def accuracy(pred,label):
    pred=torch.argmax(pred,dim=-1)
    leng=len(pred)
    return (pred==label).sum().item()/leng

def list2tensor(data,max_len,pad):
    new_list=[]
    for d in data:
        if len(d)>max_len:
            d=d[:max_len]
        else:
            d+=[pad]*(max_len-len(d))
        new_list.append(d)
    return torch.tensor(new_list)

def count(file):
    all_word=[]
    with open(file,"r",encoding="utf-8") as tf:
        tags=[]
        tag_num={"b":0,"t":1,"e":2,"m":3}
        for line in tf:
            line.strip();
            tag,sen=line.split("\t")
            tags.append(tag_num[tag])
            words=sen.strip().split(" ")
            for word in words:
                all_word.append(word)
        freq=collections.Counter(all_word)
        freq_sort=sorted(freq.items(),key=lambda x:x[1],reverse=True)
        #print(freq_sort)
        word_number={w:i+1 for i,(w,f) in enumerate(freq_sort) if f>=2}
        return word_number,tags

def id_v(file,word_id):
    with open(file,"r",encoding="utf-8") as tf:
        ids=[]
        for line in tf:
            line.strip();
            tag,sen=line.split("\t")
            words=sen.strip().split(" ")
            sent_ids=[]
            for word in words:
                if word in word_id:
                    sent_ids.append(word_id[word])
                else:
                    sent_ids.append(0)
            ids.append(sent_ids)
        return ids
    
word_id,Y_train=count("train.feature.txt")
word_id_valid,Y_valid=count("valid.feature.txt")
max_len=10
vocab=len(word_id)+1
emb=300
pad=len(word_id)
hid=50
out=4
num_layers=2

X_train=id_v("train.feature.txt",word_id)
X_train=list2tensor(X_train,max_len,pad)
Y_train=torch.tensor(Y_train)

X_valid=id_v("valid.feature.txt",word_id)
X_valid=list2tensor(X_valid,max_len,pad)
Y_valid=torch.tensor(Y_valid)

model=RNN(vocab,emb,pad,out,hid,num_layers,emb_weights=None,bidirectional=True)

dataset_train=TensorDataset(X_train,Y_train)

loader=DataLoader(dataset_train,batch_size=1024)
loss_fn=nn.CrossEntropyLoss()
optimizer=torch.optim.SGD(model.parameters(), lr=0.1)
epoch=10
for num in range(epoch):
    for X,Y in loader:
        Y_pred=model(X)
        loss=loss_fn(Y_pred,Y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        Y_pred=model(X_train)
        loss=loss_fn(Y_pred,Y_train)
        ac=accuracy(Y_pred,Y_train)
        
        Y_pred_v=model(X_valid)
        loss_v=loss_fn(Y_pred_v,Y_valid)
        ac_v=accuracy(Y_pred_v,Y_valid)
        print(f"epoch:{num}\ttrain:\tloss:{loss}\taccuracy:{ac}\tvalid:\tloss:{loss_v}\taccuracy:{ac_v}")

epoch:0	train:	loss:1.128151297569275	accuracy:0.5059880239520959	valid:	loss:1.12944495677948	accuracy:0.5065518532384875
epoch:1	train:	loss:1.0894862413406372	accuracy:0.5276946107784432	valid:	loss:1.09335458278656	accuracy:0.5406214900786223
epoch:2	train:	loss:1.0544825792312622	accuracy:0.5464071856287425	valid:	loss:1.0604034662246704	accuracy:0.5641145638337701
epoch:3	train:	loss:1.0180913209915161	accuracy:0.5808383233532934	valid:	loss:1.0248441696166992	accuracy:0.5935043055035567
epoch:4	train:	loss:0.9795293211936951	accuracy:0.6130239520958084	valid:	loss:0.9857911467552185	accuracy:0.6198053163609135
epoch:5	train:	loss:0.9391775131225586	accuracy:0.6399700598802395	valid:	loss:0.9439021944999695	accuracy:0.6427368026956196
epoch:6	train:	loss:0.8979808688163757	accuracy:0.6616766467065869	valid:	loss:0.900333046913147	accuracy:0.6686634219393486
epoch:7	train:	loss:0.8579407930374146	accuracy:0.6923652694610778	valid:	loss:0.8572514057159424	accuracy:0.687195806813927