In [1]:
import os
# Restrict this process to the GPU indices listed in `gpus`.
# These variables are set before `torch` is imported below; CUDA reads them
# when it initialises, so they should stay ahead of any CUDA-using import.
gpus = [0]
# Enumerate devices in PCI bus order so the indices above are stable.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpus))
import numpy as np
import math
import scipy
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.nn.init as init
from torch.utils.data import Dataset
import torch
import torch.nn.functional as F
from torch.autograd import Variable
from torch import nn
from torch import Tensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from torch.backends import cudnn
# Turn off cuDNN auto-tuning and request deterministic cuDNN kernels.
# NOTE(review): no random seed is set anywhere in the visible cells, so runs
# are still not reproducible end to end -- confirm whether seeding is intended.
cudnn.benchmark = False
cudnn.deterministic = True
import math
from model_utils import SSA, LightweightConv1d, Mixer1D
from TOPA import TOPA_Loss

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Network architecture adapted from SST-DPN: https://github.com/hancan16/SST-DPN
class Efficient_Encoder(nn.Module):
    """EEG feature extractor with optional domain-specific batch normalisation.

    Pipeline: ``LightweightConv1d`` (temporal filtering) -> ``SSA`` ->
    1x1 ``Conv1d`` (channel fusion) -> BatchNorm (shared or per-domain) ->
    ``Mixer1D``.

    Args:
        samples: Number of time samples per trial (forwarded to ``SSA``).
        chans: Number of input channels.
        F1: Depth multiplier of the temporal convolution.
        F2: Number of output channels of the 1x1 channel convolution.
        time_kernel1: Kernel size of the temporal convolution.
        pool_kernels: Kernel sizes forwarded to ``Mixer1D``.
        num_domains: Number of domain-specific BatchNorm layers. Domain ids
            passed to ``forward`` must lie in ``1..num_domains``.
    """

    def __init__(
        self,
        samples,
        chans,
        F1=16,
        F2=36,
        time_kernel1=75,
        pool_kernels=[50, 100, 250],
        num_domains=9,
    ):
        super().__init__()

        self.time_conv = LightweightConv1d(
            in_channels=chans,
            num_heads=1,
            depth_multiplier=F1,
            kernel_size=time_kernel1,
            stride=1,
            padding="same",
            bias=True,
            weight_softmax=False,
        )
        self.ssa = SSA(samples, chans * F1)

        self.chanConv = nn.Conv1d(
            chans * F1,
            F2,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.batchNorm1d = nn.BatchNorm1d(F2)
        # One BatchNorm per domain. They must live in an nn.ModuleList (not a
        # plain list) so that they are registered sub-modules: only then do
        # their parameters reach the optimizer and state_dict, and only then do
        # .train()/.eval()/.to()/.cuda() propagate to them. Device placement is
        # left to the caller instead of forcing .cuda() here.
        self.dBatchNorm1d = nn.ModuleList(
            [nn.BatchNorm1d(F2) for _ in range(num_domains)]
        )
        self.elu = nn.ELU()

        self.mixer = Mixer1D(dim=F2, kernel_sizes=pool_kernels)

    def forward(self, x, x_domain=None, dBatchNorm=False):
        """Encode a batch of trials.

        Args:
            x: Input batch passed to the temporal convolution.
            x_domain: 1-D tensor with one domain id per trial (1-based).
                Required when ``dBatchNorm`` is True.
            dBatchNorm: If True, normalise every trial with the BatchNorm
                layer of its own domain; otherwise use the shared BatchNorm.

        Returns:
            The output of ``Mixer1D`` for the normalised feature maps.

        Raises:
            ValueError: If ``dBatchNorm`` is True and ``x_domain`` is None.
            IndexError: If a domain id is outside ``1..num_domains``.
        """
        x = self.time_conv(x)
        x, _ = self.ssa(x)
        x_chan = self.chanConv(x)

        if dBatchNorm:
            if x_domain is None:
                raise ValueError("x_domain is required when dBatchNorm=True")
            normed = torch.zeros_like(x_chan)
            # .tolist() converts all domain ids with a single host sync.
            for domain_id in x_domain.unique().tolist():
                domain_id = int(domain_id)
                # Guard explicitly: id 0 would otherwise wrap around to the
                # last layer through negative indexing.
                if not 1 <= domain_id <= len(self.dBatchNorm1d):
                    raise IndexError(
                        "domain id %d is outside 1..%d"
                        % (domain_id, len(self.dBatchNorm1d))
                    )
                mask = x_domain == domain_id
                normed[mask] = self.dBatchNorm1d[domain_id - 1](x_chan[mask])
            x_chan = self.elu(normed)
        else:
            # NOTE(review): unlike the per-domain branch, no ELU is applied
            # here. Kept as in the original; confirm whether that is intended.
            x_chan = self.batchNorm1d(x_chan)

        feature = self.mixer(x_chan)
        return feature


class ClassifyHead(nn.Module):
    """Prototype classifier: logits are negative Euclidean distances.

    Holds one learnable prototype per class (``isp``, shape
    ``(num_classes, feat_dim)``). Defined at module level rather than inside
    ``EEGEncoder.__init__`` so that models containing it can be pickled.

    Args:
        num_classes: Number of class prototypes.
        feat_dim: Dimensionality of the feature vectors / prototypes.
    """

    def __init__(self, num_classes, feat_dim):
        super().__init__()
        # Values are fully overwritten by the Kaiming initialisation below.
        self.isp = nn.Parameter(torch.empty(num_classes, feat_dim), requires_grad=True)
        nn.init.kaiming_normal_(self.isp)

    def forward(self, x, wog=False):
        """Return ``-cdist(x, isp)``.

        Args:
            x: Feature batch.
            wog: If True the prototypes are detached, so no gradient flows
                into ``isp`` (presumably "without gradient").
        """
        prototypes = self.isp.detach() if wog else self.isp
        return -torch.cdist(x, prototypes, p=2)


class EEGEncoder(nn.Module):
    """``Efficient_Encoder`` followed by a prototype-distance classification head.

    Args:
        chans: Number of input channels.
        samples: Number of time samples per trial.
        num_classes: Number of classes (one prototype each).
        F1, F2, time_kernel1, pool_kernels: Forwarded to ``Efficient_Encoder``.
    """

    def __init__(
        self,
        chans,
        samples,
        num_classes=4,
        F1=9,
        F2=48,
        time_kernel1=75,
        pool_kernels=[50, 100, 200],
    ):
        super().__init__()
        self.encoder = Efficient_Encoder(
            samples=samples,
            chans=chans,
            F1=F1,
            F2=F2,
            time_kernel1=time_kernel1,
            pool_kernels=pool_kernels,
        )
        self.features = None

        # Infer the feature dimension with a dummy forward pass. Run it in eval
        # mode and without autograd so the all-ones batch neither updates any
        # BatchNorm running statistics nor builds a graph; then restore the
        # encoder's previous training flag.
        was_training = self.encoder.training
        self.encoder.eval()
        with torch.no_grad():
            probe = torch.ones((1, chans, samples))
            feat_dim = self.encoder(probe).shape[-1]
        self.encoder.train(was_training)

        self.classifyHead = ClassifyHead(num_classes, feat_dim)

    def get_features(self):
        """Return the features cached by the most recent ``forward`` call.

        Raises:
            RuntimeError: If ``forward`` has not been called yet.
        """
        if self.features is None:
            raise RuntimeError("No features available. Run forward() first.")
        return self.features

    def forward(self, x, x_domain):
        """Encode ``x`` with domain-specific BatchNorm and classify it.

        Args:
            x: Input batch.
            x_domain: Per-trial domain ids, forwarded to the encoder.

        Returns:
            Tuple ``(logits, features)``.
        """
        features = self.encoder(x, x_domain, dBatchNorm=True)
        self.features = features
        logits = self.classifyHead(features, wog=False)

        return logits, features

In [3]:
class TLExP():
    """Transfer-learning experiment: source sessions for training, one target session.

    File keys such as ``'2T'`` are expected to start with a digit; that digit
    is used as the domain id of every trial in the file.

    Args:
        trainSetI: List of source file keys (e.g. ``['2T', '3T']``).
        testSetI: List of target file keys (e.g. ``['1E']``).
        root: Directory prefix of the ``rawA0<key>.mat`` files. Defaults to
            the original hard-coded dataset location.
    """

    def __init__(self, trainSetI, testSetI, root=None):
        super(TLExP, self).__init__()
        self.batch_size = 72*2
        self.n_epochs = 2000
        self.c_dim = 4
        self.lr = 0.0002
        self.b1 = 0.5
        self.b2 = 0.999
        self.trainSetI = trainSetI
        self.testSetI = testSetI
        # int() instead of eval(): the key's first character is a plain digit.
        self.targetDomainLabel = int(testSetI[0][0])
        self.SetIs = {testSetI[0][0]} | {key[0] for key in trainSetI}
        self.trainNum = len(self.SetIs)

        self.start_epoch = 0
        # Raw string: same value as before, without invalid escape sequences.
        self.root = r'D:\Projects\DataSet\standard_2a_data/' if root is None else root

        # Use the GPU when present; fall back to CPU otherwise.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        use_cuda = self.device.type == 'cuda'
        # Kept for backward compatibility with code that reads these attributes.
        self.Tensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
        self.LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor

        self.criterion_cls = torch.nn.CrossEntropyLoss().to(self.device)

        self.model = EEGEncoder(chans=22, samples=1000, num_classes=4).to(self.device)

    def interaug_TL(self, timg, label):
        """Segment-recombination augmentation within each (class, domain) group.

        Every synthetic trial is assembled from 8 consecutive time segments,
        each copied from a randomly chosen real trial of the same class and
        domain. ``batch_size // (n_classes * n_domains)`` trials are generated
        per group; groups with no real trial are skipped.

        Args:
            timg: Trials of shape ``(n_trials, n_chans, n_samples)``
                (CPU tensor or array). If ``n_samples`` is not a multiple of
                8, the trailing remainder of each synthetic trial stays zero.
            label: ``(n_trials, 2)`` labels; column 0 is the class, column 1
                the domain id.

        Returns:
            Tuple ``(aug_data, aug_label)`` on ``self.device``: float32 data
            and int64 labels, jointly shuffled.
        """
        n_segments = 8
        timg_np = timg.detach().cpu().numpy() if torch.is_tensor(timg) else np.asarray(timg)
        label_np = label.detach().cpu().numpy() if torch.is_tensor(label) else np.asarray(label)

        classes = np.unique(label_np[:, 0])
        domains = np.unique(label_np[:, 1])
        per_combo = int(self.batch_size / (len(classes) * len(domains)))
        n_chans, n_samples = timg_np.shape[1], timg_np.shape[2]
        seg_len = n_samples // n_segments

        aug_data = []
        aug_label = []
        for cls0 in classes:
            for cls1 in domains:
                cls_idx = np.flatnonzero((label_np[:, 0] == cls0) & (label_np[:, 1] == cls1))
                if cls_idx.size == 0:
                    # No real trial for this combination: nothing to recombine
                    # (np.random.randint(0, 0) would raise).
                    continue
                tmp_aug_data = np.zeros((per_combo, n_chans, n_samples))
                for ri in range(per_combo):
                    for rj in range(n_segments):
                        # Drawn per segment exactly as before so the NumPy RNG
                        # stream is unchanged; only entry rj is used.
                        rand_idx = np.random.randint(0, cls_idx.size, n_segments)
                        src = cls_idx[rand_idx[rj]]
                        start, stop = rj * seg_len, (rj + 1) * seg_len
                        tmp_aug_data[ri, :, start:stop] = timg_np[src, :, start:stop]
                aug_data.append(tmp_aug_data)
                aug_label.append(np.tile(label_np[cls_idx[0]], (per_combo, 1)))

        aug_data = np.concatenate(aug_data)
        aug_label = np.concatenate(aug_label)
        aug_shuffle = np.random.permutation(len(aug_data))

        aug_data = torch.from_numpy(aug_data[aug_shuffle]).to(self.device).float()
        aug_label = torch.from_numpy(aug_label[aug_shuffle]).to(self.device).long()
        return aug_data, aug_label

    def _load_sessions(self, keys):
        """Load and concatenate the ``rawA0<key>.mat`` files for ``keys``.

        Returns ``(data, label, domain)``: data transposed with axes
        ``(2, 1, 0)`` of the stored array, plus 1-D label and domain-id arrays
        (the domain id is the leading digit of the file key).
        """
        # Local import: a bare `import scipy` does not guarantee that the
        # `scipy.io` submodule is loaded on every SciPy version.
        from scipy.io import loadmat

        data, labels, domains = [], [], []
        for key in keys:
            mat = loadmat(self.root + 'rawA0' + key + '.mat')
            data.append(mat['data'])
            labels.append(mat['label'])
            domains.append(np.full_like(mat['label'], int(key[0])))

        # Concatenate once (trials are stacked along axis 2 of the stored array).
        data = np.transpose(np.concatenate(data, axis=2), (2, 1, 0))
        labels = np.concatenate(labels)[:, 0]
        domains = np.concatenate(domains)[:, 0]
        return data, labels, domains

    def get_source_data(self,keep_ratio=0.8):
        """Load source (train) and target (test) sessions from disk.

        Args:
            keep_ratio: Unused; kept for interface compatibility.

        Returns:
            ``(train_data, train_label, train_seti, test_data, test_label,
            test_seti)`` as NumPy arrays.
        """
        print('load train data...')
        train_data, train_label, train_seti = self._load_sessions(self.trainSetI)
        print('load test data...')
        test_data, test_label, test_seti = self._load_sessions(self.testSetI)
        print('data load finished')
        return train_data,train_label,train_seti,test_data,test_label,test_seti

    @staticmethod
    def _accuracy(logits, target):
        """Fraction of rows whose arg-max over dim 1 equals ``target``."""
        pred = torch.max(logits, 1)[1]
        return float((pred == target).sum().item()) / float(target.size(0))

    def train(self):
        """Run the full training loop and report accuracies.

        Each epoch performs 8 optimisation steps on a batch made of real
        source trials, augmented source trials and (unlabelled use of) target
        trials, then evaluates on the whole target set and on a held-out
        validation split of the source data.

        Returns:
            ``(bestAcc, averAcc, Y_true, Y_pred)``. ``Y_true`` and ``Y_pred``
            are placeholders that are never populated (always 0).
        """
        (self.train_data, self.train_label, self.train_seti,
         self.test_data, self.test_label, self.test_seti) = self.get_source_data()

        # Label tensors have two columns: [class label shifted by -1, domain id].
        train_label = torch.from_numpy(self.train_label - 1)
        train_seti = torch.from_numpy(self.train_seti)
        self.train_data = torch.from_numpy(self.train_data)
        self.train_label = torch.stack((train_label,train_seti),dim=1)

        # Random hold-out: the first 12.5% of a permutation becomes validation.
        rand_index = torch.randperm(self.train_label.size(0))
        val_ratio = 0.125
        val_index = int(self.train_label.size(0)*val_ratio)
        val_idx, train_idx = rand_index[:val_index], rand_index[val_index:]
        self.val_data = self.train_data[val_idx]
        self.val_label = self.train_label[val_idx]
        self.train_data = self.train_data[train_idx]
        self.train_label = self.train_label[train_idx]
        dataset = torch.utils.data.TensorDataset(self.train_data, self.train_label)
        self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)

        test_label = torch.from_numpy(self.test_label - 1)
        test_seti = torch.from_numpy(self.test_seti)
        self.test_data = torch.from_numpy(self.test_data)
        self.test_label = torch.stack((test_label,test_seti),dim=1)
        test_dataset = torch.utils.data.TensorDataset(self.test_data,self.test_label)
        self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size*2, shuffle=True)

        # The class prototypes get their own optimizer with a 10x learning rate.
        isp_params = self.model.classifyHead.isp
        other_params = [param for name, param in self.model.named_parameters() if name != 'classifyHead.isp']
        self.optimizer_isp = torch.optim.Adam([isp_params], lr=self.lr*10, betas=(self.b1, self.b2))
        self.optimizer_other = torch.optim.Adam(other_params, lr=self.lr, betas=(self.b1, self.b2))

        # Whole target / validation sets on the device for per-epoch evaluation.
        test_x = self.test_data.to(self.device, dtype=torch.float32)
        test_y = self.test_label.to(self.device, dtype=torch.long)
        self.val_data = self.val_data.to(self.device, dtype=torch.float32)
        self.val_label = self.val_label.to(self.device, dtype=torch.long)

        bestAcc = 0
        bestAcc_val = 0
        finalAcc = 0
        averAcc = 0
        acc = 0
        num = 0
        Y_true = 0
        Y_pred = 0

        steps_per_epoch = 8
        eval_every = 1

        self.mcc_ratio = 0.
        for e in range(self.n_epochs):
            self.model.train()
            # Weight of the TOPA term: 0 at epoch 0, growing towards 1.
            self.topa_ratio = 1-np.exp(1*e/(e-self.n_epochs+1e-9))

            for _ in range(steps_per_epoch):
                # A fresh iterator per step yields the first batch of a new
                # shuffle, i.e. an independent random batch each time.
                s_img,s_label = next(iter(self.dataloader))
                s_img = s_img.to(self.device, dtype=torch.float32)
                s_label = s_label.to(self.device, dtype=torch.long)
                aug_img,aug_label = self.interaug_TL(self.train_data, self.train_label)
                s_img = torch.cat((s_img,aug_img))
                s_label = torch.cat((s_label,aug_label))
                t_img,t_label = next(iter(self.test_dataloader))
                t_img = t_img.to(self.device, dtype=torch.float32)
                t_label = t_label.to(self.device, dtype=torch.long)

                # Equal-sized halves so that chunk(2) splits source from target.
                minlen = min(s_img.shape[0],t_img.shape[0])
                s_img = s_img[:minlen,:,:]
                s_label = s_label[:minlen,:]
                t_img = t_img[:minlen,:,:]
                t_label = t_label[:minlen,:]
                img = torch.cat((s_img,t_img))
                label = torch.cat((s_label,t_label))
                domains = label[:,1]
                outputs,feature = self.model(img,domains)
                y_s, y_t = outputs.chunk(2, dim=0)

                classfyout = y_s
                trainlabel = s_label

                # Supervised loss on source only; target logits feed the TOPA loss.
                cls_loss = self.criterion_cls(classfyout, s_label[:,0])
                topa_loss = TOPA_Loss(y_t,self.model.classifyHead.isp.T)
                loss = cls_loss + 1*self.topa_ratio*topa_loss

                self.optimizer_isp.zero_grad()
                self.optimizer_other.zero_grad()
                loss.backward()
                self.optimizer_isp.step()
                self.optimizer_other.step()

            if (e + 1) % eval_every == 0:
                self.model.eval()
                with torch.no_grad():
                    Cls,_ = self.model(test_x, test_y[:,1])
                    loss_test = self.criterion_cls(Cls, test_y[:,0])
                    acc = self._accuracy(Cls, test_y[:,0])

                    # Accuracy of the last training step's source half.
                    train_acc = self._accuracy(classfyout, trainlabel[:,0])

                    Cls,_ = self.model(self.val_data, self.val_label[:,1])
                    loss_val = self.criterion_cls(Cls, self.val_label[:,0])
                    acc_val = self._accuracy(Cls, self.val_label[:,0])

                    print('Epoch:', e,
                          '  Train loss: %.6f' % loss.detach().cpu().numpy(),
                          '  TOPA_loss: %.6f' % topa_loss.detach().cpu().numpy(),
                          '  cls_loss: %.6f' % cls_loss.detach().cpu().numpy(),
                          '  Test loss: %.6f' % loss_test.detach().cpu().numpy(),
                          '  Train accuracy %.6f' % train_acc,
                          '  val accuracy is %.6f' % acc_val,
                          '  Test accuracy is %.6f' % acc)

                    num = num + 1
                    averAcc = averAcc + acc
                    if acc > bestAcc:
                        bestAcc = acc
                    # Model selection on validation accuracy: the reported
                    # "test accuracy" is the one at the best validation epoch.
                    if acc_val > bestAcc_val:
                        finalAcc = acc
                        bestAcc_val = acc_val
                        # torch.save(self.model.state_dict(), 'ckpt/edpnet_'+self.testSetI[0]+'.pth')
                        print('best!')
        # Guard against a run without any evaluation (e.g. n_epochs == 0).
        averAcc = averAcc / num if num else 0
        print('The average accuracy is:', averAcc)
        print('The test accuracy is:', finalAcc)
        print('The last accuracy is:', acc)
        print('The best accuracy is:', bestAcc)
        # torch.save(self.model.state_dict(), 'ckpt/edpnet_last_'+self.testSetI[0]+'.pth')

        return bestAcc, averAcc, Y_true, Y_pred

In [None]:
# Run one experiment: files '2T'..'9T' are the source (training) sets and
# '1E' is the target (test) set. The leading digit of each key is used as the
# domain id inside TLExP.
trainSetI = ['2T','3T','4T','5T','6T','7T','8T','9T']
testSetI = ['1E']
exp = TLExP(trainSetI,testSetI)
# NOTE: Y_true and Y_pred are returned as placeholders by TLExP.train().
bestAcc, averAcc, Y_true, Y_pred= exp.train()
# Wait for outstanding GPU work, then release cached GPU memory.
torch.cuda.synchronize()
torch.cuda.empty_cache()

load train data...
load test data...
data load finished
Epoch: 0   Train loss: 1.381169   TOPA_loss: 13.408327   cls_loss: 1.381169   Test loss: 1.382989   Train accuracy 0.261029   val accuracy is 0.284722   Test accuracy is 0.256944
best!
