In [1]:
import torch.nn.functional as F
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from my_module.tools import BertToEmoFileDataset, BertToEmoDirectDataset, calc_loss, EarlyStopping
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Trainデータセットの準備
# データセットが存在するディレクトリを指定

#ファイル分割バージョン(省メモリ設計)
# dataset_root_dir = "/workspace/dataset/data_src/BERT_to_emotion/only_emotion/train/split/"
# train_dataset = BertToEmoFileDataset(dataset_root_dir)

#一気に読み込みバージョン(TESLA用)
dataset_dir = "/workspace/dataset/data_src/BERT_to_emotion/only_emotion/train/BERT_to_emo_train.txt"
train_dataset = BertToEmoDirectDataset(dataset_dir)

# valデータセットの準備

# ファイル分割バージョン
# dataset_root_dir = "/workspace/dataset/data_src/BERT_to_emotion/only_emotion/val/split/"
# val_dataset = BertToEmoFileDataset(dataset_root_dir)

# 一気に読み込みバージョン
dataset_dir = "/workspace/dataset/data_src/BERT_to_emotion/only_emotion/val/BERT_to_emo_val.txt"
val_dataset = BertToEmoDirectDataset(dataset_dir)

# testデータセットの準備

# ファイル分割バージョン
# dataset_root_dir = "/workspace/dataset/data_src/BERT_to_emotion/only_emotion/test/split/"
# test_dataset = BertToEmoFileDataset(dataset_root_dir)

# 一気に読み込みバージョン
dataset_dir = "/workspace/dataset/data_src/BERT_to_emotion/only_emotion/test/BERT_to_emo_test.txt"
test_dataset = BertToEmoDirectDataset(dataset_dir)

In [29]:
# ハイパーパラメータ
batch_size = 16384
max_epoch = 10000

# 設定
num_workers = 0

In [30]:
class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()

        self.bn = nn.BatchNorm1d(768)
        self.fc1 = nn.Linear(768, 400)
        self.fc2 = nn.Linear(400, 100)
        self.fc3 = nn.Linear(100, 10)

    def forward(self, x):
        x = self.bn(x)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x


In [31]:
# GPUの設定状況に基づいたデバイスの選択
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print("device:", device)

# 乱数シードを固定して再現性を確保
torch.manual_seed(0)

# インスタンス化・デバイスへの転送
net = Net().to(device)

# 損失関数の選択
criterion = nn.CrossEntropyLoss()

# 最適化手法の選択
optimizer = torch.optim.Adam(net.parameters())

device: cuda:0


In [32]:
# FileDataloader
train_dataloader = DataLoader(train_dataset, batch_size = batch_size, shuffle=True, num_workers=num_workers)
val_dataloader = DataLoader(val_dataset, batch_size = batch_size, num_workers=num_workers)
test_dataloader = DataLoader(test_dataset, batch_size = batch_size, num_workers=num_workers)

In [33]:
# early stopping
earlystopping  = EarlyStopping(patience=5, verbose=True)

In [34]:
# ネットワークの学習
for epoch in range(max_epoch):
	print("epoch:", epoch)
	loss_list = []
	for batch in tqdm(train_dataloader):
		x, t = batch

		x = x.to(device)
		t = t.to(device)

		optimizer.zero_grad()

		y = net(x)

		loss = criterion(y, t)

		loss.backward()

		loss_list.append(loss)

		optimizer.step()
		
	train_loss_avg = torch.tensor(loss_list).mean()
	print("val_loss calc...")
	val_loss_avg = calc_loss(net, val_dataloader, criterion, device)
	print("train_loss: {}, val_loss: {}".format(train_loss_avg, val_loss_avg))
	earlystopping(val_loss_avg, net) #callメソッド呼び出し
	if earlystopping.early_stop: #ストップフラグがTrueの場合、breakでforループを抜ける
		print("Early Stopping!")
		break

epoch: 0


100%|██████████| 236/236 [00:45<00:00,  5.19it/s]


val_loss calc...


100%|██████████| 13/13 [00:01<00:00,  9.57it/s]


train_loss: 1.326688289642334, val_loss: 1.128109097480774
Validation loss decreased (inf --> 1.128109).  Saving model ...
epoch: 1


100%|██████████| 236/236 [00:44<00:00,  5.26it/s]


val_loss calc...


100%|██████████| 13/13 [00:01<00:00,  9.83it/s]


train_loss: 1.1212821006774902, val_loss: 1.1139591932296753
Validation loss decreased (1.128109 --> 1.113959).  Saving model ...
epoch: 2


100%|██████████| 236/236 [00:45<00:00,  5.19it/s]


val_loss calc...


100%|██████████| 13/13 [00:01<00:00,  8.62it/s]


train_loss: 1.1116551160812378, val_loss: 1.1092394590377808
Validation loss decreased (1.113959 --> 1.109239).  Saving model ...
epoch: 3


100%|██████████| 236/236 [00:51<00:00,  4.60it/s]


val_loss calc...


100%|██████████| 13/13 [00:01<00:00,  9.88it/s]


train_loss: 1.107024073600769, val_loss: 1.1065455675125122
Validation loss decreased (1.109239 --> 1.106546).  Saving model ...
epoch: 4


100%|██████████| 236/236 [00:46<00:00,  5.04it/s]


val_loss calc...


100%|██████████| 13/13 [00:01<00:00,  9.71it/s]


train_loss: 1.1040540933609009, val_loss: 1.1050387620925903
Validation loss decreased (1.106546 --> 1.105039).  Saving model ...
epoch: 5


100%|██████████| 236/236 [00:44<00:00,  5.26it/s]


val_loss calc...


100%|██████████| 13/13 [00:05<00:00,  2.53it/s]


train_loss: 1.1018798351287842, val_loss: 1.1037087440490723
Validation loss decreased (1.105039 --> 1.103709).  Saving model ...
epoch: 6


100%|██████████| 236/236 [00:45<00:00,  5.19it/s]


val_loss calc...


100%|██████████| 13/13 [00:01<00:00,  9.73it/s]


train_loss: 1.1002538204193115, val_loss: 1.1031211614608765
Validation loss decreased (1.103709 --> 1.103121).  Saving model ...
epoch: 7


100%|██████████| 236/236 [00:44<00:00,  5.26it/s]


val_loss calc...


100%|██████████| 13/13 [00:01<00:00, 10.55it/s]


train_loss: 1.0989019870758057, val_loss: 1.1031231880187988
EarlyStopping counter: 1 out of 5
epoch: 8


100%|██████████| 236/236 [00:45<00:00,  5.18it/s]


val_loss calc...


100%|██████████| 13/13 [00:01<00:00,  9.68it/s]


train_loss: 1.0978749990463257, val_loss: 1.102673053741455
Validation loss decreased (1.103121 --> 1.102673).  Saving model ...
epoch: 9


100%|██████████| 236/236 [00:45<00:00,  5.21it/s]


val_loss calc...


100%|██████████| 13/13 [00:01<00:00,  9.77it/s]


train_loss: 1.0970416069030762, val_loss: 1.102394938468933
Validation loss decreased (1.102673 --> 1.102395).  Saving model ...
epoch: 10


100%|██████████| 236/236 [00:45<00:00,  5.24it/s]


val_loss calc...


100%|██████████| 13/13 [00:05<00:00,  2.38it/s]


train_loss: 1.0964163541793823, val_loss: 1.1027823686599731
EarlyStopping counter: 1 out of 5
epoch: 11


100%|██████████| 236/236 [00:45<00:00,  5.23it/s]


val_loss calc...


100%|██████████| 13/13 [00:01<00:00, 10.47it/s]


train_loss: 1.0959291458129883, val_loss: 1.102644681930542
EarlyStopping counter: 2 out of 5
epoch: 12


100%|██████████| 236/236 [00:45<00:00,  5.24it/s]


val_loss calc...


100%|██████████| 13/13 [00:01<00:00,  9.60it/s]


train_loss: 1.0955860614776611, val_loss: 1.102810025215149
EarlyStopping counter: 3 out of 5
epoch: 13


100%|██████████| 236/236 [00:44<00:00,  5.28it/s]


val_loss calc...


100%|██████████| 13/13 [00:01<00:00,  9.22it/s]


train_loss: 1.0952938795089722, val_loss: 1.1022330522537231
Validation loss decreased (1.102395 --> 1.102233).  Saving model ...
epoch: 14


100%|██████████| 236/236 [00:45<00:00,  5.20it/s]


val_loss calc...


100%|██████████| 13/13 [00:01<00:00,  9.67it/s]


train_loss: 1.094997525215149, val_loss: 1.1022905111312866
EarlyStopping counter: 1 out of 5
epoch: 15


100%|██████████| 236/236 [00:44<00:00,  5.30it/s]


val_loss calc...


100%|██████████| 13/13 [00:05<00:00,  2.37it/s]


train_loss: 1.0947279930114746, val_loss: 1.102331519126892
EarlyStopping counter: 2 out of 5
epoch: 16


100%|██████████| 236/236 [00:44<00:00,  5.27it/s]


val_loss calc...


100%|██████████| 13/13 [00:01<00:00, 10.12it/s]


train_loss: 1.0945029258728027, val_loss: 1.1023938655853271
EarlyStopping counter: 3 out of 5
epoch: 17


100%|██████████| 236/236 [00:44<00:00,  5.26it/s]


val_loss calc...


100%|██████████| 13/13 [00:01<00:00,  8.49it/s]


train_loss: 1.0942676067352295, val_loss: 1.1020886898040771
Validation loss decreased (1.102233 --> 1.102089).  Saving model ...
epoch: 18


100%|██████████| 236/236 [00:44<00:00,  5.29it/s]


val_loss calc...


100%|██████████| 13/13 [00:01<00:00, 10.09it/s]


train_loss: 1.0940532684326172, val_loss: 1.1022896766662598
EarlyStopping counter: 1 out of 5
epoch: 19


100%|██████████| 236/236 [00:45<00:00,  5.20it/s]


val_loss calc...


100%|██████████| 13/13 [00:01<00:00,  8.59it/s]


train_loss: 1.0938992500305176, val_loss: 1.1021944284439087
EarlyStopping counter: 2 out of 5
epoch: 20


100%|██████████| 236/236 [00:44<00:00,  5.31it/s]


val_loss calc...


100%|██████████| 13/13 [00:05<00:00,  2.36it/s]


train_loss: 1.0936589241027832, val_loss: 1.101881504058838
Validation loss decreased (1.102089 --> 1.101882).  Saving model ...
epoch: 21


100%|██████████| 236/236 [00:45<00:00,  5.23it/s]


val_loss calc...


100%|██████████| 13/13 [00:01<00:00,  9.61it/s]


train_loss: 1.0935531854629517, val_loss: 1.1021755933761597
EarlyStopping counter: 1 out of 5
epoch: 22


100%|██████████| 236/236 [00:45<00:00,  5.23it/s]


val_loss calc...


100%|██████████| 13/13 [00:01<00:00,  8.16it/s]


train_loss: 1.093446969985962, val_loss: 1.102455973625183
EarlyStopping counter: 2 out of 5
epoch: 23


100%|██████████| 236/236 [00:45<00:00,  5.15it/s]


val_loss calc...


100%|██████████| 13/13 [00:01<00:00,  9.75it/s]


train_loss: 1.0932910442352295, val_loss: 1.1025288105010986
EarlyStopping counter: 3 out of 5
epoch: 24


100%|██████████| 236/236 [00:45<00:00,  5.18it/s]


val_loss calc...


100%|██████████| 13/13 [00:01<00:00, 10.25it/s]


train_loss: 1.0931280851364136, val_loss: 1.102591633796692
EarlyStopping counter: 4 out of 5
epoch: 25


100%|██████████| 236/236 [00:44<00:00,  5.28it/s]


val_loss calc...


100%|██████████| 13/13 [00:01<00:00,  9.57it/s]


train_loss: 1.0930149555206299, val_loss: 1.1028592586517334
EarlyStopping counter: 5 out of 5
Early Stopping!


In [35]:
test_loss_avg = calc_loss(net, val_dataloader, criterion, device)
print("test_loss: {}".format(test_loss_avg))
print("done!")

100%|██████████| 13/13 [00:05<00:00,  2.36it/s]

test_loss: 1.1028592586517334
done!



