In [None]:
import torch
from torch import Tensor
from torch import optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset,DataLoader
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce

import argparse
import os
import time
import math
import numpy as np
from tqdm import tqdm
from scipy.io import loadmat, savemat
import logging
import datetime

In [None]:
# Deep-learning-based spatial-temporal classification model
class MLPSpatialFilter(nn.Module):
    """Two-stage residual MLP applied along the last axis of the input.

    Stage 1 keeps the width at ``num_ROI`` and uses an identity skip
    connection; stage 2 maps to ``num_hidden`` and uses a linear skip
    (``fc23``). A final linear layer (``value``) followed by the activation
    produces the output.

    Args:
        num_ROI: Size of the last dimension of the input.
        num_hidden: Size of the last dimension of the output.
        activation: Name of a class found in ``torch.nn`` (e.g. ``'ELU'``);
            it is instantiated with no arguments.
    """

    def __init__(self, num_ROI, num_hidden, activation):
        super().__init__()
        # Stage 1: num_ROI -> num_ROI -> num_ROI
        self.fc11 = nn.Linear(num_ROI, num_ROI)
        self.fc12 = nn.Linear(num_ROI, num_ROI)
        # Stage 2: num_ROI -> num_hidden -> num_hidden, plus the skip projection
        self.fc21 = nn.Linear(num_ROI, num_hidden)
        self.fc22 = nn.Linear(num_hidden, num_hidden)
        self.fc23 = nn.Linear(num_ROI, num_hidden)
        # Read-out layer
        self.value = nn.Linear(num_hidden, num_hidden)
        # Look the activation class up by name in the torch.nn namespace.
        activation_cls = vars(nn)[activation]
        self.activation = activation_cls()

    def forward(self, x):
        """Return the activated read-out of the two residual stages."""
        act = self.activation

        # Stage 1 with identity skip.
        stage1 = act(self.fc12(act(self.fc11(x))) + x)

        # Stage 2 with linear skip through fc23.
        main_path = self.fc22(act(self.fc21(stage1)))
        skip_path = self.fc23(stage1)
        stage2 = act(main_path + skip_path)

        return act(self.value(stage2))

class PatchEmbedding_Linear(nn.Module):
    """Split a sequence into patches, embed them linearly and add a class token.

    The input is rearranged with the pattern ``'b (w s) c-> b w (s c)'`` so
    that ``patch_size`` consecutive steps along axis 1 are flattened into one
    patch, which is then projected to ``emb_size``. A learnable class token is
    prepended and learnable position embeddings are added.

    Args:
        in_channels: Size of the last input axis.
        patch_size: Number of consecutive steps grouped into one patch.
        emb_size: Embedding width of every output token.
        seq_length: Expected length of input axis 1; it fixes the number of
            position embeddings at ``seq_length // patch_size + 1``.
    """

    def __init__(self, in_channels=100, patch_size=4, emb_size=100, seq_length=225):
        super().__init__()
        flat_patch_dim = patch_size * in_channels
        num_patches = seq_length // patch_size

        self.projection = nn.Sequential(
            Rearrange('b (w s) c-> b w (s c)', s=patch_size),
            nn.Linear(flat_patch_dim, emb_size),
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        # One extra position for the class token.
        self.positions = nn.Parameter(torch.randn(num_patches + 1, emb_size))

    def forward(self, x: Tensor) -> Tensor:
        """Embed ``x`` and return the tokens with the class token first."""
        batch, _steps, _channels = x.shape
        tokens = self.projection(x)
        # Replicate the single class token for every sample in the batch.
        cls = repeat(self.cls_token, '() n e -> b n e', b=batch)
        tokens = torch.cat((cls, tokens), dim=1)
        # Position embeddings broadcast over the batch axis.
        tokens += self.positions
        return tokens

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over a ``(batch, tokens, emb_size)`` input.

    Args:
        emb_size: Embedding width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``emb_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        if emb_size % num_heads != 0:
            # The per-head split in forward() cannot work otherwise; fail early
            # with a readable message instead of at the first forward pass.
            raise ValueError(
                "emb_size ({}) must be divisible by num_heads ({})".format(emb_size, num_heads))
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: Input of shape ``(batch, tokens, emb_size)``.
            mask: Optional boolean tensor broadcastable to
                ``(batch, num_heads, query_len, key_len)``. Positions where the
                mask is ``False`` are excluded from attention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Split the embedding axis into heads: (b, n, h*d) -> (b, h, n, d).
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)

        # batch, num_heads, query_len, key_len
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            # masked_fill is out-of-place, so its result must be kept. Use the
            # most negative finite value of energy's own dtype so the fill
            # also works for non-float32 inputs.
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask, fill_value)

        # Scaling uses the full embedding size (not the per-head size); kept
        # as-is so that existing trained weights behave identically.
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)

        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        # Merge the heads back: (b, h, n, d) -> (b, n, h*d).
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out


class FeedForwardBlock(nn.Sequential):
    """Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear.

    Args:
        emb_size: Input and output width.
        expansion: Multiplier giving the hidden width ``expansion * emb_size``.
        drop_p: Dropout probability applied after the GELU.
    """

    def __init__(self, emb_size, expansion, drop_p):
        hidden_size = expansion * emb_size
        expand = nn.Linear(emb_size, hidden_size)
        nonlinearity = nn.GELU()
        dropout = nn.Dropout(drop_p)
        contract = nn.Linear(hidden_size, emb_size)
        super().__init__(expand, nonlinearity, dropout, contract)


class ResidualAdd(nn.Module):
    """Wrap a callable so that its input is added to its output (skip connection).

    Args:
        fn: Callable applied to the input; extra keyword arguments given to
            ``forward`` are passed through to it.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(x, **kwargs) + x``; the sum is accumulated in place on fn's output."""
        out = self.fn(x, **kwargs)
        out += x
        return out

class TransformerEncoderBlock(nn.Sequential):
    """One pre-norm transformer encoder block.

    Consists of two residual sub-blocks applied in sequence:
    LayerNorm -> multi-head attention -> dropout, then
    LayerNorm -> feed-forward -> dropout.

    Args:
        emb_size: Embedding width.
        num_heads: Number of attention heads.
        drop_p: Dropout probability used inside attention and after each sub-block.
        forward_expansion: Hidden-width multiplier of the feed-forward network.
        forward_drop_p: Dropout probability inside the feed-forward network.
    """

    def __init__(self,
                 emb_size=100,
                 num_heads=5,
                 drop_p=0.,
                 forward_expansion=4,
                 forward_drop_p=0.):
        attention_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, num_heads, drop_p),
            nn.Dropout(drop_p),
        )
        feed_forward_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
            nn.Dropout(drop_p),
        )
        super().__init__(
            ResidualAdd(attention_branch),
            ResidualAdd(feed_forward_branch),
        )

class TransformerEncoder(nn.Sequential):
    """Stack of ``depth`` encoder blocks applied one after another.

    Args:
        depth: Number of ``TransformerEncoderBlock`` instances to stack.
        **kwargs: Forwarded unchanged to every ``TransformerEncoderBlock``.
    """

    def __init__(self, depth=3, **kwargs):
        blocks = []
        for _ in range(depth):
            blocks.append(TransformerEncoderBlock(**kwargs))
        super().__init__(*blocks)

class RegressionHead(nn.Sequential):
    """Output head: mean over the token axis, LayerNorm, then a linear layer.

    Args:
        emb_size: Embedding width of the incoming tokens.
        n_para: Number of output values per sample.
    """

    def __init__(self, emb_size=100, n_para=10):
        super().__init__()
        pool_tokens = Reduce('b n e -> b e', reduction='mean')
        normalize = nn.LayerNorm(emb_size)
        readout = nn.Linear(emb_size, n_para)
        self.rghead = nn.Sequential(pool_tokens, normalize, readout)

    def forward(self, x):
        """Map ``(batch, tokens, emb_size)`` to ``(batch, n_para)``."""
        return self.rghead(x)

class ClassificationHead(nn.Sequential):
    """Output head: mean over the token axis, LayerNorm, then a linear layer.

    Args:
        emb_size: Embedding width of the incoming tokens.
        n_classes: Number of output values (logits) per sample.
    """

    def __init__(self, emb_size=100, n_classes=2):
        super().__init__()
        pool_tokens = Reduce('b n e -> b e', reduction='mean')
        normalize = nn.LayerNorm(emb_size)
        readout = nn.Linear(emb_size, n_classes)
        self.clshead = nn.Sequential(pool_tokens, normalize, readout)

    def forward(self, x):
        """Map ``(batch, tokens, emb_size)`` to ``(batch, n_classes)``."""
        return self.clshead(x)
        
class Regressor(nn.Sequential):
    """End-to-end model: spatial MLP filter -> patch embedding -> transformer -> head.

    Args:
        num_ROI: Input width of the spatial filter.
        num_hidden: Output width of the spatial filter.
        activation: Name of the ``torch.nn`` activation used by the spatial filter.
        in_channels: Channel count expected by the patch embedding.
        patch_size: Number of steps grouped into one patch.
        emb_size: Token embedding width.
        seq_length: Sequence length expected by the patch embedding.
        depth: Number of transformer encoder blocks.
        n_para: Number of outputs produced by the head.
        **kwargs: Forwarded to ``TransformerEncoder`` (and from there to each block).
    """

    def __init__(self,
                 num_ROI = 100,
                 num_hidden = 100,
                 activation ='ELU',
                 in_channels=100,
                 patch_size=4,
                 emb_size=100,
                 seq_length=225,
                 depth=5,
                 n_para=10,
                 **kwargs):
        spatial_filter = MLPSpatialFilter(num_ROI, num_hidden, activation)
        patch_embedding = PatchEmbedding_Linear(in_channels, patch_size, emb_size, seq_length)
        # Dropout rates of the encoder are fixed at 0.5 here.
        encoder = TransformerEncoder(depth, emb_size=emb_size, drop_p=0.5, forward_drop_p=0.5, **kwargs)
        # The classification head is used as the output layer with n_para outputs.
        output_head = ClassificationHead(emb_size, n_para)
        super().__init__(spatial_filter, patch_embedding, encoder, output_head)

    def count_parameters(self):
        """Return the total number of trainable parameter elements."""
        total = 0
        for param in self.parameters():
            if param.requires_grad:
                total += param.numel()
        return total
    
class sz_cn_Dataset(Dataset):
    """Dataset pairing samples ``x`` with labels ``y`` by index.

    Args:
        x: Indexable collection of samples; its length defines the dataset size.
        y: Indexable collection of labels, indexed the same way as ``x``.
    """

    def __init__(self, x=None,y=None):
        self.x, self.y = x, y

    def __len__(self):
        """Number of samples, taken from ``x``."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair stored at ``idx``."""
        return self.x[idx], self.y[idx]

def trainer(train_loader, model, criterion, optimizer, args_params):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        train_loader: Iterable yielding ``(data, label)`` batches of tensors.
        model: Module to optimise; it is switched to train mode.
        criterion: Loss callable invoked as ``criterion(model_output, label)``.
        optimizer: Optimizer that updates the parameters of ``model``.
        args_params: Dict with the keys ``'device'`` (where each batch is
            moved before the forward pass) and ``'logger'`` (receives a
            progress line every 10 batches).

    Returns:
        1-D numpy array holding one loss value per batch (empty when the
        loader yields no batches).
    """
    device = args_params['device']
    logger = args_params['logger']
    # switch to train mode
    model.train()
    train_loss = []
    start_time = time.time()
    for batch_idx, sample_batch in enumerate(train_loader):
        # Move the batch to the model's device; without this a model placed on
        # an accelerator would receive CPU tensors and fail.
        data = sample_batch[0].to(device=device, dtype=torch.float32)
        label = sample_batch[1].to(device=device, dtype=torch.float32)
        # training process
        optimizer.zero_grad()
        model_output = model(data)
        loss = criterion(model_output, label)
        loss.backward()
        optimizer.step()
        # detach() keeps only the value, so the autograd graph can be freed.
        train_loss.append(loss.detach().view(1))
        if (batch_idx + 1) % 10 == 0:
            print_s = "batch_idx_{}_time_{}_train_loss_{}".format(batch_idx, time.time() - start_time, train_loss[-1])
            logger.info(print_s)
    if not train_loss:
        # torch.cat rejects an empty list; return an empty loss history instead.
        return np.empty(0, dtype=np.float32)
    return torch.cat(train_loss).cpu().numpy()
# END TRAIN


# START VALIDATE FUNC
def validater(val_loader, model, criterion, args_params):
    """Evaluate ``model`` on ``val_loader`` without updating it.

    Args:
        val_loader: Iterable yielding ``(data, label)`` batches of tensors.
        model: Module to evaluate; it is switched to eval mode.
        criterion: Loss callable invoked as ``criterion(model_output, label)``.
        args_params: Dict with the key ``'device'`` (where each batch is moved
            before the forward pass).

    Returns:
        Tuple ``(val_loss, accuracy)``: a 1-D numpy array with one loss value
        per batch, and the fraction of elements for which
        ``(model_output > 0.5) == label`` over the whole loader (``nan`` when
        the loader yields no batches). The accuracy is also printed.
    """
    # switch to evaluate mode
    device = args_params['device']
    model.eval()
    val_loss = []
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for sample_batch in val_loader:
            # Move the batch to the model's device before the forward pass.
            data = sample_batch[0].to(device=device, dtype=torch.float32)
            label = sample_batch[1].to(device=device, dtype=torch.float32)
            model_output = model(data)
            loss = criterion(model_output, label)
            # Threshold the output at 0.5 and compare with the labels.
            # NOTE(review): assumes model_output and label have the same shape;
            # otherwise the comparison broadcasts. Confirm against the data.
            hits = (model_output > 0.5) == label
            # Accumulate counts rather than per-batch means so that a smaller
            # final batch does not get the same weight as a full one.
            n_correct += int(hits.sum().item())
            n_total += hits.numel()
            val_loss.append(loss.detach().view(1))
    if val_loss:
        val_loss = torch.cat(val_loss).cpu().numpy()
    else:
        # torch.cat rejects an empty list; return an empty loss history instead.
        val_loss = np.empty(0, dtype=np.float32)
    accuracy_epoch = n_correct / n_total if n_total else float('nan')
    print(accuracy_epoch)
    return val_loss, accuracy_epoch
# END VALIDATE

In [None]:
# Model parameters and optimization parameters

# --- Run configuration -------------------------------------------------------
# Defined first: the model and optimizer below read `device` and `lr`, so this
# cell must not depend on values left over from an earlier kernel session.
batch_size = 16
train = 'train_data'
test = 'test_data'
arch = 'Regressor'
device = 'cpu'
lr = 0.001
epoch = 30
resume = 0
workers = 0
model_id = 1
# NOTE(review): no random seed is set, so the train/test split and the weight
# initialisation differ between runs; consider seeding for reproducibility.

# --- Model, optimizer, schedule, loss ----------------------------------------
net = Regressor(num_ROI = 100, num_hidden = 100, activation ='ELU', in_channels=100,
             patch_size=1, emb_size=100, seq_length=225, depth=2, n_para=1).to(device)
optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=0)
# Halve the learning rate every 30 scheduler steps.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.5)
# Summed (not averaged) squared error per batch.
criterion = torch.nn.MSELoss(reduction='sum')

# --- Data ----------------------------------------------------------------------
ds_all = loadmat(r'ds_dl.mat')
ts_nc_sz = ds_all['ts_nc_sz']
ts_label = ds_all['ts_label']
all_dataset = sz_cn_Dataset(x=ts_nc_sz,y=ts_label)
# 70 / 30 random split into training and test subsets.
train_size = int(len(all_dataset) * 0.7)
test_size = len(all_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(all_dataset, [train_size, test_size])

# Use the configured batch size / worker count instead of repeating literals.
train_loader = DataLoader(train_dataset, batch_size=batch_size,
                        shuffle=True, num_workers=workers)
# Evaluation does not need shuffling; a fixed order keeps it deterministic.
test_loader = DataLoader(test_dataset, batch_size=batch_size,
                        shuffle=False, num_workers=workers)

# --- Output directory ------------------------------------------------------------
result_root = './model_result/{}_the_model'.format(model_id)
os.makedirs(result_root, exist_ok=True)

# --- Define logger -----------------------------------------------------------------
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
# Re-running this cell must not attach a second file handler, otherwise every
# message would be written to the log file more than once.
if not any(isinstance(h, logging.FileHandler) for h in logger.handlers):
    handler = logging.FileHandler(os.path.join('./outputs_{}.log'.format(arch)))
    handler.setLevel(logging.INFO)
    logger.addHandler(handler)
logger.info("============================= {} ====================================".format(datetime.datetime.now()))

In [None]:
# Model training
start_epoch = 0
best_result = np.inf  # np.Inf was removed in NumPy 2.0; np.inf works everywhere
train_loss = []
test_loss = []
accuracy = []
start_time = time.time()
save = 1
# Run epochs start_epoch+1 .. epoch inclusive. The loop variable, not the
# constant `epoch`, labels the logs and checkpoints, so each epoch gets its
# own files instead of overwriting a single one.
for epoch_idx in tqdm(range(start_epoch + 1, epoch + 1)):
    train_lss_all = trainer(train_loader, net, criterion, optimizer, {'device': device, 'logger': logger})
    # evaluate on validation set
    test_lss_all, accuracy_epoch = validater(test_loader, net, criterion, {'device': device})
    lr_scheduler.step()
    # get_last_lr() reports the rate set by the latest step(); get_lr() is not
    # meant to be called from user code.
    print(epoch_idx, lr_scheduler.get_last_lr()[0])
    # NOTE(review): the per-batch losses are summed over the whole epoch and
    # divided by batch_size, not by the number of samples - kept unchanged so
    # that recorded curves stay comparable; confirm the intended normalisation.
    train_loss.append(np.sum(np.array(train_lss_all)) / batch_size)
    test_loss.append(np.sum(np.array(test_lss_all)) / batch_size)
    accuracy.append(accuracy_epoch)
    print_s = 'Epoch {}: Time:{:6.2f}, '.format(epoch_idx, time.time() - start_time) + \
              'Train Loss:{:06.5f}'.format(train_loss[-1]) + ', Test Loss:{:06.5f}'.format(test_loss[-1])
    logger.info(print_s)
    print(print_s)
    # Track the lowest test loss seen so far and keep a copy of that model.
    is_best = test_loss[-1] < best_result
    best_result = min(test_loss[-1], best_result)
    checkpoint = {
        'epoch': epoch_idx, 'arch': arch, 'state_dict': net.state_dict(), 'best_result': best_result, 'lr': lr,
        'train': train, 'test': test, 'optimizer': optimizer.state_dict()}
    if is_best:
        torch.save(checkpoint, result_root + '/model_best.pth.tar')
    if save:
        # save checkpoint
        torch.save(checkpoint, result_root + '/epoch_{}'.format(epoch_idx))
        # Running summary over all epochs so far (overwritten every epoch).
        savemat(result_root + '/train_test_error.mat', {'train_loss': train_loss, 'test_loss': test_loss, 'accuracy': accuracy})
        # Per-batch losses of this epoch.
        savemat(result_root + '/train_test_loss_epoch{}.mat'.format(epoch_idx), {'train_loss': train_lss_all, 'test_loss': test_lss_all})
    # END MAIN_TRAIN