## Import libraries

In [1]:
import os
import pandas as pd
import numpy as np
import copy
import torch
import loss
from torch import optim
from metrics import eval_metrics, get_epoch_acc
from dataloader import DataLoader
from cross_val import CrossVal
from torchvision import transforms
from eval import eval
from config import ModelParameters
from PIL import Image
import cv2

# Import available models, you can also explore other PyTorch models
from cracknet import cracknet, CrackNet
from unet import UNet, UNetResnet
from segnet import SegNet, SegResNet



In [2]:
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
DEVICE = "cpu"
os.environ["TORCH_USE_CUDA_DSA"] = "1"

In [3]:
# Name the data directory and model filename
DIR = 'data/' # Data directory
MODEL_FILENAME = 'cracknet.pt' # Model filename



In [8]:
num_fault = []
num_horizon = []
num_seismic = []

for name in os.listdir('../data/aug_fault_mask'):
    if name == '.ipynb_checkpoints':
        continue
    code = name.replace("fault","").replace(".npy","")
    if os.path.isfile('../data/aug_horizon_mask/horizon{}.npy'.format(code)) and os.path.isfile('../data/aug_raw_seismic/seismic{}.png'.format(code)):
        num_fault.append(name)
        num_horizon.append('horizon{}.npy'.format(code))
        num_seismic.append('seismic{}.png'.format(code))
    

df = pd.DataFrame({
    'RAW_SEISMIC': [f"{x}" for x in num_seismic],
    'RAW_FAULT': [f"{x}" for x in num_fault],
    'RAW_HORIZON': [f"{x}" for x in num_horizon]
})

In [9]:
cv = CrossVal(df, 3)
dataloaders = cv
# dataset = {}
# dataset['train'] = LabelMe(data_folder=os.path.join(DIR,'train'), transform=data_transforms['train'],
#                                 img_size=(1024, 1024))
# dataset['val'] = LabelMe(data_folder=os.path.join(DIR,'val'), transform=data_transforms['val'],
#                                 img_size=(1024, 1024))
# dataloaders = {x: torch.utils.data.DataLoader(dataset[x], batch_size = BATCH_SIZE,
#                                             shuffle = True, num_workers = 8, 
#                                             drop_last = False)
#                                             for x in ['train', 'val']}
class_count = len(dataloaders[0]['train'].dataset.label)

In [10]:
# Choose a model for training, you can refer to the models that have been imported above
model = cracknet(pretrained = ModelParameters.PRETRAINED, num_classes = class_count)

my_optimizer = optim.Adam(model.parameters(), lr = ModelParameters.LEARNING_RATE) # Check https://pytorch.org/docs/stable/optim.html for other optimizers
my_lr_scheduler = optim.lr_scheduler.StepLR(my_optimizer, step_size=25, gamma=0.1) # Check https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate for other schedulers

## Start model training

In [4]:
num_fault = []
num_horizon = []
num_seismic = []

for name in os.listdir('../data/fault_mask'):
    if name == '.ipynb_checkpoints':
        continue
    code = name.replace("fault","").replace(".npy","")
    if os.path.isfile('../data/horizon_mask/horizon{}.npy'.format(code)) and os.path.isfile('../data/raw_seismic/seismic{}.png'.format(code)):
        num_fault.append(name)
        num_horizon.append('horizon{}.npy'.format(code))
        num_seismic.append('seismic{}.png'.format(code))
    

df = pd.DataFrame({
    'RAW_SEISMIC': [f"{x}" for x in num_seismic],
    'RAW_FAULT': [f"{x}" for x in num_fault],
    'RAW_HORIZON': [f"{x}" for x in num_horizon]
})

In [5]:
cv = CrossVal(df, 3)
dataloaders = cv
# dataset = {}
# dataset['train'] = LabelMe(data_folder=os.path.join(DIR,'train'), transform=data_transforms['train'],
#                                 img_size=(1024, 1024))
# dataset['val'] = LabelMe(data_folder=os.path.join(DIR,'val'), transform=data_transforms['val'],
#                                 img_size=(1024, 1024))
# dataloaders = {x: torch.utils.data.DataLoader(dataset[x], batch_size = BATCH_SIZE,
#                                             shuffle = True, num_workers = 8, 
#                                             drop_last = False)
#                                             for x in ['train', 'val']}
class_count = len(dataloaders[0]['train'].dataset.label)

In [7]:
PATH = "../shared/cracknet baseline focal lost.pt.pt"
model = torch.load(PATH, weights_only=False)

In [8]:
idx = 200
dl = DataLoader([],[],[], (512, 512))
img = []
img.append(dl.input_transforms(Image.open(os.path.join(dl.RAW_SEISMIC_FOLDER, df.at[idx, "RAW_SEISMIC"]))))
img = torch.stack(img, dim=0)
mask = []
mask.append(torch.from_numpy(cv2.resize(np.load(os.path.join(dl.FAULT_MASK_FOLDER, df.at[idx, "RAW_FAULT"])).astype(np.uint8), dl.IMAGE_SIZE, cv2.INTER_AREA).astype(np.int64)))
mask = torch.stack(mask, dim=0)
model = model.to(DEVICE)
img = img.to(DEVICE)
mask = mask.to(DEVICE)
mask_pred = model(img)

In [9]:
eval_metrics(mask_pred, mask, class_count, ModelParameters.EVAL_METRIC)

(array(257541), array(262144))

In [25]:
mask_pred1 = mask_pred[0]
mask_pred1.shape

torch.Size([2, 512, 512])

In [26]:
_, predict1 = torch.max(mask_pred1.data, 0)
predict1_np = predict1.numpy()
inverted_fault_mask = (255 - predict1_np * 255).astype('uint8')
transform = transforms.Resize((512, 512))
img = Image.open(os.path.join(dl.RAW_SEISMIC_FOLDER, df.at[idx, "RAW_SEISMIC"]))
img = np.asarray(transform(img))
# Convert binary image to a 3-channel image for overlay (BGR format)
fault_mask_bgr = cv2.cvtColor(inverted_fault_mask, cv2.COLOR_GRAY2BGR)

overlay = cv2.addWeighted(img, 0.5, fault_mask_bgr, 0.5, 0)

cv2.imwrite("overlay_img_fault.png", overlay)

True

In [27]:
import torch.nn.functional as F

In [29]:
def make_one_hot(labels, classes):
    one_hot = torch.FloatTensor(labels.size()[0], classes, labels.size()[2], labels.size()[3]).zero_().to(labels.device)
    target = one_hot.scatter_(1, labels.data, 1)
    return target

In [36]:
target = make_one_hot(mask.unsqueeze(dim=1), classes=mask_pred.size()[1])
output = F.softmax(mask_pred, dim=1)
output_flat = output.contiguous().view(-1)
target_flat = target.contiguous().view(-1)
intersection = (output_flat * target_flat).sum()
loss = 1 - ((2. * intersection + 1) /
            (output_flat.sum() + target_flat.sum() + 1))
loss

tensor(0.3761, grad_fn=<RsubBackward1>)

In [37]:
output_flat

tensor([0.7146, 0.5000, 0.5000,  ..., 0.5019, 0.3138, 0.5000],
       grad_fn=<ViewBackward0>)

In [38]:
mask_pred

tensor([[[[0.9177, 0.0000, 0.0000,  ..., 0.8915, 0.8529, 0.0000],
          [0.9456, 0.0000, 0.0000,  ..., 0.9932, 1.0765, 0.0000],
          [0.8455, 0.0000, 0.9479,  ..., 0.0000, 0.8217, 1.1003],
          ...,
          [1.0743, 0.0000, 0.0000,  ..., 0.8481, 0.9486, 0.0000],
          [0.8374, 0.0000, 1.0330,  ..., 0.9435, 0.9598, 0.7622],
          [0.9051, 0.0000, 0.9410,  ..., 0.0000, 0.7825, 0.0000]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0231, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0242,  ..., 0.0037, 0.0000, 0.0000],
          ...,
          [0.0425, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0060, 0.0000, 0.0075,  ..., 0.0414, 0.0738, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0075, 0.0000, 0.0000]]]],
       grad_fn=<ReluBackward0>)

In [40]:
output_flat

tensor([0.7146, 0.5000, 0.5000,  ..., 0.5019, 0.3138, 0.5000],
       grad_fn=<ViewBackward0>)

In [41]:
target_flat

tensor([1., 1., 1.,  ..., 0., 0., 0.])