# In [12]:
import torch
from torch import nn,optim
from torch.utils.data import TensorDataset,DataLoader
import collections
import numpy as np
    
class CNN(nn.Module):
    """1-D convolutional text classifier: embed -> Conv1d -> ReLU -> max-over-time -> linear.

    Args:
        vocab: number of rows in the embedding table (token ids 0..vocab-1).
        emb: embedding dimension (Conv1d input channels).
        pad: padding token id, passed to the embedding as ``padding_idx``.
        out: number of output classes.
        hid: number of convolution filters (Conv1d output channels).
        max_len: length inputs are padded/truncated to. Kept for interface
            compatibility; pooling is global over time, so any ``token_size``
            that yields a non-empty conv output works.
        token_size: convolution kernel width.
        emb_weights: optional (vocab, emb) array/tensor used to initialise the
            embedding table (copied, and left trainable).
        bidirectional: accepted for call-compatibility with recurrent models; unused.

    Raises:
        ValueError: if ``emb_weights`` does not have shape (vocab, emb).
    """

    def __init__(self,vocab,emb,pad,out,hid,max_len,token_size,emb_weights=None,bidirectional=False):
        super().__init__()
        if emb_weights is None:
            self.emb = nn.Embedding(vocab, emb, padding_idx=pad)
        else:
            # Copy so training never writes into the caller's array/tensor.
            weights = torch.as_tensor(emb_weights, dtype=torch.float32).detach().clone()
            if tuple(weights.shape) != (vocab, emb):
                raise ValueError(
                    f"emb_weights has shape {tuple(weights.shape)}, expected {(vocab, emb)}"
                )
            self.emb = nn.Embedding.from_pretrained(weights, freeze=False, padding_idx=pad)
        self.conv = nn.Conv1d(emb, hid, token_size, padding=1)
        self.relu = nn.ReLU()
        # Global max over the whole time axis. A fixed MaxPool1d(max_len) ignored
        # trailing positions when the conv output was longer than max_len and
        # raised when it was shorter (token_size >= 4 with padding=1).
        self.pool = nn.AdaptiveMaxPool1d(1)
        self.linear = nn.Linear(hid, out)

    def forward(self, x, h=None):
        """Return unnormalised class scores (logits) of shape (batch, out).

        x: integer token ids of shape (batch, seq_len).
        h: unused; accepted for call-compatibility with recurrent models.

        No softmax is applied: ``nn.CrossEntropyLoss`` applies log-softmax
        itself, and ``argmax`` over logits gives the same prediction.
        """
        x = self.emb(x)               # (batch, seq_len, emb)
        x = x.transpose(1, 2)         # (batch, emb, seq_len): Conv1d wants channels on dim 1
        x = self.relu(self.conv(x))   # (batch, hid, L_out)
        x = self.pool(x).squeeze(2)   # (batch, hid)
        return self.linear(x)         # (batch, out)

def accuracy(pred,label):
    """Return the fraction of rows whose arg-max class equals ``label``.

    pred: score tensor whose last axis indexes classes (probabilities or logits).
    label: integer class ids, one per row of ``pred``.
    """
    predicted = pred.argmax(dim=-1)
    n_correct = int((predicted == label).sum())
    return n_correct / len(predicted)

def list2tensor(data,max_len,pad):
    """Pad or truncate each id sequence to ``max_len`` and stack them into a tensor.

    Args:
        data: iterable of integer id sequences.
        max_len: target length of every row.
        pad: id appended to sequences shorter than ``max_len``.

    Returns:
        Tensor of shape (len(data), max_len). For empty ``data`` an empty int64
        tensor of shape (0, max_len) is returned so it still fits an
        embedding lookup.

    The input sequences are left unmodified.
    """
    rows = []
    for seq in data:
        row = list(seq[:max_len])  # copy: never pad the caller's list in place
        row.extend([pad] * (max_len - len(row)))
        rows.append(row)
    if not rows:
        return torch.empty((0, max_len), dtype=torch.long)
    return torch.tensor(rows)

def count(file):
    """Build a word->id vocabulary and the label list from a tab-separated file.

    Each non-blank line is ``<tag>\\t<space-separated sentence>``, where tag is
    one of "b", "t", "e", "m" (mapped to 0, 1, 2, 3). Words occurring at least
    twice get ids 1, 2, ... in descending frequency order (ties keep first-seen
    order); id 0 is not assigned. Blank lines are skipped.

    Returns:
        (word_number, tags): the vocabulary dict and the list of label ids.

    Raises:
        KeyError: a tag outside {"b", "t", "e", "m"}.
        ValueError: a non-blank line that does not contain exactly one tab.
    """
    tag_num = {"b": 0, "t": 1, "e": 2, "m": 3}
    tags = []
    freq = collections.Counter()
    with open(file, "r", encoding="utf-8") as tf:
        for line in tf:
            if not line.strip():
                continue  # e.g. a trailing empty line would break the unpack below
            tag, sen = line.split("\t")
            tags.append(tag_num[tag])
            freq.update(sen.strip().split(" "))
    # most_common() sorts by count descending with a stable sort, i.e. the same
    # order as sorted(freq.items(), key=count, reverse=True). Because the kept
    # words (f >= 2) form a prefix of that order, their ids are contiguous.
    word_number = {w: i + 1 for i, (w, f) in enumerate(freq.most_common()) if f >= 2}
    return word_number, tags

def id_v(file,word_id):
    """Convert every sentence of a tab-separated file into a list of word ids.

    Each non-blank line is ``<tag>\\t<space-separated sentence>``; the tag is
    ignored. Words found in ``word_id`` map to their id, all others to 0.
    Blank lines are skipped.

    Returns:
        list of per-sentence id lists, in file order.

    Raises:
        ValueError: a non-blank line that does not contain exactly one tab.
    """
    ids = []
    with open(file, "r", encoding="utf-8") as tf:
        for line in tf:
            if not line.strip():
                continue  # e.g. a trailing empty line would break the unpack below
            _tag, sen = line.split("\t")
            ids.append([word_id.get(word, 0) for word in sen.strip().split(" ")])
    return ids
    
# The vocabulary comes from the training split; for the validation split only
# its labels are used.
word_id,Y_train=count("train.feature.txt")
word_id_valid,Y_valid=count("valid.feature.txt")
max_len=10
# Token ids: 0 = word outside the vocabulary (id_v), 1..len(word_id) =
# vocabulary words (count). Padding therefore needs its own id after them:
# pad=len(word_id) would be the id of the least-frequent vocabulary word.
pad=len(word_id)+1
vocab=len(word_id)+2
emb=300
hid=50
out=4

X_train=id_v("train.feature.txt",word_id)
X_train=list2tensor(X_train,max_len,pad)
Y_train=torch.tensor(Y_train)

X_valid=id_v("valid.feature.txt",word_id)
X_valid=list2tensor(X_valid,max_len,pad)
Y_valid=torch.tensor(Y_valid)

dataset_train=TensorDataset(X_train,Y_train)

# Reshuffle every epoch so mini-batches do not always follow file order.
loader=DataLoader(dataset_train,batch_size=1024,shuffle=True)
loss_fn=nn.CrossEntropyLoss()

# Grid search over convolution filter width and SGD learning rate.
token_size=[2,3]
L=[10**i for i in range(-2,1)]
for fl in token_size:
    for l in L:
        # Build the model first, then an optimizer over *its* parameters.
        # Creating the optimizer before reassigning `model` raises NameError in
        # a fresh kernel. Otherwise it trains the previous model while the new,
        # untrained one is the one evaluated below.
        model=CNN(vocab,emb,pad,out,hid,max_len,fl,emb_weights=None,bidirectional=True)
        optimizer=torch.optim.SGD(model.parameters(), lr=l)

        for num in range(10):
            for X,Y in loader:
                Y_pred=model(X)
                loss=loss_fn(Y_pred,Y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # Full-split evaluation after each epoch; only the last epoch is printed.
            with torch.no_grad():
                Y_pred=model(X_train)
                loss=loss_fn(Y_pred,Y_train)
                ac=accuracy(Y_pred,Y_train)

                Y_pred_v=model(X_valid)
                loss_v=loss_fn(Y_pred_v,Y_valid)
                ac_v=accuracy(Y_pred_v,Y_valid)
        print(f"フィルターサイズ{fl},学習率{l}: train: loss:{loss}\taccuracy:{ac}\tvalid: loss:{loss_v}\taccuracy:{ac_v}")
    


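# Cell output (filter size, learning rate, then train/valid loss and accuracy
# after 10 epochs).
# NOTE(review): these numbers come from the cell as originally written. There the
# SGD optimizer was created from `model.parameters()` before `model` was
# reassigned. It therefore updated the previous iteration's model while the
# freshly built, never-updated one was evaluated. That is consistent with the
# near-chance losses (ln 4 ≈ 1.386 for 4 classes).
# フィルターサイズ2,学習率0.01: train: loss:1.3237619400024414	accuracy:0.4014414077124672	valid: loss:1.32319974899292	accuracy:0.3997005988023952
# フィルターサイズ2,学習率0.1: train: loss:1.4199168682098389	accuracy:0.17184575065518531	valid: loss:1.4194891452789307	accuracy:0.1751497005988024
# フィルターサイズ2,学習率1: train: loss:1.3608176708221436	accuracy:0.4206289779108948	valid: loss:1.3615612983703613	accuracy:0.4176646706586826
# フィルターサイズ3,学習率0.01: train: loss:1.3560837507247925	accuracy:0.3801946836390865	valid: loss:1.3556005954742432	accuracy:0.3847305389221557
# フィルターサイズ3,学習率0.1: train: loss:1.3262617588043213	accuracy:0.42016098839385996	valid: loss:1.3256176710128784	accuracy:0.4244011976047904
# フィルターサイズ3,学習率1: train: loss:1.3971152305603027	accuracy:0.2458816922500936	valid: loss:1.3953319787979126	accuracy:0.26047904191616766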
