In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score
import copy
import warnings
warnings.filterwarnings('ignore')


In [2]:
# Colab-only step: mount Google Drive at /content/drive so that the
# /content/drive/MyDrive/... input paths configured in the next cell resolve.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# --- Configuration -----------------------------------------------------------
# Input locations: longitudinal outcomes (CSV) and voxel-wise imaging table (pickle).
OUTCOME_CSV_PATH = '/content/drive/MyDrive/patients_742_selected.csv'
IMAGING_PKL_PATH = "/content/drive/MyDrive/x_with_coord.pkl"

# Column roles in the outcome table.
PATIENT_ID_COL = 'id'
TIME_COL = 'Years_bl'
OUTCOME_COLS = ['MMSE', 'CDRSB', 'LDELTOTAL']
STATIC_TABULAR_COLS = ['AGE', 'GENDER', 'PTEDUCAT']

# Cross-validation settings and compute device.
N_SPLITS = 5
RANDOM_STATE = 42
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device: {DEVICE}")
print("-" * 30 + "\n")


# --- Step 2: load the voxel table and mean-pool it on a coarser grid ----------
print("--- step 2: image loading and pre-processing ---")
imaging_df = pd.read_pickle(IMAGING_PKL_PATH)

# Edge length (in voxels) of the cubic pooling block.
block = 2
# Everything after the first three columns is treated as one signal column per
# image. NOTE(review): this presumes the first three columns are x, y, z.
signal_cols = imaging_df.columns[3:]

# Work on a copy that carries the coarse block coordinates of every voxel.
df = imaging_df.assign(
    Xb=imaging_df.x // block,
    Yb=imaging_df.y // block,
    Zb=imaging_df.z // block,
)

print("Mean‑pooling …")
# Average all voxels that share a block coordinate; the block coordinates come
# back as ordinary columns placed ahead of the signal columns.
pooled = df.groupby(["Xb", "Yb", "Zb"])[signal_cols].mean().reset_index()
display(pooled)


device: cuda
------------------------------

--- step 2: image loading and pre-processing ---
Mean‑pooling …


Unnamed: 0,Xb,Yb,Zb,002_S_0729_2006-07-17_5.csv,002_S_0782_2006-08-14_3.csv,002_S_0954_2006-10-10_7.csv,002_S_1070_2006-11-28_11.csv,002_S_1155_2006-12-14_9.csv,002_S_1268_2007-02-14_11.csv,002_S_2010_2010-06-24_11.csv,...,941_S_2060_2010-09-08_733.csv,941_S_4036_2011-05-10_735.csv,941_S_4187_2011-06-22_734.csv,941_S_4377_2012-01-04_740.csv,941_S_4420_2012-03-28_739.csv,941_S_4764_2012-06-01_742.csv,941_S_6017_2017-05-17_822.csv,941_S_6052_2017-07-20_825.csv,941_S_6068_2017-08-21_831.csv,941_S_6345_2018-05-10_839.csv
0,1,37,44,-2.792912,-2.378889,-1.773856,-2.335286,0.596719,-1.765281,-2.833100,...,-1.701152,0.875000,0.201460,-2.583830,-2.595766,-1.767830,-2.355316,0.226981,1.646702,-2.352754
1,1,37,45,-2.808298,-2.378889,-1.773856,-2.329250,0.901965,-1.765281,-2.829968,...,-1.701152,0.155349,0.600329,-2.571290,-2.511538,-1.328341,-2.355316,0.690027,1.329351,-2.352754
2,1,37,46,-2.789328,-2.378889,-1.773856,-2.329084,0.651638,-1.765281,-2.832926,...,-1.701152,-1.182283,0.757916,-2.688458,-2.369983,-2.284303,-2.355316,-0.780878,1.504705,-2.352754
3,1,38,43,-2.791528,-2.375753,-1.773856,-2.517997,0.632981,-1.765281,-2.769821,...,-1.701152,1.011001,0.763914,-0.858805,-2.498166,-2.478895,-2.352305,1.436756,1.877898,-2.345762
4,1,38,44,-0.863732,-2.348688,-1.773856,-2.496181,1.103007,-1.765281,-2.746532,...,-1.701152,0.858369,0.688716,-1.114168,-2.000347,-0.240278,-2.330434,1.394198,1.328678,-2.335502
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
239163,77,49,50,-2.728616,-2.360885,-1.773856,-2.321014,-2.579449,-1.765281,-2.817723,...,-1.761741,-2.519014,-2.863345,-2.497635,-2.421602,-2.173714,-2.367502,-2.477654,-2.565042,-2.344331
239164,77,49,51,-2.726615,-2.380032,-1.773856,-2.313229,-2.566476,-1.765281,-2.826856,...,-1.780595,-2.517990,-2.874205,-2.528827,-2.411746,-2.184370,-2.366173,-2.471008,-2.574568,-2.326821
239165,77,49,52,-2.722288,-2.413499,-1.773856,-2.313831,-2.569669,-1.765281,-2.828936,...,-1.737624,-2.516753,-2.876511,-2.547080,-2.399212,-2.199603,-2.369792,-2.468890,-2.563396,-2.334988
239166,77,51,40,-2.762474,-2.382130,-1.773856,-2.328615,-2.588504,-1.765231,-2.810227,...,-1.701075,-2.520640,-2.848475,-2.551446,-2.426327,-2.171802,-2.352401,-2.462863,-2.548314,-2.340197


In [4]:


# From here on the pooled (block-averaged) table is the imaging table.
imaging_df = pooled

# --- Infer the dense grid from the block coordinates --------------------------
coord_cols = ['Xb', 'Yb', 'Zb']
x_coords_all, y_coords_all, z_coords_all = (imaging_df[c].values for c in coord_cols)

# np.unique returns the sorted unique values, and with return_inverse=True also
# the position of every input element inside that sorted array. That is exactly
# the coordinate -> grid-index mapping, without a Python-level loop per voxel.
x_unique_sorted, x_indices = np.unique(x_coords_all, return_inverse=True)
y_unique_sorted, y_indices = np.unique(y_coords_all, return_inverse=True)
z_unique_sorted, z_indices = np.unique(z_coords_all, return_inverse=True)
dim_x, dim_y, dim_z = len(x_unique_sorted), len(y_unique_sorted), len(z_unique_sorted)
print(f"3d image dimension: (X: {dim_x}, Y: {dim_y}, Z: {dim_z})")

# Explicit coordinate -> index dictionaries, kept so other cells can still use them.
x_map = {val: i for i, val in enumerate(x_unique_sorted)}
y_map = {val: i for i, val in enumerate(y_unique_sorted)}
z_map = {val: i for i, val in enumerate(z_unique_sorted)}

# --- Scatter each image column into a dense 3D volume -------------------------
# Select the image columns by name (everything except the coordinates) rather
# than by position; raises KeyError if a coordinate column is missing.
patient_id_cols = imaging_df.columns.drop(coord_cols)
num_patients = len(patient_id_cols)
# NOTE(review): grid cells that have no row in the table stay at 0.0, while the
# displayed corner-voxel values are around -2 to -3. Confirm that 0 is the
# intended fill value for voxels outside the table.
all_patient_images = np.zeros((num_patients, dim_x, dim_y, dim_z), dtype=np.float32)

for i, patient_col_name in enumerate(patient_id_cols):
    all_patient_images[i, x_indices, y_indices, z_indices] = imaging_df[patient_col_name].values

# Add a channel axis for PyTorch -> (N, 1, X, Y, Z)
X_data_images = np.expand_dims(all_patient_images, axis=1)
print(f"image shape: {X_data_images.shape}")

# Lookup table linking every image column to a patient id and its row in X_data_images.
# NOTE(review): patient ids are assigned purely by column position (1..N). This
# assumes the outcome CSV numbers patients in the same order as the image
# columns. TODO confirm against the CSV.
imaging_ids_df = pd.DataFrame({
    'imaging_patient_id': patient_id_cols,
    PATIENT_ID_COL: np.arange(1, num_patients + 1),
    'image_idx': np.arange(num_patients)
})
print("-" * 30 + "\n")


3d image dimension: (X: 77, Y: 92, Z: 74)
image shape: (742, 1, 77, 92, 74)
------------------------------



In [5]:
# --- Step 3: attach the longitudinal outcome rows to the image lookup table ---
print("--- step 3: merge longitudinal and image ---")
outcome_df = pd.read_csv(OUTCOME_CSV_PATH)

# Inner join on the patient id: only ids present in both tables survive.
merged_df = imaging_ids_df.merge(outcome_df, how='inner', on=PATIENT_ID_COL)
unique_patient_ids = merged_df[PATIENT_ID_COL].unique()

print(f" {len(merged_df)} data in total")
print(f"{len(unique_patient_ids)} patients in total")
print("-" * 30 + "\n")


--- step 3: merge longitudinal and image ---
 3826 data in total
742 patients in total
------------------------------



In [10]:
# Display the merged table for a visual check of the join result.
merged_df

Unnamed: 0,imaging_patient_id,id,image_idx,MMSE,CDRSB,LDELTOTAL,Years_bl,AGE,GENDER,PTEDUCAT,APOE4
0,002_S_0729_2006-07-17_5.csv,1,0,27.0,0.5,1.0,0.000000,65.1,1,16,1
1,002_S_0729_2006-07-17_5.csv,1,0,27.0,0.5,,0.558522,65.1,1,16,1
2,002_S_0782_2006-08-14_3.csv,2,1,29.0,0.5,8.0,0.000000,81.6,0,16,0
3,002_S_0782_2006-08-14_3.csv,2,1,30.0,1.0,,0.594114,81.6,0,16,0
4,002_S_0782_2006-08-14_3.csv,2,1,28.0,1.0,6.0,1.078710,81.6,0,16,0
...,...,...,...,...,...,...,...,...,...,...,...
3821,941_S_6052_2017-07-20_825.csv,740,739,27.0,2.5,7.0,0.928131,88.1,1,16,1
3822,941_S_6068_2017-08-21_831.csv,741,740,27.0,3.0,1.0,0.000000,75.7,0,12,1
3823,941_S_6068_2017-08-21_831.csv,741,740,25.0,4.0,0.0,0.955510,75.7,0,12,1
3824,941_S_6068_2017-08-21_831.csv,741,740,26.0,3.5,1.0,1.968510,75.7,0,12,1


In [11]:
class LongitudinalPatientDataset(Dataset):
    """One sample per patient: image, static covariates and all repeated outcome measurements.

    Args:
        df: Long-format table with one row per (patient, time point). Must contain
            the id column, the time column, an ``'image_idx'`` column, and the
            columns named in ``tabular_cols`` and ``outcome_cols``.
        patient_ids: Sequence of patient ids; ``dataset[i]`` returns the sample
            for ``patient_ids[i]``.
        all_images: Array indexed by ``image_idx`` along its first axis.
        tabular_cols: Names of the static covariate columns (taken from the
            patient's first row).
        outcome_cols: Names of the longitudinal outcome columns.
        scalers: Optional dict. ``scalers['tabular']`` (if present) is applied to
            the covariates and ``scalers['outcomes'][name]`` (if present) to the
            corresponding outcome, both through ``.transform``.
        id_col: Name of the patient id column. Defaults to the notebook-level
            ``PATIENT_ID_COL``.
        time_col: Name of the time column. Defaults to the notebook-level
            ``TIME_COL``.

    Each item is a dict with:
        ``'image'``: float32 tensor, min-max scaled per image to [0, 1] (all zeros
            when the image range is at most 1e-6).
        ``'tabular'``: 1-D float32 tensor of the (optionally scaled) covariates.
        ``'longitudinal'``: ``{outcome: (times, values)}`` with 1-D float32
            tensors; missing measurements stay NaN in ``values``.
    """

    def __init__(self, df, patient_ids, all_images, tabular_cols, outcome_cols, scalers=None,
                 id_col=None, time_col=None):
        self.df = df
        self.patient_ids = patient_ids
        self.all_images = all_images
        self.tabular_cols = tabular_cols
        self.outcome_cols = outcome_cols
        self.scalers = scalers
        self.id_col = PATIENT_ID_COL if id_col is None else id_col
        self.time_col = TIME_COL if time_col is None else time_col

        # Split the table by patient once, instead of scanning (and copying) the
        # whole DataFrame on every __getitem__ call. groupby keeps the original
        # row order inside each group, so measurements stay in table order.
        self._frames_by_patient = {
            pid: frame for pid, frame in df.groupby(self.id_col, sort=False)
        }

    def __len__(self):
        return len(self.patient_ids)

    def __getitem__(self, idx):
        patient_id = self.patient_ids[idx]
        # Raises KeyError when the id has no rows in the table.
        patient_df = self._frames_by_patient[patient_id]

        # Image: fetch by stored index and min-max scale it to [0, 1].
        image_idx = patient_df['image_idx'].iloc[0]
        image_tensor = torch.from_numpy(self.all_images[image_idx].astype(np.float32))
        min_val, max_val = torch.min(image_tensor), torch.max(image_tensor)
        if (max_val - min_val) > 1e-6:
            image_tensor = (image_tensor - min_val) / (max_val - min_val)
        else:
            # (Near-)constant image: avoid dividing by ~0.
            image_tensor = torch.zeros_like(image_tensor)

        # Static covariates: first row only, as a (1, n_features) array for the scaler.
        tabular_vals_2d = patient_df[self.tabular_cols].iloc[[0]].to_numpy(dtype=np.float64)
        if self.scalers and 'tabular' in self.scalers:
            tabular_vals_2d = self.scalers['tabular'].transform(tabular_vals_2d)
        tabular_tensor = torch.tensor(tabular_vals_2d.astype(np.float32)).squeeze(0)

        # Longitudinal outcomes. The time axis is identical for every outcome, so
        # it is built once and the same tensor is shared by all entries.
        times = torch.tensor(patient_df[self.time_col].values, dtype=torch.float32)
        outcome_scalers = self.scalers.get('outcomes', {}) if self.scalers else {}
        longitudinal_data = {}
        for outcome in self.outcome_cols:
            values = patient_df[[outcome]].values  # (n_obs, 1), NaN where missing
            if outcome in outcome_scalers:
                values = outcome_scalers[outcome].transform(values)
            values = torch.tensor(values.flatten(), dtype=torch.float32)
            longitudinal_data[outcome] = (times, values)

        return {'image': image_tensor, 'tabular': tabular_tensor, 'longitudinal': longitudinal_data}



class MultiOutcomeCNN(nn.Module):
    """3D CNN plus a linear covariate branch that together predict two parameters per outcome.

    The image goes through four conv/batch-norm/ReLU/max-pool stages, global
    average pooling and a small MLP. The tabular covariates go through a single
    linear layer. Both branches emit ``2 * num_outcomes`` values, which are summed.

    Args:
        num_tabular_features: Number of static covariates fed to the linear branch.
        num_outcomes: Number of outcomes; the output has ``2 * num_outcomes`` entries.
    """

    def __init__(self, num_tabular_features, num_outcomes=3):
        super().__init__()
        output_dim = 2 * num_outcomes

        # Convolutional encoder: (in_channels, out_channels) for each stage.
        stage_channels = [(1, 32), (32, 64), (64, 128), (128, 256)]
        encoder_layers = []
        for in_ch, out_ch in stage_channels:
            encoder_layers.extend([
                nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm3d(out_ch),
                nn.ReLU(inplace=True),
                nn.MaxPool3d(kernel_size=2, stride=2),
            ])
        self.encoder = nn.Sequential(*encoder_layers)
        self.pooling = nn.AdaptiveAvgPool3d((1, 1, 1))
        self.flatten = nn.Flatten()

        # Fully connected head on top of the 256 pooled image features.
        self.image_processor = nn.Sequential(
            nn.Linear(256, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.3),
            nn.Linear(64, output_dim),
        )

        # Linear contribution of the static covariates.
        self.tabular_processor = nn.Linear(num_tabular_features, output_dim, bias=True)

        self.apply(self._initialize_weights)

    def _initialize_weights(self, m):
        """Kaiming-normal weights / zero bias for conv and linear layers; identity for batch norm."""
        if isinstance(m, nn.BatchNorm3d):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)
            return
        if not isinstance(m, (nn.Conv3d, nn.Linear)):
            return
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x_img, x_tab):
        """Return the summed image-branch and tabular-branch outputs, shape (batch, 2 * num_outcomes)."""
        encoded = self.encoder(x_img)
        img_features = self.flatten(self.pooling(encoded))
        image_part = self.image_processor(img_features)
        tabular_part = self.tabular_processor(x_tab)
        # Non-linear (image) and linear (covariate) contributions are added.
        return image_part + tabular_part


class LongitudinalLossUnsimplified(nn.Module):
    """Mean squared error of per-patient linear trajectories against repeated measurements.

    For patient ``i`` and outcome ``j`` the model output supplies two values
    ``(a, b)`` and the prediction at time ``T`` is ``a + b * T``. The squared
    errors are averaged over that patient's observed (non-NaN) measurements of
    the outcome, summed over outcomes, and averaged over the patients in the batch.

    Args:
        outcome_cols: Outcome names, in the order their parameter pairs appear
            in the model output.
    """

    def __init__(self, outcome_cols):
        super().__init__()
        self.outcome_cols = outcome_cols
        self.k_dim = 2  # number of predicted parameters per outcome

    def forward(self, f_pred_batch, y_long_batch):
        """Compute the batch loss.

        Args:
            f_pred_batch: Tensor of shape (batch, len(outcome_cols) * k_dim).
            y_long_batch: Sequence of per-patient dicts. Each one has a
                ``'longitudinal'`` entry mapping outcome name to a
                ``(times, values)`` pair of 1-D tensors, with NaN in ``values``
                marking a missing measurement.

        Returns:
            Scalar tensor attached to ``f_pred_batch``'s graph. It is zero when
            the batch is empty or contains no observed measurement, so
            ``backward()`` still works in that case.
        """
        # Follow the predictions' device rather than a notebook-level global.
        device = f_pred_batch.device
        batch_loss = 0.0

        for i, sample in enumerate(y_long_batch):
            f_pred_patient = f_pred_batch[i]
            patient_data = sample['longitudinal']
            patient_loss_sum = 0.0

            for j, outcome_name in enumerate(self.outcome_cols):
                a_j = f_pred_patient[j * self.k_dim]
                b_j = f_pred_patient[j * self.k_dim + 1]

                times, values = patient_data[outcome_name]
                # Build the mask where the targets already live (normally CPU),
                # then move only the observed entries to the compute device.
                observed = ~torch.isnan(values)
                num_valid_obs = int(observed.sum())
                if num_valid_obs == 0:
                    continue

                t_obs = times[observed].to(device)
                y_obs = values[observed].to(device)
                squared_errors = (y_obs - (a_j + b_j * t_obs)) ** 2
                # Weight 1 / m_ij: every (patient, outcome) pair contributes its mean error.
                patient_loss_sum = patient_loss_sum + squared_errors.sum() / num_valid_obs

            batch_loss = batch_loss + patient_loss_sum

        if not torch.is_tensor(batch_loss):
            # Nothing was observed (or the batch is empty): return a zero that is
            # still connected to the predictions instead of a plain Python float.
            return f_pred_batch.sum() * 0.0

        return batch_loss / len(y_long_batch)

In [12]:
def collate_fn(batch):
    """Collate per-patient samples into one batch.

    Images and tabular vectors are stacked into tensors. The longitudinal data
    have a different number of time points per patient, so the original sample
    dicts are passed through unchanged as a list under ``'longitudinal'``.
    """
    images = torch.stack([item['image'] for item in batch])
    tabulars = torch.stack([item['tabular'] for item in batch])
    longitudinals = list(batch)
    return {'image': images, 'tabular': tabulars, 'longitudinal': longitudinals}


def _new_r2_store():
    """Return empty per-outcome accumulators for targets, predicted parameters and times."""
    return {outcome: {'true': [], 'pred_params': [], 'times': []} for outcome in OUTCOME_COLS}


def _accumulate_r2(store, f_pred, longitudinals):
    """Append one batch's observed targets, times and predicted (a, b) pairs to ``store``."""
    # One device-to-host copy per batch instead of one per patient and outcome.
    f_pred_cpu = f_pred.detach().cpu()
    for i, sample in enumerate(longitudinals):
        for j, outcome in enumerate(OUTCOME_COLS):
            times, values = sample['longitudinal'][outcome]
            mask = ~torch.isnan(values)
            n_valid = int(mask.sum())
            if n_valid > 0:
                store[outcome]['true'].extend(values[mask].cpu().numpy())
                store[outcome]['times'].extend(times[mask].cpu().numpy())
                params_j = f_pred_cpu[i, j*2:(j+1)*2]
                # Repeat the patient's parameter pair once per observed measurement.
                store[outcome]['pred_params'].extend([params_j] * n_valid)


def _report_r2(prefix, store, outcome_scalers):
    """Print and return R² per outcome, computed on the original (unscaled) outcome scale."""
    scores = {}
    print(prefix, end="")
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UserWarning, message="X does not have valid feature names")
        for outcome, data in store.items():
            if len(data['true']) > 1:
                params = torch.stack(data['pred_params']).numpy()
                obs_times = np.asarray(data['times'], dtype=np.float32)
                # Linear trajectory a + b * t for every observed measurement.
                y_pred = params[:, 0] + params[:, 1] * obs_times
                y_true_unscaled = outcome_scalers[outcome].inverse_transform(np.array(data['true']).reshape(-1, 1))
                y_pred_unscaled = outcome_scalers[outcome].inverse_transform(y_pred.reshape(-1, 1))
                r2 = r2_score(y_true_unscaled, y_pred_unscaled)
                scores[outcome] = r2
                print(f"{outcome}: {r2:.4f}; ", end="")
    print()
    return scores


kf = KFold(n_splits=N_SPLITS, shuffle=True, random_state=RANDOM_STATE)
fold_results = []

for fold, (train_idx, val_idx) in enumerate(kf.split(unique_patient_ids)):
    print(f"\n========== fold {fold+1}/{N_SPLITS} ==========")

    # Seed the RNGs per fold so weight initialisation, dropout and batch
    # shuffling are repeatable across runs (cuDNN kernels may still introduce
    # small non-deterministic differences on GPU).
    torch.manual_seed(RANDOM_STATE + fold)
    np.random.seed(RANDOM_STATE + fold)

    # 1. Split by subject, so all rows of a patient land on the same side.
    train_ids = unique_patient_ids[train_idx]
    val_ids = unique_patient_ids[val_idx]
    # train_df and val_df are still unscaled at this point.
    train_df = merged_df[merged_df[PATIENT_ID_COL].isin(train_ids)].copy()
    val_df = merged_df[merged_df[PATIENT_ID_COL].isin(val_ids)].copy()

    # 2. Fit the scalers on the training fold only (no leakage from validation).
    tabular_scaler = StandardScaler()
    tabular_scaler.fit(train_df[STATIC_TABULAR_COLS])

    outcome_scalers = {}
    for col in OUTCOME_COLS:
        scaler = StandardScaler()
        scaler.fit(train_df[[col]].dropna())
        outcome_scalers[col] = scaler

    # 3. Create the Datasets and DataLoaders.
    all_scalers = {'tabular': tabular_scaler, 'outcomes': outcome_scalers}
    train_dataset = LongitudinalPatientDataset(train_df, train_ids, X_data_images, STATIC_TABULAR_COLS, OUTCOME_COLS, all_scalers)
    val_dataset = LongitudinalPatientDataset(val_df, val_ids, X_data_images, STATIC_TABULAR_COLS, OUTCOME_COLS, all_scalers)

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=collate_fn, num_workers=2, pin_memory=True)

    # 4. Fresh model, loss, optimizer and LR scheduler for every fold.
    model = MultiOutcomeCNN(num_tabular_features=len(STATIC_TABULAR_COLS), num_outcomes=len(OUTCOME_COLS)).to(DEVICE)
    criterion = LongitudinalLossUnsimplified(outcome_cols=OUTCOME_COLS)
    optimizer = optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.1, patience=20, min_lr=1e-8)

    # 5. Training and validation with early stopping on the validation loss.
    best_val_loss = float('inf')
    best_val_r2 = {outcome: -float('inf') for outcome in OUTCOME_COLS}
    best_train_r2 = {outcome: -float('inf') for outcome in OUTCOME_COLS}
    epochs_no_improve = 0
    num_epochs = 500
    early_stopping_patience = 100

    for epoch in range(num_epochs):
        print(f"--- Epoch [{epoch+1:03d}/{num_epochs}] ---")

        # ---- training pass ----
        model.train()
        train_loss_sum = 0
        train_r2_data = _new_r2_store()

        for batch in train_loader:
            images, tabulars, longitudinals = batch['image'].to(DEVICE), batch['tabular'].to(DEVICE), batch['longitudinal']
            optimizer.zero_grad()
            f_pred = model(images, tabulars)
            loss = criterion(f_pred, longitudinals)
            loss.backward()
            optimizer.step()
            train_loss_sum += loss.item()
            # NOTE: these predictions come from train mode (dropout active, batch
            # statistics in batch norm) and from weights that change during the
            # epoch, so the train R² is not directly comparable to the val R².
            _accumulate_r2(train_r2_data, f_pred, longitudinals)

        # ---- validation pass ----
        model.eval()
        val_loss_sum = 0
        val_r2_data = _new_r2_store()
        with torch.no_grad():
            for batch in val_loader:
                images, tabulars, longitudinals = batch['image'].to(DEVICE), batch['tabular'].to(DEVICE), batch['longitudinal']
                f_pred = model(images, tabulars)
                loss = criterion(f_pred, longitudinals)
                val_loss_sum += loss.item()
                _accumulate_r2(val_r2_data, f_pred, longitudinals)

        # Losses are averaged over batches (the last batch may be smaller).
        avg_train_loss = train_loss_sum / len(train_loader)
        avg_val_loss = val_loss_sum / len(val_loader)
        current_lr = optimizer.param_groups[0]['lr']
        print(f"Loss -> Train: {avg_train_loss:.6f}, Val: {avg_val_loss:.6f} | LR: {current_lr:.1e}")

        current_train_r2 = _report_r2("  R² (Train) -> ", train_r2_data, outcome_scalers)
        current_val_r2 = _report_r2("  R² (Val)   -> ", val_r2_data, outcome_scalers)

        scheduler.step(avg_val_loss)

        if avg_val_loss < best_val_loss:
            # New best epoch: remember its metrics and checkpoint the weights.
            best_val_loss = avg_val_loss
            best_val_r2 = current_val_r2
            best_train_r2 = current_train_r2
            epochs_no_improve = 0
            torch.save(model.state_dict(), f"best_model_fold_{fold+1}.pth")
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= early_stopping_patience:
                print(f"\n提前停止于周期 {epoch+1}.")
                break

    fold_results.append({'loss': best_val_loss, 'val_r2': best_val_r2, 'train_r2': best_train_r2})
    print(f"\n折叠 {fold+1} 完成。最佳验证损失: {best_val_loss:.6f}")
    print(f"对应的验证集R²: {best_val_r2}")
    print(f"对应的训练集R²: {best_train_r2}")

# --- Cross-validation summary -------------------------------------------------
print("\n" + "="*30)
print("--- 交叉验证结果总结 ---")
print("="*30)

# Best validation loss of each fold, summarised as mean ± std.
fold_losses = [res['loss'] for res in fold_results]
avg_loss, std_loss = np.mean(fold_losses), np.std(fold_losses)
print(f"平均最佳验证损失: {avg_loss:.6f} ± {std_loss:.6f}")

# Per-outcome R² at each fold's best epoch. An outcome missing from a fold is
# counted as NaN and ignored by the nan-aware mean/std.
for outcome in OUTCOME_COLS:
    val_r2_scores = []
    train_r2_scores = []
    for res in fold_results:
        val_r2_scores.append(res['val_r2'].get(outcome, np.nan))
        train_r2_scores.append(res['train_r2'].get(outcome, np.nan))
    print(f"  - {outcome} 平均 Val R²:   {np.nanmean(val_r2_scores):.4f} ± {np.nanstd(val_r2_scores):.4f}")
    print(f"  - {outcome} 平均 Train R²: {np.nanmean(train_r2_scores):.4f} ± {np.nanstd(train_r2_scores):.4f}")


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  R² (Val)   -> MMSE: 0.0950; CDRSB: 0.0359; LDELTOTAL: 0.2768; 
--- Epoch [331/500] ---
Loss -> Train: 2.432549, Val: 3.165931 | LR: 1.0e-08
  R² (Train) -> MMSE: 0.1272; CDRSB: 0.0954; LDELTOTAL: 0.2597; 
  R² (Val)   -> MMSE: 0.0905; CDRSB: 0.0362; LDELTOTAL: 0.2858; 
--- Epoch [332/500] ---
Loss -> Train: 2.439197, Val: 3.172142 | LR: 1.0e-08
  R² (Train) -> MMSE: 0.1135; CDRSB: 0.1021; LDELTOTAL: 0.2653; 
  R² (Val)   -> MMSE: 0.0944; CDRSB: 0.0369; LDELTOTAL: 0.2801; 
--- Epoch [333/500] ---
Loss -> Train: 2.379273, Val: 3.171383 | LR: 1.0e-08
  R² (Train) -> MMSE: 0.1404; CDRSB: 0.1267; LDELTOTAL: 0.2762; 
  R² (Val)   -> MMSE: 0.0943; CDRSB: 0.0343; LDELTOTAL: 0.2802; 
--- Epoch [334/500] ---
Loss -> Train: 2.413634, Val: 3.169448 | LR: 1.0e-08
  R² (Train) -> MMSE: 0.1305; CDRSB: 0.0942; LDELTOTAL: 0.2762; 
  R² (Val)   -> MMSE: 0.0869; CDRSB: 0.0355; LDELTOTAL: 0.2873; 
--- Epoch [335/500] ---
Loss -> Train: 2.3