In [1]:
import sys
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import datetime
import pywt
import scipy.signal as signal
import pyeeg
import gc

import load_data
import model 
import utils

from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import normalize

# Reproducibility: seed every RNG via the project helper before any model is built.
seed = 1
utils.set_random_seed(seed)
# Use the GPU when present, otherwise fall back to CPU so the notebook still
# runs on machines without CUDA (torch.device('cuda') alone fails at .to()).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# PATH = '/home/willer/Desktop/Development/Python/dataset/eeg-motor-movementimagery-dataset-1.0.0/files/'
# ndata = np.load(PATH + 'inte_eeg-motormovement_data.npy')
# nlabel = np.load(PATH + 'inte_eeg-motormovement_label.npy')

# ndata  = ndata[:10000].astype(np.float32)
# nlabel = nlabel.reshape(-1, 1)
# nlabel = nlabel[:10000]

In [2]:
# Graz train/test DataLoaders from the project loader; batch_size is forwarded
# to load_data.get_dataloader_graz.
train_loader, test_loader = load_data.get_dataloader_graz(
    batch_size=128,
)

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary

class ResidualBlock1d(nn.Module):
    """Pre-activation residual unit for 1-D feature maps.

    Applies (BatchNorm1d -> ReLU -> Conv1d) twice, without conv bias, and adds
    the block input to the result. The channel count is unchanged. The sequence
    length is preserved only when the conv hyper-parameters keep it (e.g. the
    defaults kernel_size=3, stride=1, padding=1, dilation=1); otherwise the
    residual addition will fail on mismatched shapes.
    """

    def __init__(self, inner_channel, kernel_size=3, stride=1, padding=1, dilation=1):
        super().__init__()

        def bn_relu_conv():
            # One pre-activation stage; the ReLU mutates the BatchNorm output
            # in place, never the block input.
            return [
                nn.BatchNorm1d(inner_channel),
                nn.ReLU(inplace=True),
                nn.Conv1d(inner_channel, inner_channel, kernel_size, stride,
                          padding, dilation, bias=False),
            ]

        # A single `block` Sequential keeps state_dict keys (block.0 ... block.5)
        # compatible with previously saved checkpoints.
        self.block = nn.Sequential(*bn_relu_conv(), *bn_relu_conv())

    def forward(self, x):
        residual = self.block(x)
        return residual + x

class SubConvNet(nn.Module):
    """Small 1-D conv encoder mapping a (N, in_channel, L) window to a flat vector.

    Pipeline: Conv1d (kernel 4, stride 2, no padding) -> two ResidualBlock1d ->
    1x1 Conv1d down to `out_channel`, then flattened per sample to
    (N, out_channel * L'), where L' = (L - 4) // 2 + 1.
    """

    def __init__(self, in_channel=1, out_channel=4, hidden_channel=64):
        super().__init__()
        # Attribute names define the state_dict key layout; do not rename.
        self.conv = nn.Conv1d(in_channel, hidden_channel, kernel_size=4, stride=2)
        self.block1 = ResidualBlock1d(hidden_channel)
        self.block2 = ResidualBlock1d(hidden_channel)
        self.conb = nn.Conv1d(hidden_channel, out_channel, 1)

    def forward(self, x):
        n = x.shape[0]
        features = x
        for layer in (self.conv, self.block1, self.block2, self.conb):
            features = layer(features)
        return features.view(n, -1)


class PretrainNet_T(nn.Module):
    """Window-wise conv encoder followed by an LSTM over the windows.

    The input of shape (batch, in_channel, sequence_lens) is split along its
    last axis into `time_lens` equal windows. Each window is encoded by
    SubConvNet, and the per-window feature vectors form a sequence for the
    LSTM. The LSTM outputs are averaged over the window axis, and two linear
    layers produce `output_size` scores.

    forward returns raw logits (no softmax) so the output can be passed
    directly to nn.CrossEntropyLoss, which applies log-softmax itself. Use
    F.softmax(out, dim=-1) when probabilities are needed; argmax predictions
    are the same either way.

    Raises:
        ValueError: if `sequence_lens` is not divisible by `time_lens`.
    """

    def __init__(
        self,
        in_channel=3,
        sequence_lens=1000,
        time_lens=10,
        hidden_size=64,
        output_size=2,
        layer_size=1,
        bidirectional=True
    ):
        super(PretrainNet_T, self).__init__()

        if sequence_lens % time_lens != 0:
            raise ValueError("Invalid time lens")

        self.in_channel     = in_channel
        self.time_lens      = time_lens
        self.hidden_size    = hidden_size
        self.num_directions = 2 if bidirectional else 1
        # Stacked layers x directions, i.e. the first dim of the LSTM state
        # (same value the original attribute held).
        self.layer_size     = layer_size * self.num_directions
        self.window_size    = sequence_lens // time_lens

        self.subconv    = SubConvNet(in_channel=in_channel, out_channel=4)
        self.input_size = self._adaptive_feature_size()

        self.lstm = nn.LSTM(self.input_size, hidden_size, layer_size, bidirectional=bidirectional)
        # Per-step LSTM output width is hidden_size * num_directions, independent
        # of the number of stacked layers.
        self.fn1  = nn.Linear(hidden_size * self.num_directions, 128)
        self.fn2  = nn.Linear(128, output_size)

    @property
    def device(self):
        """Device holding the model parameters (follows .to()/.cuda()/.cpu())."""
        return next(self.parameters()).device

    def forward(self, x):
        batch_size = x.shape[0]

        # (B, C, T*W) -> (B, T, C, W) -> (B*T, C, W): encode every window independently.
        windows = torch.stack(x.chunk(self.time_lens, 2), 1)
        windows = windows.reshape(batch_size * self.time_lens, self.in_channel, self.window_size)

        feats = self.subconv(windows)
        # (B*T, F) -> (B, T, F) -> (T, B, F): nn.LSTM here is sequence-first.
        feats = feats.view(batch_size, self.time_lens, self.input_size).permute(1, 0, 2)

        # Initial (h_0, c_0) omitted: nn.LSTM defaults them to zeros on the
        # input's device, so this works on CPU and GPU alike.
        seq_out, _ = self.lstm(feats)

        # (T, B, H*D) -> (B, H*D, T), then average over all T windows.
        pooled = F.avg_pool1d(seq_out.permute(1, 2, 0), self.time_lens)
        hidden = F.relu(self.fn1(pooled.view(batch_size, -1)), inplace=True)
        return self.fn2(hidden)

    def _adaptive_feature_size(self):
        """Return the flattened SubConvNet output size for a single window.

        The zero-window probe runs in eval mode and without autograd, so it
        neither updates BatchNorm running statistics nor builds a graph. The
        subconv training flag is restored afterwards.
        """
        was_training = self.subconv.training
        self.subconv.eval()
        try:
            with torch.no_grad():
                probe = torch.zeros(1, self.in_channel, self.window_size)
                return self.subconv(probe).numel()
        finally:
            self.subconv.train(was_training)

In [None]:
import datetime
import model 
import torch.nn.functional as F
from tensorboardX import SummaryWriter

# Train PretrainNet_T on the Graz loaders for `epoch` epochs, log the per-epoch
# mean batch loss and accuracy to TensorBoard, then save the final weights.
import os

epoch = 30  # number of training epochs
# Filesystem-safe timestamp: str(datetime.now()) contains spaces and colons,
# and colons are not allowed in Windows file names.
run_stamp = datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')

net = PretrainNet_T().to(device)
criterion_cel = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=1e-3)
writer = SummaryWriter("rl-runs/PretrainNet_T_" + run_stamp)


def run_epoch(loader, train):
    """Run one pass of `net` over `loader`.

    Moves each batch to `device`, flattens labels to 1-D and evaluates
    `criterion_cel`. When `train` is True it also backpropagates and steps
    `optimizer`.

    Returns:
        (mean loss per batch, accuracy over all samples) as Python floats.
        The mean makes train and test losses comparable even though the two
        loaders yield different numbers of batches.

    Raises:
        ValueError: if `loader` yields no batches.
    """
    net.train(train)
    loss_sum, n_batches, n_correct, n_seen = 0.0, 0, 0, 0
    # Build autograd graphs only while training.
    with torch.set_grad_enabled(train):
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device).view(-1)

            output = net(inputs)
            loss = criterion_cel(output, labels)

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            n_batches += 1
            n_correct += (torch.argmax(output, 1) == labels).sum().item()
            n_seen += len(labels)
    if n_batches == 0:
        raise ValueError('data loader yielded no batches')
    return loss_sum / n_batches, n_correct / n_seen


print('<<=== Begin ===>>')
try:
    for i in range(epoch):
        train_loss, train_acc = run_epoch(train_loader, train=True)
        test_loss, test_acc = run_epoch(test_loader, train=False)

        if i % 5 == 0:
            print('e', i)

        writer.add_scalar('loss/train', train_loss, i)
        writer.add_scalar('loss/test', test_loss, i)
        writer.add_scalar('accuracy/train', train_acc, i)
        writer.add_scalar('accuracy/test', test_acc, i)
finally:
    # Flush and close event files even if an epoch raises.
    writer.close()
print('<<=== Finish ===>>')
os.makedirs('model', exist_ok=True)
torch.save(net.state_dict(), 'model/PretrainNet_' + run_stamp + '.pkl')
print('<<=== Param Saved ===>>')

<<=== Begin ===>>
e 0
e 5
e 10
e 15
e 20
e 25
