In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split, Subset
import numpy as np
import pandas as pd
import os
import cv2
from tqdm import tqdm
from scipy.interpolate import CubicSpline
import time  
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import font_manager, rcParams
import warnings
from collections import defaultdict
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Compute device shared by the rest of the file: CUDA when available, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Window sizes. NOTE(review): presumably the number of past steps fed to the
# model and the number of future steps predicted -- not used in this part of
# the file, confirm against the dataset code.
PAST, FUTURE = 100, 10
TOTAL_WINDOW = PAST + FUTURE

# NOTE(review): presumably the state-of-health range considered at evaluation
# time -- not referenced in the visible code, confirm against the caller.
SOH_MIN_EVAL, SOH_MAX_EVAL = 0.8, 1.0

class ChannelAttention(nn.Module):
    """Channel-attention gate built from globally pooled statistics.

    The input is reduced to one value per channel twice (average pooling and
    max pooling); both descriptors go through the same two-layer bottleneck
    MLP, are summed and squashed by a sigmoid.

    Args:
        in_channels: Number of channels of the incoming feature map.
        reduction_ratio: Bottleneck divisor; the hidden width is
            ``in_channels // reduction_ratio`` but never less than 1.

    ``forward`` takes a ``(batch, in_channels, H, W)`` tensor and returns the
    gate with shape ``(batch, in_channels, 1, 1)`` (it does not multiply the
    input itself).
    """

    def __init__(self, in_channels, reduction_ratio=16):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Clamp the bottleneck width to >= 1 (same guard SEBlock uses). With a
        # width of 0 the second Linear sees no features and outputs only its
        # bias, so the gate would no longer depend on the input at all.
        hidden_channels = max(1, in_channels // reduction_ratio)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, hidden_channels),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_channels, in_channels),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        batch = x.size(0)
        # Flatten the (batch, C, 1, 1) pooled maps to (batch, C) for the MLP.
        avg_out = self.fc(self.avg_pool(x).view(batch, -1))
        max_out = self.fc(self.max_pool(x).view(batch, -1))
        gate = self.sigmoid(avg_out + max_out)
        # Restore two singleton spatial dims so the gate broadcasts over H, W.
        return gate.unsqueeze(2).unsqueeze(3)


class SpatialAttention(nn.Module):
    """Spatial-attention gate: one weight in (0, 1) per spatial location.

    The input is collapsed along the channel axis into a mean map and a max
    map; the two maps are stacked and passed through a single bias-free
    convolution followed by a sigmoid.

    Args:
        kernel_size: Size of the convolution kernel, either 3 or 7. The
            padding is chosen so the spatial size is preserved.

    ``forward`` takes a ``(batch, C, H, W)`` tensor and returns the gate with
    shape ``(batch, 1, H, W)`` (it does not multiply the input itself).
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        # The assert previously ended in a dangling comma (a SyntaxError);
        # the message is restored so the statement parses again.
        assert kernel_size in (3, 7), "kernel size must be 3 or 7"
        padding = 3 if kernel_size == 7 else 1

        self.conv = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Channel-wise statistics, each of shape (batch, 1, H, W).
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x_cat = torch.cat([avg_out, max_out], dim=1)
        x_att = self.conv(x_cat)
        return self.sigmoid(x_att)

class CBAM(nn.Module):
    """Apply channel attention, then spatial attention, to a feature map.

    Args:
        in_channels: Channel count of the incoming feature map.
        reduction_ratio: Forwarded to ``ChannelAttention``.
        kernel_size: Forwarded to ``SpatialAttention``.

    ``forward`` returns a tensor with the same shape as its input.
    """

    def __init__(self, in_channels, reduction_ratio=16, kernel_size=7):
        super().__init__()
        self.channel_attention = ChannelAttention(in_channels, reduction_ratio)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        # The spatial gate is computed from the channel-refined features,
        # not from the raw input.
        channel_refined = x * self.channel_attention(x)
        spatial_gate = self.spatial_attention(channel_refined)
        return channel_refined * spatial_gate

class SEBlock(nn.Module):
    """Squeeze-and-excitation style channel re-weighting.

    Each channel is averaged over its spatial extent, the resulting vector is
    passed through a bottleneck MLP ending in a sigmoid, and the input is
    scaled channel-wise by the result.

    Args:
        channels: Channel count of the incoming feature map.
        reduction: Bottleneck divisor; the hidden width is
            ``channels // reduction`` but at least 1.

    ``forward`` expects a 4-D ``(batch, channels, H, W)`` tensor and returns a
    tensor of the same shape.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        bottleneck = max(1, channels // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, bottleneck),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, chans, _height, _width = x.size()
        squeezed = self.avg_pool(x).view(batch, chans)
        # Reshape the per-channel gate so it broadcasts over H and W.
        gate = self.fc(squeezed).view(batch, chans, 1, 1)
        return x * gate

class Combined_Temporal_Module(nn.Module):
    """Turn a feature sequence into a 512-channel 16x16 feature map.

    A 2-layer bidirectional LSTM (15 input features, hidden size 128) reads
    the sequence; the 256-dim output at the last time step is treated as a
    1x1 map and expanded by two stride-4 transposed convolutions
    (1x1 -> 4x4 -> 16x16), each followed by two 3x3 convolutions. Every
    convolution is followed by batch norm and ReLU; channel dropout sits
    between the two upsampling stages.

    ``forward`` expects ``(batch, seq_len, 15)`` (the LSTM is batch-first) and
    returns ``(batch, 512, 16, 16)``.
    """

    def __init__(self):
        super().__init__()
        lstm_hidden = 128
        width = 512

        self.lstm = nn.LSTM(
            input_size=15,
            hidden_size=lstm_hidden,
            num_layers=2,
            batch_first=True,
            dropout=0.2,
            bidirectional=True,
        )

        # Stage 1: 1x1 -> 4x4. Input width is 2 * lstm_hidden (both directions).
        self.upconv1 = nn.ConvTranspose2d(2 * lstm_hidden, width, kernel_size=4, stride=4)
        self.bn1 = nn.BatchNorm2d(width)
        self.conv1_1 = nn.Conv2d(width, width, kernel_size=3, padding=1)
        self.bn1_1 = nn.BatchNorm2d(width)
        self.conv1_2 = nn.Conv2d(width, width, kernel_size=3, padding=1)
        self.bn1_2 = nn.BatchNorm2d(width)

        # Stage 2: 4x4 -> 16x16.
        self.upconv2 = nn.ConvTranspose2d(width, width, kernel_size=4, stride=4)
        self.bn2 = nn.BatchNorm2d(width)
        self.conv2_1 = nn.Conv2d(width, width, kernel_size=3, padding=1)
        self.bn2_1 = nn.BatchNorm2d(width)
        self.conv2_2 = nn.Conv2d(width, width, kernel_size=3, padding=1)
        self.bn2_2 = nn.BatchNorm2d(width)

        self.relu = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout2d(0.1)

    def forward(self, x):
        seq_out, _ = self.lstm(x)
        # Keep only the last time step and add two singleton spatial dims.
        feat = seq_out[:, -1, :].unsqueeze(-1).unsqueeze(-1)

        first_stage = (
            (self.upconv1, self.bn1),
            (self.conv1_1, self.bn1_1),
            (self.conv1_2, self.bn1_2),
        )
        second_stage = (
            (self.upconv2, self.bn2),
            (self.conv2_1, self.bn2_1),
            (self.conv2_2, self.bn2_2),
        )

        for layer, norm in first_stage:
            feat = self.relu(norm(layer(feat)))
        feat = self.dropout(feat)
        for layer, norm in second_stage:
            feat = self.relu(norm(layer(feat)))
        return feat

class Visual_Module(nn.Module):
    """U-Net style image encoder/decoder fused with a sequence branch.

    The image passes through five double-convolution encoder stages with a
    2x2 max-pool between consecutive stages. The sequence passes through
    ``Combined_Temporal_Module``. Both 512-channel bottleneck maps are
    re-weighted by their own ``SEBlock``, scaled by a learned two-way softmax
    ("branch attention"), concatenated to 1024 channels, refined by ``CBAM``
    and decoded through four upsampling stages with skip connections to a
    single-channel output.

    The sequence branch always produces a 16x16 map and the encoder halves the
    resolution four times, so the channel concatenation at the bottleneck only
    lines up for 256x256 input images.
    """

    @staticmethod
    def _double_conv(in_channels, out_channels):
        """Return two 3x3 conv -> batch-norm -> ReLU layers as one Sequential."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept stable so existing
        # state_dict checkpoints and seeded initialisation still match.
        self.temporal_module = Combined_Temporal_Module()
        self.se_img = SEBlock(512)
        self.se_seq = SEBlock(512)
        self.cbam = CBAM(1024, reduction_ratio=16, kernel_size=7)
        # Maps the pooled image + sequence descriptors to two branch logits.
        self.branch_att = nn.Sequential(
            nn.Linear(512 + 512, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 2),
        )

        self.encoder_1 = self._double_conv(1, 32)
        self.encoder_2 = self._double_conv(32, 64)
        self.encoder_3 = self._double_conv(64, 128)
        self.encoder_4 = self._double_conv(128, 256)
        self.encoder_5 = self._double_conv(256, 512)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Each decoder block consumes the upsampled map concatenated with the
        # matching encoder output, hence twice the upconv's output channels.
        self.upconv1 = nn.ConvTranspose2d(1024, 256, kernel_size=2, stride=2)
        self.bn_up1 = nn.BatchNorm2d(256)
        self.decoder_1 = self._double_conv(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.bn_up2 = nn.BatchNorm2d(128)
        self.decoder_2 = self._double_conv(256, 128)
        self.upconv3 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.bn_up3 = nn.BatchNorm2d(64)
        self.decoder_3 = self._double_conv(128, 64)
        self.upconv4 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.bn_up4 = nn.BatchNorm2d(32)
        self.decoder_4 = self._double_conv(64, 32)
        self.output_layer = nn.Conv2d(32, 1, kernel_size=1)

    def forward(self, x, s, return_att=False, branch_gates=None, return_stages=False):
        """Run the fused network.

        Args:
            x: Image batch of shape ``(batch, H, W)``; a channel axis is added
                internally.
            s: Sequence batch passed unchanged to ``Combined_Temporal_Module``.
            return_att: Also return a dict with ``"branch_weights"`` (the
                ``(batch, 2)`` image/sequence weights) and ``"spatial_map"``
                (the CBAM spatial gate, bilinearly resized to ``(batch, H, W)``),
                both detached CPU tensors.
            branch_gates: Optional pair ``(image_gate, sequence_gate)`` as a
                sequence or tensor. The softmax branch weights are multiplied
                by it and renormalised, e.g. ``(1, 0)`` disables the sequence
                branch.
            return_stages: Also return a dict of intermediate tensors keyed
                ``"input"``, ``"enc1"``..``"enc5"``, ``"seq_branch"``,
                ``"before_cbam"``, ``"after_cbam"``, ``"dec1"``..``"dec4"``
                and ``"output"``.

        Returns:
            ``out`` of shape ``(batch, H, W)``; or ``(out, att_dict)``;
            or ``(out, stages)``; or ``(out, att_dict, stages)`` depending on
            the two flags.
        """
        stages = {} if return_stages else None

        def record(name, tensor):
            # Only keep references to intermediates when the caller asked.
            if stages is not None:
                stages[name] = tensor

        x_input = x
        record("input", x_input)

        # ---- image encoder ------------------------------------------------
        x = x.unsqueeze(1)
        x1 = self.encoder_1(x)
        record("enc1", x1)
        x2 = self.encoder_2(self.pool(x1))
        record("enc2", x2)
        x3 = self.encoder_3(self.pool(x2))
        record("enc3", x3)
        x4 = self.encoder_4(self.pool(x3))
        record("enc4", x4)
        x5 = self.encoder_5(self.pool(x4))
        record("enc5", x5)

        # ---- sequence branch ----------------------------------------------
        s = self.temporal_module(s)
        record("seq_branch", s)

        # ---- branch weighting ---------------------------------------------
        x5 = self.se_img(x5)
        s = self.se_seq(s)
        gap = nn.functional.adaptive_avg_pool2d
        g_img = torch.flatten(gap(x5, 1), 1)
        g_seq = torch.flatten(gap(s, 1), 1)
        att_logits = self.branch_att(torch.cat([g_img, g_seq], dim=1))
        att_weights = torch.softmax(att_logits, dim=1)
        if branch_gates is not None:
            # as_tensor accepts sequences and tensors alike and makes sure the
            # gates live on the same device / dtype as the weights.
            branch_gates = torch.as_tensor(
                branch_gates, dtype=att_weights.dtype, device=att_weights.device
            )
            branch_gates = branch_gates.reshape(1, 2).expand_as(att_weights)
            att_weights = att_weights * branch_gates
            # The epsilon keeps an all-zero gate from dividing by zero.
            att_weights = att_weights / (att_weights.sum(dim=1, keepdim=True) + 1e-8)

        w_img = att_weights[:, 0].view(-1, 1, 1, 1)
        w_seq = att_weights[:, 1].view(-1, 1, 1, 1)
        fused = torch.cat([x5 * w_img, s * w_seq], dim=1)
        record("before_cbam", fused)
        x_s = self.cbam(fused)
        record("after_cbam", x_s)

        # ---- decoder with skip connections --------------------------------
        x6 = self.bn_up1(self.upconv1(x_s))
        x6 = self.decoder_1(torch.cat([x6, x4], dim=1))
        record("dec1", x6)
        x7 = self.bn_up2(self.upconv2(x6))
        x7 = self.decoder_2(torch.cat([x7, x3], dim=1))
        record("dec2", x7)
        x8 = self.bn_up3(self.upconv3(x7))
        x8 = self.decoder_3(torch.cat([x8, x2], dim=1))
        record("dec3", x8)
        x9 = self.bn_up4(self.upconv4(x8))
        x9 = self.decoder_4(torch.cat([x9, x1], dim=1))
        record("dec4", x9)

        out = self.output_layer(x9).squeeze(1)
        record("output", out.unsqueeze(1))

        att_dict = None
        if return_att:
            with torch.no_grad():
                # Reproduce the gate CBAM actually applied: spatial attention
                # is evaluated on the channel-refined, pre-CBAM features.
                # (Feeding it an already pooled 2-channel tensor built from
                # the CBAM output would pool twice and yield a different map.)
                channel_refined = fused * self.cbam.channel_attention(fused)
                sa = self.cbam.spatial_attention(channel_refined)
                sa_up = nn.functional.interpolate(
                    sa, size=x_input.shape[-2:], mode='bilinear', align_corners=False
                )
                sa_up = sa_up.squeeze(1).cpu()

            att_dict = {
                "branch_weights": att_weights.detach().cpu(),
                "spatial_map": sa_up
            }

        if return_att and return_stages:
            return out, att_dict, stages
        if return_att:
            return out, att_dict
        if return_stages:
            return out, stages
        return out

def train_model(train_loader, val_loader, model, epochs,
                criterion, optimizer, scheduler, save_path, patience=15):
    """Train ``model`` with early stopping and best-checkpoint saving.

    Each epoch runs one pass over ``train_loader`` (gradient norm clipped to
    1.0) followed by one evaluation pass over ``val_loader``. Whenever the
    mean validation loss improves, ``model.state_dict()`` is written to
    ``save_path``. Training stops early after ``patience`` consecutive epochs
    without improvement. The model and all batches are moved to the
    module-level ``device``.

    Args:
        train_loader: Iterable yielding ``(img_in, seq_in, targets)`` batches.
        val_loader: Iterable yielding ``(img_in, seq_in, targets)`` batches.
        model: Module called as ``model(img_in, seq_in)``.
        epochs: Maximum number of epochs.
        criterion: Loss callable ``criterion(outputs, targets)``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: ``None`` or a scheduler whose ``step`` takes the validation
            loss (it is called as ``scheduler.step(val_loss)``).
        save_path: Destination for the best ``state_dict``.
        patience: Epochs without validation improvement before stopping.

    Returns:
        ``(train_loss, val_loss, total_train_time_min, avg_epoch_time)``:
        per-epoch mean losses, total wall time in minutes, and mean epoch
        time in seconds (``0.0`` if no epoch ran).
    """
    train_loss = []
    val_loss = []
    model.to(device)

    best_val_loss = float("inf")
    patience_counter = 0
    epoch_times = []

    for epoch in range(epochs):
        epoch_start_time = time.perf_counter()

        # ---- training pass ------------------------------------------------
        model.train()
        batch_loss = []
        train_loader_tqdm = tqdm(train_loader,
                                 desc=f"Training Epoch {epoch + 1}/{epochs}")
        for img_in, seq_in, targets in train_loader_tqdm:
            img_in = img_in.to(device)
            seq_in = seq_in.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            outputs = model(img_in, seq_in)
            loss = criterion(outputs, targets)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Read the scalar once; every .item() forces a device sync.
            loss_value = loss.item()
            batch_loss.append(loss_value)
            train_loader_tqdm.set_postfix({'loss': f'{loss_value:.6f}'})

        # An empty loader yields NaN explicitly instead of a numpy warning.
        train_loss.append(np.mean(batch_loss) if batch_loss else float("nan"))

        # ---- validation pass ----------------------------------------------
        model.eval()
        val_batch_loss = []
        with torch.no_grad():
            for img_in, seq_in, targets in val_loader:
                img_in = img_in.to(device)
                seq_in = seq_in.to(device)
                targets = targets.to(device)

                outputs = model(img_in, seq_in)
                val_batch_loss.append(criterion(outputs, targets).item())

        val_loss.append(np.mean(val_batch_loss) if val_batch_loss else float("nan"))

        epoch_times.append(time.perf_counter() - epoch_start_time)

        if scheduler is not None:
            scheduler.step(val_loss[-1])

        # ---- checkpointing / early stopping -------------------------------
        if val_loss[-1] < best_val_loss:
            best_val_loss = val_loss[-1]
            patience_counter = 0
            torch.save(model.state_dict(), save_path)
        else:
            # A NaN validation loss also lands here (NaN < x is False).
            patience_counter += 1

        if patience_counter >= patience:
            break

    avg_epoch_time = np.mean(epoch_times) if epoch_times else 0.0
    total_train_time_min = sum(epoch_times) / 60

    return train_loss, val_loss, total_train_time_min, avg_epoch_time
