In [1]:
import math
import numpy as np

import pandas as pd 
import scipy.io as io
import os 

from tqdm import tqdm

import torch 
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split

from torch.utils.tensorboard import SummaryWriter

import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib as mpl

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def same_seed(seed):
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.deterministic = False
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    
def train_valid_split(data_set, valid_ratio, seed):
    valid_dataset_size = int(len(data_set) * valid_ratio)
    train_dataset_size = len(data_set) - valid_dataset_size
    train_dataset, valid_dataset = random_split(data_set, [train_dataset_size, valid_dataset_size], generator=torch.Generator().manual_seed(seed))
    return np.array(train_dataset), np.array(valid_dataset)

def predict(test_loader, model, device):
    model.eval()
    preds = torch.Tensor([])
    for x in tqdm(test_loader):
        x = x.float().to(device)
        with torch.no_grad():
            pred = model(x)
            preds = torch.concat((preds, pred.detach().to('cpu')))
    return preds

def splitResp(RespData, stride=6):
    repeatNum, frameNum, cellNum = RespData.shape
    result = np.empty([0, stride, cellNum])
    for i in range(repeatNum):
        for j in range(frameNum-stride):
            result = np.concatenate((result, RespData[i:i+1, j:j+stride, :]))
    
    return result

In [3]:
class RespData(Dataset):

    def __init__(self, x):
        self.x = x
    
    def __getitem__(self, idx):
        return self.x[idx]
    
    def __len__(self):
        return len(self.x)
    

In [4]:
class LSTMAE(nn.Module):

    def __init__(self, inputDim, hidden_layer=32):
        super(LSTMAE, self).__init__()
        # self.batch_size = batch_size
        self.hidden_layer = 32

        self.encoder = nn.LSTM(inputDim, self.hidden_layer, batch_first=True)
        self.decoder = nn.LSTM(self.hidden_layer, inputDim, batch_first=True)

    def forward(self, x):
        x, (h, c) = self.encoder(x)
        x, (h, c) = self.decoder(x)
        return x

In [5]:
def trainer(train_loader, valid_loader, model, config, device):
    creterion = nn.MSELoss(reduction='mean')
    optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'])
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=config['step_size'], gamma=config['gamma'])

    writer = SummaryWriter()

    if not os.path.isdir('LSTMAE'):
        os.mkdir('./LSTMAE')
    
    n_epochs, best_loss, step, early_stop_count = config['n_epochs'], np.inf, 0, 0


    for epoch in range(n_epochs):    
        model.train()
        loss_record = []

        train_qbar = tqdm(train_loader, position=0, leave=True)

        for x in train_qbar:
            optimizer.zero_grad()
            x = x.float().to(device)
            prediction = model(x)
            # print(prediction.shape)
            loss = creterion(prediction, x)
            loss.backward()
            optimizer.step()
            scheduler.step()
            step += 1
            loss_record.append(loss)

            train_qbar.set_description(f'Epoch [{epoch}/{n_epochs}]')
            train_qbar.set_postfix({'loss': loss.detach().item()})
        
        mean_train_loss = sum(loss_record) / len(loss_record)
        writer.add_scalar('LSTMAE_Loss/train', mean_train_loss, step)

        loss_record = []
        model.eval()
        for x in valid_loader:
            x = x.float().to(device)
            with torch.no_grad():
                pred = model(x)
                loss = creterion(pred, x)
            loss_record.append(loss)

        mean_valid_loss = sum(loss_record) / len(loss_record)
        print(f'Epoch [{epoch+1} / {n_epochs}]: Tain loss: {mean_train_loss:.4f}, Valid loss: {mean_valid_loss:.4f}')
        writer.add_scalar('LSTMAE_Loss/valid', mean_valid_loss, step)

        if mean_valid_loss < best_loss:
            best_loss = mean_valid_loss
            torch.save(model.state_dict(), config['best_model'])
            print(f'Saveing model with loss {best_loss:.3f} ...')
            early_stop_count = 0
        else:
            early_stop_count += 1
        
        if early_stop_count >= config['early_stop']:
            print('\nModel is not imporving, so we halt the training session')
            return
        
        torch.save(model.state_dict(), config['last_model'])
    writer.flush()
    writer.close()
            

In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

config = {
    'learning_rate': 1e-3,
    'seed': 122914,
    'valid_ratio': 0.1,
    'early_stop': 400,
    'n_epochs': 10000,
    'best_model': './LSTMAE/lstm_best_0811.ckpt',
    'last_model': './LSTMAE/lstm_last_0811.ckpt',
    'step_size': 1000,
    'gamma': 0.99,
}


In [7]:
PDG_data = np.load('./myData/Day0_PDG/PDG_mouse1_LSTMAE.npy')

In [8]:
cellNumPDG = 371

train_data, valid_data = train_valid_split(PDG_data, config['valid_ratio'], config['seed'])

train_dataset, valid_dataset = RespData(train_data), RespData(valid_data)

train_loader, valid_loader = DataLoader(train_dataset, batch_size=16), DataLoader(valid_dataset, batch_size=16)

In [9]:
model = LSTMAE(inputDim=cellNumPDG).to(device)
# print(next(model.parameters()).device)
trainer(train_loader, valid_loader, model, config, device)

Epoch [0/10000]: 100%|██████████| 324/324 [00:02<00:00, 144.81it/s, loss=0.796]


Epoch [1 / 10000]: Tain loss: 0.8485, Valid loss: 0.7889
Saveing model with loss 0.789 ...


Epoch [1/10000]: 100%|██████████| 324/324 [00:01<00:00, 217.60it/s, loss=0.75] 


Epoch [2 / 10000]: Tain loss: 0.7662, Valid loss: 0.7499
Saveing model with loss 0.750 ...


Epoch [2/10000]: 100%|██████████| 324/324 [00:01<00:00, 215.86it/s, loss=0.725]


Epoch [3 / 10000]: Tain loss: 0.7355, Valid loss: 0.7283
Saveing model with loss 0.728 ...


Epoch [3/10000]: 100%|██████████| 324/324 [00:01<00:00, 228.50it/s, loss=0.709]


Epoch [4 / 10000]: Tain loss: 0.7171, Valid loss: 0.7149
Saveing model with loss 0.715 ...


Epoch [4/10000]: 100%|██████████| 324/324 [00:01<00:00, 229.96it/s, loss=0.699]


Epoch [5 / 10000]: Tain loss: 0.7048, Valid loss: 0.7056
Saveing model with loss 0.706 ...


Epoch [5/10000]: 100%|██████████| 324/324 [00:01<00:00, 229.14it/s, loss=0.692]


Epoch [6 / 10000]: Tain loss: 0.6957, Valid loss: 0.6986
Saveing model with loss 0.699 ...


Epoch [6/10000]: 100%|██████████| 324/324 [00:01<00:00, 229.78it/s, loss=0.686]


Epoch [7 / 10000]: Tain loss: 0.6885, Valid loss: 0.6931
Saveing model with loss 0.693 ...


Epoch [7/10000]: 100%|██████████| 324/324 [00:01<00:00, 230.78it/s, loss=0.68] 


Epoch [8 / 10000]: Tain loss: 0.6826, Valid loss: 0.6886
Saveing model with loss 0.689 ...


Epoch [8/10000]: 100%|██████████| 324/324 [00:01<00:00, 222.83it/s, loss=0.675]


Epoch [9 / 10000]: Tain loss: 0.6777, Valid loss: 0.6848
Saveing model with loss 0.685 ...


Epoch [9/10000]:  44%|████▍     | 144/324 [00:00<00:00, 229.42it/s, loss=0.682]


KeyboardInterrupt: 