In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.optim as optim
from torchviz import make_dot
from tqdm import tqdm
import wandb
import yaml
from torch.utils.data import DataLoader

In [None]:
# wandb.login()  # not needed once you have already logged in on this machine

In [None]:
# Compute device used for the model and every batch: first GPU if present, else CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)  # prints "cuda:0" when CUDA is available, otherwise "cpu"

In [None]:
# Dataset definition (MNIST, flattened so each sample is a 1-D vector for the MLPs)
data_root = './data'
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5),           # (x - 0.5) / 0.5
    transforms.Lambda(lambda x: x.view(-1)),  # flatten each image to one vector
])
# download=True is safe on every run: torchvision only downloads when the files
# are missing under data_root, so a fresh kernel/machine no longer fails because
# the dataset has not been fetched yet (previously this had to be toggled by hand).
train_set = datasets.MNIST(
    root=data_root, train=True, download=True, transform=transform
)
test_set = datasets.MNIST(
    root=data_root, train=False, download=True, transform=transform
)
# Random seeds for reproducible weight initialisation and shuffling
torch.manual_seed(123)
torch.cuda.manual_seed(123)

In [None]:
# Model definition
# Fully connected network with ONE hidden layer: n_input -> n_hidden -> n_output
# (instantiated in this notebook with flattened 784-pixel MNIST inputs).
class Net1(nn.Module):
    """Single-hidden-layer MLP: Linear -> ReLU -> Linear.

    Args:
        n_input: number of input features per sample.
        n_output: number of output logits per sample.
        n_hidden: width of the hidden layer.
    """

    def __init__(self, n_input, n_output, n_hidden):
        super().__init__()
        # Creation order matters: it fixes parameter init order and state_dict keys.
        self.l1 = nn.Linear(n_input, n_hidden)   # input  -> hidden
        self.l2 = nn.Linear(n_hidden, n_output)  # hidden -> output
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return raw logits (no softmax is applied)."""
        hidden = self.relu(self.l1(x))
        return self.l2(hidden)

In [None]:
# Model definition
# Fully connected network with TWO hidden layers of equal width:
# n_input -> n_hidden -> n_hidden -> n_output.
class Net2(nn.Module):
    """Two-hidden-layer MLP: Linear -> ReLU -> Linear -> ReLU -> Linear.

    Args:
        n_input: number of input features per sample.
        n_output: number of output logits per sample.
        n_hidden: width of both hidden layers.
    """

    def __init__(self, n_input, n_output, n_hidden):
        super().__init__()
        # Creation order matters: it fixes parameter init order and state_dict keys.
        self.l1 = nn.Linear(n_input, n_hidden)   # input    -> hidden 1
        self.l2 = nn.Linear(n_hidden, n_hidden)  # hidden 1 -> hidden 2
        self.l3 = nn.Linear(n_hidden, n_output)  # hidden 2 -> output
        # One parameter-free ReLU module is shared by both hidden layers.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return raw logits (no softmax is applied)."""
        h = self.relu(self.l1(x))
        h = self.relu(self.l2(h))
        return self.l3(h)

In [None]:
class ExecClass:
    """Run one W&B sweep trial: train ``Net1`` and log per-epoch metrics.

    Hyper-parameters ``learning_rate``, ``batch_size`` and ``epochs`` are read
    from ``wandb.config``. The notebook globals ``train_set``, ``test_set``,
    ``device`` and ``Net1`` must be defined before ``main`` is called.
    """

    def __init__(self, seed=123):
        """
        Args:
            seed: passed to ``torch.manual_seed`` at the start of every trial so
                each trial starts from the same RNG state even when ``main`` is
                invoked several times in one process on this same instance
                (``wandb.agent`` may run multiple trials that way).
                ``None`` disables per-trial re-seeding.
        """
        # Loss: cross entropy, averaged over each batch (default reduction).
        self.criterion = nn.CrossEntropyLoss()
        self.seed = seed

    def init_sweep_parameter(self):
        """Read this trial's hyper-parameters from ``wandb.config`` and build the
        data loaders, model, optimizer and an empty metric history."""
        # Sweep parameters
        lr = wandb.config["learning_rate"]
        self.batch_size = wandb.config["batch_size"]
        self.num_epochs = wandb.config["epochs"]
        print("learning_rate:", lr)
        print("batch_size:", self.batch_size)
        print("num_epochs:", self.num_epochs)

        # Re-seed per trial so model init and shuffling do not depend on how
        # many trials already ran in this process.
        if self.seed is not None:
            torch.manual_seed(self.seed)

        # Data loaders
        self.train_loader = DataLoader(
            train_set, batch_size=self.batch_size,
            shuffle=True
        )
        self.test_loader = DataLoader(
            test_set, batch_size=self.batch_size,
            shuffle=False
        )

        # Input width = length of one (flattened) sample.
        image, _ = train_set[0]
        n_input = image.shape[0]
        # Class count comes from the dataset itself. Counting the distinct labels
        # of the first shuffled batch undercounts for small batch sizes, which
        # makes CrossEntropyLoss fail on class indices the model cannot output.
        n_output = len(train_set.classes)
        n_hidden = 128

        # Model
        self.net = Net1(n_input, n_output, n_hidden).to(device)
        print(self.net)
        # Optimizer
        self.optimizer = optim.SGD(self.net.parameters(), lr=lr)
        # Per-epoch metrics: [epoch, train_loss, train_acc, test_loss, test_acc]
        self.history = np.zeros((0, 5))

    def train_one_epoch(self, train_loss, train_acc, n_train):
        """Run one optimisation pass over ``self.train_loader``.

        Args:
            train_loss: running sum of per-sample losses to add to.
            train_acc: running count of correct predictions to add to.
            n_train: running number of samples seen.

        Returns:
            Updated ``(train_loss, train_acc, n_train)``.
        """
        self.net.train()
        for inputs, labels in tqdm(self.train_loader):
            batch_n = len(labels)
            n_train += batch_n
            # Move the batch to the model's device
            inputs = inputs.to(device)
            labels = labels.to(device)
            self.optimizer.zero_grad()
            outputs = self.net(inputs)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()
            predicted = outputs.argmax(dim=1)
            # The criterion returns the batch mean; weight it by the batch size so
            # the epoch average stays exact when the last batch is smaller.
            train_loss += loss.item() * batch_n
            train_acc += (predicted == labels).sum().item()
        return train_loss, train_acc, n_train

    def evaluate_one_epoch(self, test_loss, test_acc, n_test):
        """Evaluate ``self.net`` on ``self.test_loader`` without updating it.

        Args:
            test_loss: running sum of per-sample losses to add to.
            test_acc: running count of correct predictions to add to.
            n_test: running number of samples seen.

        Returns:
            Updated ``(test_loss, test_acc, n_test)``.
        """
        self.net.eval()
        # Evaluation needs no autograd graph; skipping it saves memory and time.
        with torch.no_grad():
            for inputs_test, labels_test in self.test_loader:
                batch_n = len(labels_test)
                n_test += batch_n
                inputs_test = inputs_test.to(device)
                labels_test = labels_test.to(device)
                outputs_test = self.net(inputs_test)
                loss_test = self.criterion(outputs_test, labels_test)
                predicted_test = outputs_test.argmax(dim=1)
                # Per-sample weighting, same as in train_one_epoch
                test_loss += loss_test.item() * batch_n
                test_acc += (predicted_test == labels_test).sum().item()
        return test_loss, test_acc, n_test

    @staticmethod
    def _to_epoch_metrics(loss_sum, n_correct, n_samples):
        """Turn per-sample sums into ``(mean_loss, accuracy)``; NaNs when empty."""
        if n_samples == 0:
            return float("nan"), float("nan")
        return loss_sum / n_samples, n_correct / n_samples

    def plot_loss_history(self):
        """Plot train/test loss per epoch from ``self.history``."""
        plt.plot(self.history[:,0], self.history[:,1], "b", label="train")
        plt.plot(self.history[:,0], self.history[:,3], "k", label="test")
        plt.xlabel('iterate number')
        plt.ylabel('loss')
        plt.title('learning curve')
        plt.legend()
        plt.show()

    def plot_acc_history(self):
        """Plot train/test accuracy per epoch from ``self.history``."""
        plt.plot(self.history[:,0], self.history[:,2], "b", label="train")
        plt.plot(self.history[:,0], self.history[:,4], "k", label="test")
        plt.xlabel('iterate number')
        plt.ylabel('acc')
        plt.title('learning curve(acc)')
        plt.legend()
        plt.show()

    def main(self):
        """Entry point handed to ``wandb.agent``: run one complete sweep trial."""
        wandb.init()
        self.init_sweep_parameter()

        # Main training loop
        for epoch in range(self.num_epochs):
            train_loss, train_acc, n_train = self.train_one_epoch(0.0, 0, 0)
            test_loss, test_acc, n_test = self.evaluate_one_epoch(0.0, 0, 0)

            # Per-epoch averages, recorded locally and logged to W&B
            train_loss, train_acc = self._to_epoch_metrics(train_loss, train_acc, n_train)
            test_loss, test_acc = self._to_epoch_metrics(test_loss, test_acc, n_test)
            print (f'Epoch [{epoch+1}/{self.num_epochs}], loss: {train_loss:.5f} acc: {train_acc:.5f} val_loss: {test_loss:.5f}, val_acc: {test_acc:.5f}')
            item = np.array([epoch+1 , train_loss, train_acc, test_loss, test_acc])
            self.history = np.vstack((self.history, item))
            wandb.log({ "epoch": epoch,
                        "train_loss": train_loss,
                        "test_loss": test_loss,
                        "train_acc": train_acc,
                        "test_acc": test_acc,
                         })
        # With epochs == 0 there is no row to report (indexing would raise).
        # Note: row 0 holds the metrics AFTER the first epoch, not before training.
        if len(self.history) > 0:
            print(f'初期状態：損失：{self.history[0,3]:.5f}  精度：{self.history[0,4]:.5f}')
            print(f'最終状態：損失：{self.history[-1,3]:.5f}  精度：{self.history[-1,4]:.5f}')
            

In [None]:
# Load parameters from a YAML file
def yaml_read(yaml_file):
    """Load a YAML file and return its parsed content.

    Args:
        yaml_file: path to the YAML file (here, the sweep configuration).

    Returns:
        The object produced by ``yaml.safe_load``.
    """
    # Read explicitly as UTF-8: without an encoding, open() uses the locale's
    # default (e.g. cp932 on Japanese Windows), which can raise
    # UnicodeDecodeError on UTF-8 files containing non-ASCII text.
    with open(yaml_file, encoding="utf-8") as f:
        return yaml.safe_load(f)

In [None]:
# Single trainer instance; its bound `main` is what the sweep agent runs per trial.
ExecIns = ExecClass()
# Manual (non-sweep) run. main() reads learning_rate / batch_size / epochs from
# wandb.config. NOTE(review): outside a sweep those keys may be absent — confirm
# they are supplied before uncommenting.
#ExecIns.main()
#ExecIns.plot_loss_history()
#ExecIns.plot_acc_history()

In [None]:
def sweep(config_path="config_sweep.yaml", project="test_sweep", count=None):
    """Register a W&B sweep from a YAML config and run an agent on it.

    Args:
        config_path: YAML file with the sweep configuration passed to
            ``wandb.sweep``.
        project: W&B project in which the sweep is created.
        count: maximum number of trials this agent runs; ``None`` keeps
            wandb's default behaviour.

    Returns:
        The sweep id returned by ``wandb.sweep``.
    """
    sweep_config = yaml_read(config_path)
    sweep_id = wandb.sweep(sweep_config, project=project)
    # The same ExecClass instance serves every trial the agent launches.
    wandb.agent(sweep_id, ExecIns.main, count=count)
    return sweep_id

In [None]:
# Create the sweep from config_sweep.yaml and start an agent that calls ExecIns.main for each trial.
sweep()