In [None]:
import torch
from torch.utils.data import DataLoader, ConcatDataset
import torch.nn as nn
import os 
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from torchinfo import summary
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm
import matplotlib.pyplot as plt 
from utils.timeseriesdataset import TimeSeriesDataset
from utils.pad_batch import pad_batch, LABEL_PADDING_VALUE
from models.RegressionModel import RegressionModel
import pickle 
from pathlib import Path
from utils.load_data import load_data

# Prefer the second GPU (this project's original setup) when it exists; otherwise fall back
# to the first GPU, or to the CPU when CUDA is unavailable. A hard-coded "cuda:1" fails with
# "invalid device ordinal" as soon as the model is moved on a single-GPU machine.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    DEVICE = torch.device("cpu")
BATCH_SIZE = 32
EPOCHS = 35
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 2e-6

torch.cuda.empty_cache()
print('The model is running on:', DEVICE)

# Create DataLoaders

In [None]:
SIMULATED_TRACKS_DIR = Path("../data/simulated_tracks")


def load_split_instances(split_name, root=SIMULATED_TRACKS_DIR):
    """Concatenate the pickled instance lists of one split across all sub-directories.

    Looks for ``<root>/*/<split_name>_instances.pkl``. Each file is expected to hold a list,
    and the lists are concatenated. Files are visited in sorted path order because
    ``Path.glob`` order is filesystem-dependent; sorting keeps the dataset order
    reproducible.

    Args:
        split_name: split prefix, e.g. ``"train"``, ``"val"`` or ``"test"``.
        root: directory containing one sub-directory per data batch.

    Returns:
        list: the concatenated instances (empty if no file matches).
    """
    instances = []
    for path in sorted(Path(root).glob(f"*/{split_name}_instances.pkl")):
        # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
        with open(path, "rb") as f:
            instances += pickle.load(f)
    return instances


train_instances = load_split_instances("train")
val_instances = load_split_instances("val")
test_instances = load_split_instances("test")

print("Train data: ", len(train_instances), "Test data: ", len(test_instances), "Val data: ", len(val_instances))

In [None]:
# Concatenate the per-file datasets of each split and wrap them in DataLoaders.
# `pad_batch` is the collate function; presumably it pads every sample in a batch to a
# common length, marking padded labels with LABEL_PADDING_VALUE.
conc_train = ConcatDataset(train_instances)
conc_val = ConcatDataset(val_instances)
conc_test = ConcatDataset(test_instances)

train_loader = DataLoader(conc_train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=pad_batch)
# Evaluation loaders are not shuffled. Shuffling does not help a loss average, and it makes
# the order of collected test predictions (and any sample later picked by index) change
# from run to run.
test_loader = DataLoader(conc_test, batch_size=BATCH_SIZE, shuffle=False, collate_fn=pad_batch)
val_loader = DataLoader(conc_val, batch_size=BATCH_SIZE, shuffle=False, collate_fn=pad_batch)

print("DataLoader Sizes:", len(train_loader), len(test_loader), len(val_loader))

In [None]:
# Alternative data source: `load_data()` presumably returns (train, val, test); only train
# and test are used. This cell rebinds `conc_train`, `conc_test` and `test_loader` from the
# previous cell and defines `training_loader`, which the training loop consumes.
# NOTE(review): `val_loader` still comes from the pickled-instances cell, so training and
# validation may be drawn from different sources — confirm this is intended.
train_data, _, test_data = load_data()

print("Train data: ", len(train_data))
print("Test data: ", len(test_data))

conc_train = ConcatDataset(train_data)
conc_test = ConcatDataset(test_data)

training_loader = DataLoader(conc_train, batch_size=BATCH_SIZE, shuffle=True, collate_fn=pad_batch)
# Not shuffled, so the order of test predictions (and index-based example picks) is reproducible.
test_loader = DataLoader(conc_test, batch_size=BATCH_SIZE, shuffle=False, collate_fn=pad_batch)

# Model
Build the model, TensorBoard writer, loss function, optimizer, and learning-rate scheduler.

In [None]:
# Build the K-regression model and everything the training loop needs around it.
model = RegressionModel().to(DEVICE)

# A single timestamp tags both the TensorBoard run directory and the checkpoint directory.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
run_log_dir = f'models/checkpoints/k_runs/runs_{timestamp}'
writer = SummaryWriter(run_log_dir)
model_directory = os.path.join('models/checkpoints/k_model', f'model_{timestamp}')

# Layer summary for a dummy input of shape (BATCH_SIZE, 200, 10).
model_summary = summary(model, input_size=(BATCH_SIZE, 200, 10))
print(model_summary)

# Element-wise L1 loss; callers apply their own padding mask and reduction.
continuous_loss_fn = nn.L1Loss(reduction='none')
best_val_loss = float("inf")

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=3)

# Training Functions

In [4]:
def train_one_epoch(model, optimizer, dataloader, progress_bar=None):
    """Run one training epoch and return the mean of the per-batch masked L1 losses.

    Each batch must unpack as ``(inputs, _, k_labels, _)``. Label positions equal to
    ``LABEL_PADDING_VALUE`` are masked out. Each batch loss is the masked element-wise
    ``continuous_loss_fn`` summed and divided by the number of unmasked positions.

    Args:
        model: module applied to ``inputs``; a trailing size-1 dimension of its output is squeezed.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        dataloader: iterable of batches.
        progress_bar: optional object with an ``update()`` method (e.g. a ``tqdm`` bar),
            advanced once per batch. When omitted, a module-level ``progress_bar`` is used
            if one exists, so the existing notebook training loop keeps working unchanged.

    Returns:
        float: mean of the per-batch losses.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no batches.
    """
    if progress_bar is None:
        # Backward compatibility: the training cell creates a global `progress_bar` and
        # calls this function without passing it. Previously a missing global was a NameError.
        progress_bar = globals().get("progress_bar")

    model.train()
    running_loss = 0.0
    runs = 0

    for inputs, _, k_labels, _ in dataloader:
        inputs, k_labels = inputs.to(DEVICE), k_labels.to(DEVICE)
        # 1.0 for real labels, 0.0 for padded positions.
        mask = (k_labels != LABEL_PADDING_VALUE).float()

        outputs = model(inputs).squeeze(-1)
        total_loss = (continuous_loss_fn(outputs, k_labels) * mask).sum() / mask.sum()

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        running_loss += total_loss.item()
        runs += 1

        if progress_bar is not None:
            progress_bar.update()

    return running_loss / runs

def evaluate_model(model, dataloader):
    """Return the mean of the per-batch masked L1 losses of ``model`` over ``dataloader``.

    The model is switched to eval mode and no gradients are tracked. Each batch must
    unpack as ``(inputs, _, k_labels, _)``. Positions whose label equals
    ``LABEL_PADDING_VALUE`` are masked out, and each batch loss is normalised by its number
    of unmasked positions. Raises ``ZeroDivisionError`` if the loader yields no batches.
    """
    model.eval()
    per_batch = []

    with torch.no_grad():
        for inputs, _, k_labels, _ in dataloader:
            inputs = inputs.to(DEVICE)
            k_labels = k_labels.to(DEVICE)
            valid = (k_labels != LABEL_PADDING_VALUE).float()

            preds = model(inputs).squeeze(-1)
            masked_err = continuous_loss_fn(preds, k_labels) * valid
            per_batch.append((masked_err.sum() / valid.sum()).item())

    return sum(per_batch) / len(per_batch)

# Train

In [None]:
# Resume from a previous checkpoint and fine-tune with a lower learning rate.
# NOTE(review): absolute, user-specific path — consider making it relative to the repo.
RESUME_CHECKPOINT = "/home/haidiri/Desktop/AnDiChallenge2024/src/models/checkpoints/k_model/model_20240923_161109/model_35"
# map_location lets a checkpoint saved on one device load onto whatever DEVICE is in use.
model.load_state_dict(torch.load(RESUME_CHECKPOINT, map_location=DEVICE))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=WEIGHT_DECAY)
# Rebuild the scheduler for the new optimizer. The one created with the model is bound to
# the previous optimizer, so stepping it would never lower this optimizer's learning rate.
scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=3)

os.makedirs(model_directory, exist_ok=True)

# NOTE(review): training uses `training_loader` (load_data() cell) while validation uses
# `val_loader` (pickled-instances cell) — confirm both come from the same data source.
for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch + 1))

    # `train_one_epoch` advances the global `progress_bar`. The context manager closes each
    # epoch's bar, instead of leaving all but the last one open.
    with tqdm(total=len(training_loader), desc='Training', position=0) as progress_bar:
        avg_training_loss = train_one_epoch(model, optimizer, training_loader)
    val_total_loss = evaluate_model(model, val_loader)

    print(f'Training LOSS: K {avg_training_loss}\n'
          f'Validation LOSS: K {val_total_loss} \n')

    writer.add_scalars('Losses', {
        'Training K Loss': avg_training_loss,
        'Validation K Loss': val_total_loss,
        }, epoch + 1)

    writer.flush()

    # Checkpoint only when validation improves on the best value seen so far.
    if val_total_loss < best_val_loss:
        best_val_loss = val_total_loss
        best_model_path = os.path.join(model_directory, f'model_{epoch + 1}')
        torch.save(model.state_dict(), best_model_path)

    scheduler.step(val_total_loss)

writer.close()

In [None]:
# Report the outcome of the training run above.
print(f"Best Validation Loss: {best_val_loss}")
print(f"Best Model Path {best_model_path}")

# Testing

In [None]:
# Evaluate the selected weights on the test split. Per-sample predictions and labels
# (padded positions included) are collected for the plotting cells below.
# NOTE(review): absolute, user-specific path — consider making it relative to the repo.
K_WEIGHTS_PATH = "/home/haidiri/Desktop/AnDiChallenge2024/models/optimal_weights/k_weights"
# map_location lets weights saved on one device load onto whatever DEVICE is in use.
model.load_state_dict(torch.load(K_WEIGHTS_PATH, map_location=DEVICE))
model.eval()

running_test_total = 0.0
test_runs = 0

# One entry per trajectory, in test_loader order (only reproducible if it is not shuffled).
predictions = []
ground_truth = []

with torch.no_grad(), tqdm(total=len(test_loader), desc='Testing', position=0) as progress_bar:
    for inputs, _, k_labels, _ in test_loader:
        inputs, k_labels = inputs.to(DEVICE), k_labels.to(DEVICE)

        # Exclude padded label positions from the loss.
        mask = (k_labels != LABEL_PADDING_VALUE).float()
        outputs = model(inputs).squeeze(-1)
        total_loss = (continuous_loss_fn(outputs, k_labels) * mask).sum() / mask.sum()

        running_test_total += total_loss.item()
        test_runs += 1

        predictions.extend(outputs.cpu().numpy())
        ground_truth.extend(k_labels.cpu().numpy())
        progress_bar.update()

# Mean of the per-batch losses.
avg_test_loss = running_test_total / test_runs
print(f'Average test loss: {avg_test_loss}')

# Plot Predictions

In [6]:
from utils.postprocessing import smooth_series, median_filter_1d
import ruptures as rpt
import numpy as np 

In [7]:
def getCP_rpt(array, lower_limit=0, upper_limit=float("inf"), threshold=0.05,
              pen=0.3, min_size=3, jump=1):
    """Detect change points in a predicted series with ruptures' PELT search.

    Steps:
    1. Pass the series through ``smooth_series`` (with the given limits), then ``median_filter_1d``.
    2. Min-max scale it to [0, 1], or use a constant 0.5 when it is flat.
    3. Segment it with an L2-cost PELT search.
    4. Drop interior change points whose adjacent segment means, taken on the filtered
       but unscaled series, differ by less than ``threshold``.

    Args:
        array: 1-D sequence of predictions.
        lower_limit, upper_limit: forwarded to ``smooth_series``.
        threshold: minimum absolute difference of adjacent segment means for an interior
            change point to be kept.
        pen: PELT penalty passed to ``predict`` (default keeps the previous hard-coded 0.3).
        min_size: minimum segment length for PELT (previously hard-coded 3).
        jump: PELT subsampling step (previously hard-coded 1).

    Returns:
        (cps, filtered): ``cps`` starts with 0 and ends with the final breakpoint returned
        by ruptures (the series end). ``filtered`` is the smoothed, median-filtered series.
    """
    filtered = median_filter_1d(smooth_series(array, lower_limit=lower_limit, upper_limit=upper_limit))
    lo, hi = np.min(filtered), np.max(filtered)
    if hi != lo:
        scaled = (filtered - lo) / (hi - lo)
    else:
        # A flat series cannot be min-max scaled; use a constant mid-range value.
        scaled = np.ones(len(filtered)) * 0.5

    algo = rpt.Pelt(model="l2", min_size=min_size, jump=jump).fit(scaled)
    cps = [0] + algo.predict(pen=pen)

    # Means use the segments defined by the original PELT output, so neighbouring weak
    # change points are judged independently rather than after merging.
    weak = set()
    for i in range(1, len(cps) - 1):
        left_mean = filtered[cps[i - 1]:cps[i]].mean()
        right_mean = filtered[cps[i]:cps[i + 1]].mean()
        if abs(left_mean - right_mean) < threshold:
            weak.add(cps[i])

    cps = [cp for cp in cps if cp not in weak]

    return cps, filtered

def getCP_gt(array):
    """Return the segment boundaries of a piecewise-constant series.

    The result is ``[0, *change_indices, len(array)]``. A change index is any ``i`` where
    ``array[i]`` differs from ``array[i - 1]``. An empty input gives ``[0, 0]``.
    """
    n = len(array)
    changes = [idx for idx in range(1, n) if array[idx] != array[idx - 1]]
    return [0, *changes, n]

In [None]:
# Inspect one test trajectory: detect change points in the predicted K series and overlay
# predictions (red), ground truth (blue) and detected change points (dashed lines).
# NOTE(review): INDEX addresses the flattened `predictions` list. The same index only picks
# the same trajectory across runs if `test_loader` is not shuffled.
INDEX = 1086

pred_k = predictions[INDEX]
true_k = ground_truth[INDEX]

# Strip trailing label padding (added when batching) so it is not treated as a real segment.
is_padding = true_k == LABEL_PADDING_VALUE
valid_len = int(is_padding.argmax()) if is_padding.any() else len(true_k)
pred_k = pred_k[:valid_len]
true_k = true_k[:valid_len]

# `pred_k` is replaced by its smoothed/filtered version, which is what gets plotted below.
cp_pred, pred_k = getCP_rpt(pred_k, lower_limit=0, upper_limit=6, threshold=0.05)
cp_gt = getCP_gt(true_k)

# Keep only interior change points (drop the leading 0 and the final breakpoint).
cp_pred = cp_pred[1:-1]
print(cp_pred, cp_gt)

fig, ax = plt.subplots()
for cp in cp_pred:
    ax.axvline(x=cp, color='black', linestyle='--')
ax.scatter(np.arange(len(pred_k)), pred_k, color="red", label="Prediction")
ax.scatter(np.arange(len(true_k)), true_k, color="blue", label="Ground truth")
ax.set_xlabel("Time")
ax.set_ylabel("K")
ax.legend()
plt.show()

In [44]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from matplotlib.backends.backend_svg import FigureCanvasSVG

def plot_k_with_ruptures(pred_k, true_k, cp_pred, cp_gt, save_path=None):
    """Publication-style plot of predicted K against ground truth, with change points.

    The plot has three layers:
    - predictions as hollow green circles;
    - the ground truth as one black horizontal segment per interval between consecutive
      entries of ``cp_gt``;
    - a dashed vertical line at each entry of ``cp_pred``.
    The y-range spans both series plus a 10% margin.

    NOTE: a later cell redefines ``plot_k_with_ruptures`` with the signature
    ``(pred, true, label_type, save_path=None)``. Once that cell has run, this version is
    shadowed.

    Args:
        pred_k: 1-D predicted series.
        true_k: 1-D ground-truth series.
        cp_pred: x positions of the dashed change-point lines.
        cp_gt: segment boundaries of ``true_k`` (first entry 0, last entry ``len(true_k)``,
            as produced by ``getCP_gt``).
        save_path: optional path; when truthy, the figure is also written there as SVG.

    Returns:
        (fig, ax): a standalone ``matplotlib.figure.Figure`` (not managed by pyplot) and its Axes.
    """
    fig = Figure(figsize=(8, 6), dpi=300)
    canvas = FigureCanvasSVG(fig)
    ax = fig.add_subplot(111)

    green = '#1B9E77'
    hollow_green = dict(color=green, linestyle='none', marker='o',
                        markerfacecolor='none', markeredgecolor=green,
                        markeredgewidth=1.5)

    ax.plot(np.arange(len(pred_k)), pred_k, markersize=6, alpha=0.7, **hollow_green)

    # One flat segment per ground-truth interval; the final one ends on the last sample.
    n_true = len(true_k)
    for seg_start, seg_end in zip(cp_gt[:-1], cp_gt[1:]):
        if seg_end == n_true:
            seg_end -= 1
        level = true_k[seg_start]
        ax.plot([seg_start, seg_end], [level, level],
                color='black', linewidth=2, alpha=1)

    for x_cp in cp_pred:
        ax.axvline(x=x_cp, color='black', linestyle='--',
                   linewidth=3, alpha=0.8)

    lowest = min(min(pred_k), min(true_k))
    highest = max(max(pred_k), max(true_k))
    pad = (highest - lowest) * 0.1
    ax.set_ylim(lowest - pad, highest + pad)

    ax.set_axisbelow(True)
    ax.tick_params(which='both', direction='out', length=6, width=1,
                   colors='black', pad=2)
    ax.tick_params(axis='both', labelsize=32)

    # Empty proxy artists provide the legend entries; the legend marker is drawn larger.
    ax.plot([], [], markersize=12, label='Prediction', **hollow_green)
    ax.plot([], [], color='black', linewidth=2, label='Ground Truth')

    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1)
        spine.set_color('black')

    ax.set_xlabel('Time', fontsize=32)
    ax.set_ylabel(r'$K$', fontsize=32)
    ax.legend(fontsize=28, frameon=True, loc='best')

    fig.tight_layout(pad=1.0)

    if save_path:
        canvas.print_figure(save_path, bbox_inches='tight',
                            pad_inches=0.1, format='svg')

    return fig, ax

# Example usage
# `plot_k_with_ruptures` builds a standalone `Figure` that pyplot does not manage, so
# `plt.show()` would display nothing. End the cell with the figure so Jupyter renders it.
fig_k, ax_k = plot_k_with_ruptures(pred_k, true_k, cp_pred, cp_gt, save_path='k_ruptures_example.svg')
fig_k

# Example plots for the test set

In [None]:
# Scan the test set for the first full-length (200-step) example whose prediction changes
# exactly once and whose ground truth contains a zero value; print its index and stop.
for idx in range(len(ground_truth)):
    is_padding = ground_truth[idx] == LABEL_PADDING_VALUE
    # First padded position, or the whole row when it has no padding. The row length is
    # used instead of a hard-coded 200, because padding is applied when batching and the
    # row length need not be 200.
    padding_starts = int(is_padding.argmax()) if is_padding.any() else len(ground_truth[idx])
    if padding_starts == 0:
        # Row is entirely padding: nothing to analyse.
        continue

    pred_k = predictions[idx][:padding_starts]
    true_k = ground_truth[idx][:padding_starts]

    # NOTE(review): this counts raw step-to-step changes of the (continuous) prediction.
    # Unless the model output is quantized, almost every step differs — confirm this is intended.
    diff = np.diff(pred_k)
    diff[diff != 0] = 1
    changepoints = int(np.sum(diff))

    cp_pred, _ = getCP_rpt(pred_k, lower_limit=0, upper_limit=6, threshold=0.05)
    cp_gt = getCP_gt(true_k)

    if len(true_k) == 200 and changepoints == 1 and 0 in true_k:
        print(idx)
        break

In [None]:
# TODO(review): this cell held an unfinished assignment (`example_bound = `). That is a
# SyntaxError, which stopped "Run All" here. Complete the assignment or delete the cell.


In [None]:
def plot_k_with_ruptures(pred, true, label_type, save_path=None):
    K_COLOR = '#1B9E77'      # Green
    ALPHA_COLOR = '#E69F00'  # Orange
    STATE_COLOR = '#9970AB'  # Purple
    
    if label_type =="alpha":
        COLOR = ALPHA_COLOR 
        ymin = 0
        ymax = 2
        label = r'$K$'
    elif label_type == "k":
        COLOR = K_COLOR
        ymin = 0
        ymax = 3
        label = r'$α$'
    elif label_type == "state":
        COLOR = STATE_COLOR
        ymin = 0
        ymax = 4
        label = r'$s$'
    else:
        raise ValueError

    def getCP_gt(array):
        cps = [0]
        for i in range(1, len(array)):
            if array[i-1] != array[i]:
                cps.append(i)

        return cps + [len(array)]

    fig = Figure(figsize=(8, 6), dpi=300)
    canvas = FigureCanvasSVG(fig)
    ax = fig.add_subplot(111)
    
    time = np.arange(len(pred))
    ax.plot(time, pred, color=COLOR, linestyle='none', 
            marker='o', markersize=6, markerfacecolor='none',
            markeredgecolor='#1B9E77', markeredgewidth=1.5,
            alpha=0.7)
    
    cp_gt = getCP_gt(true)

    for i in range(len(cp_gt)-1):
        start_idx = cp_gt[i]
        end_idx = cp_gt[i+1]
        if end_idx == len(true):
            end_idx -= 1
        segment_value = true[start_idx]  # value for this segment
        ax.plot([start_idx, end_idx], [segment_value, segment_value],
                color='black', linewidth=2, alpha=1)

    margin = (ymax - ymin) * 0.1  # 10% margin
    ax.set_ylim(ymin - margin, ymax + margin)
    ax.set_axisbelow(True)
    ax.tick_params(which='both', direction='out', length=6, width=1,
                  colors='black', pad=2)
    
    ax.tick_params(axis='both', labelsize=32)

    for spine in ax.spines.values():
        spine.set_visible(True)
        spine.set_linewidth(1)
        spine.set_color('black')
    
    # Set labels with exact same formatting as heatmap
    ax.set_xlabel('Time', fontsize=32)
    ax.set_ylabel(label, fontsize=32)
    
    # # Configure legend with matching font size
    # ax.legend(fontsize=28, frameon=True, loc='best')
    fig.tight_layout(pad=1.0)
    
    # Save as SVG with same parameters
    if save_path:
        canvas.print_figure(save_path, bbox_inches='tight', 
                          pad_inches=0.1, format='svg')
    
    return fig, ax