In [1]:
import torch
from torch import nn,optim
from torch.utils.data import TensorDataset,DataLoader
import collections

class RNN(nn.Module):
    def __init__(self,vocab,emb,pad,out,hid):
        super().__init__()
        self.hidden_size=hid
        self.emb=nn.Embedding(vocab,emb,padding_idx=pad)
        self.rnn=nn.RNN(emb,hid,batch_first=True)
        self.linear=nn.Linear(hid,out)
        self.softmax=nn.Softmax(dim=1)
        
    def forward(self,x,h=None):
        x=self.emb(x)
        y,h=self.rnn(x,h)
        y=y[:,-1,:]
        y=self.linear(y)
        return y
    
def accuracy(pred,label):
    pred=torch.argmax(pred,dim=-1)
    leng=len(pred)
    return (pred==label).sum().item()/leng

def list2tensor(data,max_len,pad):
    new_list=[]
    for d in data:
        if len(d)>max_len:
            d=d[:max_len]
        else:
            d+=[pad]*(max_len-len(d))
        new_list.append(d)
    return torch.tensor(new_list)

def count(file):
    all_word=[]
    with open(file,"r",encoding="utf-8") as tf:
        tags=[]
        tag_num={"b":0,"t":1,"e":2,"m":3}
        for line in tf:
            line.strip();
            tag,sen=line.split("\t")
            tags.append(tag_num[tag])
            words=sen.strip().split(" ")
            for word in words:
                all_word.append(word)
        freq=collections.Counter(all_word)
        freq_sort=sorted(freq.items(),key=lambda x:x[1],reverse=True)
        #print(freq_sort)
        word_number={w:i+1 for i,(w,f) in enumerate(freq_sort) if f>=2}
        return word_number,tags

def id_v(file,word_id):
    with open(file,"r",encoding="utf-8") as tf:
        ids=[]
        for line in tf:
            line.strip();
            tag,sen=line.split("\t")
            words=sen.strip().split(" ")
            sent_ids=[]
            for word in words:
                if word in word_id:
                    sent_ids.append(word_id[word])
                else:
                    sent_ids.append(0)
            ids.append(sent_ids)
        return ids    

word_id,Y_train=count("train.feature.txt")
word_id_valid,Y_valid=count("valid.feature.txt")
max_len=10
vocab=len(word_id)+1
emb=300
pad=len(word_id)
hid=50
out=4

X_train=id_v("train.feature.txt",word_id)
X_train=list2tensor(X_train,max_len,pad)
Y_train=torch.tensor(Y_train)

X_valid=id_v("valid.feature.txt",word_id)
X_valid=list2tensor(X_valid,max_len,pad)
Y_valid=torch.tensor(Y_valid)

model=RNN(vocab,emb,pad,out,hid)
dataset_train=TensorDataset(X_train,Y_train)

loader=DataLoader(dataset_train,batch_size=1)
loss_fn=nn.CrossEntropyLoss()
optimizer=torch.optim.SGD(model.parameters(), lr=0.01)
epoch=3
for num in range(epoch):
    for X,Y in loader:
        Y_pred=model(X)
        loss=loss_fn(Y_pred,Y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        Y_pred=model(X_train)
        loss=loss_fn(Y_pred,Y_train)
        ac=accuracy(Y_pred,Y_train)
        
        Y_pred_v=model(X_valid)
        loss_v=loss_fn(Y_pred_v,Y_valid)
        ac_v=accuracy(Y_pred_v,Y_valid)
        print(f"epoch:{num}\ttrain:\tloss:{loss}\taccuracy:{ac}\tvalid:\tloss:{loss_v}\taccuracy:{ac_v}")

epoch:0	train:	loss:0.946532666683197	accuracy:0.6710969674279296	valid:	loss:0.8978021144866943	accuracy:0.6796407185628742
epoch:1	train:	loss:0.928434431552887	accuracy:0.6441407712467241	valid:	loss:0.8812766671180725	accuracy:0.6609281437125748
epoch:2	train:	loss:0.9729471802711487	accuracy:0.6256083863721452	valid:	loss:0.9141689538955688	accuracy:0.6586826347305389
