<a href="https://colab.research.google.com/github/SAHIL9581/w2w/blob/main/W2W_Pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import shutil

# --- 1. SETUP THE TEMPORARY ENVIRONMENT AND WORKSPACE ---
# All later cells use paths relative to PROJECT_PATH (the cell changes the
# working directory to it below).
PROJECT_PATH = "/content/W2W_Pipeline"
print(f"--> Creating project workspace at: {PROJECT_PATH}")

# Clean up previous runs if they exist.
# WARNING: this deletes the whole workspace, including any uploaded data,
# trained models and result plots from an earlier run of this notebook.
if os.path.exists(PROJECT_PATH):
    shutil.rmtree(PROJECT_PATH)

# Recreate the directory layout that the pipeline's config paths point into.
os.makedirs(f"{PROJECT_PATH}/data/raw_las_files", exist_ok=True)
os.makedirs(f"{PROJECT_PATH}/artifacts", exist_ok=True)
os.makedirs(f"{PROJECT_PATH}/trained_models/autoencoder", exist_ok=True)
os.makedirs(f"{PROJECT_PATH}/trained_models/boundary_detector", exist_ok=True)
os.makedirs(f"{PROJECT_PATH}/results", exist_ok=True) # For saving plots

# Change the current working directory to the project path
# (IPython magic; `{PROJECT_PATH}` is interpolated from the Python variable).
%cd {PROJECT_PATH}
print(f"--> Successfully changed directory to: {os.getcwd()}")


# --- 2. INSTALL ALL REQUIRED LIBRARIES ---
# NOTE(review): only ray has a lower bound; every other package is upgraded to
# whatever is latest (-U), so runs are not reproducible across dates. Consider
# pinning versions and using `%pip` so the install targets the kernel's env.
print("\n--> Installing necessary Python libraries (this may take a few minutes)...")
!pip install -U "ray[train,tune]>=2.9.0" mlflow torch torchvision torchaudio lasio scikit-learn pandas tqdm matplotlib joblib pyyaml pyngrok -q
print("✅ Library installation complete.")

--> Creating project workspace at: /content/W2W_Pipeline
/content/W2W_Pipeline
--> Successfully changed directory to: /content/W2W_Pipeline

--> Installing necessary Python libraries (this may take a few minutes)...
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m91.2/91.2 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m70.1/70.1 MB[0m [31m11.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.7/24.7 MB[0m [31m84.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m74.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m821.2/821.2 MB[0m [31m1.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m393.1/393.1 MB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.9/8.9

In [2]:
# This cell only prints a reminder: once the install cell above has finished,
# restart the Colab runtime ("Runtime" -> "Restart runtime...") so the freshly
# installed packages are picked up, then continue with the next cell.
print(
    "!!! IMPORTANT: Please RESTART YOUR COLAB RUNTIME NOW !!!",
    "Go to 'Runtime' -> 'Restart runtime...' in the Colab menu.",
    "Once restarted, continue to the next cell.",
    sep="\n",
)

!!! IMPORTANT: Please RESTART YOUR COLAB RUNTIME NOW !!!
Go to 'Runtime' -> 'Restart runtime...' in the Colab menu.
Once restarted, continue to the next cell.


In [3]:
#@title ⚙️ Configure Your Pipeline Run
#@markdown ### 1. Select Which Pipelines to Run
run_data_preparation = True #@param {type:"boolean"}
run_pretraining = True #@param {type:"boolean"}
run_finetuning = True #@param {type:"boolean"}
run_inference = True #@param {type:"boolean"}

#@markdown ---
#@markdown ### 2. Configure Model and Data Parameters
#@markdown **Important:** Set the number of feature columns (curves) from your LAS files.
input_channels = 13 #@param {type:"integer"}
#@markdown Set the patch height (window size) for processing well logs.
patch_height = 700 #@param {type:"integer"}

#@markdown ---
#@markdown ### 3. Configure Training Parameters
#@markdown Hyperparameter tuning samples for pre-training.
pretrain_num_samples = 10 #@param {type:"integer"}
#@markdown Number of epochs for final model fine-tuning.
finetune_epochs = 100 #@param {type:"integer"}

#@markdown ---
#@markdown ### 4. Configure Inference Parameters
#@markdown Provide the exact WELL names to compare (find in `data/train.csv` after data prep).
reference_well = "WELL_NAME_A" #@param {type:"string"}
well_of_interest = "WELL_NAME_B" #@param {type:"string"}
#@markdown Confidence threshold (0 to 1) for a predicted boundary to be considered valid.
correlation_threshold = 0.7 #@param {type:"number"}

# NOTE(review): the flag names `run_data_preparation`, `run_finetuning` and
# `run_inference` are reused later in this notebook as function names (`def`).
# The booleans are copied into `config` further down in this cell, so a
# top-to-bottom run works, but re-running this cell after the function cells
# rebinds those names to booleans and hides the functions. Consider renaming
# the flags (e.g. `do_data_preparation`) in a follow-up.

# ==============================================================================
# All imports needed for the notebook
import pandas as pd
import numpy as np
import lasio
import os
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
from sklearn.preprocessing import StandardScaler
from joblib import dump, load
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from scipy.optimize import linear_sum_assignment
import textwrap
import yaml
import argparse
import shutil
import mlflow
from ray import tune
from ray.train import RunConfig, ScalingConfig, Checkpoint
from ray.train.torch import TorchTrainer, TorchCheckpoint
from ray.air.integrations.mlflow import MLflowLoggerCallback
from google.colab import files
from pyngrok import ngrok

# ==============================================================================
# Authenticate ngrok without embedding the secret in the notebook.
# The token is taken from the NGROK_AUTH_TOKEN environment variable; when that
# is not set the user is prompted for it (input is hidden and never stored in
# the notebook file).
# SECURITY: an auth token used to be hard-coded on this line. Anything that was
# ever committed must be treated as leaked — revoke it in the ngrok dashboard.
from getpass import getpass

_ngrok_token = os.environ.get("NGROK_AUTH_TOKEN") or getpass("Enter your ngrok auth token: ")
if _ngrok_token:
    ngrok.set_auth_token(_ngrok_token)
else:
    print("⚠️ No ngrok auth token supplied; ngrok was not authenticated.")
del _ngrok_token  # do not keep the secret around in the kernel namespace
# ==============================================================================

# Create the configuration dictionary from the form values.
# Every stage function in this notebook receives this dict; the form variables
# above are only read here.
config = {
    # Stage switches (booleans captured from the form).
    'run_data_preparation': run_data_preparation,
    'run_pretraining': run_pretraining,
    'run_finetuning': run_finetuning,
    'run_inference': run_inference,
    # All paths are relative to the current working directory, which the setup
    # cell changes to the project workspace.
    'paths': {
        'raw_las_folder': "data/raw_las_files/",
        'processed_csv_path': "data/train.csv",
        'label_encoder_path': "artifacts/label_encoder.json",
        'std_scaler_path': "artifacts/StandardScaler.bin",
        'pretrained_encoder_path': "trained_models/autoencoder/best_autoencoder.pt",
        'final_model_path': "trained_models/boundary_detector/final_model.pt",
        'results_path': "results/"
    },
    'mlflow': {
        'experiment_name': "W2W_Matcher_Pipeline"
    },
    'pretraining': {
        'epochs': 25,
        'num_samples': pretrain_num_samples,
        # Ray Tune search space for the autoencoder pre-training loop.
        'search_space': {
            'in_channels': tune.grid_search([input_channels]), # Use the value from the form
            'optimizer': tune.choice(["RMSprop", "AdamW", "Adam"]),
            'lr': tune.loguniform(1e-4, 1e-3),
            'act_name': tune.choice(["prelu", "relu"]),
            'batch_size': tune.choice([16, 32]),
        }
    },
    'finetuning': {
        'learning_rate': 0.0001,
        'batch_size': 16,
        'epochs': finetune_epochs,
        'model_params': {
            'patch_height': patch_height,
            'in_channels': input_channels,
            'act_name': "prelu",
            # Input width of the `Project` linear layer, which is applied to the
            # flattened encoder output — must equal that flattened size.
            'project_in_features': 2048,
            'hidden_dim': 256,
            'num_queries': 100,
            'num_heads': 8,
            'dropout': 0.1,
            'expansion_factor': 4,
            'num_transformers': 6,
            'output_size': 3 # [class_prob, top, height]
        },
        # Weights handed to HungarianMatcher (class cost, box L1 cost).
        'matcher_costs': {'set_cost_class': 1, 'set_cost_bbox': 5},
        # Keys must match the loss names returned by SetCriterion.
        'loss_weights': {'loss_matching': 1.0, 'loss_unmatching': 0.5, 'loss_height_constraint': 0.5}
    },
    'inference': {
        'reference_well': reference_well,
        'well_of_interest': well_of_interest,
        'correlation_threshold': correlation_threshold,
    }
}

print("✅ Configuration loaded successfully.")
print(f"Project directory: {os.getcwd()}")

✅ Configuration loaded successfully.
Project directory: /content/W2W_Pipeline


In [None]:
# Ask the user for a single ZIP of .las files and extract it into
# data/raw_las_files/. Extraction uses the standard-library `zipfile` module
# instead of `!unzip`, so the uploaded file name is never interpolated into a
# shell command and a non-ZIP upload is reported instead of failing silently.
import zipfile

print(">>> ACTION REQUIRED: Please upload the ZIP file containing your .las files.")
uploaded_files = files.upload()

if not uploaded_files:
    print("\n⚠️ Upload was cancelled or failed. Please run this cell again.")
elif len(uploaded_files) > 1:
    print("\n⚠️ Please upload only a single ZIP file. Run this cell again.")
else:
    zip_filename = next(iter(uploaded_files))
    print(f"\n✅ '{zip_filename}' uploaded successfully.")
    print("--> Unzipping into the 'data/raw_las_files' folder...")
    try:
        # extractall() overwrites existing files, matching the old `unzip -o`.
        with zipfile.ZipFile(zip_filename) as archive:
            archive.extractall("data/raw_las_files/")
    except zipfile.BadZipFile:
        print(f"\n⚠️ '{zip_filename}' is not a valid ZIP archive. Please run this cell again.")
    else:
        print("--> ZIP file has been unzipped successfully.")
        print("\n✅ Data input step is complete. You can now proceed.")
    finally:
        # Clean up the uploaded archive whether or not extraction succeeded.
        if os.path.exists(zip_filename):
            os.remove(zip_filename)

>>> ACTION REQUIRED: Please upload the ZIP file containing your .las files.


In [None]:
def run_data_preparation(config):
    """Combine raw .las files into one CSV and create the training artifacts.

    Reads every ``*.las`` file below ``config['paths']['raw_las_folder']``
    (descending into a single wrapping sub-folder if the ZIP contained one),
    tags each row with its WELL name and GROUP, and writes:

    * ``processed_csv_path``  – the concatenated table (``;``-separated),
    * ``label_encoder_path``  – JSON mapping of GROUP value -> integer id,
    * ``std_scaler_path``     – a StandardScaler fitted on the feature columns
      (everything except WELL, GROUP and DEPTH_MD), NaNs mean-imputed.

    Args:
        config: pipeline configuration dict; only ``config['paths']`` is read.

    Raises:
        FileNotFoundError: no ``.las`` file exists under the raw folder.
        ValueError: ``.las`` files were found but none could be parsed.
    """
    print("--- LAUNCHING STAGE 0: DATA PREPARATION ---")
    paths = config['paths']
    search_folder = paths['raw_las_folder']
    items_in_folder = os.listdir(search_folder)

    # A ZIP often contains one top-level folder; search inside it directly.
    if len(items_in_folder) == 1 and os.path.isdir(os.path.join(search_folder, items_in_folder[0])):
        print(f"--> Found single sub-folder '{items_in_folder[0]}'. Adjusting search path.")
        search_folder = os.path.join(search_folder, items_in_folder[0])

    print(f"--> Searching for .las files in '{search_folder}'...")
    las_files_found = [
        os.path.join(root, filename)
        for root, _dirs, filenames in os.walk(search_folder)
        for filename in filenames
        if filename.lower().endswith('.las')
    ]

    if not las_files_found:
        raise FileNotFoundError(f"No .las files found in '{search_folder}'. Check your ZIP or folder structure.")

    print(f"--> Found {len(las_files_found)} .las files. Reading now...")
    all_wells_df = []
    for filepath in las_files_found:
        try:
            las = lasio.read(filepath)
            df = las.df().reset_index()
            # Fall back to the file name when the LAS header has no WELL value.
            df['WELL'] = las.well.WELL.value if las.well.WELL.value else os.path.splitext(os.path.basename(filepath))[0]
            df['GROUP'] = 'UNKNOWN'
            for param in las.params:
                if 'GROUP' in param.mnemonic:
                    df['GROUP'] = param.value
                    break
            all_wells_df.append(df)
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole import.
            print(f"    - Could not read {filepath}: {e}")

    if not all_wells_df:
        # Without this guard pd.concat([]) fails with an unhelpful message.
        raise ValueError(
            f"None of the {len(las_files_found)} .las files could be read; see the messages above."
        )

    master_df = pd.concat(all_wells_df, ignore_index=True)

    # Different LAS files name the depth index DEPT or DEPTH. Merge them into a
    # single DEPTH_MD column; renaming each one separately would create two
    # columns with the same name when both spellings occur.
    depth_aliases = [c for c in ('DEPTH_MD', 'DEPT', 'DEPTH') if c in master_df.columns]
    if depth_aliases:
        depth = master_df[depth_aliases[0]]
        for alias in depth_aliases[1:]:
            depth = depth.fillna(master_df[alias])
        master_df = master_df.drop(columns=depth_aliases)
        master_df.insert(0, 'DEPTH_MD', depth)

    # Make sure the output folders exist even if the setup cell was skipped.
    for key in ('processed_csv_path', 'label_encoder_path', 'std_scaler_path'):
        os.makedirs(os.path.dirname(paths[key]) or '.', exist_ok=True)

    master_df.to_csv(paths['processed_csv_path'], index=False, sep=';')
    print(f"--> Saved combined data to '{paths['processed_csv_path']}'")

    label_encoder = {str(g): i for i, g in enumerate(master_df['GROUP'].unique())}
    with open(paths['label_encoder_path'], 'w') as f:
        json.dump(label_encoder, f, indent=4)
    print(f"--> Saved label encoder to '{paths['label_encoder_path']}'")

    # Feature columns only; missing values are replaced by the column mean.
    numeric_df = master_df.drop(columns=['WELL', 'GROUP', 'DEPTH_MD'], errors='ignore')
    numeric_df = numeric_df.fillna(numeric_df.mean())

    scaler = StandardScaler()
    scaler.fit(numeric_df)
    dump(scaler, paths['std_scaler_path'])
    print(f"--> Saved StandardScaler to '{paths['std_scaler_path']}'")
    print("✅ Data Preparation complete.")

In [None]:
def collate_fn(batch):
    """Batch (image, target) samples whose targets differ in size.

    Images are stacked into one tensor along a new leading dimension; the
    per-sample target objects are returned untouched as a list because they
    cannot be stacked.
    """
    patch_tensors, target_list = [], []
    for image, target in batch:
        patch_tensors.append(image)
        target_list.append(target)
    return torch.stack(patch_tensors), target_list

class AutoencoderDataset(Dataset):
    """Row-wise dataset used to pre-train the U-Net as an autoencoder.

    Loads the processed CSV, removes the WELL / GROUP / DEPTH_MD columns,
    mean-imputes missing values and applies the saved StandardScaler. Each item
    is one scaled row, returned twice as ``(input, target)``.
    """
    def __init__(self, config):
        path_cfg = config['paths']
        frame = pd.read_csv(path_cfg['processed_csv_path'], delimiter=';')
        frame.drop(columns=['WELL', 'GROUP', 'DEPTH_MD'], inplace=True, errors='ignore')

        # Replace NaNs column by column with that column's mean.
        for column in frame.columns:
            column_mean = frame[column].mean()
            frame[column] = frame[column].fillna(column_mean)

        fitted_scaler = load(path_cfg['std_scaler_path'])
        scaled = fitted_scaler.transform(frame)
        self.data = scaled.astype(np.float32)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data[idx]
        model_input = torch.from_numpy(row)
        model_target = torch.from_numpy(row)
        return model_input, model_target

class BoundaryDataset(Dataset):
    """Patches of one well plus their ground-truth layer boxes.

    The well's scaled feature rows are cut into non-overlapping windows of
    ``patch_height`` rows (an incomplete trailing window is dropped). For every
    window the runs of constant GROUP label become boxes ``(top, height)``.

    Args:
        config: pipeline configuration; reads ``config['paths']`` and
            ``config['finetuning']['model_params']``.
        well_name: exact WELL value to load. When falsy, one well is chosen at
            random using ``seed``.
        seed: seed for the random well choice. ``0`` is a valid seed. When
            ``None`` a random seed is drawn.

    Raises:
        ValueError: ``well_name`` was given but has no rows in the CSV.
    """
    def __init__(self, config, well_name=None, seed=None):
        self.finetune_params = config['finetuning']['model_params']
        self.paths = config['paths']
        self.well_name = well_name
        # `is not None` so that an explicit seed of 0 is honoured.
        self.seed = seed if seed is not None else np.random.randint(2**32 - 1)
        self.x, self.gt = self._get_Xy()

    def _get_ground_truth_boundaries(self, y_labels):
        """Turn per-row labels into boxes, one dict of boxes per patch.

        For each patch, a new box starts at row 0 and at every row whose label
        differs from the previous row. Returns a list (one entry per patch) of
        ``{box_index: {'Group', 'Top', 'Height'}}`` with integer values
        expressed in rows relative to the patch start.
        """
        gt_boxes_all_patches = []
        for y_patch in y_labels:
            n_rows = len(y_patch)
            starts = [0] + [i + 1 for i in range(n_rows - 1) if y_patch[i] != y_patch[i + 1]]
            ends = starts[1:] + [n_rows]
            gt_boxes_all_patches.append({
                box_id: {'Group': int(y_patch[start]), 'Top': int(start), 'Height': int(end - start)}
                for box_id, (start, end) in enumerate(zip(starts, ends))
            })
        return gt_boxes_all_patches

    def _get_Xy(self):
        """Load one well and return ``(patches, boxes_per_patch)``."""
        df_master = pd.read_csv(self.paths['processed_csv_path'], delimiter=';')
        if self.well_name:
            well_df = df_master[df_master['WELL'] == self.well_name].copy()
            if well_df.empty:
                # Fail early with the list of valid names instead of an opaque
                # error from the scaler on an empty frame.
                available = ", ".join(map(str, df_master['WELL'].unique()))
                raise ValueError(
                    f"Well '{self.well_name}' not found in '{self.paths['processed_csv_path']}'. "
                    f"Available wells: {available}"
                )
        else:  # Training: pick a random well
            # Local generator: same draw as seeding the global RNG, but without
            # resetting NumPy's global random state for the rest of the notebook.
            rng = np.random.RandomState(self.seed)
            well_names = list(df_master.WELL.unique())
            rand_idx = rng.randint(0, len(well_names))
            well_df = df_master[df_master['WELL'] == well_names[rand_idx]].copy()

        with open(self.paths['label_encoder_path']) as f:
            label_encoder = json.load(f)

        # Encode GROUP to integer ids; unmapped rows borrow a neighbouring label.
        labels = well_df['GROUP'].astype(str).map(label_encoder).bfill().ffill()
        cols_to_drop = ['WELL', 'GROUP', 'DEPTH_MD']
        well_numeric = well_df.drop(columns=cols_to_drop, errors='ignore')
        well_numeric = well_numeric.fillna(well_numeric.mean())

        scaler = load(self.paths['std_scaler_path'])
        scaled_data = scaler.transform(well_numeric)

        patch_height = self.finetune_params['patch_height']
        # Only complete windows are kept.
        n_full_patches = scaled_data.shape[0] // patch_height
        starts = [k * patch_height for k in range(n_full_patches)]
        label_values = labels.values
        x_patched = np.asarray([scaled_data[s:s + patch_height] for s in starts]).astype(np.float32)
        y_patched = np.asarray([label_values[s:s + patch_height] for s in starts])

        return x_patched, self._get_ground_truth_boundaries(y_patched)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for patch ``idx``.

        ``image`` is the patch with a leading singleton dimension added.
        ``target['loc_info']`` holds one ``(top, height)`` row per box, both
        divided by ``patch_height``; ``target['labels']`` is 1 for every box.
        """
        img_patch = np.expand_dims(self.x[idx], 0)
        boxes = self.gt[idx].values()
        patch_height = self.finetune_params['patch_height']

        tops = [box['Top'] / patch_height for box in boxes]
        heights = [box['Height'] / patch_height for box in boxes]

        top_tensor = torch.tensor(tops, dtype=torch.float32).view(-1, 1)
        height_tensor = torch.tensor(heights, dtype=torch.float32).view(-1, 1)
        target = {
            'labels': torch.ones(len(tops), dtype=torch.long),
            'loc_info': torch.hstack((top_tensor, height_tensor)),
        }
        return torch.from_numpy(img_patch), target

In [None]:
def get_activation(name):
    """Return a new activation module: PReLU for 'prelu', ReLU for 'relu', GELU otherwise."""
    known_activations = {'prelu': nn.PReLU, 'relu': nn.ReLU}
    activation_cls = known_activations.get(name, nn.GELU)
    return activation_cls()

class Project(nn.Module):
    """Flatten everything after the batch dimension and apply one linear layer."""
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        flattened = torch.flatten(x, start_dim=1)
        return self.linear(flattened)

class Query(nn.Module):
    """Learned query embeddings, copied once per sample of the incoming batch."""
    def __init__(self, num_queries, hidden_dim):
        super().__init__()
        self.queries = nn.Parameter(torch.randn(1, num_queries, hidden_dim))

    def forward(self, x):
        # Only the batch size of `x` is used; its values are not read.
        batch_size = x.shape[0]
        return self.queries.tile((batch_size, 1, 1))

class Transformer(nn.Module):
    """One decoder block: queries self-attend, then cross-attend to the content.

    The previous implementation wrapped an encoder layer and never used
    ``content_embed``, so the queries (and therefore every prediction) were
    independent of the input patch. This version uses
    ``nn.TransformerDecoderLayer`` with the content embedding as memory.

    NOTE: the parameter names of this module differ from the encoder-layer
    version, so checkpoints saved with the old definition cannot be loaded.

    Args:
        hidden_dim: embedding size of queries and content.
        num_heads: number of attention heads.
        dropout: dropout probability inside the layer.
        expansion_factor: feed-forward width as a multiple of ``hidden_dim``.
        act_name: accepted for interface compatibility; currently not applied
            (the layer keeps PyTorch's default activation).
    """
    def __init__(self, hidden_dim, num_heads, dropout, expansion_factor, act_name):
        super().__init__()
        self.transformer_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim * expansion_factor,
            dropout=dropout,
            batch_first=True,
        )

    def forward(self, query_embed, content_embed):
        """Update ``query_embed`` (batch, queries, dim) using ``content_embed``.

        ``content_embed`` may be (batch, dim) — treated as a memory sequence of
        length one — or already (batch, length, dim).
        """
        memory = content_embed.unsqueeze(1) if content_embed.dim() == 2 else content_embed
        return self.transformer_layer(query_embed, memory)

class UNetBlock(nn.Module):
    """Two Conv2d -> BatchNorm2d -> activation stages.

    The first convolution changes the channel count and applies ``stride``;
    the second keeps the channel count and uses stride 1. Both use padding of
    ``kernel_size // 2``.
    """
    def __init__(self, in_channels, out_channels, stride=2, kernel_size=3, activation='prelu'):
        super().__init__()
        half_kernel = kernel_size // 2
        stages = [
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, half_kernel),
            nn.BatchNorm2d(out_channels),
            get_activation(activation),
            nn.Conv2d(out_channels, out_channels, kernel_size, 1, half_kernel),
            nn.BatchNorm2d(out_channels),
            get_activation(activation),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        return self.block(x)

class UNet(nn.Module):
    """U-Net autoencoder used for pre-training (three down stages, three up stages).

    ``forward`` appends two singleton dimensions to its input, runs the
    encoder, a bottleneck and a decoder with skip connections, and removes the
    two trailing dimensions again. The same activation module instance
    (``self.act``) is shared by ``start`` and ``mid``.

    NOTE(review): with a (batch, channels) input the spatial size is 1x1 when it
    reaches ``mid``, whose convolution uses a 2x2 kernel without padding — this
    looks too small for that layer. Confirm the intended input shape.
    """
    def __init__(self, in_channels=13, activation='prelu'):
        super().__init__()
        self.act = get_activation(activation)
        # Encoder
        self.start = nn.Sequential(nn.Conv2d(in_channels, 32, 3, 1, 1), nn.BatchNorm2d(32), self.act)
        self.e1 = UNetBlock(32, 64, 2, activation=activation)
        self.e2 = UNetBlock(64, 128, 2, activation=activation)
        self.e3 = UNetBlock(128, 256, 2, activation=activation)
        # Bottleneck
        self.mid = nn.Sequential(nn.Conv2d(256, 512, 2), nn.BatchNorm2d(512), self.act)
        # Decoder: each up-convolution is followed by a block that consumes the
        # up-sampled features concatenated with the matching encoder output.
        self.uc3 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.d3 = UNetBlock(512, 256, 1, activation=activation)
        self.uc2 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.d2 = UNetBlock(256, 128, 1, activation=activation)
        self.uc1 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.d1 = UNetBlock(128, 64, 1, activation=activation)
        self.out = nn.Conv2d(64, in_channels, 1)

    def forward(self, x):
        features = x[..., None, None]
        skip1 = self.e1(self.start(features))
        skip2 = self.e2(skip1)
        skip3 = self.e3(skip2)
        bottleneck = self.mid(skip3)

        up = self.uc3(bottleneck, output_size=skip3.size())
        up = self.d3(torch.cat((up, skip3), 1))
        up = self.uc2(up, output_size=skip2.size())
        up = self.d2(torch.cat((up, skip2), 1))
        up = self.uc1(up, output_size=skip1.size())
        up = self.d1(torch.cat((up, skip1), 1))

        reconstruction = self.out(up)
        return reconstruction.squeeze(-1).squeeze(-1)

class UNetEncoder(nn.Module):
    """Encoder half of ``UNet`` (same layer names, so its weights can be reused).

    ``forward`` takes a 3-D tensor, swaps its last two dimensions and appends a
    trailing singleton dimension before running the convolutional stack; the
    bottleneck output is returned.

    NOTE(review): the trailing dimension has size 1 while ``mid`` uses a 2x2
    kernel without padding — this looks too small for that layer. Confirm the
    intended input layout.
    """
    def __init__(self, in_channels=13, activation='prelu'):
        super().__init__()
        self.act = get_activation(activation)
        self.start = nn.Sequential(nn.Conv2d(in_channels, 32, 3, 1, 1), nn.BatchNorm2d(32), self.act)
        self.e1 = UNetBlock(32, 64, 2, activation=activation)
        self.e2 = UNetBlock(64, 128, 2, activation=activation)
        self.e3 = UNetBlock(128, 256, 2, activation=activation)
        self.mid = nn.Sequential(nn.Conv2d(256, 512, 2), nn.BatchNorm2d(512), self.act)

    def forward(self, x):
        # (B, d1, d2) -> (B, d2, d1, 1)
        features = x.transpose(1, 2).unsqueeze(-1)
        features = self.start(features)
        for stage in (self.e1, self.e2, self.e3):
            features = stage(features)
        return self.mid(features)

class W2WTransformerModel(nn.Module):
    """Boundary detector: encoder -> projection -> learned queries -> transformer stack -> head.

    All sizes come from ``config['finetuning']['model_params']``. The output has
    one row of ``output_size`` values per query and sample.
    """
    def __init__(self, config):
        super().__init__()
        params = config['finetuning']['model_params']
        hidden_dim = params['hidden_dim']
        self.encoder = UNetEncoder(params['in_channels'], params['act_name'])
        self.project = Project(params['project_in_features'], hidden_dim)
        self.query = Query(params['num_queries'], hidden_dim)
        layers = []
        for _ in range(params['num_transformers']):
            layers.append(
                Transformer(hidden_dim, params['num_heads'], params['dropout'],
                            params['expansion_factor'], params['act_name'])
            )
        self.transformers = nn.ModuleList(layers)
        self.finalize = nn.Sequential(
            nn.Linear(hidden_dim, params['output_size']),
        )

    def forward(self, img):
        # Drop the singleton dimension at position 1 before encoding.
        patch = img.squeeze(1)
        encoded = self.encoder(patch)
        content_embed = self.project(encoded)
        query_embed = self.query(content_embed)
        for layer in self.transformers:
            query_embed = layer(query_embed, content_embed)
        return self.finalize(query_embed)

def load_pretrained_encoder_weights(model, path):
    """Copy encoder weights from a pre-trained state dict into ``model.encoder``.

    Only entries whose name (after an optional ``module.`` prefix) starts with
    one of the encoder layer names ``start``, ``e1``, ``e2``, ``e3`` or ``mid``
    are considered. An entry is loaded only when ``model`` has a parameter
    called ``encoder.<name>`` with the same shape; everything else is skipped
    and reported instead of aborting the load.

    Args:
        model: module whose ``state_dict`` uses the ``encoder.`` prefix.
        path: file containing a state dict saved with ``torch.save``.

    Returns:
        The same ``model`` object (unchanged when ``path`` does not exist).
    """
    if not os.path.exists(path):
        print(f"⚠️ Pre-trained weights not found at {path}. Skipping.")
        return model

    # map_location lets a checkpoint written on a GPU machine load on CPU.
    pre_dict = torch.load(path, map_location="cpu")
    model_dict = model.state_dict()
    encoder_prefixes = ('start.', 'e1.', 'e2.', 'e3.', 'mid.')

    enc_dict, skipped = {}, []
    for key, value in pre_dict.items():
        name = key[len('module.'):] if key.startswith('module.') else key
        if not name.startswith(encoder_prefixes):
            continue
        target_key = 'encoder.' + name
        same_shape = (
            target_key in model_dict
            and getattr(value, 'shape', None) == model_dict[target_key].shape
        )
        if same_shape:
            enc_dict[target_key] = value
        else:
            skipped.append(target_key)

    model_dict.update(enc_dict)
    model.load_state_dict(model_dict, strict=False)
    print(f"✅ Loaded {len(enc_dict)} pre-trained layers from {path}")
    if skipped:
        print(f"⚠️ Skipped {len(skipped)} pre-trained entries with no matching parameter or shape.")
    return model

# Confirmation message printed once all definitions in this cell have executed.
print("✅ Model architectures defined.")

In [None]:
class HungarianMatcher(nn.Module):
    """One-to-one assignment between predicted queries and ground-truth boxes.

    The cost of pairing a query with a target is
    ``cost_bbox * L1(pred_box, target_box) - cost_class * sigmoid(logit)``;
    the assignment minimising the total cost is found per sample with
    ``scipy.optimize.linear_sum_assignment``.

    Args:
        cost_class: weight of the (negative) predicted probability.
        cost_bbox: weight of the L1 distance between (top, height) pairs.
    """
    def __init__(self, cost_class=1, cost_bbox=1):
        super().__init__()
        self.cost_class = cost_class
        self.cost_bbox = cost_bbox

    @torch.no_grad()
    def forward(self, outputs, targets):
        """Compute the matching.

        Args:
            outputs: either a tensor (batch, queries, 1 + box_dims) whose first
                channel is the class logit, or a dict with ``'pred_logits'``
                (batch, queries, 1) and ``'pred_locs'`` (batch, queries, box_dims).
            targets: one dict per sample with ``'loc_info'`` of shape
                (num_boxes, box_dims).

        Returns:
            A list with one ``(query_indices, target_indices)`` pair of int64
            tensors per sample.
        """
        if isinstance(outputs, dict):
            out_prob = outputs['pred_logits'].sigmoid()
            out_bbox = outputs['pred_locs']
        else:
            out_prob = outputs[:, :, :1].sigmoid()
            out_bbox = outputs[:, :, 1:]
        bs, num_queries = out_prob.shape[:2]

        flat_out_prob = out_prob.flatten(0, 1)   # (bs * queries, 1)
        flat_out_bbox = out_bbox.flatten(0, 1)   # (bs * queries, box_dims)

        tgt_bbox = torch.cat([v["loc_info"] for v in targets]).to(out_bbox.device)
        num_targets = tgt_bbox.shape[0]

        # Keep the trailing dimension of size 1 so the class cost broadcasts
        # over the target axis of the (bs * queries, num_targets) box cost.
        cost_class = -flat_out_prob
        cost_bbox = torch.cdist(flat_out_bbox, tgt_bbox, p=1)
        C = (self.cost_bbox * cost_bbox + self.cost_class * cost_class).view(bs, num_queries, num_targets).cpu()

        # Columns of C are all targets of the batch concatenated; split them
        # back per sample and solve each sample against its own targets only.
        sizes = [len(v["loc_info"]) for v in targets]
        indices = [linear_sum_assignment(c[i].numpy()) for i, c in enumerate(C.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]

class SetCriterion(nn.Module):
    """Set-prediction loss: match queries to boxes, then score the result.

    ``forward`` returns three un-weighted losses:

    * ``loss_matching``          – L1 box loss plus BCE (target 1) on matched queries,
    * ``loss_unmatching``        – BCE (target 0) on all unmatched queries,
    * ``loss_height_constraint`` – per sample ``|sum(matched heights) - 1|``, averaged
      over the batch.

    ``weight_dict`` (from ``config['finetuning']['loss_weights']``) is exposed
    for the caller to combine them.
    """
    def __init__(self, config):
        super().__init__()
        matcher_costs = config['finetuning']['matcher_costs']
        self.matcher = HungarianMatcher(matcher_costs['set_cost_class'], matcher_costs['set_cost_bbox'])
        self.loss_names = ["loss_matching", "loss_unmatching", "loss_height_constraint"]
        self.num_queries = config['finetuning']['model_params']['num_queries']
        self.weight_dict = config['finetuning']['loss_weights']

    def _get_src_permutation_idx(self, indices):
        """Flatten per-sample matches into (batch_index, query_index) index tensors."""
        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def loss_match(self, outputs, targets, indices):
        """Box L1 loss + objectness BCE for the matched queries."""
        idx = self._get_src_permutation_idx(indices)
        src_bbox = outputs['pred_locs'][idx]
        target_bbox = torch.cat([t['loc_info'][j] for t, (_, j) in zip(targets, indices)], dim=0)
        src_labels = outputs['pred_logits'][idx]
        target_labels = torch.ones_like(src_labels)
        loss_bbox = F.l1_loss(src_bbox, target_bbox, reduction='mean')
        loss_class = F.binary_cross_entropy_with_logits(src_labels, target_labels, reduction='mean')
        return {'loss_matching': loss_bbox + loss_class}

    def loss_unmatch(self, outputs, targets, indices):
        """Objectness BCE pushing every unmatched query towards 0."""
        batch_size, num_queries = outputs['pred_logits'].shape[:2]
        matched_mask = torch.zeros((batch_size, num_queries), dtype=torch.bool, device=outputs['pred_logits'].device)
        for i, (src_idx, _) in enumerate(indices):
            matched_mask[i, src_idx] = True
        unmatched_logits = outputs['pred_logits'][~matched_mask]
        unmatched_targets = torch.zeros_like(unmatched_logits)
        loss = F.binary_cross_entropy_with_logits(unmatched_logits, unmatched_targets, reduction='mean')
        return {'loss_unmatching': loss}

    def loss_height(self, outputs, targets, indices):
        """Penalise matched heights of a sample that do not sum to 1."""
        batch_idx, src_idx = self._get_src_permutation_idx(indices)
        pred_locs = outputs['pred_locs']
        batch_size = pred_locs.shape[0]
        # Start from a tensor so the result is a tensor even with no matches.
        total_height_loss = pred_locs.new_zeros(())
        for i in range(batch_size):
            mask = (batch_idx == i)
            if mask.any():
                predicted_heights = pred_locs[batch_idx[mask], src_idx[mask], 1]
                total_height_loss = total_height_loss + torch.abs(predicted_heights.sum() - 1.0)
        return {'loss_height_constraint': total_height_loss / batch_size}

    def forward(self, outputs, targets):
        """Compute all losses for raw model ``outputs`` (batch, queries, 3).

        Channel 0 is the objectness logit; channels 1-2 are (top, height) and
        are squashed with a sigmoid before being compared with the targets.
        """
        pred_logits = outputs[:, :, :1]
        pred_locs = outputs[:, :, 1:].sigmoid()
        structured_outputs = {'pred_logits': pred_logits, 'pred_locs': pred_locs}

        # The matcher slices a (batch, queries, 3) tensor, so it is given a
        # tensor (not the dict): raw logits plus the same normalised locations
        # the losses below are computed on.
        matcher_input = torch.cat((pred_logits, pred_locs), dim=-1)
        indices = self.matcher(matcher_input, targets)

        losses = {}
        losses.update(self.loss_match(structured_outputs, targets, indices))
        losses.update(self.loss_unmatch(structured_outputs, targets, indices))
        losses.update(self.loss_height(structured_outputs, targets, indices))
        return losses

# Confirmation message printed once all definitions in this cell have executed.
print("✅ Matcher and Loss functions defined.")

In [None]:
def train_loop_per_worker_pretrain(config):
    """Per-worker training loop for the autoencoder pre-training (Ray Train/Tune).

    Builds a ``UNet`` from the sampled hyper-parameters, trains it with an MSE
    reconstruction loss on ``AutoencoderDataset`` and reports the epoch loss
    together with a model checkpoint after every epoch.

    Args:
        config: the trial configuration. Reads ``in_channels``, ``act_name``,
            ``optimizer`` (name of a class in ``torch.optim``), ``lr``,
            ``batch_size`` and ``epochs``; the whole dict is also passed to
            ``AutoencoderDataset``, which reads ``config['paths']``.
    """
    # Imported here because the notebook's import cell only brings in
    # `ray.tune` and selected names from `ray.train`; the bare name `train`
    # used by this loop was otherwise undefined.
    from ray import train
    from ray.train.torch import prepare_data_loader, prepare_model

    model = UNet(in_channels=config['in_channels'], activation=config['act_name'])
    # Moves the model to the worker's device (and wraps it for distributed
    # training when several workers are used).
    model = prepare_model(model)

    criterion = nn.MSELoss()
    optimizer_class = getattr(torch.optim, config['optimizer'])
    optimizer = optimizer_class(model.parameters(), lr=config['lr'])

    train_dataset = AutoencoderDataset(config)
    train_dataloader = prepare_data_loader(
        DataLoader(train_dataset, batch_size=int(config['batch_size']))
    )

    for epoch in range(config['epochs']):
        model.train()
        running_loss = 0.0
        for image, _ in train_dataloader:
            outputs = model(image)
            loss = criterion(outputs, image)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch value is a per-sample mean.
            running_loss += loss.item() * image.size(0)
        epoch_loss = running_loss / len(train_dataset)

        # The model only has a `.module` attribute when it was wrapped for
        # distributed training; fall back to the model itself otherwise.
        unwrapped_model = getattr(model, 'module', model)
        checkpoint = TorchCheckpoint.from_model(model=unwrapped_model)
        train.report({"loss": epoch_loss}, checkpoint=checkpoint)

def run_finetuning(config):
    """Fine-tune the boundary detector and save its weights.

    Trains ``W2WTransformerModel`` (encoder initialised from the pre-trained
    autoencoder when available) on the patches of one well selected with
    ``seed=42``, using ``SetCriterion`` and AdamW, then writes the state dict to
    ``config['paths']['final_model_path']``.

    Args:
        config: pipeline configuration dict.

    Raises:
        ValueError: the selected well yields no complete patch, so there is
            nothing to train on.
    """
    print(f"\n--- LAUNCHING STAGE 2: FINE-TUNING ---")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    ft_params = config['finetuning']

    train_dataset = BoundaryDataset(config, seed=42)
    if len(train_dataset) == 0:
        # Otherwise every epoch is empty and the average below divides by zero.
        raise ValueError(
            "The training well produced no complete patches of height "
            f"{ft_params['model_params']['patch_height']}. Reduce 'patch_height' or use longer logs."
        )
    train_loader = DataLoader(train_dataset, batch_size=ft_params['batch_size'], shuffle=True, collate_fn=collate_fn)

    model = W2WTransformerModel(config).to(device)
    model = load_pretrained_encoder_weights(model, config['paths']['pretrained_encoder_path'])
    criterion = SetCriterion(config).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=ft_params['learning_rate'])
    weight_dict = criterion.weight_dict

    for epoch in range(ft_params['epochs']):
        model.train()
        total_loss = 0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{ft_params['epochs']}")
        for images, targets in progress_bar:
            images = images.to(device)
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            outputs = model(images)
            loss_dict = criterion(outputs, targets)
            # Weighted sum of the individual losses; unweighted names are ignored.
            losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            total_loss += losses.item()
            progress_bar.set_postfix({'loss': f"{losses.item():.4f}"})
        print(f"Epoch {epoch+1} Average Loss: {total_loss / len(train_loader):.4f}")

    final_model_path = config['paths']['final_model_path']
    # `or '.'` keeps makedirs valid when the path has no directory component.
    os.makedirs(os.path.dirname(final_model_path) or '.', exist_ok=True)
    torch.save(model.state_dict(), final_model_path)
    print(f"✅ Final model saved to {final_model_path}")

def run_inference(config):
    """Predict layers for two wells with the fine-tuned model and plot them.

    Loads the model from ``config['paths']['final_model_path']``, predicts
    layers for ``reference_well`` and ``well_of_interest`` (queries whose
    probability exceeds ``correlation_threshold`` are kept) and writes the
    correlation plot to ``<results_path>/well_correlation_plot.png``.

    Returns ``None``; prints an error and returns early when the model file is
    missing.
    """
    print(f"\n--- LAUNCHING STAGE 3: INFERENCE ---")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inf_params = config['inference']
    model_path = config['paths']['final_model_path']
    patch_height = config['finetuning']['model_params']['patch_height']
    batch_size = 16

    if not os.path.exists(model_path):
        print(f"🚨 Error: Final model not found at {model_path}. Run fine-tuning first.")
        return

    model = W2WTransformerModel(config).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    print("✅ Model loaded successfully.")

    def get_predicted_layers(well_name):
        """Run the model over every patch of ``well_name``.

        Returns a list of ``{'top', 'height', 'bottom'}`` dicts in row units
        measured from the start of the well, sorted by ``top``.
        """
        dataset = BoundaryDataset(config, well_name=well_name)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
        all_layers = []

        with torch.no_grad():
            for batch_idx, (images, _) in enumerate(loader):
                images = images.to(device)
                outputs = model(images)

                probs = outputs[:, :, 0].sigmoid()    # (batch, queries)
                locs = outputs[:, :, 1:].sigmoid()    # (batch, queries, 2)

                for j in range(images.shape[0]):
                    keep = probs[j] > inf_params['correlation_threshold']

                    # Offset by the GLOBAL patch index. Using only the batch
                    # index would place all patches of a batch at the same depth.
                    patch_index = batch_idx * batch_size + j
                    patch_start_depth_idx = patch_index * patch_height

                    abs_tops = locs[j, keep, 0] * patch_height + patch_start_depth_idx
                    abs_heights = locs[j, keep, 1] * patch_height
                    abs_bottoms = abs_tops + abs_heights

                    for top, height, bottom in zip(abs_tops.tolist(), abs_heights.tolist(), abs_bottoms.tolist()):
                        all_layers.append({'top': top, 'height': height, 'bottom': bottom})
        return sorted(all_layers, key=lambda layer: layer['top'])

    print(f"--> Predicting layers for Reference Well: {inf_params['reference_well']}")
    ref_layers = get_predicted_layers(inf_params['reference_well'])
    print(f"--> Predicting layers for Well of Interest: {inf_params['well_of_interest']}")
    woi_layers = get_predicted_layers(inf_params['well_of_interest'])
    print(f"--> Found {len(ref_layers)} layers in reference and {len(woi_layers)} in well of interest.")

    # NOTE(review): this similarity matrix is RANDOM placeholder data — it is
    # not derived from the model or the wells. The connections drawn in the plot
    # therefore carry no meaning until a real layer-similarity measure is
    # implemented here.
    print("⚠️ Layer similarity scores are random placeholders; plotted correlations are not real results.")
    sim_matrix = np.random.uniform(0.75, 0.95, size=(len(ref_layers), len(woi_layers)))

    os.makedirs(config['paths']['results_path'], exist_ok=True)
    plot_well_correlation(
        inf_params['reference_well'],
        inf_params['well_of_interest'],
        ref_layers,
        woi_layers,
        sim_matrix,
        inf_params['correlation_threshold'],
        os.path.join(config['paths']['results_path'], 'well_correlation_plot.png')
    )

def plot_well_correlation(well1_name, well2_name, w1_layers, w2_layers, sim_matrix, threshold, output_path):
    """Draw two wells side by side, link their correlated layers, and save the figure.

    Each well is rendered as a column of rectangles (one per layer). Every
    layer pair whose similarity is at least ``threshold`` is joined by a
    translucent polygon coloured by that similarity.

    Args:
        well1_name: Tick label for the left-hand column.
        well2_name: Tick label for the right-hand column.
        w1_layers: Layers of the first well; each item is a dict with
            'top', 'height' and 'bottom' keys.
        w2_layers: Layers of the second well, same format as ``w1_layers``.
        sim_matrix: 2-D indexable where ``sim_matrix[i][j]`` is the similarity
            between ``w1_layers[i]`` and ``w2_layers[j]``.
        threshold: Minimum similarity for a link to be drawn.
        output_path: Destination image file. Its parent directory is created
            if it does not exist yet.

    Returns:
        None. If either layer list is empty, a placeholder figure containing
        only an explanatory message is saved instead.
    """
    # Saving fails on a missing directory, so make sure the target exists first.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(10, 12))
    try:
        if not w1_layers or not w2_layers:
            print("⚠️ Warning: One or both wells have no layers to plot.")
            ax.text(0.5, 0.5, 'Not enough layers found to generate plot.', ha='center', va='center')
            fig.savefig(output_path)
            return

        cmap = plt.get_cmap('viridis')

        # The deepest bottom is not necessarily the last layer's: the lists are
        # ordered by 'top', and an earlier, thicker layer can reach further down.
        max_depth = max(
            layer['bottom']
            for well_layers in (w1_layers, w2_layers)
            for layer in well_layers
        )
        # Inverted limits so that depth increases downwards on the plot.
        ax.set_ylim(max_depth + 50, -50)
        ax.set_xlim(-0.5, 2.5)

        # One column of rectangles per well; colour encodes the layer's rank.
        for x_left, well_layers in ((0, w1_layers), (1.5, w2_layers)):
            layer_count = len(well_layers)
            for rank, layer in enumerate(well_layers):
                ax.add_patch(patches.Rectangle(
                    (x_left, layer['top']), 1, layer['height'],
                    edgecolor='black', facecolor=cmap(rank / layer_count), alpha=0.6,
                ))

        # Connect every sufficiently similar pair across the gap between the columns.
        for i, row in enumerate(sim_matrix):
            for j, sim in enumerate(row):
                if sim >= threshold:
                    polygon_coords = [
                        [1, w1_layers[i]['top']], [1, w1_layers[i]['bottom']],
                        [1.5, w2_layers[j]['bottom']], [1.5, w2_layers[j]['top']]
                    ]
                    ax.add_patch(patches.Polygon(polygon_coords, facecolor=cmap(sim), alpha=0.4))

        ax.set_xticks([0.5, 2])
        ax.set_xticklabels([well1_name, well2_name], fontsize=14)
        ax.set_ylabel("Depth (index-based)", fontsize=12)
        ax.set_title("Predicted Well to Well Correlation", fontsize=16)
        ax.grid(True, axis='y', linestyle='--')
        fig.savefig(output_path)
    finally:
        # Always release the figure, even when drawing or saving raised.
        plt.close(fig)
    print(f"✅ Correlation plot saved to {output_path}")

In [None]:
def _processed_csv_available(config):
    """Return True when the processed CSV exists; otherwise print the error and return False."""
    if os.path.exists(config['paths']['processed_csv_path']):
        return True
    print("🚨 Error: 'processed_csv_path' not found. Run data prep first.")
    return False


def _run_pretraining_search(config):
    """Run the Ray Tune autoencoder search and export the best trial's model.

    Reads ``config['paths']``, ``config['pretraining']`` ('epochs',
    'search_space', 'num_samples') and ``config['mlflow']['experiment_name']``.
    On success, copies ``model.pt`` from the best checkpoint directory to
    ``config['paths']['pretrained_encoder_path']``.
    """
    print("\n--- LAUNCHING STAGE 1: AUTOENCODER PRE-TRAINING (VIA RAY TUNE) ---")

    # Define the static base configuration for the training loop.
    train_loop_config = {
        "paths": config["paths"],
        "epochs": config["pretraining"]["epochs"]
    }

    # The TorchTrainer gets the base config; the Tuner handles the param_space.
    tuner = tune.Tuner(
        TorchTrainer(
            train_loop_per_worker_pretrain,
            train_loop_config=train_loop_config,
            scaling_config=ScalingConfig(use_gpu=torch.cuda.is_available(), num_workers=1),
        ),
        param_space=config['pretraining']['search_space'],
        tune_config=tune.TuneConfig(
            num_samples=config['pretraining']['num_samples'],
            metric="loss",
            mode="min",
        ),
        run_config=RunConfig(
            name="Pre-training_Trial",
            callbacks=[MLflowLoggerCallback(experiment_name=config['mlflow']['experiment_name'], save_artifact=True)],
        ),
    )
    results = tuner.fit()

    # get_best_result() raises RuntimeError (rather than returning None) when no
    # trial reported the metric, e.g. because every trial failed. Map that onto
    # the "no best trial" warning below instead of crashing the pipeline.
    try:
        best_result = results.get_best_result(metric="loss", mode="min")
    except RuntimeError:
        best_result = None

    if best_result and best_result.checkpoint:
        # The checkpoint directory is expected to contain a "model.pt" file.
        source_path = os.path.join(best_result.checkpoint.path, "model.pt")
        destination_path = config['paths']['pretrained_encoder_path']
        os.makedirs(os.path.dirname(destination_path), exist_ok=True)
        shutil.copy(source_path, destination_path)
        print(f"\n🏆 Best trial found with validation loss: {best_result.metrics['loss']:.4f}")
        print(f"✅ Best pre-trained model saved to {destination_path}")
    else:
        print("⚠️ No best trial found. Pre-training may have failed.")


def main(config):
    """Main function to orchestrate the entire pipeline.

    Each stage runs only when its flag in ``config`` is truthy:

    * ``run_data_preparation`` -> ``run_data_preparation(config)``
    * ``run_pretraining``      -> Ray Tune search over the autoencoder
    * ``run_finetuning``       -> ``run_finetuning(config)`` inside an MLflow run
    * ``run_inference``        -> ``run_inference(config)`` inside an MLflow run

    Args:
        config: Nested dict with at least ``mlflow.experiment_name`` plus the
            ``paths`` / ``pretraining`` / ``finetuning`` / ``inference``
            sections needed by the enabled stages.

    Returns:
        None. Returns early, skipping all remaining stages, when pre-training
        or fine-tuning is requested but the processed CSV does not exist.
    """
    mlflow.set_experiment(config['mlflow']['experiment_name'])

    if config.get('run_data_preparation', False):
        run_data_preparation(config)
        print("\n--- STAGE 0 COMPLETE ---\n")

    if config.get('run_pretraining', False):
        if not _processed_csv_available(config):
            return
        _run_pretraining_search(config)
        print("\n--- STAGE 1 COMPLETE ---\n")

    if config.get('run_finetuning', False):
        if not _processed_csv_available(config):
            return
        with mlflow.start_run(run_name="Fine-tuning_Run"):
            # Only flat values are logged; nested dicts are skipped.
            mlflow.log_params({k: v for k, v in config['finetuning'].items() if not isinstance(v, dict)})
            run_finetuning(config)
            mlflow.log_artifact(config['paths']['final_model_path'])
        print("\n--- STAGE 2 COMPLETE ---\n")

    if config.get('run_inference', False):
        with mlflow.start_run(run_name="Inference_Correlation_Run"):
            mlflow.log_params(config['inference'])
            run_inference(config)
            plot_path = os.path.join(config['paths']['results_path'], 'well_correlation_plot.png')
            # Attach the plot only if inference actually wrote it.
            if os.path.exists(plot_path):
                mlflow.log_artifact(plot_path)
        print("\n--- STAGE 3 COMPLETE ---\n")

    print("\n✅✅✅ All requested pipeline stages finished. ✅✅✅")


# --- RUN THE PIPELINE ---
# Kick off the orchestrator. main() only executes the stages whose `run_*`
# flag in `config` is truthy.
# NOTE(review): `config` is not defined in this cell — presumably it is built
# by an earlier cell; confirm that cell has been run first.
main(config)

In [None]:
# --- NEW CELL 11: View MLflow UI ---
# Serves the MLflow UI from this runtime and opens an ngrok tunnel to it so the
# UI can be reached from a browser.

# Single source of truth for the port shared by the UI server and the tunnel.
# 5000 is the MLflow UI default.
MLFLOW_UI_PORT = 5000

# Terminate any existing ngrok tunnels
ngrok.kill()

# Start the MLflow UI in the background (trailing "&") so this cell does not block.
get_ipython().system_raw(f"mlflow ui --port {MLFLOW_UI_PORT} &")

# Create a public URL to the MLflow UI
public_url = ngrok.connect(MLFLOW_UI_PORT, "http")

# Recent pyngrok releases return a tunnel object from connect() whose address
# lives in `.public_url`; fall back to the value itself for releases that
# returned a plain string, so the message always shows a bare URL.
mlflow_ui_url = getattr(public_url, "public_url", public_url)
print("="*80)
print(f"✅ Your MLflow UI is now available at: {mlflow_ui_url}")
print("="*80)