In [1]:
from config import Config
from load_dataset import Dataset
from model.discriminator import Discriminator
from model.generator import Generator

import torch
import torch.nn as nn
from torch.optim import Adam
from tensorboardX import SummaryWriter
from torchvision.utils import save_image
from torchmetrics.classification import BinaryAccuracy
from torch.utils.data import DataLoader, random_split
import os

In [2]:
real_label, fake_label = 0, 1
normal_label, abnormal_label = 0, 1

In [3]:
device = torch.device('cuda')

In [4]:
compute_acc = BinaryAccuracy(threshold=Config.detection_thr).to(device)

In [5]:
gen = Generator().to(device)
dis1 = Discriminator().to(device)
dis2 = Discriminator().to(device)

In [6]:
criterion = nn.BCELoss()

In [7]:
writer = SummaryWriter()

In [8]:
dataset = Dataset()

In [9]:
train_size = int(len(dataset) * 0.9)
test_size = len(dataset) - train_size
train_set, test_set = random_split(dataset, [train_size, test_size])

In [10]:
train_dataloader = DataLoader(dataset=train_set, batch_size=Config.batch_size, shuffle=False, num_workers=1)
test_dataloader = DataLoader(dataset=test_set, batch_size=Config.batch_size, shuffle=False, num_workers=1)

In [11]:
optim_G = Adam(gen.parameters(), lr=Config.lr, betas=(Config.b1, Config.b2))
optim_D1 = Adam(dis1.parameters(), lr=Config.lr, betas=(Config.b1, Config.b2))
optim_D2 = Adam(dis2.parameters(), lr=Config.lr, betas=(Config.b1, Config.b2))

In [12]:
def test(epoch):
    
    dis1.eval()
    dis2.eval()
    gen.eval()

    with torch.no_grad():

        img_path = Config.save_path + '/generated_img_samples'.format(epoch)
        if not os.path.exists(img_path):
            os.makedirs(img_path)

        random_x = torch.randn(64, 256, 1, 1).to(device)
        test_sample = gen(random_x).detach().cpu()

        save_image(test_sample[0], '{}/{}.png'.format(img_path, epoch))
        writer.add_image('generated_img_samples', test_sample, epoch, dataformats='NCHW')

        batch_acc = 0

        for batch_idx, (inputs, labels) in enumerate(test_dataloader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            output = dis1(inputs).to(device)

            for out in range(len(output)):
                if output[out] < Config.detection_thr:
                    output[out] = dis2(inputs)[out].to(device)

            output = output.to(device)

            batch_acc += compute_acc(output.to(torch.float32), labels.to(torch.float32))

        epoch_acc = batch_acc / len(test_dataloader)
        print(f'Test accuracy for epoch {epoch}: {epoch_acc}')

        writer.add_scalar('test_acc', epoch_acc, epoch)

In [13]:
def train():

    for epoch in range(Config.epochs):

        # TODO: TRAIN

        dis1.train()
        dis2.train()
        gen.train()

        for batch_idx, (inputs, labels) in enumerate(train_dataloader):

            optim_G.zero_grad()
            optim_D1.zero_grad()
            optim_D2.zero_grad()

            inputs = inputs.to(device)  # batch, 1, 64, 48
            labels = labels.to(device)

            # TODO: train generator

            # labels.fill_(real_label)
            gen_target = torch.zeros(Config.batch_size, requires_grad=False).to(device)

            noise = torch.randn(Config.batch_size, 256, 1, 1).to(device)
            fake_inputs = gen(noise).to(device)

            gen_loss = criterion(dis2(fake_inputs).to(torch.float32), gen_target.to(torch.float32))
            gen_loss.backward()
            optim_G.step()

            # TODO: train first discriminator for normal/abnormal data

            dis1_output = dis1(inputs).to(device)

            dis_1_loss = criterion(dis1_output.to(torch.float32), labels.to(torch.float32))
            dis_1_loss.backward()
            optim_D1.step()

            # TODO: train second discriminator for real/fake data

            # noise = torch.randn(Config.batch_size, 256, 1, 1).to(device)
            # fake_inputs = gen(noise).to(device)

            dis2_real_output = dis2(inputs).to(device)
            real_target = torch.zeros(dis2_real_output.shape[0], requires_grad=False).to(device)

            dis_2_real_loss = criterion(dis2_real_output.to(torch.float32), real_target.to(torch.float32))
            # dis_2_real_loss.backward()

            dis2_fake_output = dis2(fake_inputs.detach())
            fake_target = torch.ones(dis2_fake_output.shape[0], requires_grad=False).to(device)

            dis_2_fake_loss = criterion(dis2_fake_output.to(torch.float32), fake_target.to(torch.float32))
            # dis_2_fake_loss.backward()

            dis_2_total_loss = (dis_2_real_loss + dis_2_fake_loss) / 2
            dis_2_total_loss.backward()

            optim_D2.step()

            writer.add_scalar('loss/dis1_loss', dis_1_loss.data, epoch)

            writer.add_scalar('loss/dis_2_real_loss', dis_2_real_loss.data, epoch)
            writer.add_scalar('loss/dis2_fake_loss', dis_2_fake_loss, epoch)
            writer.add_scalar('loss/dis2_total_loss', dis_2_total_loss, epoch)

            writer.add_scalar('loss/gen_loss', gen_loss.data, epoch)

            gen_path = Config.save_path + '/gen/epoch_{}'.format(epoch)
            if not os.path.exists(gen_path):
                os.makedirs(gen_path)
            dis1_path = Config.save_path + '/dis1/epoch_{}'.format(epoch)
            if not os.path.exists(dis1_path):
                os.makedirs(dis1_path)
            dis2_path = Config.save_path + '/dis2/epoch_{}'.format(epoch)
            if not os.path.exists(dis2_path):
                os.makedirs(dis2_path)

            torch.save(gen.state_dict(), gen_path + '/state_dict.pth')
            torch.save(dis1.state_dict(), dis1_path + '/state_dict.pth')
            torch.save(dis2.state_dict(), dis2_path + '/state_dict.pth')
            torch.save(gen, gen_path + '/model.pth')
            torch.save(dis1, dis1_path + '/model.pth')
            torch.save(dis2, dis2_path + '/model.pth')

            if batch_idx % Config.log_f == 0:
                print("[Train] Epoch: {}/{}, Batch: {}/{}, D1 loss: {}, D2 loss: {}, G loss: {}".format(epoch,
                           Config.epochs, batch_idx, len(train_dataloader), dis_1_loss, dis_2_total_loss, gen_loss))


        # TODO: TEST

        test(epoch)

In [14]:
train()

[Train] Epoch: 0/30, Batch: 0/7282, D1 loss: 0.6960949301719666, D2 loss: 0.6928837299346924, G loss: 0.6931027173995972
[Train] Epoch: 0/30, Batch: 3000/7282, D1 loss: 0.004239245317876339, D2 loss: 0.5627424120903015, G loss: 0.39344072341918945
[Train] Epoch: 0/30, Batch: 6000/7282, D1 loss: 0.001236495329067111, D2 loss: 0.5537540912628174, G loss: 0.4013541638851166
Test accuracy for epoch 0: 0.9976466298103333
[Train] Epoch: 1/30, Batch: 0/7282, D1 loss: 0.0007217179518193007, D2 loss: 0.5443126559257507, G loss: 0.4109313488006592
[Train] Epoch: 1/30, Batch: 3000/7282, D1 loss: 0.00237908773124218, D2 loss: 0.5493583083152771, G loss: 0.4057232737541199
[Train] Epoch: 1/30, Batch: 6000/7282, D1 loss: 0.0008084701257757843, D2 loss: 0.5616193413734436, G loss: 0.3935122787952423
Test accuracy for epoch 1: 0.9978009462356567
[Train] Epoch: 2/30, Batch: 0/7282, D1 loss: 0.0005566740874201059, D2 loss: 0.5488399863243103, G loss: 0.4060681462287903
[Train] Epoch: 2/30, Batch: 3000/7

[Train] Epoch: 19/30, Batch: 3000/7282, D1 loss: 0.00018940218433272094, D2 loss: 0.5949111580848694, G loss: 0.3680686354637146
[Train] Epoch: 19/30, Batch: 6000/7282, D1 loss: 0.0001353346451651305, D2 loss: 0.471989244222641, G loss: 0.5098506212234497
Test accuracy for epoch 19: 0.9983025193214417
[Train] Epoch: 20/30, Batch: 0/7282, D1 loss: 0.00030966708436608315, D2 loss: 0.6830312609672546, G loss: 0.30486947298049927
[Train] Epoch: 20/30, Batch: 3000/7282, D1 loss: 0.00016359535220544785, D2 loss: 0.568905770778656, G loss: 0.39439553022384644
[Train] Epoch: 20/30, Batch: 6000/7282, D1 loss: 0.00013210423639975488, D2 loss: 0.593609094619751, G loss: 0.3694118857383728
Test accuracy for epoch 20: 0.9983025193214417
[Train] Epoch: 21/30, Batch: 0/7282, D1 loss: 0.00029348451062105596, D2 loss: 0.459960013628006, G loss: 0.5201161503791809
[Train] Epoch: 21/30, Batch: 3000/7282, D1 loss: 0.00014247040962800384, D2 loss: 0.5550004839897156, G loss: 0.4086458683013916
[Train] Epoc

In [15]:
writer.close()