In [None]:
import csv, os, glob, torch, time
from tqdm import tqdm
from utils.module import write_to_csv
from utils.autoencoder import VAE, vae_loss
from dataloader.dataset import UnlabeledDataset2, UnlabeledTransform2
import torch.utils.data as data
from utils.module import EarlyStopping

In [None]:
# Check which GPUs are visible to this kernel (and their current load) before training.
!nvidia-smi

In [None]:
# Gather training images from both source folders and build the (unlabeled) loader.
CROP_SIZE = 32    # side length passed to UnlabeledTransform2 as crop_size
BATCH_SIZE = 256

img_file_path = sorted(glob.glob('data/Train/images/*'))
img_file_path2 = sorted(glob.glob('data/original_split_resized/*'))
img_list = img_file_path + img_file_path2

# drop_last=True below means that fewer than BATCH_SIZE images produce zero
# batches, which later makes the per-epoch mean loss divide by zero. Fail fast
# here instead (assumes UnlabeledDataset2 yields one sample per path — TODO confirm).
if len(img_list) < BATCH_SIZE:
    raise ValueError(
        f'Found {len(img_list)} images; need at least {BATCH_SIZE} '
        '(batch_size with drop_last=True) to form a single batch.')

train_dataset = UnlabeledDataset2(
    img_list, transform=UnlabeledTransform2(crop_size=CROP_SIZE))
train_dataloader = data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,
    pin_memory=True, drop_last=True)

In [None]:
# Instantiate the VAE with a 2-D latent space (so latents can be plotted directly).
MAX_GPUS = 4          # upper bound on data-parallel replicas (original used GPUs 0-3)
LEARNING_RATE = 0.001

model = VAE(latent_dim=2)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hard-coding device_ids=[0, 1, 2, 3] crashes on machines with fewer GPUs.
# Use up to MAX_GPUS of the GPUs actually present. With no CUDA, DataParallel
# just wraps the module. The model is always wrapped, so state_dict keys keep
# their 'module.' prefix and existing checkpoints stay loadable.
gpu_ids = list(range(min(MAX_GPUS, torch.cuda.device_count())))
model = torch.nn.DataParallel(model, device_ids=gpu_ids)
model = model.to(device)

# Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Maximum number of epochs (early stopping may end training sooner).
num_epochs = 1000

In [None]:
# Train the VAE, resuming from this project's best checkpoint when one exists.
# Weights are written to weights/<project>/best.pth whenever EarlyStopping's
# counter is 0 (presumably a new best loss — depends on EarlyStopping's
# implementation), and the final weights go to weights/<project>/last.pth.
project = 'vae2'
earlystopping = EarlyStopping(patience=50)
weights_dir = os.path.join('weights', project)
os.makedirs(weights_dir, exist_ok=True)
best_path = os.path.join(weights_dir, 'best.pth')
last_path = os.path.join(weights_dir, 'last.pth')


def _cpu_state_dict(module):
    """Return ``module.state_dict()`` with every tensor copied to the CPU.

    Unlike ``module.to('cpu').state_dict()``, this leaves ``module`` on its
    current device, so there is no CPU<->GPU round trip and later cells still
    find the model on ``device``. The same OrderedDict is updated in place,
    so its ``_metadata`` (used by ``load_state_dict``) is preserved.
    """
    state = module.state_dict()
    for key in list(state):
        state[key] = state[key].cpu()
    return state


def _train_one_epoch(net, loader, opt, dev):
    """Run one optimisation pass over ``loader``; return the mean per-batch VAE loss."""
    net.train()
    running_loss = 0.0
    for inputs in tqdm(loader):
        inputs = inputs.to(dev)
        opt.zero_grad()
        outputs, mu, logvar, z = net(inputs)
        loss = vae_loss(outputs, inputs, mu, logvar)
        loss.backward()
        opt.step()
        running_loss += loss.item()
    return running_loss / len(loader)


# Resume only if a checkpoint exists; the original unconditional load raised
# FileNotFoundError on a fresh run. map_location lets a checkpoint saved on
# another device load here. Note: optimizer state is not checkpointed, so Adam
# restarts its moment estimates on resume.
if os.path.exists(best_path):
    model.load_state_dict(torch.load(best_path, map_location=device))
else:
    print(f'No checkpoint at {best_path}; training from scratch.')
model = model.to(device)

if len(train_dataloader) == 0:
    raise ValueError('train_dataloader yields no batches '
                     '(dataset smaller than batch_size with drop_last=True?)')

for epoch in range(num_epochs):
    start_time = time.time()
    epoch_train_loss = _train_one_epoch(model, train_dataloader, optimizer, device)
    print(f"Epoch {epoch + 1}, Loss: {epoch_train_loss}")

    time_elapsed = time.time() - start_time
    print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    # Early stopping is monitored on the training loss (no validation split here).
    earlystopping(epoch_train_loss)
    if earlystopping.early_stop:
        print("Early stopping")
        break
    if earlystopping.counter == 0:
        torch.save(_cpu_state_dict(model), best_path)

    print(f'Early Stopping Counter = {earlystopping.counter}')

torch.save(_cpu_state_dict(model), last_path)

In [None]:
# For each dataset, encode images with the model and plot the latent variables in 2-D.
# NOTE(review): the original called model.forward(x) with `x` never defined
# (NameError on a fresh kernel). Here one batch from train_dataloader is encoded.
# TODO: confirm which images should be plotted, and split them per source
# folder if a per-dataset plot is intended.
model = model.to(device)  # the training cell may have left the model on the CPU
model.eval()
with torch.no_grad():
    x = next(iter(train_dataloader)).to(device)
    y, mu, logvar, z = model(x)