In [None]:
"""
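Description: Semi-supervised cross-subject training and evaluation of the SemiGCL model on the SEED, SEED-IV and SEED-V datasets.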
Author: voicebeer
Date: 2020-09-14
LastEditDate: 2025-04-20
"""

In [None]:
import os
import random
import sys
from typing import Any

import numpy as np
import torch
from tqdm import notebook as tqdm
from torch import nn, optim, utils
from torch.nn import functional as F

if "google.colab" in sys.modules:
    colab_cwd = "/content"  # Default Colab working directory
    repo_name = "DS-AGC-Colab"
    repo_url = f"https://github.com/theJingqiZhou/{repo_name}.git"
    repo_dir = f"{colab_cwd}/{repo_name}"

    # Re-running this cell must be harmless: if both artefacts are already in
    # place, cloning again would make `mv` fail on the existing `modules` directory.
    already_installed = os.path.isdir(f"{colab_cwd}/modules") and os.path.isfile(
        f"{colab_cwd}/datasets.py"
    )

    if already_installed:
        print(f"Google Colab detected. {repo_name} files already present in {colab_cwd}")
    else:
        print(f"Google Colab detected. Cloning {repo_name} repository...")

        # Clone into an explicit path (independent of the current directory), move
        # the necessary files to the working directory, then drop the clone.
        # The leading `rm -rf` clears a stale clone left by an earlier failed run.
        exit_status = os.system(
            f"""
            rm -rf {repo_dir} &&
            git clone {repo_url} {repo_dir} &&
            mv {repo_dir}/modules {colab_cwd}/ &&
            mv {repo_dir}/datasets.py {colab_cwd}/ &&
            rm -rf {repo_dir}
        """
        )
        # `os.system` reports failure only through its return value; without this
        # check the cell would announce success and fail later at import time.
        if exit_status != 0:
            raise RuntimeError(
                f"Failed to set up {repo_name} in {colab_cwd} "
                f"(shell exit status {exit_status}). Remove any partial "
                f"'{colab_cwd}/modules' or '{colab_cwd}/datasets.py' and re-run this cell."
            )

        print(
            f"Setup complete. Required modules and datasets.py are now in {colab_cwd}"
        )

import datasets, modules

In [None]:
# Default seed applied by `setup_seed` to every random number generator.
RANDOM_SEED = 20

# Dataset tag: selects the data sub-directory and is embedded in saved checkpoint file names.
DATASET_NAME = "SEED"

# Dataset directory handed to the loaders; results are written to its "model_result" sub-folder.
DATA_DIR = os.path.join("data", DATASET_NAME)

In [None]:
def setup_seed(seed: int = RANDOM_SEED) -> None:
    """Reset every random number generator used here and force deterministic cuDNN.

    :param seed: value applied to Python's ``random``, NumPy and PyTorch
    """
    # Host-side generators.
    random.seed(seed)  # Python random module.
    np.random.seed(seed)  # Numpy module.
    torch.manual_seed(seed)

    # Device-side generators, only touched when CUDA is usable.
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # covers the multi-GPU case.

    # Fixed cuDNN algorithm selection: no auto-tuning, deterministic kernels only.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def weight_init(m: nn.Module) -> None:
    """Initialise the parameters of a single layer; intended for ``model.apply``.

    * ``nn.Conv2d``: Xavier-uniform weights, bias set to 0.3.
    * ``nn.BatchNorm1d`` / ``nn.BatchNorm2d``: weight set to 1, bias set to 0.
    * ``nn.Linear``: weights drawn from N(0, 0.03), bias set to 0.

    Any other module type is left untouched. Layers without the respective
    parameter (``bias=False``, ``affine=False``) are skipped instead of raising.

    :param m: the module to initialise in place
    """
    if isinstance(m, nn.Conv2d):
        # The RNG is reset before every random draw, so each layer's initial
        # weights do not depend on how many layers were initialised before it.
        setup_seed()
        # Xavier initialisation: to let information flow well through the
        # network, the output variance of every layer should be roughly equal.
        nn.init.xavier_uniform_(m.weight.data)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0.3)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        # With affine=False both parameters are None; there is nothing to reset.
        if m.weight is not None:
            m.weight.data.fill_(1)
        if m.bias is not None:
            m.bias.data.zero_()
    elif isinstance(m, nn.Linear):
        setup_seed()
        m.weight.data.normal_(0, 0.03)
        # nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in', nonlinearity='relu')
        if m.bias is not None:
            m.bias.data.zero_()


def get_cos_similarity_distance(
    pseudo: torch.Tensor, pred: torch.Tensor
) -> torch.Tensor:
    """Pairwise cosine similarity between the rows of two matrices.

    :param pseudo: 2-D tensor of row vectors, (n_pseudo, dim)
    :param pred: 2-D tensor of row vectors, (n_pred, dim)
    :return: similarity matrix, (n_pseudo, n_pred); entry ``[i, j]`` is the
        cosine similarity between ``pseudo[i]`` and ``pred[j]``. An all-zero
        row has similarity 0 with everything instead of NaN.
    """
    # F.normalize divides each row by max(||row||_2, eps), so a zero row stays
    # zero rather than turning into NaN through a 0/0 division.
    pseudo = F.normalize(pseudo, p=2, dim=1)
    pred = F.normalize(pred, p=2, dim=1)

    # Dot products of unit-length rows are the cosine similarities.
    return torch.mm(pseudo, pred.transpose(0, 1))

In [None]:
def test_gcn_contrast(
    model: modules.SemiGCL, target_loader: utils.data.DataLoader, device: str
) -> float:
    """Evaluate ``model`` on the complete target dataset in one forward pass.

    The tensors are read straight from ``target_loader.dataset``, so the
    loader's batching, shuffling and ``drop_last`` settings do not apply here.
    The model is left in evaluation mode when the function returns.

    :param model: network exposing ``predict(data)`` that returns per-class scores
    :param target_loader: loader whose dataset is a ``TensorDataset`` of
        (features, labels); labels are compared through ``argmax(dim=1)``
    :param device: device the tensors are moved to before prediction
    :return: fraction of target samples whose predicted class matches the label
    """
    model.eval()
    assert isinstance(target_loader.dataset, utils.data.TensorDataset)
    data_target = target_loader.dataset.tensors[0].to(device)
    labels_target = target_loader.dataset.tensors[1].to(device)

    # Evaluation needs no gradients; skipping the autograd graph avoids keeping
    # every intermediate activation of the full-dataset forward pass in memory.
    with torch.no_grad():
        pred = model.predict(data_target)

    target_scores = pred.argmax(dim=1)
    num_correct = (target_scores == labels_target.argmax(dim=1)).float().sum().item()
    target_acc = num_correct / len(data_target)
    print("target_acc:", target_acc)
    return target_acc


def _next_batch(loader: utils.data.DataLoader, batch_iter: Any) -> tuple[Any, Any]:
    """Return ``(batch, iterator)``, restarting ``loader`` when ``batch_iter`` is exhausted."""
    try:
        return next(batch_iter), batch_iter
    except StopIteration:
        batch_iter = iter(loader)
        return next(batch_iter), batch_iter


def train_gcn_contrast(
    subject_id: int,
    parameter: dict[str, Any],
    net_params: dict[str, Any],
    source_labeled_loaders: utils.data.DataLoader,
    source_unlabeled_loaders: utils.data.DataLoader,
    target_loader: utils.data.DataLoader,
) -> tuple[np.ndarray, float]:
    """Train a ``modules.SemiGCL`` model for one target subject.

    Each step consumes one target batch together with one labeled and one
    unlabeled source batch. Before epoch ``parameter["threshold"]`` the model
    only sees labeled-source and target data; from that epoch on the unlabeled
    source batch is included and the source logits are scaled by the
    ``sim_weight`` returned by the model before the cross-entropy is taken.

    After every epoch the model is evaluated on the whole target set and its
    ``state_dict`` is saved under ``DATA_DIR/model_result`` whenever that
    accuracy improves.

    :param subject_id: index of the target subject (used in logs and file names)
    :param parameter: training settings; reads ``init_lr``, ``weight_decay``,
        ``epochs``, ``threshold``, ``T_DANN``, ``DANN``, ``dynamic_adj``,
        ``GCL``, ``batch_size`` and ``num_of_U``
    :param net_params: model settings passed to ``modules.SemiGCL``; ``DEVICE``
        selects the device
    :param source_labeled_loaders: loader of labeled source batches
    :param source_unlabeled_loaders: loader of unlabeled source batches (their
        labels are ignored)
    :param target_loader: loader of target batches; it drives the epoch length
    :return: per-epoch target test accuracy and the best of those accuracies
    """
    device = net_params["DEVICE"]
    setup_seed()
    model = modules.SemiGCL(net_params).to(device)
    setup_seed()
    model.apply(weight_init)
    # NOTE(review): `awl` is registered with the optimizer but never enters the
    # loss below, so its parameters receive no gradient — confirm this is intended.
    awl = modules.AutomaticWeightedLoss(4)
    optimizer = optim.RMSprop(
        [
            {
                "params": model.parameters(),
                "lr": parameter["init_lr"],
                "weight_decay": parameter["weight_decay"],
            },
            {"params": awl.parameters(), "weight_decay": 0},
        ]
    )
    best_acc, best_test_acc = 0.0, 0.0
    acc_list = np.zeros(parameter["epochs"])
    threshold = parameter["threshold"]
    # Loop-invariant: the flag handed to the model once the threshold is reached.
    tripleada = 0 if parameter["T_DANN"] else 1

    for epoch in range(parameter["epochs"]):
        model.train()
        total_loss, total_num, target_bar = 0.0, 0, tqdm.tqdm(target_loader)
        source_acc_total, target_acc_total = 0, 0
        labeled_iter = iter(source_labeled_loaders)
        unlabeled_iter = iter(source_unlabeled_loaders)
        setup_seed()
        use_unlabeled = epoch >= threshold

        for data_target, label_target in target_bar:
            # The source loaders are restarted when they run out, so a source
            # set with fewer batches than the target set does not abort the epoch.
            (data_source, labels_source), labeled_iter = _next_batch(
                source_labeled_loaders, labeled_iter
            )
            (x_un, _), unlabeled_iter = _next_batch(
                source_unlabeled_loaders, unlabeled_iter
            )
            x_un = x_un.to(device)
            data_source = data_source.to(device)
            labels_source = labels_source.to(device)
            data_target = data_target.to(device)
            labels_target = label_target.to(device)

            if use_unlabeled:
                pred, domain_loss, ajloss, contrastive_loss, sim_weight, _ = model(
                    torch.cat((data_source, x_un, data_target)),
                    tripleada=tripleada,
                    threshold=1,
                )
            else:
                pred, domain_loss, ajloss, contrastive_loss, sim_weight, _ = model(
                    torch.cat((data_source, data_target)),
                    tripleada=0,
                    threshold=0,
                )

            # Source rows lead the concatenated batch and target rows end it;
            # each slice is sized by its own batch.
            source_pred = pred[: len(data_source), :]
            target_pred = pred[-len(data_target) :, :]

            if use_unlabeled:
                log_prob = F.log_softmax(sim_weight * source_pred, dim=1)
            else:
                log_prob = F.log_softmax(source_pred, dim=1)

            # Cross-entropy against the source label rows, averaged over the batch.
            celoss = -torch.sum(log_prob * labels_source) / len(labels_source)
            loss = (
                celoss
                + parameter["DANN"] * domain_loss
                + parameter["dynamic_adj"] * ajloss
                + parameter["GCL"] * contrastive_loss
            )

            source_scores = source_pred.detach().argmax(dim=1)
            source_acc_total += (
                (source_scores == labels_source.argmax(dim=1)).float().sum().item()
            )
            target_scores = target_pred.detach().argmax(dim=1)
            target_acc_total += (
                (target_scores == labels_target.argmax(dim=1)).float().sum().item()
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_num += parameter["batch_size"]
            total_loss += loss.item() * parameter["batch_size"]
            epoch_train_loss = total_loss / total_num
            target_bar.set_description(
                f"sub:{subject_id} "
                f'Train Epoch: [{epoch + 1}/{parameter["epochs"]}] '
                f"Loss: {epoch_train_loss:.4f} "
                f"source_acc:{source_acc_total / total_num * 100:.2f}% "
                f"target_acc:{target_acc_total / total_num * 100:.2f}%"
            )

        target_test_acc = test_gcn_contrast(model, target_loader, device)
        acc_list[epoch] = target_test_acc
        best_acc = max(best_acc, target_acc_total / total_num)

        # Keep a checkpoint of the best model seen so far on the target set.
        if best_test_acc < target_test_acc:
            best_test_acc = target_test_acc
            save_dir = os.path.join(DATA_DIR, "model_result")
            os.makedirs(save_dir, exist_ok=True)
            save_path = os.path.join(
                save_dir,
                f"model_semi_session_{DATASET_NAME}_batch48_{parameter['threshold']}"
                f"epoch_{parameter['num_of_U']}U_sub{subject_id}.pkl"
            )
            torch.save(model.state_dict(), save_path)
    print("best_acc:", best_acc, "best_test_acc:", best_test_acc)
    return acc_list, best_test_acc

In [None]:
def cross_subject(
    target_set: dict[str, np.ndarray],
    source_set_labeled: dict[str, np.ndarray],
    source_set_unlabeled: dict[str, np.ndarray],
    subject_id: int,
    parameter: dict[str, int],
    net_params: dict[str, Any],
) -> tuple[np.ndarray, float]:
    """Wrap the three data splits in loaders and train for one target subject.

    :param target_set: target split with "feature" and "label" arrays
    :param source_set_labeled: labeled source split, same keys
    :param source_set_unlabeled: unlabeled source split, same keys
    :param subject_id: index of the target subject
    :param parameter: training settings; ``batch_size`` sizes every loader
    :param net_params: model settings forwarded to the training routine
    :return: whatever ``train_gcn_contrast`` returns (accuracy curve, best accuracy)
    """
    setup_seed()

    def _as_loader(split: dict[str, np.ndarray]) -> utils.data.DataLoader:
        # Every split is served the same way: shuffled, full batches only.
        tensor_pair = utils.data.TensorDataset(
            torch.from_numpy(split["feature"]),
            torch.from_numpy(split["label"]),
        )
        return utils.data.DataLoader(
            dataset=tensor_pair,
            batch_size=parameter["batch_size"],
            shuffle=True,
            drop_last=True,
        )

    target_loader = _as_loader(target_set)
    labeled_loader = _as_loader(source_set_labeled)
    unlabeled_loader = _as_loader(source_set_unlabeled)

    return train_gcn_contrast(
        subject_id,
        parameter,
        net_params,
        labeled_loader,
        unlabeled_loader,
        target_loader,
    )

In [None]:
def main(parameter: dict[str, int], net_params: dict[str, Any]) -> list[float]:
    # data preparation
    setup_seed()
    match net_params["category_number"]:
        case 3:
            print("Model name: MS-MDAER. Dataset name: SEED")
            sub_num = 15
        case 4:
            print("Model name: MS-MDAER. Dataset name: SEED_IV")
            sub_num = 15
        case 5:
            print("Model name: MS-MDAER. Dataset name: SEED_V")
            sub_num = 16
        case _:
            pass

    print(f'BS: {parameter["batch_size"]}, epoch: {parameter["epochs"]}')
    # store the results
    csub: list[float] = []
    # for session_id_main in range(3):
    session_id = 0
    # subject_id = 0

    best_acc_mat = np.zeros(sub_num)
    target_acc_curve = np.zeros((sub_num, parameter["epochs"]))

    for subject_id in range(sub_num):
        match net_params["category_number"]:
            case 3:
                target_set, source_set_labeled, source_set_unlabeled = (
                    datasets.seed.load_dataset(
                        DATA_DIR, subject_id, session_id, parameter
                    )
                )
                DATASET_NAME = "SEED"
            case 4:
                target_set, source_set_labeled, source_set_unlabeled = (
                    datasets.seed_iv.load_dataset(DATA_DIR, subject_id, parameter)
                )
                DATASET_NAME = "SEEDIV"
            case 5:
                target_set, source_set_labeled, source_set_unlabeled = (
                    datasets.seed_v.load_dataset(DATA_DIR, subject_id, parameter)
                )
                DATASET_NAME = "SEEDV"
            case _:
                pass

        acc = cross_subject(
            target_set,
            source_set_labeled,
            source_set_unlabeled,
            subject_id,
            parameter,
            net_params,
        )
        csub.append(acc[1])
        target_acc_curve[subject_id, :] = acc[0]
        best_acc_mat[subject_id] = acc[1]
    print("Cross-subject: ", csub)

    result_list = {
        "best_acc_mat": best_acc_mat,
        "target_acc_curve": target_acc_curve,
    }
    os.makedirs(os.path.join(DATA_DIR, "model_result"), exist_ok=True)
    np.save(
        os.path.join(DATA_DIR, "model_result",
        f'result_list_semi_session_{DATASET_NAME}_batch48_{parameter["threshold"]}'
        f'epoch_{parameter["num_of_U"]}U.npy'),
        result_list,  # type: ignore[reportArgumentType]
    )
    return csub


# Training settings read by `main`, `cross_subject` and `train_gcn_contrast`.
parameter = {
    "epochs": 100,
    "init_lr": 1e-3,
    "weight_decay": 1e-5,
    "semi": 1,
    "threshold": 30,  # first epoch at which unlabeled source batches are used
    "num_of_U": 2,
    "GCL": 1,  # weight of the contrastive loss term
    "dynamic_adj": 1,  # weight of the adjacency loss term
    "DANN": 1,  # weight of the domain loss term
    "T_DANN": 1,
    "batch_size": 48,
}

# Model settings passed to `modules.SemiGCL`.
net_params = {
    "category_number": 3,
    "DEVICE": (
        torch.accelerator.current_accelerator().type
        if torch.accelerator.is_available()
        else "cpu"
    ),
    "node_feature_hidden1": 5,
    "num_of_vertices": 62,
    "linearsize": 128,
    "drop_rate": 0.8,
    "batch_size": 48,  # Also used by class `modules.models.SemiGCL`
    "Multi_att": 1,
    "num_of_features": 5,
    "GLalpha": 0.01,
    "K": 3,
}

csub = main(parameter, net_params)

# `main` writes its result file into DATA_DIR/model_result, so it has to be
# read back from there; a bare file name would be looked up in the current
# working directory and not be found.
# NOTE(review): the file name uses the module-level DATASET_NAME, which matches
# what `main` writes for category_number 3 — adjust it for other datasets.
result_path = os.path.join(
    DATA_DIR,
    "model_result",
    f'result_list_semi_session_{DATASET_NAME}_batch48_{parameter["threshold"]}'
    f'epoch_{parameter["num_of_U"]}U.npy',
)
# The file stores a pickled dict, hence allow_pickle=True; only load result
# files produced by this notebook, never files from an untrusted source.
c = np.load(result_path, allow_pickle=True).item()
c_mean = np.mean(c["best_acc_mat"])
c_std = np.std(c["best_acc_mat"])