In [3]:
import torch
from torch import nn,optim
from torch.utils.data import TensorDataset,DataLoader
import collections

class RNN(nn.Module):
    """Sentence classifier: embedding -> single-layer ``nn.RNN`` -> linear layer.

    Args:
        vocab: Number of rows in the embedding table.
        emb: Embedding dimension (input size of the recurrent layer).
        pad: Index handed to ``nn.Embedding`` as ``padding_idx``.
        out: Number of output classes.
        hid: Hidden size of the recurrent layer.
    """

    def __init__(self,vocab,emb,pad,out,hid):
        super().__init__()
        self.hidden_size=hid
        self.emb=nn.Embedding(vocab,emb,padding_idx=pad)
        self.rnn=nn.RNN(emb,hid,batch_first=True)
        self.linear=nn.Linear(hid,out)
        # Name the class axis explicitly: constructing Softmax without ``dim``
        # relies on the deprecated implicit-dimension choice and warns on use.
        # The layer is not applied in ``forward`` (which returns raw scores),
        # it is only available for turning those scores into probabilities.
        self.softmax=nn.Softmax(dim=-1)

    def forward(self,x,h=None):
        """Return unnormalised class scores for a batch of id sequences.

        Args:
            x: Integer tensor of token ids, shape ``(batch, seq_len)``.
            h: Optional initial hidden state forwarded to ``nn.RNN``.

        Returns:
            Tensor of shape ``(batch, out)`` with raw scores (no softmax).
        """
        x=self.emb(x)
        y,h=self.rnn(x,h)
        # Keep only the output of the final time step of every sequence.
        # NOTE(review): if inputs are right-padded, this final step is a
        # padding position rather than the last real token -- confirm intended.
        y=y[:,-1,:]
        y=self.linear(y)
        return y
    
def accuracy(pred,label):
    """Return the fraction of samples whose arg-max class equals the label.

    Args:
        pred: Score tensor whose last dimension indexes the classes.
        label: Tensor of class indices shaped like ``pred`` without its
            last dimension. It may live on a different device than ``pred``.

    Returns:
        A Python float in ``[0, 1]``; ``0.0`` when there are no samples.
    """
    pred=torch.argmax(pred,dim=-1)
    # Count every predicted element so the ratio stays correct for inputs
    # with more than one leading dimension.
    total=pred.numel()
    if total==0:
        # Nothing to score; avoid dividing by zero.
        return 0.0
    # Comparing tensors on different devices raises, so align the label first.
    label=label.to(pred.device)
    return (pred==label).sum().item()/total

def list2tensor(data,max_len,pad):
    """Truncate or right-pad every sequence to ``max_len`` and stack them.

    The input lists are left untouched; each row is built as a fresh list.

    Args:
        data: Iterable of sequences (lists of ids).
        max_len: Length every row is forced to.
        pad: Value appended to sequences shorter than ``max_len``.

    Returns:
        ``torch.tensor`` built from the resulting equally long rows.
    """
    # ``[pad]*k`` is empty for k <= 0, so sequences that were truncated (or are
    # already exactly ``max_len`` long) receive no padding.
    rows=[list(d[:max_len])+[pad]*(max_len-len(d)) for d in data]
    return torch.tensor(rows)

def count(file):
    """Build a word->id vocabulary and the label list from a corpus file.

    Each non-blank line must have the form ``<tag>\\t<sentence>`` where the tag
    is one of ``b``, ``t``, ``e``, ``m`` and the sentence is space-separated.

    Args:
        file: Path of the UTF-8 text file to read.

    Returns:
        A tuple ``(word_number, tags)``. ``word_number`` maps every word that
        occurs at least twice to its 1-based frequency rank (ties keep
        first-seen order). ``tags`` holds the numeric label of every line.

    Raises:
        KeyError: If a line carries a tag outside the four known ones.
        ValueError: If a non-blank line does not have exactly two
            tab-separated fields.
    """
    tag_num={"b":0,"t":1,"e":2,"m":3}
    tags=[]
    freq=collections.Counter()
    with open(file,"r",encoding="utf-8") as tf:
        for line in tf:
            # Skip blank lines (e.g. a trailing empty line) instead of
            # failing on the two-field unpack below.
            if not line.strip():
                continue
            tag,sen=line.split("\t")
            tags.append(tag_num[tag])
            freq.update(sen.strip().split(" "))
    # most_common() sorts by descending count and keeps first-seen order for
    # ties, so the words kept (count >= 2) receive the contiguous ids 1..K.
    word_number={w:i+1 for i,(w,f) in enumerate(freq.most_common()) if f>=2}
    return word_number,tags

def id_v(file,word_id):
    """Convert every sentence of a corpus file into a list of word ids.

    Each non-blank line must have the form ``<tag>\\t<sentence>``; the tag is
    ignored here and the sentence is split on single spaces.

    Args:
        file: Path of the UTF-8 text file to read.
        word_id: Mapping from word to integer id.

    Returns:
        One list of ids per non-blank line; words missing from ``word_id``
        are encoded as ``0``.

    Raises:
        ValueError: If a non-blank line does not have exactly two
            tab-separated fields.
    """
    ids=[]
    with open(file,"r",encoding="utf-8") as tf:
        for line in tf:
            # Skip blank lines (e.g. a trailing empty line) instead of
            # failing on the two-field unpack below.
            if not line.strip():
                continue
            _tag,sen=line.split("\t")
            ids.append([word_id.get(word,0) for word in sen.strip().split(" ")])
    return ids

# ---- Data preparation -------------------------------------------------------
# The vocabulary (word -> id, ids start at 1) is built from the training set.
word_id,Y_train=count("train.feature.txt")
# Only the labels of the validation file are used below; validation sentences
# are encoded with the training vocabulary.
word_id_valid,Y_valid=count("valid.feature.txt")
max_len=10
# Index layout of the embedding table:
#   0                -> unknown word
#   1..len(word_id)  -> vocabulary words
#   len(word_id)+1   -> padding
# The padding index must not coincide with a real word id, hence the +1 here
# and the +2 for the table size.
pad=len(word_id)+1
vocab=len(word_id)+2
emb=300
hid=50
out=4

X_train=id_v("train.feature.txt",word_id)
X_train=list2tensor(X_train,max_len,pad)
Y_train=torch.tensor(Y_train)

X_valid=id_v("valid.feature.txt",word_id)
X_valid=list2tensor(X_valid,max_len,pad)
Y_valid=torch.tensor(Y_valid)

# ---- Model and training setup -----------------------------------------------
model=RNN(vocab,emb,pad,out,hid)
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
model=model.to(device)

# Move the full tensors to the device once instead of once per epoch.
X_train_dev,Y_train_dev=X_train.to(device),Y_train.to(device)
X_valid_dev,Y_valid_dev=X_valid.to(device),Y_valid.to(device)
dataset_train=TensorDataset(X_train_dev,Y_train_dev)

B=1024
loader=DataLoader(dataset_train,batch_size=B)
loss_fn=nn.CrossEntropyLoss()
optimizer=torch.optim.SGD(model.parameters(), lr=0.01)
epoch=10

# ---- Training loop ----------------------------------------------------------
for num in range(epoch):
    model.train()
    for X,Y in loader:
        Y_pred=model(X)
        loss=loss_fn(Y_pred,Y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Evaluate on the complete training and validation sets after each epoch.
    model.eval()
    with torch.no_grad():
        Y_pred=model(X_train_dev)
        loss=loss_fn(Y_pred,Y_train_dev)
        # Predictions and labels are passed on the same device.
        ac=accuracy(Y_pred,Y_train_dev)

        Y_pred_v=model(X_valid_dev)
        loss_v=loss_fn(Y_pred_v,Y_valid_dev)
        ac_v=accuracy(Y_pred_v,Y_valid_dev)
        print(f"epoch:{num}\ttrain:\tloss:{loss}\taccuracy:{ac}\tvalid:\tloss:{loss_v}\taccuracy:{ac_v}")

epoch:0	train:	loss:1.4151177406311035	accuracy:0.2679640718562874	valid:	loss:1.420035719871521	accuracy:0.2649756645451142
epoch:1	train:	loss:1.3949556350708008	accuracy:0.29565868263473055	valid:	loss:1.3997656106948853	accuracy:0.2942718083114938
epoch:2	train:	loss:1.376684546470642	accuracy:0.3255988023952096	valid:	loss:1.3814066648483276	accuracy:0.31748408835642083
epoch:3	train:	loss:1.3600627183914185	accuracy:0.34805389221556887	valid:	loss:1.3647135496139526	accuracy:0.33536128790715086
epoch:4	train:	loss:1.3448925018310547	accuracy:0.3592814371257485	valid:	loss:1.3494853973388672	accuracy:0.35482965181579934
epoch:5	train:	loss:1.331007957458496	accuracy:0.36826347305389223	valid:	loss:1.3355544805526733	accuracy:0.371209284912018
epoch:6	train:	loss:1.318269968032837	accuracy:0.37350299401197606	valid:	loss:1.322779893875122	accuracy:0.3842193934855859
epoch:7	train:	loss:1.306559443473816	accuracy:0.38323353293413176	valid:	loss:1.3110417127609253	accuracy:0.39470235