In [17]:
import torch
from torch import nn,optim
from torch.utils.data import TensorDataset,DataLoader
import collections
import numpy as np

from gensim.models import KeyedVectors
vectors=KeyedVectors.load_word2vec_format("GoogleNews-vectors-negative300.bin.gz",binary=True)

class RNN(nn.Module):
    def __init__(self,vocab,emb,pad,out,hid,emb_weights=None):
        super().__init__()
        self.hidden_size=hid
        if emb_weights!=None:
            self.emb=nn.Embedding.from_pretrained(emb_weights,padding_idx=pad)
        else:
            self.emb=nn.Embedding(vocab,emb,padding_idx=pad)
        self.rnn=nn.RNN(emb,hid,batch_first=True)
        self.linear=nn.Linear(hid,out)
        self.softmax=nn.Softmax()
        
    def forward(self,x,h=None):
        x=self.emb(x)
        y,h=self.rnn(x,h)
        y=y[:,-1,:]
        y=self.linear(y)
        return y
    
def accuracy(pred,label):
    pred=torch.argmax(pred,dim=-1)
    leng=len(pred)
    return (pred==label).sum().item()/leng

def list2tensor(data,max_len,pad):
    new_list=[]
    for d in data:
        if len(d)>max_len:
            d=d[:max_len]
        else:
            d+=[pad]*(max_len-len(d))
        new_list.append(d)
    return torch.tensor(new_list)

def count(file):
    all_word=[]
    with open(file,"r",encoding="utf-8") as tf:
        tags=[]
        tag_num={"b":0,"t":1,"e":2,"m":3}
        for line in tf:
            line.strip();
            tag,sen=line.split("\t")
            tags.append(tag_num[tag])
            words=sen.strip().split(" ")
            for word in words:
                all_word.append(word)
        freq=collections.Counter(all_word)
        freq_sort=sorted(freq.items(),key=lambda x:x[1],reverse=True)
        #print(freq_sort)
        word_number={w:i+1 for i,(w,f) in enumerate(freq_sort) if f>=2}
        return word_number,tags

def id_v(file,word_id):
    with open(file,"r",encoding="utf-8") as tf:
        ids=[]
        for line in tf:
            line.strip();
            tag,sen=line.split("\t")
            words=sen.strip().split(" ")
            sent_ids=[]
            for word in words:
                if word in word_id:
                    sent_ids.append(word_id[word])
                else:
                    sent_ids.append(0)
            ids.append(sent_ids)
        return ids
    
def init_emb(vocab,emb,ids):
    weights=np.zeros((vocab,emb))
    words_in_pretrained=0
    for i,word in enumerate(ids.keys()):
        try:
            weights[i]=vectors[word]
            words_in_pretrained+=1
        except KeyError:
            weights[i]=np.random.randn(emb)
    weights=torch.from_numpy(weights.astype((np.float32)))
    return weights

word_id,Y_train=count("train.feature.txt")
word_id_valid,Y_valid=count("valid.feature.txt")
max_len=10
vocab=len(word_id)+1
emb=300
pad=len(word_id)
hid=50
out=4

X_train=id_v("train.feature.txt",word_id)
X_train=list2tensor(X_train,max_len,pad)
Y_train=torch.tensor(Y_train)

X_valid=id_v("valid.feature.txt",word_id)
X_valid=list2tensor(X_valid,max_len,pad)
Y_valid=torch.tensor(Y_valid)

ew=init_emb(vocab,emb,word_id)

model=RNN(vocab,emb,pad,out,hid,emb_weights=ew)

dataset_train=TensorDataset(X_train,Y_train)

loader=DataLoader(dataset_train,batch_size=1024)
loss_fn=nn.CrossEntropyLoss()
optimizer=torch.optim.SGD(model.parameters(), lr=0.1)
epoch=10
for num in range(epoch):
    for X,Y in loader:
        Y_pred=model(X)
        loss=loss_fn(Y_pred,Y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        Y_pred=model(X_train)
        loss=loss_fn(Y_pred,Y_train)
        ac=accuracy(Y_pred,Y_train)
        
        Y_pred_v=model(X_valid)
        loss_v=loss_fn(Y_pred_v,Y_valid)
        ac_v=accuracy(Y_pred_v,Y_valid)
        print(f"epoch:{num}\ttrain:\tloss:{loss}\taccuracy:{ac}\tvalid:\tloss:{loss_v}\taccuracy:{ac_v}")

epoch:0	train:	loss:1.2171216011047363	accuracy:0.4184131736526946	valid:	loss:1.2207508087158203	accuracy:0.4353238487457881
epoch:1	train:	loss:1.1689057350158691	accuracy:0.42889221556886226	valid:	loss:1.1710766553878784	accuracy:0.44627480344440285
epoch:2	train:	loss:1.1607307195663452	accuracy:0.44011976047904194	valid:	loss:1.1625874042510986	accuracy:0.4514226881317858
epoch:3	train:	loss:1.1572202444076538	accuracy:0.4438622754491018	valid:	loss:1.1591989994049072	accuracy:0.45582178959191316
epoch:4	train:	loss:1.1545565128326416	accuracy:0.4491017964071856	valid:	loss:1.1567227840423584	accuracy:0.46050168476226133
epoch:5	train:	loss:1.1522235870361328	accuracy:0.4491017964071856	valid:	loss:1.1545758247375488	accuracy:0.4634032197678772
epoch:6	train:	loss:1.1500414609909058	accuracy:0.45583832335329344	valid:	loss:1.1525777578353882	accuracy:0.4669599400973418
epoch:7	train:	loss:1.1479121446609497	accuracy:0.45583832335329344	valid:	loss:1.1506445407867432	accuracy:0.47