# In [None]:
# %%
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import transforms
import numpy as np
import copy
from torch.utils.data import DataLoader, Subset, Dataset
from sklearn.metrics import roc_auc_score, confusion_matrix, ConfusionMatrixDisplay
import random
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# %%
def set_seed(seed=66):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices), then configures cuDNN for repeatable runs.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU rather than only the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Autotuning can pick different kernels from run to run, which would
    # undermine the deterministic flag set above.
    torch.backends.cudnn.benchmark = False

# Experiment-wide settings read by the data, training and evaluation code.
CONFIG = dict(
    batch_size=128,           # mini-batch size of every DataLoader
    closed_lr=0.02,           # SGD learning rate for client training
    local_epochs=5,           # passes over a client's data per round
    num_clients=5,            # number of federated clients
    global_rounds=50,         # number of communication rounds
    known_classes=6,          # how many classes are treated as known
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    lambda_contrastive=0.5,   # weight of the supervised contrastive loss
    lambda_trash=1.0,         # weight of the trash (rejection) loss
    temp=0.08,                # temperature for similarity logits
)

# Fix the RNG state before any data or model is created.
set_seed()
print(f"Running on device: {CONFIG['device']}")

# %%
#  [Dataset Logic]
class CIFAR10FedOSR(Dataset):
    """In-memory known/unknown view of a labelled dataset.

    The wrapped dataset is iterated once at construction time and the
    selected ``(image, label)`` pairs are stored in plain lists.

    * Closed mode (``is_open=False``): keeps samples whose label is in
      ``known_classes`` and relabels each one with the position of its
      class inside ``known_classes``.
    * Open mode (``is_open=True``): keeps the remaining samples and gives
      all of them the single label ``len(known_classes)``.

    Args:
        full_dataset: Iterable yielding ``(image, label)`` pairs.
        known_classes: Iterable of hashable class labels treated as known.
            Any iterable works (list, tuple, NumPy array, ...).
        is_open: Select the unknown-class samples instead of the known ones.
    """

    def __init__(self, full_dataset, known_classes, is_open=False):
        known_classes = list(known_classes)
        # One dictionary lookup per sample instead of a membership test
        # plus a linear ``list.index`` scan.  ``setdefault`` keeps the first
        # position of a repeated class, matching ``list.index``.
        class_to_index = {}
        for position, cls in enumerate(known_classes):
            class_to_index.setdefault(cls, position)
        unknown_index = len(known_classes)

        self.data = []
        self.targets = []
        for img, label in full_dataset:
            mapped = class_to_index.get(label)
            if is_open:
                if mapped is None:
                    self.data.append(img)
                    self.targets.append(unknown_index)
            elif mapped is not None:
                self.data.append(img)
                self.targets.append(mapped)

    def __getitem__(self, index):
        """Return the ``(image, remapped_label)`` pair stored at ``index``."""
        return self.data[index], self.targets[index]

    def __len__(self):
        """Return the number of samples that were kept."""
        return len(self.data)


def get_federated_data(num_clients=5, known_count=6, seed=66):
    """Build the per-client training loaders and the two test loaders.

    NumPy's global generator is re-seeded with ``seed``; it then picks
    ``known_count`` of the ten CIFAR-10 classes as known and shuffles the
    training indices before splitting them across clients.

    Args:
        num_clients: Number of shards the training indices are split into.
        known_count: Number of classes drawn as known classes.
        seed: Seed applied to NumPy's global random state.

    Returns:
        A tuple ``(loaders, test_close, test_open, known_classes)``:
        one shuffled training loader per client, a loader over the test
        samples selected with ``is_open=False``, a loader over those
        selected with ``is_open=True``, and the sorted list of known
        class ids.
    """
    np.random.seed(seed)

    # Map images to tensors normalised with mean 0.5 / std 0.5 per channel.
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    trainset = torchvision.datasets.CIFAR10(
        root='./cifar10-python', train=True, download=True, transform=preprocessing)
    testset = torchvision.datasets.CIFAR10(
        root='./cifar10-python', train=False, download=True, transform=preprocessing)

    # The class draw must happen before the index shuffle so the seeded
    # random stream is consumed in a fixed order.
    class_order = np.arange(10)
    np.random.shuffle(class_order)
    known_classes = sorted(class_order[:known_count].tolist())

    sample_order = np.arange(len(trainset))
    np.random.shuffle(sample_order)

    batch_size = CONFIG['batch_size']
    loaders = [
        DataLoader(
            CIFAR10FedOSR(Subset(trainset, shard), known_classes),
            batch_size=batch_size,
            shuffle=True,
        )
        for shard in np.array_split(sample_order, num_clients)
    ]

    closed_test = CIFAR10FedOSR(testset, known_classes, is_open=False)
    open_test = CIFAR10FedOSR(testset, known_classes, is_open=True)
    test_close = DataLoader(closed_test, batch_size=batch_size)
    test_open = DataLoader(open_test, batch_size=batch_size)
    return loaders, test_close, test_open, known_classes


# %%
#  [Model: PrototypeCNN (256-dim features)]
class PrototypeCNN(nn.Module):
    """Eight-layer convolutional encoder with a bias-free prototype head.

    The encoder is exposed in two stages so callers can operate on the
    intermediate feature map:

    * ``pre2block``: conv1-conv4 (32 and 64 channels), two 2x2 max-pools
      and dropout.
    * ``latter_forward``: conv5-conv8 (128 and 256 channels), two more
      max-pools, global average pooling and L2 normalisation, giving one
      256-dimensional unit vector per sample.

    ``fc`` is a linear layer without bias whose rows act as the anchors of
    the known classes.

    Args:
        num_classes: Number of known classes (rows of ``fc``).
        input_channels: Number of channels of the input images.
    """

    def __init__(self, num_classes=6, input_channels=3):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)
        self.conv5 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(128)
        self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn6 = nn.BatchNorm2d(128)
        self.conv7 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn7 = nn.BatchNorm2d(256)
        self.conv8 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn8 = nn.BatchNorm2d(256)
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        # Anchors for known classes
        self.fc = nn.Linear(256, num_classes, bias=False)

    def pre2block(self, x):
        """Run the first two convolutional blocks and return the feature map."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool(x)
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.pool(x)
        x = self.dropout(x)
        return x

    def latter_forward(self, x):
        """Map a ``pre2block`` feature map to L2-normalised feature vectors."""
        x = F.relu(self.bn5(self.conv5(x)))
        x = F.relu(self.bn6(self.conv6(x)))
        x = self.pool(x)
        x = F.relu(self.bn7(self.conv7(x)))
        x = F.relu(self.bn8(self.conv8(x)))
        x = self.pool(x)
        x = self.global_pool(x)
        feat = x.view(x.size(0), -1)
        return F.normalize(feat, p=2, dim=1)

    def forward(self, x):
        """Return ``(logits, feat)`` for a batch of images.

        ``feat`` is the L2-normalised feature vector and ``logits`` is its
        product with the rows of ``fc``.
        """
        feat = self.get_features(x)
        logits = self.fc(feat)
        return logits, feat

    def get_features(self, x):
        """Return the L2-normalised feature vectors for a batch of images."""
        # Compose the two stages instead of repeating their layers, so the
        # full path and the staged path cannot drift apart.
        return self.latter_forward(self.pre2block(x))


# %%
#  [Losses and Helpers]
class SupConLoss(nn.Module):
    """Supervised contrastive loss over one batch of feature vectors.

    Every sample is used as an anchor.  Its positives are the other
    samples with the same label, and the softmax denominator runs over all
    other samples in the batch.  Anchors without any positive contribute
    zero to the batch mean.

    Args:
        temperature: Divisor applied to the pairwise dot products.
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        """Return the scalar loss.

        Args:
            features: ``(N, D)`` feature matrix; similarities are plain dot
                products, so callers normally pass L2-normalised rows.
            labels: ``N`` class labels (any shape that flattens to ``N``).
        """
        sim = torch.matmul(features, features.T) / self.temperature
        if sim.numel() > 0:
            # Subtracting each row's maximum leaves the softmax unchanged
            # but keeps exp() finite for large similarities; without it
            # ``inf * 0`` on the masked diagonal turns the loss into NaN.
            # The shift is detached because it is a constant offset.
            row_max = sim.max(dim=1, keepdim=True).values
            sim = sim - row_max.detach()

        labels = labels.contiguous().view(-1, 1)
        mask = torch.eq(labels, labels.T).float().to(features.device)
        # Zero on the diagonal: a sample is neither its own positive nor
        # part of its own denominator.
        logits_mask = torch.ones_like(mask) - torch.eye(mask.size(0), device=features.device)
        mask = mask * logits_mask

        exp_sim = torch.exp(sim) * logits_mask
        log_prob = sim - torch.log(exp_sim.sum(dim=1, keepdim=True) + 1e-12)
        # The epsilon avoids 0/0 for anchors that have no positive.
        mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + 1e-12)
        return -mean_log_prob_pos.mean()


def compute_class_means(model, loader, device, num_classes):
    """Compute one L2-normalised mean feature vector per class.

    The model is switched to eval mode (and left there) and every batch of
    ``loader`` is passed through it without gradient tracking.

    Args:
        model: Callable module returning ``(logits, features)`` per batch.
        loader: Iterable of ``(inputs, labels)`` batches.
        device: Device the inputs are moved to before the forward pass.
        num_classes: Number of classes; labels are compared to ``0..num_classes-1``.

    Returns:
        Tensor of shape ``(num_classes, feature_dim)``.  A class with no
        sample in ``loader`` gets an all-zero row.
    """
    model.eval()
    feats, labels = [], []
    with torch.no_grad():
        for x, y in loader:
            _, f = model(x.to(device))
            feats.append(f)
            # Keep the labels next to the features so the boolean mask
            # below lives on the same device as the tensor it indexes.
            labels.append(y.to(f.device))
    feats = torch.cat(feats)
    labels = torch.cat(labels)

    class_means = []
    for k in range(num_classes):
        class_feats = feats[labels == k]
        if len(class_feats) > 0:
            class_means.append(F.normalize(class_feats.mean(0), dim=0))
        else:
            # Take size, dtype and device from the features instead of a
            # hard-coded 256 so the rows always stack.
            class_means.append(feats.new_zeros(feats.size(1)))
    return torch.stack(class_means)


def calculate_osr_scores(features, centroids, temperature):
    """Score a batch of features for open-set detection in three ways.

    All three scores are oriented so that a larger value points towards an
    unknown (open-set) sample.

    Args:
        features: ``(N, D)`` feature vectors.
        centroids: ``(K, D)`` class centroids / prototypes.
        temperature: Scale applied to the similarities before the energy
            and entropy computations.

    Returns:
        Three Python lists of ``N`` floats:
        ``1 - max_k similarity`` (distance to the closest centroid),
        the energy ``-T * logsumexp(similarity / T)``, and
        the entropy of ``softmax(similarity / T)``.
    """
    # Dot products with every centroid, shape (N, K).
    sims = features @ centroids.T

    # Distance-like score: a small best similarity means a far sample.
    best_sim = sims.max(dim=1).values
    scores_min_dist = (1 - best_sim).cpu().tolist()

    scaled = sims / temperature

    # Energy is less negative when no centroid matches strongly.
    log_partition = torch.logsumexp(scaled, dim=1)
    scores_energy = (-temperature * log_partition).cpu().tolist()

    # Entropy grows as the class distribution flattens out; the epsilon
    # protects log() from exact zeros.
    class_probs = F.softmax(scaled, dim=1)
    weighted_logs = class_probs * torch.log(class_probs + 1e-12)
    scores_entropy = (-torch.sum(weighted_logs, dim=1)).cpu().tolist()

    return scores_min_dist, scores_energy, scores_entropy


def create_curriculum_trash(inputs, progress, device):
    """Return a scrambled copy of a batch to serve as "trash" samples.

    The transformation applied to every sample depends on ``progress``:

    * ``progress < 0.2``: the four spatial quadrants are randomly permuted.
    * ``0.2 <= progress < 0.5``: vertical flip or 180-degree rotation,
      chosen with equal probability.
    * otherwise: rotation by 90 or 270 degrees.

    Args:
        inputs: ``(B, C, H, W)`` tensor; it is cloned, not modified.
        progress: Training progress used to select the tier.
        device: Device the result is moved to.

    Returns:
        Tensor with the same shape as ``inputs``.
    """
    scrambled = inputs.clone()

    if progress < 0.2:
        tier = 1
    elif progress < 0.5:
        tier = 2
    else:
        tier = 3

    for idx in range(inputs.size(0)):
        sample = scrambled[idx]
        if tier == 1:
            # Jigsaw: cut into quadrants, shuffle, and stitch back together.
            _, height, width = sample.shape
            mid_h, mid_w = height // 2, width // 2
            quadrants = [
                sample[:, :mid_h, :mid_w],
                sample[:, :mid_h, mid_w:],
                sample[:, mid_h:, :mid_w],
                sample[:, mid_h:, mid_w:],
            ]
            random.shuffle(quadrants)
            top_row = torch.cat([quadrants[0], quadrants[1]], 2)
            bottom_row = torch.cat([quadrants[2], quadrants[3]], 2)
            sample = torch.cat([top_row, bottom_row], 1)
        elif tier == 2:
            # Vertical flip or 180-degree rotation.
            if random.random() < 0.5:
                sample = transforms.functional.vflip(sample)
            else:
                sample = torch.rot90(sample, 2, [1, 2])
        else:
            # Quarter turn in either direction.
            sample = torch.rot90(sample, random.choice([1, 3]), [1, 2])
        scrambled[idx] = sample

    return scrambled.to(device)


def plot_tsne_osr(model, test_cl, test_op, round_idx, device):
    """Show a 2-D t-SNE scatter plot of known versus unknown test features.

    The model is put in eval mode, at most the first twelve batches of each
    loader are embedded, and the second output of the model is projected
    with t-SNE.  Known classes are drawn in blue with one marker each;
    unknown samples are drawn in orange with an "X" marker.

    Args:
        model: Module returning ``(logits, features)``.
        test_cl: Loader of known-class samples with their labels.
        test_op: Loader of unknown samples (labels are ignored).
        round_idx: Round number shown in the plot title.
        device: Device the inputs are moved to.
    """
    model.eval()

    # Marker shapes cycled over the known classes.
    marker_list = ['o', 's', '^', 'D', 'P', '*', 'v', 'p', 'h', 'X']

    embeddings, tags = [], []
    with torch.no_grad():
        # Known samples are tagged with their class index.
        for batch_no, (images, targets) in enumerate(test_cl):
            _, emb = model(images.to(device))
            embeddings.append(emb.cpu())
            tags.extend([f"Known_{j}" for j in targets.numpy()])
            if batch_no > 10:
                break

        # Unknown samples all share a single tag.
        for batch_no, (images, _) in enumerate(test_op):
            _, emb = model(images.to(device))
            embeddings.append(emb.cpu())
            tags.extend(["Unknown"] * images.size(0))
            if batch_no > 10:
                break

    stacked = torch.cat(embeddings).numpy()
    tags = np.array(tags)
    reducer = TSNE(n_components=2, perplexity=30, random_state=66)
    points = reducer.fit_transform(stacked)

    # Colour encodes known/unknown; marker shape separates known classes.
    distinct_tags = sorted(list(set(tags)))
    custom_palette = {}
    for tag in distinct_tags:
        custom_palette[tag] = "blue" if "Known" in tag else "orange"
    known_tags = [tag for tag in distinct_tags if "Known" in tag]
    custom_markers = {}
    for position, tag in enumerate(known_tags):
        custom_markers[tag] = marker_list[position % len(marker_list)]
    custom_markers["Unknown"] = "X"

    plt.figure(figsize=(10, 7))
    sns.scatterplot(
        x=points[:, 0],
        y=points[:, 1],
        hue=tags,
        style=tags,
        palette=custom_palette,
        markers=custom_markers,
        s=60,
        alpha=0.7
    )

    plt.title(f"t-SNE of Contrastive Projections - Round {round_idx}")
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout()
    plt.show()
    plt.close()


# %%
#  [Federated Training Logic]
def train_client(model, loader, config, current_round, client_idx):
    """Run one client's local training for a single federated round.

    Each batch is optimised with the sum of three terms:

    * cross-entropy on the known-class logits,
    * supervised contrastive loss on the features
      (weight ``config['lambda_contrastive']``),
    * a "trash" term (weight ``config['lambda_trash']``) that pushes
      scrambled intermediate feature maps towards a uniform distribution
      over the class anchors.

    Args:
        model: Client model, updated in place.
        loader: Loader over the client's ``(inputs, labels)`` batches.
        config: Settings dict; reads ``temp``, ``closed_lr``,
            ``global_rounds``, ``local_epochs``, ``device``,
            ``lambda_contrastive`` and ``lambda_trash``.
        current_round: Index of the current round; divided by
            ``global_rounds`` to select the scrambling tier.
        client_idx: Identifier of the client (currently unused).
    """
    model.train()
    device = config["device"]
    supcon_loss_fn = SupConLoss(temperature=config['temp'])
    optimizer = optim.SGD(model.parameters(), lr=config["closed_lr"], momentum=0.9, weight_decay=5e-4)
    progress = current_round / config['global_rounds']

    for _ in range(config["local_epochs"]):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()

            # Run the first two blocks once and share the result between
            # the known and the trash path.  Calling model(x) and then
            # pre2block(x) would repeat the most expensive layers and
            # update their BatchNorm statistics twice per batch.
            latent = model.pre2block(x)
            feat = model.latter_forward(latent)
            logits = model.fc(feat)

            # Known Training
            loss_ce = F.cross_entropy(logits, y)
            loss_sc = supcon_loss_fn(feat, y)

            # Trash Training (Entropic Rejection)
            trash_latent = create_curriculum_trash(latent, progress, device)
            trash_feat = model.latter_forward(trash_latent)
            anchors = F.normalize(model.fc.weight, dim=1)
            trash_logits = torch.matmul(trash_feat, anchors.T) / config['temp']
            # Size the uniform target from the logits so it always sums to
            # one over the anchors the model actually has.
            uniform_target = torch.full_like(trash_logits, 1.0 / trash_logits.size(1))
            loss_trash = F.kl_div(F.log_softmax(trash_logits, dim=1), uniform_target, reduction='batchmean')

            loss = loss_ce + (config['lambda_contrastive'] * loss_sc) + (config['lambda_trash'] * loss_trash)
            loss.backward()
            optimizer.step()


def federated_communication_fedavg(server_model, client_models):
    """Average the clients' state dicts and load the result into the server.

    Every entry is summed over all clients with equal weight.
    Floating-point entries are divided by the number of clients; integer
    entries are floor-divided so they keep their integer dtype.  The
    client models are not modified.

    Args:
        server_model: Model that receives the averaged state.
        client_models: Non-empty sequence of models with identical state
            dict keys.

    Returns:
        ``server_model`` after loading the averaged state.
    """
    # Build each client's state dict once, not once per key.
    client_states = [client.state_dict() for client in client_models]
    num_clients = len(client_states)

    # Deep copy so the in-place accumulation below cannot touch client 0.
    new_global = copy.deepcopy(client_states[0])
    for key, total in new_global.items():
        for state in client_states[1:]:
            total += state[key]
        if torch.is_floating_point(total):
            total /= num_clients
        else:
            new_global[key] = torch.div(total, num_clients, rounding_mode='floor')

    server_model.load_state_dict(new_global)
    return server_model


# %%
# Build the data split from the same settings the models use, so editing
# CONFIG cannot leave the loaders out of step with the number of clients
# or with the size of the classifier.
loaders, test_cl, test_op, known_list = get_federated_data(
    num_clients=CONFIG['num_clients'],
    known_count=CONFIG['known_classes'],
)
print(f"Known classes: {known_list}")

server_model = PrototypeCNN(num_classes=CONFIG['known_classes']).to(CONFIG['device'])
client_models = [copy.deepcopy(server_model) for _ in range(CONFIG['num_clients'])]

# Best value of each metric seen over all rounds
best_acc = 0
best_auc_min = 0
best_auc_eng = 0
best_auc_ent = 0

# Plot a t-SNE every ``tsne_every`` rounds.  The default of 1 plots after
# every round; 0 or None disables plotting.
tsne_every = CONFIG.get('tsne_every', 1)

for r in range(CONFIG['global_rounds']):
    print(f"\n--- Round {r + 1} ---")

    # Train clients
    for i in range(CONFIG['num_clients']):
        train_client(client_models[i], loaders[i], CONFIG, r, i + 1)

    # Aggregate models, then give every client the new global weights
    server_model = federated_communication_fedavg(server_model, client_models)
    for m in client_models:
        m.load_state_dict(server_model.state_dict())

    # Evaluation
    server_model.eval()
    correct, total = 0, 0
    labels_binary = []
    scores_min, scores_energy, scores_entropy = [], [], []

    with torch.no_grad():
        # The normalised rows of the FC layer are the learned anchors of
        # the known classes and serve as centroids for the OSR scores.
        centroids = F.normalize(server_model.fc.weight, dim=1)  # (K, 256)

        # Known samples get binary label 0 and should score low; open-set
        # samples get label 1 and should score high.
        for eval_loader, is_unknown in ((test_cl, 0), (test_op, 1)):
            for x, y in eval_loader:
                x = x.to(CONFIG['device'])
                logits, feat = server_model(x)

                # Closed-set accuracy is measured on known samples only.
                if not is_unknown:
                    y = y.to(CONFIG['device'])
                    preds = logits.argmax(dim=1)
                    correct += preds.eq(y).sum().item()
                    total += y.size(0)

                s_min, s_eng, s_ent = calculate_osr_scores(feat, centroids, CONFIG['temp'])
                scores_min.extend(s_min)
                scores_energy.extend(s_eng)
                scores_entropy.extend(s_ent)
                labels_binary.extend([is_unknown] * x.size(0))

    # --- Metrics ---
    acc = 100. * correct / total
    auc_min = roc_auc_score(labels_binary, scores_min) * 100
    auc_eng = roc_auc_score(labels_binary, scores_energy) * 100
    auc_ent = roc_auc_score(labels_binary, scores_entropy) * 100

    best_acc = max(best_acc, acc)
    best_auc_min = max(best_auc_min, auc_min)
    best_auc_eng = max(best_auc_eng, auc_eng)
    best_auc_ent = max(best_auc_ent, auc_ent)

    print(f"[Global Metrics] Acc: {acc:.2f}% (Best: {best_acc:.2f}%)")
    print(f"AUROC (Min Dist): {auc_min:.2f}% (Best: {best_auc_min:.2f}%)")
    print(f"AUROC (Energy):   {auc_eng:.2f}% (Best: {best_auc_eng:.2f}%)")
    print(f"AUROC (Entropy):  {auc_ent:.2f}% (Best: {best_auc_ent:.2f}%)")

    # Visualization
    if tsne_every and (r + 1) % tsne_every == 0:
        plot_tsne_osr(server_model, test_cl, test_op, r + 1, CONFIG['device'])

print("\n" + "=" * 50)
print("Training Complete!")
print(f"Best Accuracy: {best_acc:.2f}%")
print(f"Best AUROC (Min Dist): {best_auc_min:.2f}%")
print(f"Best AUROC (Energy): {best_auc_eng:.2f}%")
print(f"Best AUROC (Entropy): {best_auc_ent:.2f}%")