In [1]:
from config import Config
from load_dataset import Dataset
from model.discriminator import Discriminator
from model.generator import Generator

import torch
import torch.nn as nn
from torch.optim import Adam
from tensorboardX import SummaryWriter
from torchvision.utils import save_image
from torchmetrics.classification import BinaryAccuracy
from torch.utils.data import DataLoader, random_split
import os

In [2]:
real_label, fake_label = 0, 1
normal_label, abnormal_label = 0, 1

In [3]:
device = torch.device('cuda')

In [4]:
compute_acc = BinaryAccuracy(threshold=Config.detection_thr).to(device)

In [5]:
gen = Generator().to(device)
dis1 = Discriminator().to(device)
dis2 = Discriminator().to(device)

In [6]:
criterion = nn.BCELoss()

In [7]:
writer = SummaryWriter()

In [8]:
dataset = Dataset('road')

In [9]:
train_size = int(len(dataset) * 0.9)
test_size = len(dataset) - train_size
train_set, test_set = random_split(dataset, [train_size, test_size])

In [10]:
train_dataloader = DataLoader(dataset=train_set, batch_size=Config.batch_size, shuffle=False, num_workers=1)
test_dataloader = DataLoader(dataset=test_set, batch_size=Config.batch_size, shuffle=False, num_workers=1)

In [11]:
optim_G = Adam(gen.parameters(), lr=Config.lr, betas=(Config.b1, Config.b2))
optim_D1 = Adam(dis1.parameters(), lr=Config.lr, betas=(Config.b1, Config.b2))
optim_D2 = Adam(dis2.parameters(), lr=Config.lr, betas=(Config.b1, Config.b2))

In [12]:
def test(epoch):
    
    dis1.eval()
    dis2.eval()
    gen.eval()

    with torch.no_grad():

        img_path = Config.save_path + '/generated_img_samples'.format(epoch)
        if not os.path.exists(img_path):
            os.makedirs(img_path)

        random_x = torch.randn(64, 256, 1, 1).to(device)
        test_sample = gen(random_x).detach().cpu()

        save_image(test_sample[0], '{}/{}.png'.format(img_path, epoch))
        writer.add_image('generated_img_samples', test_sample, epoch, dataformats='NCHW')

        batch_acc = 0

        for batch_idx, (inputs, labels) in enumerate(test_dataloader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            output = dis1(inputs).to(device)

            for out in range(len(output)):
                if output[out] < Config.detection_thr:
                    output[out] = dis2(inputs)[out].to(device)

            output = output.to(device)

            batch_acc += compute_acc(output.to(torch.float32), labels.to(torch.float32))

        epoch_acc = batch_acc / len(test_dataloader)
        print(f'Test accuracy for epoch {epoch}: {epoch_acc}')

        writer.add_scalar('test_acc', epoch_acc, epoch)

In [13]:
def train():

    for epoch in range(Config.epochs):

        # TODO: TRAIN

        dis1.train()
        dis2.train()
        gen.train()

        for batch_idx, (inputs, labels) in enumerate(train_dataloader):

            optim_G.zero_grad()
            optim_D1.zero_grad()
            optim_D2.zero_grad()

            inputs = inputs.to(device)  # batch, 1, 64, 48
            labels = labels.to(device)

            # TODO: train generator

            # labels.fill_(real_label)
            gen_target = torch.zeros(Config.batch_size, requires_grad=False).to(device)

            noise = torch.randn(Config.batch_size, 256, 1, 1).to(device)
            fake_inputs = gen(noise).to(device)

            gen_loss = criterion(dis2(fake_inputs).to(torch.float32), gen_target.to(torch.float32))
            gen_loss.backward()
            optim_G.step()

            # TODO: train first discriminator for normal/abnormal data

            dis1_output = dis1(inputs).to(device)

            dis_1_loss = criterion(dis1_output.to(torch.float32), labels.to(torch.float32))
            dis_1_loss.backward()
            optim_D1.step()

            # TODO: train second discriminator for real/fake data

            # noise = torch.randn(Config.batch_size, 256, 1, 1).to(device)
            # fake_inputs = gen(noise).to(device)

            dis2_real_output = dis2(inputs).to(device)
            real_target = torch.zeros(dis2_real_output.shape[0], requires_grad=False).to(device)

            dis_2_real_loss = criterion(dis2_real_output.to(torch.float32), real_target.to(torch.float32))
            # dis_2_real_loss.backward()

            dis2_fake_output = dis2(fake_inputs.detach())
            fake_target = torch.ones(dis2_fake_output.shape[0], requires_grad=False).to(device)

            dis_2_fake_loss = criterion(dis2_fake_output.to(torch.float32), fake_target.to(torch.float32))
            # dis_2_fake_loss.backward()

            dis_2_total_loss = (dis_2_real_loss + dis_2_fake_loss) / 2
            dis_2_total_loss.backward()

            optim_D2.step()

            writer.add_scalar('loss/dis1_loss', dis_1_loss.data, epoch)

            writer.add_scalar('loss/dis_2_real_loss', dis_2_real_loss.data, epoch)
            writer.add_scalar('loss/dis2_fake_loss', dis_2_fake_loss, epoch)
            writer.add_scalar('loss/dis2_total_loss', dis_2_total_loss, epoch)

            writer.add_scalar('loss/gen_loss', gen_loss.data, epoch)

            gen_path = Config.save_path + '/gen/epoch_{}'.format(epoch)
            if not os.path.exists(gen_path):
                os.makedirs(gen_path)
            dis1_path = Config.save_path + '/dis1/epoch_{}'.format(epoch)
            if not os.path.exists(dis1_path):
                os.makedirs(dis1_path)
            dis2_path = Config.save_path + '/dis2/epoch_{}'.format(epoch)
            if not os.path.exists(dis2_path):
                os.makedirs(dis2_path)

            torch.save(gen.state_dict(), gen_path + '/state_dict.pth')
            torch.save(dis1.state_dict(), dis1_path + '/state_dict.pth')
            torch.save(dis2.state_dict(), dis2_path + '/state_dict.pth')
            torch.save(gen, gen_path + '/model.pth')
            torch.save(dis1, dis1_path + '/model.pth')
            torch.save(dis2, dis2_path + '/model.pth')

            if batch_idx % Config.log_f == 0:
                print("[Train] Epoch: {}/{}, Batch: {}/{}, D1 loss: {}, D2 loss: {}, G loss: {}".format(epoch,
                           Config.epochs, batch_idx, len(train_dataloader), dis_1_loss, dis_2_total_loss, gen_loss))


        # TODO: TEST

        test(epoch)

In [14]:
train()

[Train] Epoch: 0/30, Batch: 0/10929, D1 loss: 0.694366455078125, D2 loss: 0.6918829679489136, G loss: 0.6930721402168274
[Train] Epoch: 0/30, Batch: 3000/10929, D1 loss: 9.530570423521567e-06, D2 loss: 0.6201072931289673, G loss: 0.3713977336883545
[Train] Epoch: 0/30, Batch: 6000/10929, D1 loss: 1.0550903652983834e-06, D2 loss: 0.6315886974334717, G loss: 0.3431209325790405
[Train] Epoch: 0/30, Batch: 9000/10929, D1 loss: 1.7458458501096175e-07, D2 loss: 0.6070866584777832, G loss: 0.3556692898273468
Test accuracy for epoch 0: 0.999974250793457
[Train] Epoch: 1/30, Batch: 0/10929, D1 loss: 5.8665712288075156e-08, D2 loss: 0.5463626980781555, G loss: 0.40931928157806396
[Train] Epoch: 1/30, Batch: 3000/10929, D1 loss: 1.0570586184144304e-08, D2 loss: 0.5491909384727478, G loss: 0.40599238872528076
[Train] Epoch: 1/30, Batch: 6000/10929, D1 loss: 1.9963142250389865e-09, D2 loss: 0.33076557517051697, G loss: 0.7279291749000549
[Train] Epoch: 1/30, Batch: 9000/10929, D1 loss: 4.3362369250

Test accuracy for epoch 14: 0.9396862387657166
[Train] Epoch: 15/30, Batch: 0/10929, D1 loss: 5.956940669904487e-13, D2 loss: 0.6710757613182068, G loss: 0.33585408329963684
[Train] Epoch: 15/30, Batch: 3000/10929, D1 loss: 5.648056891974229e-13, D2 loss: 1.1788508892059326, G loss: 0.09967878460884094
[Train] Epoch: 15/30, Batch: 6000/10929, D1 loss: 5.36701379953014e-13, D2 loss: 0.5625198483467102, G loss: 0.39420291781425476
[Train] Epoch: 15/30, Batch: 9000/10929, D1 loss: 5.499939154281208e-13, D2 loss: 0.5421832799911499, G loss: 0.4128880500793457
Test accuracy for epoch 15: 0.9712191224098206
[Train] Epoch: 16/30, Batch: 0/10929, D1 loss: 5.682806655804562e-13, D2 loss: 0.5792561173439026, G loss: 0.38265708088874817
[Train] Epoch: 16/30, Batch: 3000/10929, D1 loss: 5.3955917424936e-13, D2 loss: 0.5092418789863586, G loss: 0.4491315484046936
[Train] Epoch: 16/30, Batch: 6000/10929, D1 loss: 5.133520561764748e-13, D2 loss: 0.5598098039627075, G loss: 0.3954022228717804
[Train] 

[Train] Epoch: 29/30, Batch: 9000/10929, D1 loss: 3.7590351838201475e-13, D2 loss: 0.5465521216392517, G loss: 0.41044384241104126
Test accuracy for epoch 29: 0.9357202053070068


In [15]:
writer.close()