## Import Libraries

In [1]:
!pip -q install vit_pytorch linformer

In [8]:
import os # system-wide functions
import numpy as np # For numerical computation
import scipy.io as sio# reading matlab files in python
import warnings
warnings.filterwarnings('ignore')
from matplotlib import pyplot as plt # For plotting graphs(Visualization)
import torch
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torchsummary import summary
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from tqdm.notebook import tqdm
from vit_torch.efficient import ViT
import glob
from torch.utils.tensorboard import SummaryWriter
import time
from linformer import Linformer
from vit_torch.ResNet import myResNet
import dask.dataframe as dd

## Load Data

In [2]:
class BpDataset(Dataset):
    """Map-style dataset where every item is read from its own parquet file.

    Each file must provide a ``ppg_data`` column and a ``norm_abp_data`` column;
    an item is the pair ``(np.array(ppg_data), np.array(norm_abp_data))``.

    Args:
        file_list: sequence of parquet paths; position in the list is the index.
    """

    def __init__(self, file_list):
        self.file_list = file_list

    def __len__(self):
        n_files = len(self.file_list)
        # Also cached on the instance; kept for any code that reads `filelength`.
        self.filelength = n_files
        return self.filelength

    def __getitem__(self, idx):
        path = self.file_list[idx]
        frame = dd.read_parquet(path).compute()
        ppg = np.array(frame.ppg_data)
        abp = np.array(frame.norm_abp_data)
        return ppg, abp



### Load Datasets

In [5]:
batch_size = 256

# Each BpDataset item is one parquet file from the split directory.
train_dir = 'data/filter_norm_data_train_split_'
valid_dir = 'data/filter_norm_data_valid_split_'
# Test split (not used in this notebook): 'data/n_10_dim_125_filter_norm/test/*'

# glob() order is filesystem-dependent; sort so a given seed produces the same
# shuffled batches on every machine.
train_list = sorted(glob.glob(os.path.join(train_dir, '*.parquet')))
valid_list = sorted(glob.glob(os.path.join(valid_dir, '*.parquet')))
# Fail early with a clear message instead of an opaque DataLoader/sampler error.
assert train_list, f"no .parquet files found in {train_dir!r}"
assert valid_list, f"no .parquet files found in {valid_dir!r}"

train_data = BpDataset(train_list)
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)

valid_data = BpDataset(valid_list)
# Validation does not need shuffling; a fixed order keeps batch composition,
# and therefore the per-epoch validation loss, reproducible.
valid_loader = DataLoader(dataset=valid_data, batch_size=batch_size, shuffle=False)


In [6]:
# Per split: number of samples (files), then number of batches.
for split_dataset, split_loader in ((train_data, train_loader), (valid_data, valid_loader)):
    print(len(split_dataset), len(split_loader))

172578 675
19196 75


In [7]:
# Linformer encoder used as the ViT's transformer backbone.
# `dim` must stay equal to the ViT `dim` (both 1024 in this notebook).
efficient_transformer = Linformer(
    dim=1024,
    # 10 patch tokens (ViT num_patches=10) + 1 cls token.
    # NOTE(review): confirm vit_torch.efficient.ViT prepends exactly one cls token.
    seq_len=10 + 1,
    depth=12,
    heads=8,
    k=64,  # Linformer projection length for keys/values
)

### Vision Transformer (ViT)

In [9]:
# ViT whose attention stack is the Linformer built above; keep `dim` in sync
# with the Linformer `dim`.
# Unused alternative configuration (no external transformer):
#   ViT(sequence_len=125, num_patches=10, num_classes=3, dim=128, depth=12,
#       heads=8, mlp_dim=256, dropout=0.1, emb_dropout=0.1).to("cuda")
vit_kwargs = dict(
    sequence_len=125,
    num_patches=10,
    dim=1024,
    transformer=efficient_transformer,
)
model = ViT(**vit_kwargs).to("cuda")

### Training

In [10]:
# ---- Training settings ----
epochs, lr, gamma, seed = 1000, 3e-5, 0.7, 42

# L1 (mean absolute error) loss.
# Alternatives tried earlier: nn.CrossEntropyLoss(), nn.MSELoss().
criterion = nn.L1Loss()

optimizer = optim.Adam(model.parameters(), lr=lr)

# Multiplies the learning rate by `gamma` on every scheduler.step() call.
# NOTE(review): the training loop never calls scheduler.step(), so the LR stays
# at `lr` throughout. If it is wired in per epoch as configured, the LR would
# shrink by 0.7**50 ≈ 1.8e-8 after 50 epochs — revisit step_size/gamma first.
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

In [11]:
def seed_everything(seed):
    """Seed every RNG used in training and request deterministic cuDNN kernels.

    Seeds Python's `random`, NumPy's global RNG and PyTorch (CPU plus all CUDA
    devices). Also exports PYTHONHASHSEED. That variable only affects Python
    interpreters started after this call (e.g. spawned subprocesses), not str
    hashing in the current process.

    Args:
        seed: integer seed applied to every generator.
    """
    # Imported here so the function works even though the top import cell does
    # not import `random` (without it the first line raised NameError and
    # nothing was seeded).
    import random

    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may pick different algorithms run to run; keep it off.
    torch.backends.cudnn.benchmark = False

# NOTE(review): this runs after the Linformer/ViT cells have already built the
# model, so weight initialisation is not covered by this seed; call it before
# constructing the model for fully reproducible runs.
seed_everything(seed=seed)

In [12]:
# TensorBoard logs and periodic checkpoints for this run.
log_path = 'log/class_N_10_sqe_len_125_transform'
# makedirs also creates a missing 'log/' parent and is a no-op on re-runs.
os.makedirs(log_path, exist_ok=True)
logger = SummaryWriter(log_dir=log_path)
model_path = 'model/class_N_10_sqe_len_125_transform'
os.makedirs(model_path, exist_ok=True)

try:
    for epoch in range(epochs):
        # ---- train ----
        model.train()
        epoch_loss = 0.0
        for data, label in tqdm(train_loader, disable=True):
            data = data.to("cuda")
            label = label.to("cuda")
            output = model(data)
            loss = criterion(output, label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() detaches the value: summing the loss tensors themselves
            # would chain every batch into one ever-growing autograd graph.
            # Weight by batch size so the short final batch is not over-counted.
            epoch_loss += loss.item() * data.size(0)
        epoch_loss /= len(train_loader.dataset)
        logger.add_scalar("train_loss", epoch_loss, global_step=epoch)

        # ---- validate ----
        model.eval()  # inference behaviour for layers such as dropout
        with torch.no_grad():
            epoch_val_loss = 0.0
            for data, label in valid_loader:
                data = data.to("cuda")
                label = label.to("cuda")

                val_output = model(data)
                val_loss = criterion(val_output, label)

                epoch_val_loss += val_loss.item() * data.size(0)
            epoch_val_loss /= len(valid_loader.dataset)
            logger.add_scalar("valid_loss", epoch_val_loss, global_step=epoch)

        # Checkpoint every 50 epochs. The file index is the 0-based epoch, so
        # the 50th epoch is saved as model_49.pth.
        model_name = '%s/model_%d.pth' % (model_path, epoch)
        if (epoch + 1) % 50 == 0:
            torch.save(model.state_dict(), model_name)
        cur_time = time.strftime("%Y%m%d-%H:%M:%S", time.localtime())
        print(
            f"{cur_time} Epoch : {epoch+1}/{epochs} - loss : {epoch_loss:.4f} - val_loss : {epoch_val_loss:.4f}"
        )
finally:
    # Flush pending TensorBoard events even if the run is interrupted.
    logger.close()


20220204-16:31:36 Epoch : 1/1000 - loss : 0.0206 - val_loss : 0.0176
20220204-16:56:04 Epoch : 2/1000 - loss : 0.0164 - val_loss : 0.0166
20220204-17:20:40 Epoch : 3/1000 - loss : 0.0156 - val_loss : 0.0159
20220204-17:45:51 Epoch : 4/1000 - loss : 0.0152 - val_loss : 0.0157
20220204-18:10:08 Epoch : 5/1000 - loss : 0.0151 - val_loss : 0.0155
20220204-18:34:26 Epoch : 6/1000 - loss : 0.0148 - val_loss : 0.0148
20220204-18:59:00 Epoch : 7/1000 - loss : 0.0141 - val_loss : 0.0149
20220204-19:23:27 Epoch : 8/1000 - loss : 0.0136 - val_loss : 0.0136
20220204-19:48:04 Epoch : 9/1000 - loss : 0.0130 - val_loss : 0.0135
20220204-20:12:15 Epoch : 10/1000 - loss : 0.0126 - val_loss : 0.0133
20220204-20:36:51 Epoch : 11/1000 - loss : 0.0123 - val_loss : 0.0126
20220204-21:01:15 Epoch : 12/1000 - loss : 0.0120 - val_loss : 0.0124
20220204-21:25:42 Epoch : 13/1000 - loss : 0.0117 - val_loss : 0.0125
20220204-21:50:08 Epoch : 14/1000 - loss : 0.0115 - val_loss : 0.0124
20220204-22:14:04 Epoch : 15/