In [2]:
import torch
from torch import nn,optim
from torch.utils.data import TensorDataset,DataLoader
import collections
import numpy as np
    
class CNN(nn.Module):
    def __init__(self,vocab,emb,pad,out,hid,max_len,emb_weights=None,bidirectional=False):
        super().__init__()        
        self.emb=nn.Embedding(vocab,emb,padding_idx=pad)
        self.conv=nn.Conv1d(emb,hid,3,padding=1)
        self.relu=nn.ReLU()
        self.pool=nn.MaxPool1d(max_len)
        self.linear=nn.Linear(hid,out)
        self.softmax=nn.Softmax(dim=1)
    def forward(self, x,h=None):
        x = self.emb(x)
        x = x.view(x.shape[0],x.shape[2],x.shape[1])
        x = self.conv(x)
        x = self.relu(x)
        x = x.view(x.shape[0],x.shape[1],x.shape[2])
        x = self.pool(x)
        x = x.view(x.shape[0],x.shape[1])
        y = self.linear(x)
        y=self.softmax(y)
        return y
def accuracy(pred,label):
    pred=torch.argmax(pred,dim=-1)
    leng=len(pred)
    return (pred==label).sum().item()/leng

def list2tensor(data,max_len,pad):
    new_list=[]
    for d in data:
        if len(d)>max_len:
            d=d[:max_len]
        else:
            d+=[pad]*(max_len-len(d))
        new_list.append(d)
    return torch.tensor(new_list)

def count(file):
    all_word=[]
    with open(file,"r",encoding="utf-8") as tf:
        tags=[]
        tag_num={"b":0,"t":1,"e":2,"m":3}
        for line in tf:
            line.strip();
            tag,sen=line.split("\t")
            tags.append(tag_num[tag])
            words=sen.strip().split(" ")
            for word in words:
                all_word.append(word)
        freq=collections.Counter(all_word)
        freq_sort=sorted(freq.items(),key=lambda x:x[1],reverse=True)
        #print(freq_sort)
        word_number={w:i+1 for i,(w,f) in enumerate(freq_sort) if f>=2}
        return word_number,tags

def id_v(file,word_id):
    with open(file,"r",encoding="utf-8") as tf:
        ids=[]
        for line in tf:
            line.strip();
            tag,sen=line.split("\t")
            words=sen.strip().split(" ")
            sent_ids=[]
            for word in words:
                if word in word_id:
                    sent_ids.append(word_id[word])
                else:
                    sent_ids.append(0)
            ids.append(sent_ids)
        return ids
    
word_id,Y_train=count("train.feature.txt")
word_id_valid,Y_valid=count("valid.feature.txt")
max_len=10
vocab=len(word_id)+1
emb=300
pad=len(word_id)
hid=50
out=4

X_train=id_v("train.feature.txt",word_id)
X_train=list2tensor(X_train,max_len,pad)
Y_train=torch.tensor(Y_train)

X_valid=id_v("valid.feature.txt",word_id)
X_valid=list2tensor(X_valid,max_len,pad)
Y_valid=torch.tensor(Y_valid)

dataset_train=TensorDataset(X_train,Y_train)

model=CNN(vocab,emb,pad,out,hid,max_len,emb_weights=None,bidirectional=True)
loader=DataLoader(dataset_train,batch_size=1024)
loss_fn=nn.CrossEntropyLoss()
optimizer=torch.optim.SGD(model.parameters(), lr=0.1)
epoch=10
for num in range(epoch):
    for X,Y in loader:
        Y_pred=model(X)
        loss=loss_fn(Y_pred,Y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        Y_pred=model(X_train)
        loss=loss_fn(Y_pred,Y_train)
        ac=accuracy(Y_pred,Y_train)
        
        Y_pred_v=model(X_valid)
        loss_v=loss_fn(Y_pred_v,Y_valid)
        ac_v=accuracy(Y_pred_v,Y_valid)
        print(f"epoch:{num}\ttrain:\tloss:{loss}\taccuracy:{ac}\tvalid:\tloss:{loss_v}\taccuracy:{ac_v}")

epoch:0	train:	loss:1.2834141254425049	accuracy:0.4239049045301385	valid:	loss:1.284889817237854	accuracy:0.41691616766467066
epoch:1	train:	loss:1.2605077028274536	accuracy:0.46780232122800447	valid:	loss:1.2617436647415161	accuracy:0.4513473053892216
epoch:2	train:	loss:1.2473677396774292	accuracy:0.510295769374766	valid:	loss:1.2486563920974731	accuracy:0.5052395209580839
epoch:3	train:	loss:1.235690712928772	accuracy:0.5441782104080869	valid:	loss:1.237125039100647	accuracy:0.5396706586826348
epoch:4	train:	loss:1.2236007452011108	accuracy:0.5741295394983152	valid:	loss:1.2252192497253418	accuracy:0.5718562874251497
epoch:5	train:	loss:1.2106983661651611	accuracy:0.5967802321228004	valid:	loss:1.2125215530395508	accuracy:0.5920658682634731
epoch:6	train:	loss:1.1971592903137207	accuracy:0.6154062149007862	valid:	loss:1.199273705482483	accuracy:0.6032934131736527
epoch:7	train:	loss:1.1833938360214233	accuracy:0.6282291276675402	valid:	loss:1.1857844591140747	accuracy:0.620508982035