## Carregamento dos dados

In [1]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset

X_ = torch.load('../data/processed/X_.pt')
y_ = torch.load('../data/processed/y_.pt')

real_set = TensorDataset(X_, y_)

X_0 = torch.load('../data/processed/X_0.pt')
X_1 = torch.load('../data/processed/X_1.pt')
X_2 = torch.load('../data/processed/X_2.pt')
X_3 = torch.load('../data/processed/X_3.pt')

X_0 = torch.from_numpy(np.expand_dims(X_0, axis=1))
X_1 = torch.from_numpy(np.expand_dims(X_1, axis=1))
X_2 = torch.from_numpy(np.expand_dims(X_2, axis=1))
X_3 = torch.from_numpy(np.expand_dims(X_3, axis=1))

gen_set_0 = TensorDataset(X_0, torch.as_tensor(np.full(1296, 0)))
gen_set_1 = TensorDataset(X_1, torch.as_tensor(np.full(1296, 1)))
gen_set_2 = TensorDataset(X_2, torch.as_tensor(np.full(1296, 2)))
gen_set_3 = TensorDataset(X_3, torch.as_tensor(np.full(1296, 3)))

print(X_0.shape)
print(X_1.shape)
print(X_2.shape)
print(X_3.shape)
print(X_.shape)
print(y_.shape)

  from .autonotebook import tqdm as notebook_tqdm


torch.Size([1296, 1, 22, 1125])
torch.Size([1296, 1, 22, 1125])
torch.Size([1296, 1, 22, 1125])
torch.Size([1296, 1, 22, 1125])
torch.Size([5184, 22, 1125])
torch.Size([5184])


## Definição do modelo da GAN

In [2]:
from torch import flatten
from torch import nn

In [3]:
class Generator(nn.Module):
    def __init__(self, inputDim=100, outputChannels=1):
        super(Generator, self).__init__()
        
        self.ct1 = nn.ConvTranspose2d(in_channels=inputDim, out_channels=128, kernel_size=(1, 140), stride=2, padding=0, bias=False)
        self.relu1 = nn.ReLU()
        self.batchNorm1 = nn.BatchNorm2d(128)
        
        self.ct2 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=(4, 4), stride=2, padding=0, bias=False)
        self.relu2 = nn.ReLU()
        self.batchNorm2 = nn.BatchNorm2d(64)
        
        self.ct3 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=(4, 1), stride=2, padding=0, bias=False)
        self.relu3 = nn.ReLU()
        self.batchNorm3 = nn.BatchNorm2d(32)
        
        self.ct4 = nn.ConvTranspose2d(in_channels=32, out_channels=outputChannels, kernel_size=(4, 1), stride=2, padding=0, bias=False)
        self.tanh = nn.Tanh()
        
    def forward(self, x):
        x = self.ct1(x)
        x = self.relu1(x)
        x = self.batchNorm1(x)
        
        x = self.ct2(x)
        x = self.relu2(x)
        x = self.batchNorm2(x)
        
        x = self.ct3(x)
        x = self.relu3(x)
        x = self.batchNorm3(x)
        
        x = self.ct4(x)
        output = self.tanh(x)
        
        return output

In [4]:
class Discriminator(nn.Module):
	def __init__(self, depth, alpha=0.2):
		super(Discriminator, self).__init__()
        
		self.conv1 = nn.Conv2d(in_channels=depth, out_channels=32, kernel_size=4, stride=2, padding=1)
		self.leakyRelu1 = nn.LeakyReLU(alpha, inplace=True)
        
		self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2, padding=1)
		self.leakyRelu2 = nn.LeakyReLU(alpha, inplace=True)
        
		self.fc1 = nn.Linear(in_features=89920, out_features=512)
		self.leakyRelu3 = nn.LeakyReLU(alpha, inplace=True)
        
		self.fc2 = nn.Linear(in_features=512, out_features=1)
		self.sigmoid = nn.Sigmoid()

	def forward(self, x):
		x = self.conv1(x)
		x = self.leakyRelu1(x)
        
		x = self.conv2(x)
		x = self.leakyRelu2(x)
        
		x = flatten(x, 1)
		x = self.fc1(x)
		x = self.leakyRelu3(x)
        
		x = self.fc2(x)
		output = self.sigmoid(x)
        
		return output

In [5]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
NUM_EPOCHS = 20
BATCH_SIZE = 128

In [6]:
dataloader_0 = DataLoader(dataset=gen_set_0, batch_size=128)
dataloader_1 = DataLoader(dataset=gen_set_1, batch_size=128)
dataloader_2 = DataLoader(dataset=gen_set_2, batch_size=128)
dataloader_3 = DataLoader(dataset=gen_set_3, batch_size=128)

In [7]:
from torch.optim import Adam
from torch.nn import BCELoss

stepsPerEpoch = len(dataloader_0.dataset) // BATCH_SIZE

print("[INFO] building generator...")
gen = Generator(inputDim=100, outputChannels=1)
gen.to(DEVICE)

print("[INFO] building discriminator...")
disc = Discriminator(depth=1)
disc.to(DEVICE)

genOpt = Adam(gen.parameters(), lr=0.0002, betas=(0.5, 0.999), weight_decay=0.0002 / NUM_EPOCHS)
discOpt = Adam(disc.parameters(), lr=0.0002, betas=(0.5, 0.999), weight_decay=0.0002 / NUM_EPOCHS)

criterion = BCELoss()

print("[INFO] starting training...")
benchmarkNoise = torch.randn(1296, 100, 1, 1, device=DEVICE)

realLabel = 1
fakeLabel = 0

[INFO] building generator...
[INFO] building discriminator...
[INFO] starting training...


## Geração dos dados sintéticos

In [8]:
for epoch in range(NUM_EPOCHS):
    print("[INFO] starting epoch {} of {}...".format(epoch + 1, NUM_EPOCHS))
    
    epochLossG = 0
    epochLossD = 0
    for x in dataloader_0:
        disc.zero_grad()
        
        images = x[0]
        images = images.to(DEVICE)
        
        bs =  images.size(0)
        labels = torch.full((bs,), realLabel, dtype=torch.float, device=DEVICE)
        
        output = disc(images).view(-1)
        
        errorReal = criterion(output, labels)
        
        errorReal.backward()
        
        noise = torch.randn(bs, 100, 1, 1, device=DEVICE)
        
        fake_0 = gen(noise)
        labels.fill_(fakeLabel)
        
        output = disc(fake_0.detach()).view(-1)
        errorFake = criterion(output, labels)
        
        errorFake.backward()
        
        errorD = errorReal + errorFake
        discOpt.step()
        
        gen.zero_grad()
        
        labels.fill_(realLabel)
        output = disc(fake_0).view(-1)
        
        errorG = criterion(output, labels)
        errorG.backward()
        
        genOpt.step()
        
        epochLossD += errorD
        epochLossG += errorG
        
    print("[INFO] Generator Loss: {:.4f}, Discriminator Loss: {:.4f}".format(epochLossG / stepsPerEpoch, epochLossD / stepsPerEpoch))
    if (epoch + 1) % 2 == 0:
        gen.eval()
        fake_0 = gen(benchmarkNoise)
        gen.train()

[INFO] starting epoch 1 of 20...
[INFO] Generator Loss: 6.5883, Discriminator Loss: 0.5440
[INFO] starting epoch 2 of 20...
[INFO] Generator Loss: 6.9375, Discriminator Loss: 0.0082
[INFO] starting epoch 3 of 20...
[INFO] Generator Loss: 6.3571, Discriminator Loss: 0.0169
[INFO] starting epoch 4 of 20...
[INFO] Generator Loss: 8.2464, Discriminator Loss: 0.0015
[INFO] starting epoch 5 of 20...
[INFO] Generator Loss: 8.6315, Discriminator Loss: 0.0011
[INFO] starting epoch 6 of 20...
[INFO] Generator Loss: 8.6627, Discriminator Loss: 0.0009
[INFO] starting epoch 7 of 20...
[INFO] Generator Loss: 9.6442, Discriminator Loss: 0.0004
[INFO] starting epoch 8 of 20...
[INFO] Generator Loss: 10.1951, Discriminator Loss: 0.0003
[INFO] starting epoch 9 of 20...
[INFO] Generator Loss: 10.6040, Discriminator Loss: 0.0002
[INFO] starting epoch 10 of 20...
[INFO] Generator Loss: 10.4195, Discriminator Loss: 0.0002
[INFO] starting epoch 11 of 20...
[INFO] Generator Loss: 10.4494, Discriminator Loss: 

In [9]:
print(fake_0.shape)

torch.Size([1296, 1, 22, 1125])


In [10]:
size=(22, 1125)
fake_0 = fake_0.detach().cpu().view(-1, *size)
print(fake_0.shape)

torch.Size([1296, 22, 1125])


In [11]:
import gc

gc.collect()
torch.cuda.empty_cache()

In [12]:
for epoch in range(NUM_EPOCHS):
    print("[INFO] starting epoch {} of {}...".format(epoch + 1, NUM_EPOCHS))
    
    epochLossG = 0
    epochLossD = 0
    for x in dataloader_1:
        disc.zero_grad()
        
        images = x[0]
        images = images.to(DEVICE)
        
        bs =  images.size(0)
        labels = torch.full((bs,), realLabel, dtype=torch.float, device=DEVICE)
        
        output = disc(images).view(-1)
        
        errorReal = criterion(output, labels)
        
        errorReal.backward()
        
        noise = torch.randn(bs, 100, 1, 1, device=DEVICE)
        
        fake_1 = gen(noise)
        labels.fill_(fakeLabel)
        
        output = disc(fake_1.detach()).view(-1)
        errorFake = criterion(output, labels)
        
        errorFake.backward()
        
        errorD = errorReal + errorFake
        discOpt.step()
        
        gen.zero_grad()
        
        labels.fill_(realLabel)
        output = disc(fake_1).view(-1)
        
        errorG = criterion(output, labels)
        errorG.backward()
        
        genOpt.step()
        
        epochLossD += errorD
        epochLossG += errorG
        
    print("[INFO] Generator Loss: {:.4f}, Discriminator Loss: {:.4f}".format(epochLossG / stepsPerEpoch, epochLossD / stepsPerEpoch))
    if (epoch + 1) % 2 == 0:
        gen.eval()
        fake_1 = gen(benchmarkNoise)
        gen.train()

[INFO] starting epoch 1 of 20...
[INFO] Generator Loss: 11.2221, Discriminator Loss: 0.0001
[INFO] starting epoch 2 of 20...
[INFO] Generator Loss: 10.7367, Discriminator Loss: 0.0001
[INFO] starting epoch 3 of 20...
[INFO] Generator Loss: 10.5085, Discriminator Loss: 0.0001
[INFO] starting epoch 4 of 20...
[INFO] Generator Loss: 10.4779, Discriminator Loss: 0.0001
[INFO] starting epoch 5 of 20...
[INFO] Generator Loss: 10.4281, Discriminator Loss: 0.0001
[INFO] starting epoch 6 of 20...
[INFO] Generator Loss: 10.8077, Discriminator Loss: 0.0001
[INFO] starting epoch 7 of 20...
[INFO] Generator Loss: 11.1647, Discriminator Loss: 0.0001
[INFO] starting epoch 8 of 20...
[INFO] Generator Loss: 11.4017, Discriminator Loss: 0.0001
[INFO] starting epoch 9 of 20...
[INFO] Generator Loss: 11.6652, Discriminator Loss: 0.0000
[INFO] starting epoch 10 of 20...
[INFO] Generator Loss: 11.8915, Discriminator Loss: 0.0000
[INFO] starting epoch 11 of 20...
[INFO] Generator Loss: 12.0710, Discriminator

In [13]:
print(fake_1.shape)

torch.Size([1296, 1, 22, 1125])


In [14]:
fake_1 = fake_1.detach().cpu().view(-1, *size)
print(fake_1.shape)

torch.Size([1296, 22, 1125])


In [15]:
gc.collect()
torch.cuda.empty_cache()

In [16]:
for epoch in range(NUM_EPOCHS):
    print("[INFO] starting epoch {} of {}...".format(epoch + 1, NUM_EPOCHS))
    
    epochLossG = 0
    epochLossD = 0
    for x in dataloader_2:
        disc.zero_grad()
        
        images = x[0]
        images = images.to(DEVICE)
        
        bs =  images.size(0)
        labels = torch.full((bs,), realLabel, dtype=torch.float, device=DEVICE)
        
        output = disc(images).view(-1)
        
        errorReal = criterion(output, labels)
        
        errorReal.backward()
        
        noise = torch.randn(bs, 100, 1, 1, device=DEVICE)
        
        fake_2 = gen(noise)
        labels.fill_(fakeLabel)
        
        output = disc(fake_2.detach()).view(-1)
        errorFake = criterion(output, labels)
        
        errorFake.backward()
        
        errorD = errorReal + errorFake
        discOpt.step()
        
        gen.zero_grad()
        
        labels.fill_(realLabel)
        output = disc(fake_2).view(-1)
        
        errorG = criterion(output, labels)
        errorG.backward()
        
        genOpt.step()
        
        epochLossD += errorD
        epochLossG += errorG
        
    print("[INFO] Generator Loss: {:.4f}, Discriminator Loss: {:.4f}".format(epochLossG / stepsPerEpoch, epochLossD / stepsPerEpoch))
    if (epoch + 1) % 2 == 0:
        gen.eval()
        fake_2 = gen(benchmarkNoise)
        gen.train()

[INFO] starting epoch 1 of 20...
[INFO] Generator Loss: 12.4943, Discriminator Loss: 0.0000
[INFO] starting epoch 2 of 20...
[INFO] Generator Loss: 12.7471, Discriminator Loss: 0.0000
[INFO] starting epoch 3 of 20...
[INFO] Generator Loss: 12.9096, Discriminator Loss: 0.0000
[INFO] starting epoch 4 of 20...
[INFO] Generator Loss: 13.0524, Discriminator Loss: 0.0000
[INFO] starting epoch 5 of 20...
[INFO] Generator Loss: 13.1832, Discriminator Loss: 0.0000
[INFO] starting epoch 6 of 20...
[INFO] Generator Loss: 13.3301, Discriminator Loss: 0.0000
[INFO] starting epoch 7 of 20...
[INFO] Generator Loss: 13.4495, Discriminator Loss: 0.0000
[INFO] starting epoch 8 of 20...
[INFO] Generator Loss: 13.5171, Discriminator Loss: 0.0000
[INFO] starting epoch 9 of 20...
[INFO] Generator Loss: 13.5830, Discriminator Loss: 0.0000
[INFO] starting epoch 10 of 20...
[INFO] Generator Loss: 13.6201, Discriminator Loss: 0.0000
[INFO] starting epoch 11 of 20...
[INFO] Generator Loss: 13.7459, Discriminator

In [17]:
print(fake_2.shape)

torch.Size([1296, 1, 22, 1125])


In [18]:
fake_2 = fake_2.detach().cpu().view(-1, *size)
print(fake_2.shape)

torch.Size([1296, 22, 1125])


In [19]:
gc.collect()
torch.cuda.empty_cache()

In [20]:
for epoch in range(NUM_EPOCHS):
    print("[INFO] starting epoch {} of {}...".format(epoch + 1, NUM_EPOCHS))
    
    epochLossG = 0
    epochLossD = 0
    for x in dataloader_3:
        disc.zero_grad()
        
        images = x[0]
        images = images.to(DEVICE)
        
        bs =  images.size(0)
        labels = torch.full((bs,), realLabel, dtype=torch.float, device=DEVICE)
        
        output = disc(images).view(-1)
        
        errorReal = criterion(output, labels)
        
        errorReal.backward()
        
        noise = torch.randn(bs, 100, 1, 1, device=DEVICE)
        
        fake_3 = gen(noise)
        labels.fill_(fakeLabel)
        
        output = disc(fake_3.detach()).view(-1)
        errorFake = criterion(output, labels)
        
        errorFake.backward()
        
        errorD = errorReal + errorFake
        discOpt.step()
        
        gen.zero_grad()
        
        labels.fill_(realLabel)
        output = disc(fake_3).view(-1)
        
        errorG = criterion(output, labels)
        errorG.backward()
        
        genOpt.step()
        
        epochLossD += errorD
        epochLossG += errorG
        
    print("[INFO] Generator Loss: {:.4f}, Discriminator Loss: {:.4f}".format(epochLossG / stepsPerEpoch, epochLossD / stepsPerEpoch))
    if (epoch + 1) % 2 == 0:
        gen.eval()
        fake_3 = gen(benchmarkNoise)
        gen.train()

[INFO] starting epoch 1 of 20...
[INFO] Generator Loss: 14.5325, Discriminator Loss: 0.0000
[INFO] starting epoch 2 of 20...
[INFO] Generator Loss: 14.5642, Discriminator Loss: 0.0000
[INFO] starting epoch 3 of 20...
[INFO] Generator Loss: 14.6097, Discriminator Loss: 0.0000
[INFO] starting epoch 4 of 20...
[INFO] Generator Loss: 14.6546, Discriminator Loss: 0.0000
[INFO] starting epoch 5 of 20...
[INFO] Generator Loss: 14.6978, Discriminator Loss: 0.0000
[INFO] starting epoch 6 of 20...
[INFO] Generator Loss: 14.7417, Discriminator Loss: 0.0000
[INFO] starting epoch 7 of 20...
[INFO] Generator Loss: 14.7856, Discriminator Loss: 0.0000
[INFO] starting epoch 8 of 20...
[INFO] Generator Loss: 14.8293, Discriminator Loss: 0.0000
[INFO] starting epoch 9 of 20...
[INFO] Generator Loss: 14.8658, Discriminator Loss: 0.0000
[INFO] starting epoch 10 of 20...
[INFO] Generator Loss: 14.9002, Discriminator Loss: 0.0000
[INFO] starting epoch 11 of 20...
[INFO] Generator Loss: 14.9376, Discriminator

In [21]:
print(fake_3.shape)

torch.Size([1296, 1, 22, 1125])


In [22]:
fake_3 = fake_3.detach().cpu().view(-1, *size)
print(fake_3.shape)

torch.Size([1296, 22, 1125])


In [23]:
gen.cpu()
disc.cpu()
del gen, disc
gc.collect()
torch.cuda.empty_cache()

## Processamento dos dados sintéticos para classificação

In [25]:
fake = torch.cat((fake_0, fake_1, fake_2, fake_3), 0)
y_fake = torch.cat((torch.as_tensor(np.full(1296, 0)), torch.as_tensor(np.full(1296, 1)), torch.as_tensor(np.full(1296, 2)), torch.as_tensor(np.full(1296, 3))), 0)
print(fake.shape)
print(y_fake.shape)

torch.Size([5184, 22, 1125])
torch.Size([5184])


In [26]:
fake_set = TensorDataset(fake, y_fake)

In [27]:
from torch.utils.data import random_split

fake_fulltrainset, fake_evalset = random_split(fake_set, [2592, 2592])
fake_trainset, fake_testset = random_split(fake_fulltrainset, [1728, 864])
real_fulltrainset, real_evalset = random_split(real_set, [2592, 2592])
real_trainset, real_testset = random_split(real_fulltrainset, [1728, 864])

## Definição do modelo do classificador

In [28]:
from braindecode.util import set_random_seeds
from braindecode.models import EEGNetv4

cuda = torch.cuda.is_available()
device = 'cuda' if cuda else 'cpu'
if cuda:
    torch.backends.cudnn.benchmark = False

seed = 20200220
set_random_seeds(seed=seed, cuda=cuda)

n_classes = 4
n_chans = 22
input_window_samples = 1125
F1, D = 4, 2
kernel_length = 64

model = EEGNetv4(
    n_chans,
    n_classes,
    input_window_samples=input_window_samples,
    final_conv_length='auto',
    F1=8,
    D=2,
    F2=F1*D,
    kernel_length=kernel_length,
    drop_prob=0.5
)
model.to(device);

## Treinamento do classificador com dados reais

In [29]:
from skorch.helper import predefined_split
from skorch.callbacks import LRScheduler
from braindecode import EEGClassifier

batch_size = 32
n_epochs = 50

real_clf = EEGClassifier(
    model,
    criterion=torch.nn.NLLLoss,
    optimizer=torch.optim.Adam,
    train_split=predefined_split(real_testset),
    batch_size=batch_size,
    callbacks=[
        "accuracy", ("lr_scheduler", LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),
    ],
    device=device,
)
real_clf.fit(real_trainset, y=None, epochs=n_epochs);

  epoch    train_accuracy    train_loss    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  ----------------  ------------  ------  ------
      1            [36m0.3791[0m        [32m1.4077[0m            [35m0.3345[0m        [31m1.3455[0m  0.0100  0.7004
      2            [36m0.3981[0m        [32m1.3305[0m            [35m0.3519[0m        1.3692  0.0100  0.6654
      3            [36m0.4728[0m        [32m1.2907[0m            [35m0.4039[0m        [31m1.2703[0m  0.0100  0.6891
      4            0.4363        [32m1.2580[0m            0.3877        1.4141  0.0099  0.6702
      5            [36m0.4896[0m        [32m1.2370[0m            [35m0.4282[0m        1.2771  0.0098  0.6630
      6            [36m0.5272[0m        [32m1.2070[0m            [35m0.4664[0m        [31m1.1793[0m  0.0097  0.6683
      7            [36m0.5353[0m        [32m1.1885[0m            [35m0.5058[0m        [31m1.1672[0m  0.0096  0.6665
 

In [30]:
print(f"Mean Accuracy: {np.mean(real_clf.predict(real_evalset) == [y for X,y in real_evalset])*100:.2f}%")

Mean Accuracy: 59.38%


In [31]:
print(f"Mean Accuracy: {np.mean(real_clf.predict(fake_evalset) == [y for X,y in fake_evalset])*100:.2f}%")

Mean Accuracy: 25.04%


## Treinamento do classificador com dados sintéticos

In [32]:
fake_clf = EEGClassifier(
    model,
    criterion=torch.nn.NLLLoss,
    optimizer=torch.optim.Adam,
    train_split=predefined_split(fake_testset),
    batch_size=batch_size,
    callbacks=[
        "accuracy", ("lr_scheduler", LRScheduler('CosineAnnealingLR', T_max=n_epochs - 1)),
    ],
    device=device,
)
fake_clf.fit(fake_trainset, y=None, epochs=n_epochs);

  epoch    train_accuracy    train_loss    valid_accuracy    valid_loss      lr     dur
-------  ----------------  ------------  ----------------  ------------  ------  ------
      1            [36m0.2419[0m        [32m0.1274[0m            [35m0.2662[0m        [31m2.6020[0m  0.0100  0.7219
      2            0.2419        [32m0.0004[0m            0.2662        4.2445  0.0100  0.6852
      3            0.2419        0.0006            0.2662        4.5083  0.0100  0.6943
      4            0.2419        [32m0.0003[0m            0.2662        4.4194  0.0099  0.6826
      5            [36m0.4821[0m        [32m0.0001[0m            [35m0.5347[0m        3.8980  0.0098  0.6734
      6            0.4821        0.0001            0.5347        3.2590  0.0097  0.6748
      7            0.4821        [32m0.0001[0m            0.5347        [31m2.4245[0m  0.0096  0.6756
      8            [36m0.7459[0m        0.0001            [35m0.7558[0m        [31m1.4161[0m  0.0095  0

In [33]:
print(f"Mean Accuracy: {np.mean(fake_clf.predict(real_evalset) == [y for X,y in real_evalset])*100:.2f}%")

Mean Accuracy: 25.73%


In [34]:
print(f"Mean Accuracy: {np.mean(fake_clf.predict(fake_evalset) == [y for X,y in fake_evalset])*100:.2f}%")

Mean Accuracy: 100.00%


## Distância euclidiana entre os dados reais e sintéticos

In [35]:
real = X_.view(-1, *size)

In [36]:
torch.norm(real-fake, 2)

tensor(44012.6172)