In [None]:
import os
import time
from tqdm import tqdm
import numpy as np

from datetime import datetime

import torch
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import transforms
from torch.utils.data import DataLoader
from torchsummary import summary

import utils, draw
from models import net3d
from data.dataloader import Dataset
from data.augments import Reshape, ToTensor
from infer import infer 

# Compute device: the first CUDA GPU when one is visible, otherwise the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Timestamp tag for this run (abbreviated month + day, underscore, HHMMSS).
session_name = "{:%b%d_%H%M%S}".format(datetime.now())

In [None]:
# Session path
# Resulting layout: <cwd>/sessions/session_<timestamp>_test/{picture,bin}
root_path = os.path.abspath(".")
sessions_path = os.path.join(root_path, "sessions")
# This cell reassigns `session_name` from itself, so decorate it only once.
# Without the guard, re-running the cell would wrap the name again
# (session_session_..._test_test) and point every path below at a new folder.
if not (session_name.startswith("session_") and session_name.endswith("_test")):
    session_name = '_'.join(("session", session_name, 'test'))
session_path = os.path.join(sessions_path, session_name)
picture_path = os.path.join(session_path, "picture")
bin_path = os.path.join(session_path, "bin")

# utils.makeDir is a project helper; presumably it creates the directory if
# it is missing -- confirm in utils.
utils.makeDir(session_path)
utils.makeDir(picture_path)
utils.makeDir(bin_path)

### Load data and CNN model

In [None]:
# Parameters
# n1, n2, n3 and n_channels form the target shape handed to Reshape below.
n1 = 256
n2 = 256
n3 = 128
n_channels = 1
dataset_name = "demo"
data_root_dir = os.path.join(root_path, "datasets", dataset_name)
data_path = os.path.join(data_root_dir, "seis")
data_list = os.listdir(data_path)
list_IDs = utils.sort_list_IDs(data_list)
only_load_input = True

# Dataset
# Per-sample preprocessing: reshape to the target shape, then convert to a tensor.
input_transform = transforms.Compose([
    Reshape((n1, n2, n3, n_channels)),
    ToTensor(),
])
dataset = Dataset(
    root_dir=data_root_dir,
    list_IDs=list_IDs,
    transform=input_transform,
    only_load_input=only_load_input,
)
# One sample per batch, in list_IDs order (no shuffling).
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)
dataloader_val = None

In [None]:
# Define CNN model
# Configuration dict consumed by net3d.model; 'trained_model_path' is also
# read below to locate the checkpoint file.
param_model = {}
param_model['input_channels'] = 1
param_model['encoder_channels'] = 512
param_model['decoder_channels'] = 16
param_model['trained_model_path'] = os.path.join(root_path, "checkpoints",
                                                 "trained_RGTNet_parameters.pth")
model = net3d.model(param_model)

# Load trained model parameters
# Always deserialize the checkpoint onto the CPU. The model is only sent to
# `device` further down, and load_state_dict copies the values into the
# model's existing parameters. Without map_location, torch.load tries to
# restore each tensor onto the device it was saved from, which fails when the
# checkpoint was written from a GPU index that this machine does not have.
model.load_state_dict(
    torch.load(param_model['trained_model_path'], map_location='cpu')
)

# Snapshot of (name, parameter) pairs, kept for inspection.
params = list(model.named_parameters())

# Send CNN model to GPU or CPU
if use_cuda:
    # Replicate across every visible GPU; `device` is the primary one.
    num_GPU = torch.cuda.device_count()
    model = torch.nn.DataParallel(model, device_ids=list(range(num_GPU))).to(device)
else:
    print("CPU mode")
    model = model.to(device)

### Prediction

In [None]:
# Inference
# `infer` is a project helper. Its result is consumed by the later cells via
# len() and integer indexing, with each entry providing 'seis' and 'pred_rgt'.
# Presumably it also writes outputs under bin_path / picture_path -- confirm in infer.
pared_sample_list = infer(
    model,
    dataloader,
    only_load_input,
    bin_path,
    picture_path,
    device,
)

In [None]:
# Pick one inferred sample at random for the plots in the following cells.
num_sample = len(pared_sample_list)
if num_sample == 0:
    # np.random.randint(0) would fail with an opaque "high <= 0" message.
    raise ValueError("No inferred samples to display: pared_sample_list is empty.")
# Unseeded draw: a different sample may be shown on every run.
random_idex_sample = np.random.randint(num_sample)

In [None]:
# Fetch the 'seis' entry of the chosen sample (presumably the input seismic
# volume) and draw it in grayscale at the given x/y/z slice positions.
selected_sample = pared_sample_list[random_idex_sample]
seis = selected_sample['seis']
draw.draw_slice(
    seis,
    x_slice=30,
    y_slice=30,
    z_slice=120,
    cmap='gray',
)

In [None]:
# Normalise the chosen sample's 'pred_rgt' entry with the project's
# min_max_norm helper, then draw it together with `seis` at the same slice
# positions. `isofs` presumably lists iso-surface levels -- confirm in draw.
pred_rgt_raw = pared_sample_list[random_idex_sample]['pred_rgt']
pred = utils.min_max_norm(pred_rgt_raw)
draw.draw_slice_surf(
    seis,
    volume2=pred,
    x_slice=30,
    y_slice=30,
    z_slice=120,
    cmap='gray',
    isofs=[0.25, 0.5, 0.75],
)