In [1]:
from scripts import trees, pq_gram
import scripts.w_pq_batch as w_pq
import torch, pyconll
from sklearn.model_selection import train_test_split

from pqgrams.PQGram import Profile
import random
from tqdm import tqdm

# Input corpora (CoNLL-U files) and the class label assigned to their
# sentences; CORPORA[i] is paired with LABELS[i].
CORPORA = [
    "corpora/English-EWT.conllu",
]
LABELS = [
    "EWT",
]

### ランダムな完全平衡二分木との比較

In [2]:
# Every data/model artifact of this run shares one tag. Deriving the paths
# from it (instead of retyping the tag ten times) keeps the train/valid/test
# files and the model file from drifting apart when the experiment changes.
_RUN_TAG = "en_corpora_En_EWT_binary_unlabel_50"


def _split_path(split, kind):
    """Return the `data/*.pt` path for one artifact kind ("tensors",
    "labels" or "indexes") of one split ("train", "valid" or "test")."""
    return f"data/{split}_{kind}_{_RUN_TAG}.pt"


train_tensors_path = _split_path("train", "tensors")
train_labels_path = _split_path("train", "labels")
train_indexes_path = _split_path("train", "indexes")

valid_tensors_path = _split_path("valid", "tensors")
valid_labels_path = _split_path("valid", "labels")
valid_indexes_path = _split_path("valid", "indexes")

test_tensors_path = _split_path("test", "tensors")
test_labels_path = _split_path("test", "labels")
test_indexes_path = _split_path("test", "indexes")

model_path = f"models/model_{_RUN_TAG}.pth"

# Fixed order: (tensors, labels, indexes) for train, then valid, then test.
save_file_list = [
    train_tensors_path, train_labels_path, train_indexes_path,
    valid_tensors_path, valid_labels_path, valid_indexes_path,
    test_tensors_path, test_labels_path, test_indexes_path,
]

# Path for the training-loss figure (uses a shorter tag than the data files).
loss_figure_path = "figures/En_EWT_binary_unlabel_50.png"


In [3]:
# Build the dataset: one pq-gram profile (p=2, q=2) per EWT sentence tree
# (label LABELS[0]) followed by N_RANDOM_TREES synthetic binary trees
# (label "random_binary"), vectorised over the shared pq-gram vocabulary J,
# split 60/20/20 into train/valid/test and saved to the paths in save_file_list.

CORPUS_i_LENGTH = []  # NOTE(review): unused in this cell; kept in case a later cell reads it

N_RANDOM_TREES = 10000  # synthetic trees to add; drives both `labels` and `pqIndex`
RANDOM_TREE_SEED = 50   # seeds `random` so the synthetic tree heights are reproducible
SPLIT_SEED = 50         # random_state for both splits (the plain, unsuffixed run used 42)

random.seed(RANDOM_TREE_SEED)

EWT_conll = pyconll.load_from_file(CORPORA[0])
# Labels are listed in exactly the order profiles are appended to pqIndex below.
labels = [LABELS[0]] * len(EWT_conll) + ["random_binary"] * N_RANDOM_TREES

pqtree_EWT = [trees.conllTree_to_pqTree_unlabeled(conll.to_tree()) for conll in EWT_conll]
pqIndex = [Profile(tree, p=2, q=2) for tree in pqtree_EWT]

for _ in range(N_RANDOM_TREES):
    height = random.randint(2, 8)  # inclusive on both ends
    t = trees.create_binary_tree(height, "_")
    pqIndex.append(Profile(t, p=2, q=2))

assert len(pqIndex) == len(labels), "pq-gram profiles and labels are out of sync"

# Vocabulary of every pq-gram appearing in any profile. A single union is
# linear, unlike rebuilding the set with J = J.union(...) per profile.
# The result is sorted because set iteration order depends on per-process
# string hashing: an unsorted list yields a different feature order after
# every kernel restart, so regenerated tensors would not line up with a
# previously saved model. key=repr avoids requiring pq-grams to be orderable;
# assumes their repr is stable (e.g. tuples of strings) — TODO confirm.
J = sorted(set().union(*pqIndex), key=repr)

tensors = [pq_gram.pqgram_to_tensor(pqgram, J) for pqgram in tqdm(pqIndex, desc="convert pqgram into tensor\t")]
# Original position of each sample; split alongside the data so every split
# row can be mapped back to its source tree.
indexes = torch.Tensor(range(len(labels)))

# Hold out 40%, then halve it: 60% train / 20% valid / 20% test.
train_tensors, test_tensors, train_labels, test_labels, train_indexes, test_indexes = train_test_split(tensors, labels, indexes, test_size=0.4, random_state=SPLIT_SEED)
valid_tensors, test_tensors, valid_labels, test_labels, valid_indexes, test_indexes = train_test_split(test_tensors, test_labels, test_indexes, test_size=0.5, random_state=SPLIT_SEED)

# Save the splits; this order must match save_file_list.
split_artifacts = [
    train_tensors, train_labels, train_indexes,
    valid_tensors, valid_labels, valid_indexes,
    test_tensors, test_labels, test_indexes,
]
assert len(split_artifacts) == len(save_file_list)
for artifact, path in zip(split_artifacts, save_file_list):
    torch.save(artifact, path)

print(len(labels), labels[0], labels[-1])


convert pqgram into tensor	: 100%|██████████| 26621/26621 [00:01<00:00, 18960.60it/s]


26621 EWT random_binary


In [None]:
# Train on the saved train split (with the valid split), then evaluate the
# saved model on the test split.
w_pq.train(
        train_tensors_path, train_labels_path,
        valid_tensors_path, valid_labels_path,
        # Use the configured figure path. The previous literal
        # "figures/En_EWT_chatGPT_unlabel_50.png" was left over from another
        # experiment and saved this run's loss curve under the wrong name.
        model_path, loss_figure_path=loss_figure_path
)

print('\ntrain finished.')

# NOTE(review): test() also receives the train split — presumably to rebuild
# label encoding or similar state; confirm in scripts.w_pq_batch.
w_pq.test(
        train_tensors_path, train_labels_path,
        test_tensors_path, test_labels_path,
        model_path
    )

print('\ntest finished.')

    

  from .autonotebook import tqdm as notebook_tqdm
[train loop]:   2%|▏         | 23/1000 [00:01<00:45, 21.36it/s]


Epoch: 1,	Loss: 0.000784080708399415


[train loop]:   4%|▍         | 45/1000 [00:01<00:21, 44.80it/s]


Epoch: 50,	Loss: 0.00023056216014083475


[train loop]:  10%|▉         | 96/1000 [00:04<00:38, 23.47it/s]


Epoch: 100,	Loss: 4.031254502478987e-05


[train loop]:  15%|█▍        | 149/1000 [00:08<00:39, 21.57it/s]


Epoch: 150,	Loss: 3.931401352019748e-06


[train loop]:  17%|█▋        | 169/1000 [00:11<01:03, 13.16it/s]


Epoch: 200,	Loss: 1.9425941388817591e-07


[train loop]:  23%|██▎       | 233/1000 [00:14<00:43, 17.83it/s]


Epoch: 250,	Loss: 4.20126822220368e-09
