In [1]:
import torch
from torch import nn,optim
from torch.utils.data import TensorDataset,DataLoader
import collections
import numpy as np
    
class CNN(nn.Module):
    """Single-layer 1-D convolutional text classifier.

    Embeds a batch of padded token-id sequences, applies one width-3
    convolution along the sequence axis followed by ReLU, max-pools over the
    whole sequence, and maps the pooled features to class probabilities.

    Args:
        vocab: Number of rows in the embedding table.
        emb: Embedding dimension (= conv input channels).
        pad: Padding index; per ``nn.Embedding(padding_idx=...)`` this row is
            zero-initialised and receives no gradient.
        out: Number of output classes.
        hid: Number of convolution filters (= conv output channels).
        max_len: Pooling window; inputs are expected to be padded/truncated to
            exactly this length so the pool covers the whole sequence.
        emb_weights: Accepted for signature compatibility; currently unused.
        bidirectional: Accepted for signature compatibility; currently unused.
    """

    def __init__(self, vocab, emb, pad, out, hid, max_len, emb_weights=None, bidirectional=False):
        super().__init__()
        self.emb = nn.Embedding(vocab, emb, padding_idx=pad)
        # kernel 3 with padding 1 keeps the sequence length unchanged.
        self.conv = nn.Conv1d(emb, hid, 3, padding=1)
        self.relu = nn.ReLU()
        # Window == full sequence length -> one max per filter (max over time).
        self.pool = nn.MaxPool1d(max_len)
        self.linear = nn.Linear(hid, out)
        # NOTE(review): nn.CrossEntropyLoss expects raw logits; if this model is
        # trained with it, the softmax should be applied only at inference.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x, h=None):
        """Return class probabilities of shape (batch, out).

        Args:
            x: Long tensor of token ids, shape (batch, max_len).
            h: Unused; accepted for interface compatibility.
        """
        x = self.emb(x)  # (batch, seq, emb)
        # Conv1d wants channels first. transpose really moves the embedding
        # axis; view would merely reinterpret memory and scramble values
        # across tokens and embedding dimensions.
        x = x.transpose(1, 2)  # (batch, emb, seq)
        x = self.relu(self.conv(x))  # (batch, hid, seq)
        x = self.pool(x)  # (batch, hid, 1)
        x = x.view(x.shape[0], x.shape[1])  # (batch, hid)
        return self.softmax(self.linear(x))
def accuracy(pred, label):
    """Return the fraction of rows whose arg-max class equals ``label``.

    Args:
        pred: Per-example scores or probabilities, class axis last.
        label: Tensor of gold class indices, one per row of ``pred``.

    Returns:
        A Python float; raises ZeroDivisionError when there are no rows.
    """
    predicted_class = pred.argmax(dim=-1)
    n_correct = (predicted_class == label).sum().item()
    n_total = len(predicted_class)
    return n_correct / n_total

def list2tensor(data, max_len, pad):
    """Pad or truncate each id sequence to ``max_len`` and stack them.

    Args:
        data: Iterable of sequences of int token ids.
        max_len: Target length of every row.
        pad: Id appended to sequences shorter than ``max_len``.

    Returns:
        Tensor of shape (len(data), max_len). The input sequences are not
        modified (the previous ``+=`` padded the caller's lists in place).
    """
    # A negative repeat count yields [], so long rows are only truncated.
    rows = [list(d[:max_len]) + [pad] * (max_len - len(d)) for d in data]
    if not rows:
        # torch.tensor([]) would give a float tensor of shape (0,).
        return torch.empty((0, max_len), dtype=torch.long)
    return torch.tensor(rows)

def count(file, min_freq=2):
    """Read a ``TAG<TAB>SENTENCE`` file; build a vocabulary and a label list.

    Tags map to integer labels via {"b": 0, "t": 1, "e": 2, "m": 3}; an
    unknown tag raises KeyError. Words are the stripped sentence split on
    single spaces (so consecutive spaces yield empty-string tokens, matching
    how ``id_v`` tokenises). Blank lines are skipped.

    Args:
        file: Path of the UTF-8 input file.
        min_freq: Minimum frequency for a word to receive an id.

    Returns:
        (word_number, tags): ``word_number`` maps each word seen at least
        ``min_freq`` times to an id in 1..len(word_number), most frequent
        first (ties keep first-occurrence order); id 0 is not assigned.
        ``tags`` holds one integer label per non-blank line.

    Raises:
        ValueError: a non-blank line contains no tab separator.
    """
    tag_num = {"b": 0, "t": 1, "e": 2, "m": 3}
    tags = []
    freq = collections.Counter()
    with open(file, "r", encoding="utf-8") as tf:
        for lineno, line in enumerate(tf, 1):
            if not line.strip():
                continue  # e.g. a trailing empty line
            # partition splits on the first tab only; a tab inside the
            # sentence no longer breaks the unpack.
            tag, sep, sen = line.partition("\t")
            if not sep:
                raise ValueError(f"{file}:{lineno}: expected 'TAG<TAB>SENTENCE', got {line!r}")
            tags.append(tag_num[tag.strip()])
            freq.update(sen.strip().split(" "))
    # Stable sort: equal counts keep first-occurrence order.
    freq_sort = sorted(freq.items(), key=lambda x: x[1], reverse=True)
    # Kept words form a prefix of freq_sort, so their ids are contiguous 1..k.
    word_number = {w: i + 1 for i, (w, f) in enumerate(freq_sort) if f >= min_freq}
    return word_number, tags

def id_v(file, word_id):
    """Convert each sentence of a ``TAG<TAB>SENTENCE`` file to word ids.

    The tag column is ignored. Words are the stripped sentence split on
    single spaces, the same tokenisation ``count`` uses. Blank lines are
    skipped.

    Args:
        file: Path of the UTF-8 input file.
        word_id: Mapping word -> id; words missing from it map to 0.

    Returns:
        List of id lists, one per non-blank line, in file order.

    Raises:
        ValueError: a non-blank line contains no tab separator.
    """
    ids = []
    with open(file, "r", encoding="utf-8") as tf:
        for lineno, line in enumerate(tf, 1):
            if not line.strip():
                continue  # e.g. a trailing empty line
            _tag, sep, sen = line.partition("\t")
            if not sep:
                raise ValueError(f"{file}:{lineno}: expected 'TAG<TAB>SENTENCE', got {line!r}")
            ids.append([word_id.get(word, 0) for word in sen.strip().split(" ")])
    return ids
    
# Vocabulary and integer labels are built from the training split only.
word_id, Y_train = count("train.feature.txt")
# NOTE(review): word_id_valid is not used below; validation sentences must be
# encoded with the training vocabulary. Kept in case later cells use it.
word_id_valid, Y_valid = count("valid.feature.txt")
max_len = 10
# Id layout: 0 = unknown word (id_v), 1..len(word_id) = known words (count),
# len(word_id) + 1 = padding. The pad id must differ from every real word id,
# otherwise nn.Embedding(padding_idx=pad) zero-initialises that word's vector
# and never updates it.
pad = len(word_id) + 1
vocab = len(word_id) + 2
emb = 300
hid = 50
out = 4

X_train = id_v("train.feature.txt", word_id)
X_train = list2tensor(X_train, max_len, pad)
Y_train = torch.tensor(Y_train)

dataset_train = TensorDataset(X_train, Y_train)

# CNN ignores emb_weights/bidirectional, so they are not passed.
model = CNN(vocab, emb, pad, out, hid, max_len)
with torch.no_grad():
    Y_pred = model(X_train)
    print(Y_pred)

tensor([[0.3752, 0.1877, 0.2642, 0.1729],
        [0.4253, 0.2414, 0.1890, 0.1443],
        [0.3624, 0.2219, 0.2103, 0.2053],
        ...,
        [0.1913, 0.3041, 0.3219, 0.1828],
        [0.4040, 0.1956, 0.2848, 0.1156],
        [0.3977, 0.2660, 0.1783, 0.1580]])
