In [1]:
# imports
import os
import torch
from DataHandlers import *
#from nn_models.NN_model_BN import *
from nn_models.lstm_unet import UNet_ConvLSTM
from nn_models.DBlink_NN import *

from Trainers_ULM import *
from Utils import *
from demo_exp_params import *
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import torchvision
from utils.utilities import *
# from torchview import draw_graph


In [2]:
# Prefer the first CUDA GPU; fall back to the CPU when CUDA is not available.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("Using device", device)

path = r'./'  # folder that holds the saved model checkpoints

# Scratch folder for the images / videos written by the test cells.
RESULTS_DIR = "./tmp_results"
tmp_result_dir_exist = os.path.exists(RESULTS_DIR)
if not tmp_result_dir_exist:
    os.makedirs(RESULTS_DIR)  # first run: create it

Using device cuda:0


### Training (or loading) the model

In [3]:


# Architecture to train / evaluate. Options: 'DeepSMV', 'DBlink', 'Reg_Unet', 'DBlinkBase'.
method = 'DBlinkBase'

# True  -> train a new model with the method's trainer (given `modelname=model_name`).
# False -> load pretrained weights from the file named `model_name`.
TrainNetFlag = False

KNOWN_METHODS = ('DBlink', 'DBlinkBase', 'DeepSMV', 'Reg_Unet')
if method not in KNOWN_METHODS:
    # Fail fast: every later cell needs `model`, so continuing would only surface as a
    # confusing NameError further down.
    raise ValueError(f"check model name: unknown method {method!r}, expected one of {KNOWN_METHODS}")


def _make_optimization(net, lr, betas, patience):
    """Build the (criterion, optimizer, scheduler) triple shared by every method.

    Returns an MSE loss, an Adam optimizer over ``net``'s parameters with the given
    ``lr``/``betas``, and a ReduceLROnPlateau scheduler in 'min' mode with the given
    ``patience`` and a 1e-9 learning-rate floor.
    """
    criterion = nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=lr, betas=betas)
    # NOTE(review): `verbose` may emit a deprecation warning on recent PyTorch releases.
    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=patience, min_lr=1e-9, verbose=True)
    return criterion, optimizer, scheduler


def _make_loaders(X_tr, y_tr, X_v, y_v, batch_size):
    """Return ``(train_loader, val_loader)`` built with CreateDataLoader."""
    return (CreateDataLoader(X_tr, y_tr, batch_size=batch_size),
            CreateDataLoader(X_v, y_v, batch_size=batch_size))


def _load_weights(net, model_name):
    """Load the state_dict stored in file ``model_name`` into ``net``, mapped onto ``device``."""
    net.load_state_dict(torch.load(model_name, map_location=torch.device(device)))


if TrainNetFlag:
    # DBlinkBase has its own dataset saved with a 'Base' prefix; the other methods share the
    # default files. Loading only the selected method's files avoids reading (and requiring)
    # the default dataset when training DBlinkBase.
    data_prefix = 'Base' if method == 'DBlinkBase' else ''
    X_train = torch.load(data_prefix + 'X_train')
    y_train = torch.load(data_prefix + 'y_train')
    X_val = torch.load(data_prefix + 'X_val')
    y_val = torch.load(data_prefix + 'y_val')

if method == 'DBlink':
    model_name = 'DBlink_model'
    img_size = 128
    num_layers = 2        # number of stacked LSTM layers
    hidden_channels = 4   # channels at the output of each LSTM cell (more = higher capacity)
    window_size = 12      # frames used in each direction when reconstructing a frame

    model = ConvOverlapBLSTM(input_size=(img_size, img_size), input_channels=1,
                             hidden_channels=hidden_channels, num_layers=num_layers,
                             device=device).to(device)

    if TrainNetFlag:
        lr = 1e-3
        betas = (0.99, 0.999)  # Adam betas
        epochs = 40
        batch_size = 4
        patience = 5
        criterion, optimizer, scheduler = _make_optimization(model, lr, betas, patience)

        # Regress target channel 4 only (the velocity map).
        y_train = y_train[:, [4], :]
        y_val = y_val[:, [4], :]
        dl_train, dl_val = _make_loaders(X_train, y_train, X_val, y_val, batch_size)

        trainer = DBlink_trainer(model, criterion, optimizer, scheduler, batch_size,
                                 window_size=window_size, vid_length=X_train.shape[1],
                                 patience=patience, device=device, modelname=model_name)
        trainer.fit(dl_train, dl_val, num_epochs=epochs)
    else:
        _load_weights(model, model_name)

elif method == 'DBlinkBase':
    model_name = 'DBlinkBase_model'
    img_size = 32
    num_layers = 2        # number of stacked LSTM layers
    hidden_channels = 4   # channels at the output of each LSTM cell (more = higher capacity)
    window_size = 25      # frames used in each direction when reconstructing a frame

    model = ConvOverlapBLSTM(input_size=(img_size, img_size), input_channels=1,
                             hidden_channels=hidden_channels, num_layers=num_layers,
                             device=device).to(device)

    if TrainNetFlag:
        lr = 1e-4
        betas = (0.99, 0.999)  # Adam betas
        epochs = 1
        batch_size = 16
        patience = 3
        criterion, optimizer, scheduler = _make_optimization(model, lr, betas, patience)

        # The Base dataset targets are used as-is (no channel selection).
        dl_train, dl_val = _make_loaders(X_train, y_train, X_val, y_val, batch_size)

        trainer = LSTM_overlap_Trainer(model, criterion, optimizer, scheduler, batch_size,
                                       window_size=window_size, vid_length=X_train.shape[1],
                                       patience=patience, device=device, modelname=model_name)
        trainer.fit(dl_train, dl_val, num_epochs=epochs)
    else:
        _load_weights(model, model_name)

elif method == 'DeepSMV':
    model_name = 'DeepSMV_model'
    num_lstm_layers = 1
    in_ch = 1
    out_ch = 1

    model = UNet_ConvLSTM(n_channels=in_ch, n_classes=out_ch, use_LSTM=True,
                          parallel_encoder=False, lstm_layers=num_lstm_layers).to(device)

    if TrainNetFlag:
        lr = 1e-3
        betas = (0.99, 0.999)  # Adam betas
        epochs = 40
        batch_size = 8
        patience = 3
        criterion, optimizer, scheduler = _make_optimization(model, lr, betas, patience)

        # Regress target channel 4 only (the velocity map).
        y_train = y_train[:, [4], :]
        y_val = y_val[:, [4], :]
        dl_train, dl_val = _make_loaders(X_train, y_train, X_val, y_val, batch_size)

        trainer = Deepsmv_trainer(model, criterion, optimizer, scheduler, batch_size,
                                  vid_length=X_train.shape[1], patience=patience,
                                  device=device, modelname=model_name)
        trainer.fit(dl_train, dl_val, num_epochs=epochs)
    else:
        _load_weights(model, model_name)

else:  # method == 'Reg_Unet' (guaranteed by the KNOWN_METHODS check above)
    model_name = 'Reg_Unet_model'
    num_lstm_layers = 1
    in_ch = 2             # two input channels: target channels 3 and 1 (see below)
    out_ch = 1
    LSTM_used = False

    model = UNet_ConvLSTM(n_channels=in_ch, n_classes=out_ch, use_LSTM=LSTM_used,
                          parallel_encoder=False, lstm_layers=num_lstm_layers).to(device)

    if TrainNetFlag:
        lr = 1e-3
        betas = (0.99, 0.999)  # Adam betas
        epochs = 100
        batch_size = 4
        patience = 5
        criterion, optimizer, scheduler = _make_optimization(model, lr, betas, patience)

        # The network input is built from target channels 3 and 1; this must happen before
        # y_* is reduced to channel 4 below.
        X_train = y_train[:, [3, 1], :]
        X_val = y_val[:, [3, 1], :]
        y_train = y_train[:, [4], :]  # velocity map
        y_val = y_val[:, [4], :]      # velocity map
        dl_train, dl_val = _make_loaders(X_train, y_train, X_val, y_val, batch_size)

        trainer = Deepsmv_trainer(model, criterion, optimizer, scheduler, batch_size,
                                  vid_length=X_train.shape[1], patience=patience,
                                  device=device, modelname=model_name)
        trainer.fit(dl_train, dl_val, num_epochs=epochs)
    else:
        _load_weights(model, model_name)

In [None]:
# Show the trainer's up / down / out_ind attributes (presumably the per-frame window bounds,
# mirroring the ones built in the DBlink testing cell). `trainer` only exists when the
# training branch ran in this kernel session, so guard against NameError on the load path.
if 'trainer' in globals():
    print(trainer.up)
    print(trainer.down)
    print(trainer.out_ind)
else:
    print('No trainer in this session (the training branch did not run); nothing to show.')

## Testing model
Skip the cell below when evaluating **DBlinkBase**: that model uses its own dataset and is tested in the DBlink testing section further down.


In [None]:


# Evaluate DBlink / DeepSMV / Reg_Unet on the default test set and save sample images.
folder = r'./tmp_results'

if method == 'DBlinkBase':
    # DBlinkBase uses its own dataset and a sliding-window inference loop.
    print("Skipping generic test cell for DBlinkBase; run the DBlink testing cells below.")
elif method not in ('DBlink', 'DeepSMV', 'Reg_Unet'):
    raise ValueError(f"check method name: {method!r}")
else:
    X_test = torch.load('X_test').to(device)
    y_test = torch.load('y_test').to(device)

    # Target channel 4 of y_test is the ground truth for every method. Keeping the channel
    # axis (y_test[:, [4], :]) gives every branch matching 4-D tensors for MSE and save_image.
    GT = y_test[:, [4], :]
    model.eval()

    with torch.no_grad():
        if method == 'DBlink':
            # Bidirectional model: feed the sequence forwards and time-reversed, then keep the
            # output at time index `window_size` (channel 0, as a 1-wide slice).
            out = model(X_test, torch.flip(X_test, dims=[1]))
            realout = out[:, window_size, 0:1]
        elif method == 'DeepSMV':
            realout, _ = model(X_test)
        else:  # 'Reg_Unet': input is y_test channels 3 and 1, as during training
            realout, _ = model(y_test[:, [3, 1], :])

        criterion = nn.MSELoss()
        Metric = criterion(realout, GT)  # MSE
        print('R.M.S.E.:', np.sqrt(Metric.item()))

    realout = normalize_images(realout)
    GT = normalize_images(GT)
    tracks = normalize_images(y_test[:, [3], :])

    torchvision.utils.save_image(realout, f"{folder}/inference.png")
    torchvision.utils.save_image(GT, f"{folder}/GT.png")
    torchvision.utils.save_image(tracks, f"{folder}/input.png")

## DBlink testing (DBlinkBase sliding-window inference)

In [4]:

# Load the DBlinkBase test set and the trained DBlinkBase checkpoint.
X_test = torch.load('BaseX_test').to(device)
y_test = torch.load('Basey_test').to(device)

N, T, C, H, W = X_test.shape  # presumably (videos, frames, channels, height, width)

# Hyperparameters of the DBlinkBase checkpoint. They match the 'DBlinkBase' branch of the
# training cell and are set explicitly here, so this cell does not inherit another method's
# values (e.g. DBlink's 128x128 / window 12). The inference cell below also uses window_size.
model_name = 'DBlinkBase_model'
img_size = 32
num_layers = 2
hidden_channels = 4
window_size = 25

model = ConvOverlapBLSTM(input_size=(img_size, img_size), input_channels=1,
                         hidden_channels=hidden_channels, num_layers=num_layers,
                         device=device).to(device)
model.load_state_dict(torch.load(os.path.join(path, model_name), map_location=torch.device(device)))



<All keys matched successfully>

In [7]:
# Sliding-window inference: each frame of a test video is reconstructed from a window of
# neighbouring frames, fed to the model both forwards and time-reversed.
criterion = nn.MSELoss()
sum_factor = 1  # temporal stride inside each window (1 = use every frame)

num_vids, num_frames = X_test.size(0), X_test.size(1)
C, H, W = X_test.shape[2:]

# Window for output frame t: frames [down[t], up[t]) clipped to the video bounds;
# out_ind[t] is the position of frame t inside its window.
frame_idx = torch.arange(num_frames, dtype=torch.int)
down = torch.clamp(frame_idx - window_size, min=0)
up = torch.clamp(frame_idx + window_size, max=num_frames)
out_ind = frame_idx - down

# A freshly constructed model starts in training mode; switch to inference mode so layers
# such as dropout or batch-norm (if the architecture has any) behave deterministically.
model.eval()

# The last 2*window_size frames are not saved. max(0, ...) keeps the slice well-defined for
# any window_size (":-0" would silently select nothing).
# NOTE(review): the reason for dropping these frames is not visible here — confirm.
keep = max(0, num_frames - 2 * window_size)

for i in range(num_vids):
    print('Forward pass through the network')
    frames = []
    with torch.no_grad():
        for j in tqdm(range(num_frames)):
            window = X_test[i:i + 1, down[j]:up[j]:sum_factor]
            pred = model(window, torch.flip(window, dims=[1]))
            # Keep only the prediction for frame j, moved to the CPU so GPU memory does not
            # grow with the number of frames.
            frames.append(pred.detach().cpu()[0, int(out_ind[j] / sum_factor)])

    out = torch.stack(frames, dim=1)

    # MSE of the raw (un-normalized) reconstruction against channel 0 of y_test.
    metric = criterion(out.to(device), y_test[[i], :, 0, :])
    print('M.S.E.:  ', metric.item())

    curr_vid = np.zeros([1, num_frames, C, H, W])
    for j in tqdm(range(num_frames)):
        # normalize_input_01 presumably maps to [0, 1]; scale to the 8-bit range.
        curr_vid[0, j] = 255 * normalize_input_01(out[0, j].numpy())

    np.save('tmp_results/np_vid_{}'.format(i + 1), curr_vid[0, :keep])
    np.save('tmp_results/gt_vid_{}'.format(i + 1), y_test[i, :keep].detach().cpu().numpy())

    print("-I- Completed vid", i + 1)

    # NOTE: post_process_results cannot run on the GPU cluster; run this step locally instead.
    # Post-process the reconstruction and generate the output video.
    post_process_results(r'./tmp_results', i + 1)
    

Forward pass through the network


100% 300/300 [00:32<00:00,  9.37it/s]


M.S.E.:   631.3927001953125


100% 300/300 [00:00<00:00, 48251.06it/s]


-I- Completed vid 1
Forward pass through the network


100% 300/300 [00:31<00:00,  9.38it/s]


M.S.E.:   640.6298828125


100% 300/300 [00:00<00:00, 49580.01it/s]


-I- Completed vid 2
Forward pass through the network


100% 300/300 [00:32<00:00,  9.37it/s]


M.S.E.:   2763.0439453125


100% 300/300 [00:00<00:00, 42473.96it/s]


-I- Completed vid 3
Forward pass through the network


100% 300/300 [00:31<00:00,  9.38it/s]


M.S.E.:   5.891513347625732


100% 300/300 [00:00<00:00, 49768.27it/s]

-I- Completed vid 4



