# 在 Google Colab 运行本项目（自动环境 & 数据获取）

本笔记本会在全新 Colab 环境中完成：
- 安装所需依赖（numpy, matplotlib, PyTorch 在 Colab 已自带）。
- 可选挂载 Google Drive，从指定目录拷贝数据文件：modes_lp_10.npz、gauss_5x2_custom.npz。
- 也可配置 GitHub 原始链接前缀，自动下载上述数据文件。
- 将当前仓库中的 utils.py 与 MPLC_CUDA2.py 写入工作目录。
- 提供一个快速“烟雾测试运行”（默认迭代较少，验证环境就绪），以及可选的“完整运行”。

请先在下方“环境与依赖安装”与“数据来源配置”两节确认设置。

# 在 Google Colab 运行 MPLC_CUDA2（多波长）

本 Notebook 会：
- 安装所需依赖（最少化变更，优先使用 Colab 预装 PyTorch/CUDA）
- 挂载 Google Drive 或从 GitHub 下载数据文件（modes_lp_10.npz、gauss_5x2_custom.npz）
- 将项目核心脚本与工具写入当前运行环境（utils.py、MPLC_CUDA2.py）
- 一键运行并展示结果图片（保存在 results/）

如果你没有公共 GitHub 链接，请把数据文件放到 Drive 指定目录（默认 `/content/drive/MyDrive/mplc_data`），或通过左侧 Files 面板手动上传到工作目录 `/content/work`。

In [None]:
# Environment and dependency installation (idempotent).
import importlib
import sys, subprocess, os, json

# Colab normally ships torch and CUDA already; pin or upgrade versions here if needed.
req = [
    'numpy',
    'matplotlib',
]

# Install a package only when importing it fails, so re-running the cell
# does not pay the network cost again.
installed_any = False
for pkg in req:
    try:
        __import__(pkg)
    except Exception:
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', pkg])
        installed_any = True

if installed_any:
    # Packages installed while the interpreter is running may stay invisible
    # to the import system until its finder caches are invalidated.
    importlib.invalidate_caches()

# Report key package versions and CUDA status.
import numpy as np, matplotlib
try:
    import torch
    print('[Env] torch:', torch.__version__, 'CUDA available:', torch.cuda.is_available())
    if torch.cuda.is_available():
        print('[Env] CUDA device count:', torch.cuda.device_count())
except Exception as e:
    # torch is optional at this point; report the problem and carry on.
    print('[Warn] torch import failed:', e)

os.makedirs('results', exist_ok=True)
print('[OK] Base deps ready.')

In [None]:
# 环境与依赖安装（Colab）
# - Colab 通常自带 torch/cuda、numpy、matplotlib。
# - 若需要固定版本，可在下方解开注释安装。
# - 自动检测 CUDA 可用性。

import os, sys, subprocess, textwrap, json, shutil

print("Python:", sys.version)

try:
    import torch, numpy as np, matplotlib
    print("torch:", torch.__version__, "cuda:", torch.cuda.is_available())
except Exception as e:
    print("Import error:", e)

# 可选：固定第三方版本（按需）
# !pip -q install numpy==1.26.4 matplotlib==3.7.1

# ========== 数据来源配置 ==========
# 选项一：从 Google Drive 复制（推荐你先把 npz 放到 Drive 的某个目录）
USE_DRIVE = True  # 如果不用 Drive，设为 False
DRIVE_DIR = "/content/drive/MyDrive/mplc_data"  # 你的 npz 所在目录

# 选项二：从 GitHub 原始链接下载（把 raw 前缀改成你的仓库文件原始链接前缀）
USE_GITHUB = False
GITHUB_RAW_PREFIX = "https://raw.githubusercontent.com/<your_user>/<your_repo>/<branch_or_tag>/"  # 修改为你自己的

DATA_FILES = [
    ("modes_lp_10.npz", "modes_lp_10.npz"),
    ("gauss_5x2_custom.npz", "gauss_5x2_custom.npz"),
]

# ========== 可选：挂载 Google Drive ==========
if USE_DRIVE:
    try:
        from google.colab import drive  # type: ignore
        drive.mount('/content/drive')
        print("Drive mounted")
    except Exception as e:
        print("Drive not available (non-Colab or no permission):", e)
        USE_DRIVE = False

os.makedirs('/content/work', exist_ok=True)
%cd /content/work

# ========== Fetch data files ==========
# For every (local name, remote name) pair: keep an existing local file,
# otherwise try Drive first and the GitHub raw prefix second.
import urllib.request  # only used by the GitHub fallback; imported once, outside the loop

missing = []
for local_name, drive_name in DATA_FILES:
    if os.path.exists(local_name):
        continue
    copied = False
    if USE_DRIVE:
        src = os.path.join(DRIVE_DIR, drive_name)
        if os.path.exists(src):
            shutil.copy2(src, local_name)
            print(f"Copied from Drive: {src} -> {local_name}")
            copied = True
        else:
            print(f"Not found in Drive: {src}")
    if (not copied) and USE_GITHUB:
        url = GITHUB_RAW_PREFIX.rstrip('/') + '/' + drive_name
        # Download under a temporary name and rename on success, so that an
        # interrupted transfer never leaves a truncated file which the
        # exists() check above would accept on the next run.
        tmp_name = local_name + '.part'
        try:
            print("Downloading:", url)
            urllib.request.urlretrieve(url, tmp_name)
            os.replace(tmp_name, local_name)
            copied = True
        except Exception as e:
            print("Download failed:", e)
            if os.path.exists(tmp_name):
                os.remove(tmp_name)
    if not copied:
        missing.append(local_name)

if missing:
    print("WARNING: 下列数据文件未找到。请：")
    print("1) 打开上方 USE_DRIVE=True 并把文件放到 DRIVE_DIR；或")
    print("2) 打开 USE_GITHUB=True 并配置 GITHUB_RAW_PREFIX；或")
    print("3) 手动上传到左侧 Files 面板当前目录(/content/work)。")
    print("Missing:", missing)
else:
    print("All data files ready.")

# ========== 写入 utils.py 与 MPLC_CUDA2.py ==========
# 我们将把本仓库中的两个关键文件内容嵌入到 Colab 工作目录，避免 import 失败。

UTILS_CODE = r"""
from typing import Union
import numpy as np
import torch
from functools import singledispatch
from math import pi
from matplotlib import pyplot as plt

__all__ = [
    "fft2",
    "ifft2",
    "normalize",
    "propagate_HK",
    "fidelity",
    "loc_fidelity",
    "performance_loc_fidelity",
    "performance_efficiency",
    "performance_crosstalk",
    "complim",
    "complim_subplot2",
    "plot_in_GS",
]

@singledispatch
def fft2(x: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
    '''Orthonormal 2-D FFT, with an ifftshift applied before and an fftshift after.

    Dispatches on the type of x; the shifts act on the last two axes.
    Raises NotImplementedError for types without a registered implementation.
    '''
    raise NotImplementedError(f"Cannot fourier transform `x` for type: {type(x)}")

@fft2.register
def _(x: np.ndarray) -> np.ndarray:
    # NumPy implementation: shift, transform, shift back.
    shifted_in = np.fft.ifftshift(x, axes=(-1, -2))
    spectrum = np.fft.fft2(shifted_in, norm="ortho")
    return np.fft.fftshift(spectrum, axes=(-1, -2))

@fft2.register
def fft2_torch(x: torch.Tensor) -> torch.Tensor:
    # PyTorch implementation: same three steps with torch.fft.
    shifted_in = torch.fft.ifftshift(x, dim=(-1, -2))
    spectrum = torch.fft.fft2(shifted_in, norm="ortho")
    return torch.fft.fftshift(spectrum, dim=(-1, -2))

@singledispatch
def ifft2(x: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
    '''Orthonormal inverse 2-D FFT, with an ifftshift before and an fftshift after.

    Dispatches on the type of x; the shifts act on the last two axes.
    Raises NotImplementedError for types without a registered implementation.
    '''
    raise NotImplementedError(f"Cannot Inverse fourier transform `x` for {type(x)}")

@ifft2.register
def _(x: np.ndarray) -> np.ndarray:
    # NumPy implementation: shift, inverse transform, shift back.
    shifted_in = np.fft.ifftshift(x, axes=(-1, -2))
    field = np.fft.ifft2(shifted_in, norm="ortho")
    return np.fft.fftshift(field, axes=(-1, -2))

@ifft2.register
def ifft2_torch(x: torch.Tensor) -> torch.Tensor:
    # PyTorch implementation: same three steps with torch.fft.
    shifted_in = torch.fft.ifftshift(x, dim=(-1, -2))
    field = torch.fft.ifft2(shifted_in, norm="ortho")
    return torch.fft.fftshift(field, dim=(-1, -2))

@singledispatch
def normalize(x: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
    '''Divide x by its norm as computed by linalg.norm with default arguments.

    Dispatches on the type of x. No guard is applied for a zero norm.
    Raises NotImplementedError for types without a registered implementation.
    '''
    raise NotImplementedError(f"Cannot normalize for {type(x)}")

@normalize.register
def _(x: torch.Tensor) -> torch.Tensor:
    scale = torch.linalg.norm(x)
    return x / scale

@normalize.register
def _(x: np.ndarray) -> np.ndarray:
    scale = np.linalg.norm(x)
    return x / scale

@singledispatch
def propagate_HK(FieldIn: Union[np.ndarray, torch.Tensor], kz: Union[np.ndarray, torch.Tensor], distance: float = 0.0) -> Union[np.ndarray, torch.Tensor]:
    '''Propagate a field by multiplying its spectrum with exp(1j*kz*distance).

    FieldIn:  field to propagate; transformed with fft2 / ifft2.
    kz:       longitudinal wavenumber per spectral sample; samples whose
              imaginary part is non-zero are removed from the result.
    distance: propagation distance (may be negative).

    Dispatches on the type of FieldIn. Raises NotImplementedError for
    types without a registered implementation.
    '''
    raise NotImplementedError(f"Cannot process `FieldIn` type: {type(FieldIn)}")

@propagate_HK.register
def _(FieldIn: np.ndarray, kz: np.ndarray, distance: float = 0.0) -> np.ndarray:
    FieldIn_FT = fft2(FieldIn)
    transfer = np.exp(1j*kz*distance)
    # Select 0 for samples with imaginary kz instead of multiplying by 0:
    # exp() can overflow to inf for those samples (e.g. negative distance),
    # and inf*0 would turn into NaN and contaminate the whole output field.
    transfer = np.where(np.imag(kz) == 0, transfer, 0)
    FieldOut_FT = FieldIn_FT*transfer
    FieldOut = ifft2(FieldOut_FT)
    return FieldOut

@propagate_HK.register
def _(FieldIn: torch.Tensor, kz: torch.Tensor, distance: float = 0.0) -> torch.Tensor:
    FieldIn_FT = fft2(FieldIn)
    transfer = torch.exp(1j*kz*distance)
    # Same reasoning as the NumPy overload: avoid inf*0 = NaN.
    transfer = torch.where(torch.imag(kz) == 0, transfer, torch.zeros_like(transfer))
    FieldOut_FT = FieldIn_FT*transfer
    FieldOut = ifft2(FieldOut_FT)
    return FieldOut

@singledispatch
def fidelity(a: Union[np.ndarray, torch.Tensor], b: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
    '''Squared magnitude of the inner product of normalize(a) and normalize(b).

    Dispatches on the type of a. Raises NotImplementedError for types
    without a registered implementation.
    '''
    raise NotImplementedError(f"Cannot check fidelity of `a`, `b` for {type(a)}, {type(b)}")

@fidelity.register
def _(a: np.ndarray, b: np.ndarray) -> float:
    inner = np.sum(normalize(a).conj() * normalize(b))
    return np.square(np.abs(inner))

@fidelity.register
def _(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    inner = torch.sum(normalize(a).conj() * normalize(b))
    return torch.square(torch.abs(inner))

@singledispatch
def loc_fidelity(a: Union[np.ndarray, torch.Tensor], channel: Union[np.ndarray, torch.Tensor], b: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
    '''Fidelity between a*channel and b, each normalized before the overlap.

    a is multiplied element-wise by channel first; the result is the
    squared magnitude of the inner product of the two normalized fields.
    Dispatches on the type of a. Raises NotImplementedError for types
    without a registered implementation.
    '''
    raise NotImplementedError(f"Cannot check fidelity of `a`, `b` for {type(a)}, {type(b)}")

@loc_fidelity.register
def _(a: np.ndarray, channel: np.ndarray, b: np.ndarray) -> float:
    windowed = a*channel
    inner = np.sum(normalize(windowed).conj() * normalize(b))
    return np.square(np.abs(inner))

@loc_fidelity.register
def _(a: torch.Tensor, channel: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    windowed = a*channel
    inner = torch.sum(normalize(windowed).conj() * normalize(b))
    return torch.square(torch.abs(inner))

def _restore_mode_axis(x):
    # squeeze() also removes the leading mode axis when only one mode is
    # present; put it back so that x[i, :, :] indexing works for any count.
    # x[None] and .ndim are available on both ndarray and Tensor.
    return x[None] if x.ndim == 2 else x

@singledispatch
def performance_loc_fidelity(A: Union[np.ndarray, torch.Tensor], channels: Union[np.ndarray, torch.Tensor], B: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
    '''Per-mode loc_fidelity(A[i], channels[i], B[i]) and its average in percent.

    All three inputs are squeezed and then indexed along their first axis.
    Returns (average * 100, per-mode values).
    Dispatches on the type of A. Raises NotImplementedError for types
    without a registered implementation.
    '''
    raise NotImplementedError(f"Cannot check fidelity of `A`, `B` for {type(A)}, {type(B)}")

@performance_loc_fidelity.register
def _(A: np.ndarray, channels: np.ndarray, B: np.ndarray) -> tuple:
    A = _restore_mode_axis(np.squeeze(A))
    B = _restore_mode_axis(np.squeeze(B))
    CH = _restore_mode_axis(np.squeeze(channels))
    n_modes = A.shape[0]
    fid_list = np.zeros((n_modes))
    for i in range(n_modes):
        fid_list[i] = loc_fidelity(A[i,:,:], CH[i,:,:], B[i,:,:])
    av_loc_fid = 100*np.sum(fid_list)/n_modes
    return av_loc_fid, fid_list

@performance_loc_fidelity.register
def _(A: torch.Tensor, channels: torch.Tensor, B: torch.Tensor) -> tuple:
    A = _restore_mode_axis(torch.squeeze(A))
    B = _restore_mode_axis(torch.squeeze(B))
    CH = _restore_mode_axis(torch.squeeze(channels))
    n_modes = A.shape[0]
    # Allocated with torch defaults (dtype and device), as before, so the
    # returned tensors keep the dtype/device callers already receive.
    fid_list = torch.zeros((n_modes))
    for i in range(n_modes):
        fid_list[i] = loc_fidelity(A[i,:,:], CH[i,:,:], B[i,:,:])
    av_loc_fid = 100*torch.sum(fid_list)/n_modes
    return av_loc_fid, fid_list

def _restore_mode_axis(x):
    # squeeze() also removes the leading mode axis when only one mode is
    # present; put it back so that x[i, :, :] indexing works for any count.
    # x[None] and .ndim are available on both ndarray and Tensor.
    return x[None] if x.ndim == 2 else x

@singledispatch
def performance_efficiency(A: Union[np.ndarray, torch.Tensor], channels: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
    '''Per-mode sum of A[i] * channels[i] and its average in percent.

    Both inputs are squeezed and then indexed along their first axis.
    Returns (average * 100, per-mode sums).
    Dispatches on the type of A. Raises NotImplementedError for types
    without a registered implementation.
    '''
    raise NotImplementedError(f"Cannot check efficiency of `A` for {type(A)}")

@performance_efficiency.register
def _(A: np.ndarray, channels: np.ndarray) -> tuple:
    A = _restore_mode_axis(np.squeeze(A))
    CH = _restore_mode_axis(np.squeeze(channels))
    n_modes = A.shape[0]
    eff_list = np.zeros((n_modes))
    for i in range(n_modes):
        eff_list[i] = np.sum(A[i,:,:]*CH[i,:,:])
    av_eff = 100*np.sum(eff_list)/n_modes
    return av_eff, eff_list

@performance_efficiency.register
def _(A: torch.Tensor, channels: torch.Tensor) -> tuple:
    A = _restore_mode_axis(torch.squeeze(A))
    CH = _restore_mode_axis(torch.squeeze(channels))
    n_modes = A.shape[0]
    # Allocated with torch defaults (dtype and device), as before, so the
    # returned tensors keep the dtype/device callers already receive.
    eff_list = torch.zeros((n_modes))
    for i in range(n_modes):
        eff_list[i] = torch.sum(A[i,:,:]*CH[i,:,:])
    av_eff = 100*torch.sum(eff_list)/n_modes
    return av_eff, eff_list

def _restore_mode_axis(x):
    # squeeze() also removes the leading mode axis when only one mode is
    # present; put it back so that x[i, :, :] indexing works for any count.
    # x[None] and .ndim are available on both ndarray and Tensor.
    return x[None] if x.ndim == 2 else x

@singledispatch
def performance_crosstalk(A: Union[np.ndarray, torch.Tensor], channels: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
    '''Cross-talk figures from the overlap of every A[j] with every channels[i].

    crs_matrix[i, j] = sum(A[j] * channels[i]);
    crs_list[i] = 1 - crs_matrix[i, i] / sum(crs_matrix[:, i]).
    Returns (mean(crs_list) * 100, crs_list, crs_matrix).
    Dispatches on the type of A. Raises NotImplementedError for types
    without a registered implementation.
    '''
    raise NotImplementedError(f"Cannot check cross-talk for {type(A)}")

@performance_crosstalk.register
def _(A: np.ndarray, channels: np.ndarray) -> tuple:
    A = _restore_mode_axis(np.squeeze(A))
    CH = _restore_mode_axis(np.squeeze(channels))
    n_modes = A.shape[0]
    crs_list = np.zeros((n_modes))
    crs_matrix = np.zeros((n_modes, n_modes))
    for i in range(n_modes):
        for j in range(n_modes):
            crs_matrix[i,j] = np.sum(A[j,:,:]*CH[i,:,:])
    for i in range(n_modes):
        crs_list[i] = 1 - (crs_matrix[i,i]/np.sum(crs_matrix[:,i]))
    av_crs = 100*np.sum(crs_list)/n_modes
    return av_crs, crs_list, crs_matrix

@performance_crosstalk.register
def _(A: torch.Tensor, channels: torch.Tensor) -> tuple:
    A = _restore_mode_axis(torch.squeeze(A))
    CH = _restore_mode_axis(torch.squeeze(channels))
    n_modes = A.shape[0]
    # Allocated with torch defaults (dtype and device), as before, so the
    # returned tensors keep the dtype/device callers already receive.
    crs_list = torch.zeros((n_modes))
    crs_matrix = torch.zeros((n_modes, n_modes))
    for i in range(n_modes):
        for j in range(n_modes):
            crs_matrix[i,j] = torch.sum(A[j,:,:]*CH[i,:,:])
    for i in range(n_modes):
        crs_list[i] = 1 - (crs_matrix[i,i]/torch.sum(crs_matrix[:,i]))
    av_crs = 100*torch.sum(crs_list)/n_modes
    return av_crs, crs_list, crs_matrix

@singledispatch
def complim(x: Union[np.ndarray, torch.Tensor]):
    '''Show a complex 2-D field as an RGB image with matplotlib.

    The field is scaled by its largest magnitude; the scaled magnitude
    multiplies three phase-shifted cosines of the phase (shifts of
    -2*pi/3, 0 and +2*pi/3) to form the R, G and B planes.
    Nothing is returned; the figure is displayed with plt.show().
    Dispatches on the type of x. Raises NotImplementedError for types
    without a registered implementation.
    '''
    raise NotImplementedError(f"Cannot visualize `x` for type: {type(x)}")

@complim.register
def _(x: np.ndarray) -> None:
    mAx = np.amax(np.abs(x))
    M = x/mAx
    A = np.abs(M)
    P = np.angle(M)
    A[A > 1.] = 1.  # clamp against rounding slightly above 1
    R = A*(np.cos(P - 2*pi/3)/2+0.5)
    G = A*(np.cos(P)/2+0.5)
    B = A*(np.cos(P + 2*pi/3)/2+0.5)
    C = np.dstack((R, G, B))
    plt.imshow(C)
    plt.show()

@complim.register
def _(x: torch.Tensor) -> None:
    mAx = torch.amax(torch.abs(x))
    M = x/mAx
    A = torch.abs(M)
    P = torch.angle(M)
    A[A > 1.] = 1.  # clamp against rounding slightly above 1
    R = A*(torch.cos(P - 2*pi/3)/2+0.5)
    G = A*(torch.cos(P)/2+0.5)
    B = A*(torch.cos(P + 2*pi/3)/2+0.5)
    C = torch.dstack((R, G, B))
    # matplotlib needs host memory: a CUDA tensor (or one that requires
    # grad) cannot be converted implicitly, so hand over a NumPy copy.
    plt.imshow(C.detach().cpu().numpy())
    plt.show()

@singledispatch
def plot_in_GS(x: Union[np.ndarray, torch.Tensor]):
    '''Show angle(exp(1j*x)) as a grayscale image with matplotlib.

    Nothing is returned; the figure is displayed with plt.show().
    Dispatches on the type of x. Raises NotImplementedError for types
    without a registered implementation.
    '''
    raise NotImplementedError(f"Cannot visualize `x` for type: {type(x)}")

@plot_in_GS.register
def _(x: np.ndarray) -> None:
    x = np.angle(np.exp(1j*x))
    plt.imshow(x, cmap="gray")
    plt.show()

@plot_in_GS.register
def _(x: torch.Tensor) -> None:
    x = torch.angle(torch.exp(1j*x))
    # matplotlib needs host memory: a CUDA tensor (or one that requires
    # grad) cannot be converted implicitly, so hand over a NumPy copy.
    plt.imshow(x.detach().cpu().numpy(), cmap="gray")
    plt.show()

@singledispatch
def complim_subplot2(x: Union[np.ndarray, torch.Tensor], y=None, titles=None):
    '''Show two complex 2-D fields side by side as RGB images.

    x, y:   the two fields; each is mapped to RGB independently.
    titles: iterable of subplot titles, paired with the axes in order.

    The base signature accepts the same three arguments as the registered
    implementations, so an unsupported type of x raises NotImplementedError
    rather than a TypeError about the argument count.
    Nothing is returned; the figure is displayed with plt.show().
    '''
    raise NotImplementedError(f"Cannot visualize `x` for type: {type(x)}")

def _complim_rgb_numpy(z):
    # Scale by the largest magnitude, then combine the magnitude with three
    # phase-shifted cosines of the phase into R, G and B planes.
    M = z/np.amax(np.abs(z))
    A = np.abs(M)
    P = np.angle(M)
    A[A > 1.] = 1.
    R = A*(np.cos(P - 2*pi/3)/2+0.5)
    G = A*(np.cos(P)/2+0.5)
    B = A*(np.cos(P + 2*pi/3)/2+0.5)
    return np.dstack((R, G, B))

def _complim_rgb_torch(z):
    # Tensor counterpart of _complim_rgb_numpy.
    M = z/torch.amax(torch.abs(z))
    A = torch.abs(M)
    P = torch.angle(M)
    A[A > 1.] = 1.
    R = A*(torch.cos(P - 2*pi/3)/2+0.5)
    G = A*(torch.cos(P)/2+0.5)
    B = A*(torch.cos(P + 2*pi/3)/2+0.5)
    return torch.dstack((R, G, B))

def _complim_show_pair(images, titles):
    # One row, two columns; zip stops at the shortest of axes/images/titles.
    fig, axs = plt.subplots(1, 2)
    for ax, image, title in zip(axs, images, titles):
        ax.imshow(image)
        ax.set_title(title, fontsize=10)
    plt.show()

@complim_subplot2.register
def _(x: np.ndarray, y: np.ndarray, titles: list) -> None:
    _complim_show_pair([_complim_rgb_numpy(x), _complim_rgb_numpy(y)], titles)

@complim_subplot2.register
def _(x: torch.Tensor, y: torch.Tensor, titles: list) -> None:
    # matplotlib needs host memory: a CUDA tensor (or one that requires
    # grad) cannot be converted implicitly, so hand over NumPy copies.
    images = [_complim_rgb_torch(z).detach().cpu().numpy() for z in (x, y)]
    _complim_show_pair(images, titles)
"""

# Write the embedded module next to the notebook. The context manager
# closes (and therefore flushes) the file before anything imports it.
with open("utils.py", "w", encoding="utf-8") as f:
    f.write(UTILS_CODE)
print("Wrote utils.py")

# 把 MPLC_CUDA2.py 写入（从仓库拷贝的版本，去掉 argparse 的 CLI 退出副作用以支持 notebook 运行）
MPLC_CODE = r"""
import torch
import numpy as np
import torch.nn as nn
import math
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import argparse
from utils import *

# Built-in configuration. Every key that parse_cfg() declares as an option
# can be overridden through its argv parameter.
DEFAULTS = {
    "n_of_modes": 10,
    "Planes": 7,
    "iterations": 60,   # lowered by default so the smoke run stays short
    "alpha": 1.0,
    "beta": 2.0,
    "gamma": 0.0,
    "first_n_iterations": 10,
    "delta_theta_1": 2*math.pi/255,
    "delta_theta_0": 10*(2*math.pi/255),
    "Nx": 512,
    "Ny": 512,
    "pixelSize": 8e-6,
    "wavelength": 1.57e-6,
    "d_in": 20e-3,
    "d": 2*9.7e-3,
    "d_out": 15e-3,
    "calc_perf_every_it": 10,
    "equalize_efficiency": 1,
    "plot_eff_distribution": 0,
    "smoothing_switch": 1,
    "OffsetMultiplier": 0e-5,
    "plot_results": 0,
    "do_padded_eval": 0,
}

def parse_cfg(argv=None) -> dict:
    '''Return the run configuration: DEFAULTS overlaid with optional overrides.

    argv: optional sequence of option strings, e.g. ["--iterations", "5"].
        None (the default) means "no overrides". sys.argv is deliberately
        never read, so notebook kernel arguments cannot leak in.

    Returns a new dict; DEFAULTS itself is not modified.
    If argparse rejects argv it reports the problem itself (SystemExit is
    caught here) and the unmodified defaults are returned.
    '''
    int_options = (
        "n_of_modes", "Planes", "iterations", "first_n_iterations",
        "Nx", "Ny", "calc_perf_every_it",
    )
    switch_options = (
        "equalize_efficiency", "plot_eff_distribution", "smoothing_switch",
        "plot_results", "do_padded_eval",
    )
    float_options = (
        "alpha", "beta", "gamma", "delta_theta_1", "delta_theta_0",
        "pixelSize", "wavelength", "d_in", "d", "d_out", "OffsetMultiplier",
    )

    parser = argparse.ArgumentParser(add_help=True)
    # default=None marks "not given", so only explicit overrides replace DEFAULTS.
    for name in int_options:
        parser.add_argument(f"--{name}", type=int, default=None)
    for name in switch_options:
        parser.add_argument(f"--{name}", type=int, choices=[0,1], default=None)
    for name in float_options:
        parser.add_argument(f"--{name}", type=float, default=None)

    try:
        args = parser.parse_args([] if argv is None else list(argv))
    except SystemExit:
        # argparse exits on bad input or --help; a notebook should keep running.
        args = argparse.Namespace()

    cfg = DEFAULTS.copy()
    cfg.update({k: v for k, v in vars(args).items() if v is not None})
    return cfg

# Resolve the configuration once and bind each setting to a module-level name.
# "plot_results" and "do_padded_eval" exist in CFG but are not bound here.
CFG = parse_cfg()

# Problem size and iteration count
n_of_modes = CFG["n_of_modes"]
Planes = CFG["Planes"]
iterations = CFG["iterations"]

# Weights of the three terms combined at the last plane
alpha = CFG["alpha"]
beta = CFG["beta"]
gamma = CFG["gamma"]

# Phase step schedule
first_n_iterations = CFG["first_n_iterations"]
delta_theta_1 = CFG["delta_theta_1"]
delta_theta_0 = CFG["delta_theta_0"]

# Sampling grid and wavelength
Nx = CFG["Nx"]
Ny = CFG["Ny"]
pixelSize = CFG["pixelSize"]
wavelength = CFG["wavelength"]

# Propagation distances
d_in = CFG["d_in"]
d = CFG["d"]
d_out = CFG["d_out"]

# Reporting interval and switches
calc_perf_every_it = CFG["calc_perf_every_it"]
equalize_efficiency = CFG["equalize_efficiency"]
plot_eff_distribution = CFG["plot_eff_distribution"]
smoothing_switch = CFG["smoothing_switch"]
OffsetMultiplier = CFG["OffsetMultiplier"]

# Run on the GPU when one is visible to torch, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[MPLC2] Using device: {DEVICE}")

# Derived scalars. NOTE(review): reprW/reprH, crs_delta and maskOffset are
# not read anywhere else in the visible script.
reprW, reprH = Nx * pixelSize, Ny * pixelSize
crs_delta = 0.0001 * calc_perf_every_it
maskOffset = OffsetMultiplier * np.sqrt(1e-3 / (Nx * Ny * n_of_modes))

# Spatial coordinate grids: sample positions symmetric about zero, spaced by pixelSize.
nx_m = pixelSize*np.linspace(-(Nx-1)/2, (Nx-1)/2, num=Nx)
ny_m = pixelSize*np.linspace(-(Ny-1)/2, (Ny-1)/2, num=Ny)
X,Y = np.meshgrid(nx_m,ny_m)
X_torch = torch.from_numpy(X).to(DEVICE)
Y_torch = torch.from_numpy(Y).to(DEVICE)

# Transverse wavenumber grids: 2*pi*n / (N*pixelSize) for the same symmetric index range.
nx = np.linspace(-(Nx-1)/2, (Nx-1)/2, num=Nx)
ny = np.linspace(-(Ny-1)/2, (Ny-1)/2, num=Ny)
kx, ky = np.meshgrid(2*np.pi*nx/(Nx*pixelSize),2*np.pi*ny/(Ny*pixelSize))

# Wavelengths handled by the run; lambda_c is the reference used later to
# scale mask phases by lambda_c / lambda_list[l].
lambda_list = np.array([1.53e-6, 1.55e-6, 1.57e-6, 1.59e-6, 1.61e-6, 1.625e-6], dtype=np.float64)
lambda_c = 1.57e-6

# Input and target profiles, read from the 'profiles' array of each npz file
# (relative paths: the files must be in the current working directory).
lp_data = np.load('modes_lp_10.npz')
lp_modes = lp_data['profiles']
gauss_data = np.load('gauss_5x2_custom.npz')
gauss_modes = gauss_data['profiles']

# Use only as many entries as both files and lambda_list provide.
# Presumably axis 0 of 'profiles' indexes wavelength and axis 1 indexes mode - TODO confirm.
L = min(lp_modes.shape[0], gauss_modes.shape[0], len(lambda_list))
lambda_list = lambda_list[:L]

Speckle_basis = lp_modes[:L, 0:n_of_modes, :, :]
Gaussian_basis = gauss_modes[:L, 0:n_of_modes, :, :]
Speckle_basis_torch = torch.from_numpy(Speckle_basis).to(torch.cdouble).to(DEVICE)
Gaussian_basis_torch = torch.from_numpy(Gaussian_basis).to(torch.cdouble).to(DEVICE)

# Binary mask per target profile: 1 where the intensity exceeds 5% of that profile's peak.
Gaussian_Masks = np.zeros_like(Gaussian_basis, dtype=np.float64)
for l in range(L):
    for m in range(n_of_modes):
        inten = np.abs(Gaussian_basis[l, m, :, :]) ** 2
        thr = 0.05 * np.max(inten)
        Gaussian_Masks[l, m, :, :] = inten > thr
Gaussian_Masks_torch = torch.from_numpy(Gaussian_Masks).to(torch.double).to(DEVICE)

# Zero-pad when the grid is larger than 512.
# NOTE(review): the hard-coded 512 assumes the stored profiles are 512x512 - confirm.
if (Nx > 512) or (Ny > 512):
    pad_x = int((Nx-512)/2)
    pad_y = int((Ny-512)/2)
    Speckle_basis_torch = nn.functional.pad(Speckle_basis_torch, (pad_x, Nx-512-pad_x, pad_y, Ny-512-pad_y), mode='constant', value=0.+0.j)
    Gaussian_basis_torch = nn.functional.pad(Gaussian_basis_torch, (pad_x, Nx-512-pad_x, pad_y, Ny-512-pad_y), mode='constant', value=0.+0.j)
    Gaussian_Masks_torch = nn.functional.pad(Gaussian_Masks_torch, (pad_x, Nx-512-pad_x, pad_y, Ny-512-pad_y), mode='constant', value=0.0)

# phi_bk: one minus the sum of all masks along axis 1.
# phi_cr[l, i]: the sum of all masks for entry l minus mask i.
phi_bk = torch.ones((Gaussian_Masks_torch.shape[0], Ny, Nx), dtype=torch.double, device=DEVICE) - torch.sum(Gaussian_Masks_torch, axis = 1)
phi_cr = torch.zeros((Gaussian_Masks_torch.shape[0], n_of_modes, Ny, Nx), dtype = torch.double, device=DEVICE)
for l in range(Gaussian_Masks_torch.shape[0]):
    for i in range(n_of_modes):
        phi_cr[l,i,:,:] = torch.sum(Gaussian_Masks_torch[l], axis = 0) - Gaussian_Masks_torch[l,i,:,:]

# Target fields used in the overlap computations.
phi = Gaussian_basis_torch

# One phase mask per plane, initialised to zero phase; Masks_complex holds exp(1j*Masks).
Masks = torch.zeros((Planes,Ny,Nx), dtype=torch.double, device=DEVICE)
Masks_complex = torch.exp(1j*Masks)

# Field buffers indexed [wavelength entry, plane, mode, y, x].
L = Gaussian_Masks_torch.shape[0]
Modes_in = torch.zeros((L, Planes, n_of_modes, Ny, Nx), dtype = torch.cdouble, device=DEVICE)
Modes_out = torch.zeros((L, Planes, n_of_modes, Ny, Nx), dtype = torch.cdouble, device=DEVICE)

# Work buffers. NOTE(review): in the visible script eff_distribution stays all
# ones, and crs_array_convergence / conv_count are never read or updated.
overlap = torch.zeros((n_of_modes), dtype = torch.cdouble, device=DEVICE)
eff_distribution = torch.ones((n_of_modes), dtype = torch.double, device=DEVICE)
dFdpsi = torch.zeros((L, Planes, n_of_modes, Ny, Nx), dtype = torch.cdouble, device=DEVICE)
crs_array_convergence = torch.zeros((max(1,iterations//max(1,calc_perf_every_it))), dtype = torch.double, device=DEVICE)
conv_count = 0

# Per-wavelength longitudinal wavenumber kz = sqrt(k^2 - kx^2 - ky^2), plus the
# initial fields: inputs propagated by d_in to the first plane, and targets
# propagated by -d_out to the last plane.
kz_torch_list = []
for l in range(L):
    k_l = (2*np.pi)/lambda_list[l]
    # Take the square root in the complex domain: a real-valued sqrt turns
    # negative arguments into NaN, whereas propagate_HK identifies the
    # samples to discard by a non-zero imaginary part of kz.
    kz_sq = (k_l**2 - (kx**2 + ky**2)).astype(np.cdouble)
    kz_l = np.sqrt(kz_sq)
    kz_torch_list.append(torch.from_numpy(kz_l).to(DEVICE))
    Modes_in[l, 0, :, :, :] = propagate_HK(Speckle_basis_torch[l], kz_torch_list[l], d_in)
    Modes_out[l, Planes-1, :, :, :] = propagate_HK(phi[l], kz_torch_list[l], -d_out)

# Main optimisation loop. Each iteration visits every mask plane once; for a
# given plane the contributions of all wavelength entries are accumulated
# into a single phase update applied to the shared mask.
for i in range(1, iterations+1):
    # Larger phase step while i < first_n_iterations, smaller step afterwards.
    delta_theta = delta_theta_0 if i < first_n_iterations else delta_theta_1
    for mask_ind in range(Planes):
        for l in range(L):
            # The mask phase is scaled by lambda_c / lambda for this wavelength entry.
            scale_l = lambda_c / lambda_list[l]
            # NOTE(review): this allocation is overwritten on the first pass of the loop below.
            modes = torch.zeros((n_of_modes, Ny, Nx), dtype = torch.cdouble, device=DEVICE)
            # Forward pass: apply each mask and propagate by d, refreshing Modes_in on every plane.
            for pl in range(Planes-1):
                mask_cmplx_l = torch.exp(1j*(Masks[pl, :, :]*scale_l))
                modes = Modes_in[l, pl, :, :, :] * mask_cmplx_l
                modes = propagate_HK(modes, kz_torch_list[l], d)
                Modes_in[l, pl+1, :, :, :] = modes
            # Output field: last mask applied, then propagated by d_out.
            modes_forw_last_plane = Modes_in[l, Planes-1, :, :, :] * torch.exp(1j*(Masks[Planes-1, :, :]*scale_l))
            eout_l = propagate_HK(modes_forw_last_plane, kz_torch_list[l], d_out)
            # Per mode: combine the target term (weight alpha) with the parts of the
            # output that fall on the other modes' masks (beta) and on phi_bk (gamma).
            for j in range(n_of_modes):
                overlap = torch.sum(torch.squeeze(eout_l[j,:,:]) * torch.conj(torch.squeeze(phi[l,j,:,:])))
                a = (phi[l, j, :, :]) * overlap
                psi_cr_l = (torch.squeeze(eout_l[j,:,:])) * torch.squeeze(phi_cr[l,j,:,:])
                psi_bk_l = (torch.squeeze(eout_l[j,:,:])) * phi_bk[l]
                dFdpsi[l, Planes-1, j, :, :] = - alpha*a + (beta*psi_cr_l - gamma*psi_bk_l)*0.5
            # Bring that quantity back from the output to the last mask plane.
            dFdpsi[l, Planes-1, :, :, :] = propagate_HK(dFdpsi[l, Planes-1, :, :, :], kz_torch_list[l], -d_out)
            # Backward pass from the last plane down to mask_ind: undo each mask
            # (conjugate) and propagate by -d, for both dFdpsi and Modes_out.
            for pl in range(Planes-1, mask_ind, -1):
                mask_cmplx_l = torch.exp(1j*(Masks[pl, :, :]*scale_l))
                dFdpsi_prop = dFdpsi[l, pl, :, :, :] * torch.conj(mask_cmplx_l)
                dFdpsi_prop = propagate_HK(dFdpsi_prop, kz_torch_list[l], -d)
                dFdpsi[l, pl-1, :, :, :] = dFdpsi_prop
                phi_prop = Modes_out[l, pl, :, :, :] * torch.conj(mask_cmplx_l)
                phi_prop = propagate_HK(phi_prop, kz_torch_list[l], -d)
                Modes_out[l, pl-1, :, :, :] = phi_prop
        # Phase update for this plane: +/- delta_theta per pixel, signed by the
        # imaginary part of the accumulated overlap term.
        if equalize_efficiency == 1:
            total_term = torch.zeros((Ny,Nx), dtype=torch.cdouble, device=DEVICE)
            for l in range(L):
                scale_l = lambda_c / lambda_list[l]
                mask_cmplx_l = torch.exp(1j*(Masks[mask_ind, :, :]*scale_l))
                weighted_overlaps = torch.zeros((Ny,Nx), dtype=torch.cdouble, device=DEVICE)
                # Each mode is weighted by 1/eff_distribution[mode].
                # NOTE(review): eff_distribution is never updated in the visible
                # script, so the weights all stay at 1 - confirm whether intended.
                for mode in range(n_of_modes):
                    weighted_overlaps = weighted_overlaps + (1/eff_distribution[mode]) * torch.squeeze(Modes_in[l, mask_ind, mode, :, :]) * torch.conj(torch.squeeze(dFdpsi[l, mask_ind, mode, :, :]))
                total_term = total_term + mask_cmplx_l * weighted_overlaps
            delta_P = delta_theta*torch.sign(torch.imag(total_term))
        else:
            # Unweighted variant: plain sum over modes.
            total_term = torch.zeros((Ny,Nx), dtype=torch.cdouble, device=DEVICE)
            for l in range(L):
                scale_l = lambda_c / lambda_list[l]
                mask_cmplx_l = torch.exp(1j*(Masks[mask_ind, :, :]*scale_l))
                overlaps = torch.sum(torch.squeeze(Modes_in[l, mask_ind, :, :, :]) * torch.conj(torch.squeeze(dFdpsi[l, mask_ind, :, :, :])), axis = 0)
                total_term = total_term + mask_cmplx_l * overlaps
            delta_P = delta_theta*torch.sign(torch.imag(total_term))
        if smoothing_switch == 1:
            # Scale the updated phasor by the wavelength-averaged magnitude of the
            # Modes_in / Modes_out overlap, then keep only its angle.
            ov_sum = torch.zeros((Ny, Nx), dtype=torch.double, device=DEVICE)
            for l in range(L):
                ov_sum = ov_sum + torch.abs(torch.sum(torch.squeeze(Modes_in[l, mask_ind, :, :, :]*torch.conj(Modes_out[l, mask_ind, :, :, :])), axis = 0))
            ovrlp_in_out = ov_sum / L
            mask_cmplx = ovrlp_in_out*torch.exp(1j*(Masks[mask_ind, :, :] + delta_P)) 
            # NOTE(review): adding 0.0 is a no-op; maskOffset is computed earlier in
            # the script but never used - possibly it was meant to be added here. Confirm.
            mask_cmplx = mask_cmplx + 0.0
            Masks[mask_ind, :, :] = torch.angle(mask_cmplx)
        else:
            Masks[mask_ind, :, :] = Masks[mask_ind, :, :] + delta_P
        Masks_complex[mask_ind, :, :] = torch.exp(1j*torch.squeeze(Masks[mask_ind, :, :]))

    # Every calc_perf_every_it iterations: recompute the output per wavelength
    # entry and print the metrics averaged over the entries.
    if i % max(1,calc_perf_every_it) == 0:
        fids = []
        crss = []
        effs = []
        for l in range(L):
            scale_l = lambda_c / lambda_list[l]
            for pl in range(Planes-1):
                mask_cmplx_l = torch.exp(1j*(Masks[pl, :, :]*scale_l))
                modes = Modes_in[l, pl, :, :, :]*mask_cmplx_l
                modes = propagate_HK(modes, kz_torch_list[l], d)
                Modes_in[l, pl+1, :, :, :] = modes
            # `modes` is the field arriving at the last plane, left over from the loop above.
            # NOTE(review): with Planes == 1 that loop does not run and `modes` is stale.
            modes = modes*torch.exp(1j*(Masks[Planes-1, :, :]*scale_l))
            eout = propagate_HK(modes, kz_torch_list[l], d_out)
            eout_int_only = (torch.abs(eout))**2
            fid, _ = performance_loc_fidelity(eout, Gaussian_Masks_torch[l], phi[l]) 
            crs, _, _ = performance_crosstalk(eout_int_only, Gaussian_Masks_torch[l]) 
            eff, eff_list = performance_efficiency(eout_int_only, Gaussian_Masks_torch[l])
            fids.append(fid); crss.append(crs); effs.append(eff)
        fid = torch.stack(fids).mean(); crs = torch.stack(crss).mean(); eff = torch.stack(effs).mean()
        print('iteration', i, ': loc. fidelity =', round(fid.detach().cpu().numpy().item(),2), ', crosstalk =', round(crs.detach().cpu().numpy().item(),2), ', efficiency =', round(eff.detach().cpu().numpy().item(),2))

print("Training loop finished.")
"""

# Write the embedded script next to the notebook. The context manager
# closes (and therefore flushes) the file before the smoke test imports it.
with open("MPLC_CUDA2.py", "w", encoding="utf-8") as f:
    f.write(MPLC_CODE)
print("Wrote MPLC_CUDA2.py")

print("Setup done. You can now run the smoke test below.")

In [None]:
# Smoke test (quick validation): few iterations, default parameters.
# Requires modes_lp_10.npz and gauss_5x2_custom.npz in /content/work.

import importlib
import os
import sys

assert os.path.exists('modes_lp_10.npz'), '缺少 modes_lp_10.npz'
assert os.path.exists('gauss_5x2_custom.npz'), '缺少 gauss_5x2_custom.npz'

# MPLC_CUDA2 is script-style: importing it runs the whole optimisation.
# Import it on the first run and reload it only on later runs, so the
# optimisation executes exactly once each time this cell is run
# (an import followed by an unconditional reload would execute it twice).
if 'MPLC_CUDA2' in sys.modules:
    runmod = importlib.reload(sys.modules['MPLC_CUDA2'])
else:
    import MPLC_CUDA2 as runmod

# To override the default parameters, edit DEFAULTS in the previous cell.
print('Smoke test done.')

## 可选：完整运行与参数

- 你可以在“环境与依赖安装”单元格中调整 `DEFAULTS`（如 iterations, n_of_modes, Planes 等）。
- 若需要保存可视化结果，会写入 `/content/work/results` 目录。

常见问题：
- 缺数据文件：请挂载 Drive 或配置 GitHub 原始地址，或手动上传到 `/content/work`。
- CUDA 不可用：Colab 免费版 GPU 需在菜单 Runtime -> Change runtime type -> GPU。
- 内存不足：降低 `n_of_modes`，减少 `iterations`，或关闭 `plot_results`。