In [1]:
import math
import numpy as np

import pandas as pd
import scipy.io as io
import os 
import csv

from tqdm import tqdm 
import time

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

from torch.utils.tensorboard import SummaryWriter

import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def same_seed(seed):
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def train_valid_split(data_set, valid_ratio, seed):
    valid_dataset_size = int(len(data_set) * valid_ratio)
    train_dataset_size = len(data_set) - valid_dataset_size
    train_set, valid_dataset = random_split(data_set, [train_dataset_size, valid_dataset_size], generator=torch.Generator().manual_seed(seed))
    return np.array(train_set), np.array(valid_dataset)

def predict(test_loader, model, device):
    model.eval()
    preds = torch.Tensor([])
    for x in tqdm(test_loader):
        x = x.float().to(device)
        with torch.no_grad():
            pred = model(x)
            preds = torch.concat((preds, pred.detach().to('cpu')))
    # prds = torch.cat(preds, dim=0).numpy()
    return preds

def sparseLoss(model, signal):
    model_children = list(model.children())
    loss = 0
    value = signal
    for child in model_children:
        value = model(signal)
        loss += torch.mean(torch.abs(value))
    return loss

def split_data_FCAE(data, stride=6, withoutGray=False, vedio='PDG'):
    data = torch.tensor(data)
    if withoutGray:
        if vedio == 'PDG':
            data = data.reshape(8, 12, 60, -1) 
            data = data[:, :, :20].reshape((8, 20*12, -1))
        elif vedio == 'MOV':
            data = data[:, 50:]
    repeat, frame, cell = data.shape
    if stride > frame:
        print("stride cannot larger than frame number")
        return None
    repeat_data = data.repeat((1, 1, stride))
    reshape_data = repeat_data.reshape((repeat, frame, stride, cell))
    for i in range(1, stride):
        reshape_data[:, :-i, i] = reshape_data[:, i:, i]
    reshape_data = reshape_data.reshape((repeat*frame, cell*stride))
    return reshape_data.numpy()


In [3]:
class RespData(Dataset):

    def __init__(self, x):
        self.x = x

    def __getitem__(self, idx):
        return self.x[idx]
    
    def __len__(self):
        return len(self.x)

In [4]:
class FCAE(nn.Module):

    def __init__(self, inputDim):
        super(FCAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(inputDim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 16)
        )
        self.decoder = nn.Sequential(
            nn.Linear(16, 32),
            nn.ReLU(),
            nn.Linear(32, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, inputDim)
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [5]:
def trainer(train_loader, valid_loader, model, config, device):
    criterion = nn.MSELoss(reduction='mean')
    optimizer = torch.optim.Adam(model.parameters(), lr=config['learning_rate'])    
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=config['step_size'], gamma=config['gamma'])
    best_path = config['best_model'].format(day=time.strftime("%m%d%H%M", time.localtime()), mouseID=config['mouseID'], dataDay=config['day'])
    last_path = config['last_model'].format(day=time.strftime("%m%d%H%M", time.localtime()), mouseID=config['mouseID'], dataDay=config['day'])

    writer = SummaryWriter()

    if not os.path.isdir('./FCAE/'):
        os.mkdir('./FCAE')
    
    n_epochs, best_loss, step, early_stop_count = config['n_epochs'], np.inf, 0, 0

    for epoch in range(n_epochs):
        model.train()
        loss_record = []

        train_pbar = tqdm(train_loader, position=0, leave=True)

        for x in train_pbar:
            optimizer.zero_grad()
            x = x.float().to(device)
            pred = model(x)
            loss = criterion(pred, x)
            if config['useL1']:
                loss += sparseLoss(model, x)

            loss.backward()
            optimizer.step()
            scheduler.step()

            step += 1
            loss_record.append(loss)

            train_pbar.set_description(f'Epoch [{epoch}/{n_epochs}]')
            train_pbar.set_postfix({'loss': loss.detach().item()})
        
        mean_train_loss = sum(loss_record) / len(loss_record)
        writer.add_scalar('FCAE_Loss/train', mean_train_loss, step)

        model.eval()
        loss_record = []
        for x in valid_loader:
            x = x.float().to(device)
            with torch.no_grad():
                pred = model(x)
                loss = criterion(pred, x)
            loss_record.append(loss.item())
        
        mean_valid_loss = sum(loss_record) / len(loss_record)
        print(f'Epoch [{epoch+1}/{n_epochs}]: Train loss: {mean_train_loss:.4f}, Valid loss: {mean_valid_loss:.4f}')
        writer.add_scalar('FCAE_Loss/valid', mean_valid_loss, step)

        if mean_valid_loss < best_loss:
            best_loss = mean_valid_loss
            torch.save(model.state_dict(), best_path)
            print('Saving model with loss {:.3f}...'.format(best_loss))
            early_stop_count = 0
        else:
            early_stop_count += 1
        
        if early_stop_count >= config['early_stop']:
            print('\nModel is not imporving, so we halt the traing session.')
            return 
        
        torch.save(model.state_dict(), last_path)
    writer.flush()
    writer.close()


In [10]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

config = {
    'learning_rate': 1e-4,
    'seed': 914122,
    'valid_ratio': 0.1,
    'n_epochs': 10000,
    'batch_size': 32,
    'early_stop': 400,
    'best_model': './FCAE/best_{day}_mouse{mouseID}_day{dataDay}.ckpt',
    'last_model': './FCAE/last_{day}_mouse{mouseID}_day{dataDay}.ckpt',
    'step_size': 1000,
    'gamma': 0.9999,
    'useL1': True,
    'mouseID': 1,
    'day': 0
}

In [12]:
print(type(device))

<class 'torch.device'>


In [14]:
print(config['best_model'].format(day=time.strftime("%m%d%H%M", time.localtime()), mouseID=config['mouseID'], dataDay=config['day']))

./FCAE/best_08261055_mouse1_day0.ckpt


In [13]:
same_seed(config['seed'])

frame_cell = 6*371

data = np.load('./myData/Day0_PDG/PDG_mouse1_FCAE.npy')

sequenceNum, cellNum = data.shape
data = data.reshape((sequenceNum, cellNum))

train_data, valid_data = train_valid_split(data, config['valid_ratio'], config['seed'])
train_dataset, valid_dataset = RespData(train_data), RespData(valid_data)
print(f'train dataset size: {train_data.shape}')
print(f'valid dataset size: {valid_data.shape}')

train dataset size: (5184, 2226)
valid dataset size: (576, 2226)


In [15]:
train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=config['batch_size'], pin_memory=True)

In [16]:
Day1FCAE = FCAE(frame_cell).to(device)
trainer(train_loader, valid_loader, Day1FCAE, config, device)

Epoch [0/10000]: 100%|██████████| 162/162 [00:02<00:00, 61.91it/s, loss=0.999]


Epoch [1/10000]: Train loss: 1.0052, Valid loss: 0.9950
Saving model with loss 0.995...


Epoch [1/10000]: 100%|██████████| 162/162 [00:02<00:00, 73.08it/s, loss=0.994]


Epoch [2/10000]: Train loss: 0.9973, Valid loss: 0.9890
Saving model with loss 0.989...


Epoch [2/10000]: 100%|██████████| 162/162 [00:02<00:00, 73.74it/s, loss=0.997]


Epoch [3/10000]: Train loss: 0.9953, Valid loss: 0.9806
Saving model with loss 0.981...


Epoch [3/10000]: 100%|██████████| 162/162 [00:02<00:00, 75.39it/s, loss=0.993]


Epoch [4/10000]: Train loss: 0.9944, Valid loss: 0.9779
Saving model with loss 0.978...


Epoch [4/10000]: 100%|██████████| 162/162 [00:01<00:00, 81.22it/s, loss=0.982]


Epoch [5/10000]: Train loss: 0.9940, Valid loss: 0.9623
Saving model with loss 0.962...


Epoch [5/10000]: 100%|██████████| 162/162 [00:02<00:00, 75.79it/s, loss=0.981]


Epoch [6/10000]: Train loss: 0.9778, Valid loss: 0.9487
Saving model with loss 0.949...


Epoch [6/10000]: 100%|██████████| 162/162 [00:01<00:00, 85.53it/s, loss=0.966]


KeyboardInterrupt: 

In [None]:
%reload_ext tensorboard
%tensorboard --logdir=./runs/