In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.data import Subset, DataLoader
from torch.cuda.amp import autocast, GradScaler
import numpy as np
import matplotlib.pyplot as plt

# Run on GPU index 5 whenever CUDA is usable; otherwise fall back to the CPU.
_device_name = "cuda:5" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(device)

In [None]:
from torchvision.datasets import CIFAR100

# ============ 1) coarse-label version CIFAR-100 ============
class CIFAR100Coarse(CIFAR100):
    """CIFAR-100 dataset that yields coarse (superclass) labels instead of fine labels.

    The parent class loads the images and the 100 fine labels into
    ``self.targets``; this subclass maps every fine label to one of the 20
    coarse labels and returns that coarse label from ``__getitem__``.

    Attributes:
        coarse_targets: list with one coarse label (0-19) per sample, in the
            same order as ``self.targets``.
    """

    # Lookup table: position = fine label (0-99), value = coarse label (0-19).
    # It follows the official CIFAR-100 superclass grouping, in which every
    # coarse class contains exactly five fine classes. The last five entries
    # (fine labels 95-99: whale, willow_tree, wolf, woman, worm) were previously
    # wrong, which left some superclasses with four or six members.
    FINE_TO_COARSE = (
        4, 1, 14, 8, 0, 6, 7, 7, 18, 3, 3, 14, 9, 18, 7, 11, 3, 9, 7, 11,          # 0-19
        6, 11, 5, 10, 7, 6, 13, 15, 3, 15, 0, 11, 1, 10, 12, 14, 16, 9, 11, 5,     # 20-39
        5, 19, 8, 8, 15, 13, 14, 17, 18, 10, 16, 4, 17, 4, 2, 0, 17, 4, 18, 17,    # 40-59
        10, 3, 2, 12, 12, 16, 12, 1, 9, 19, 2, 10, 0, 1, 16, 12, 9, 13, 15, 13,    # 60-79
        16, 19, 2, 4, 6, 19, 5, 5, 8, 19, 18, 1, 2, 15, 6, 0, 17, 8, 14, 13,       # 80-99
    )

    def __init__(self, *args, **kwargs):
        """Build the parent dataset, then precompute the coarse label of every sample."""
        super().__init__(*args, **kwargs)
        self.coarse_targets = self._fine_to_coarse(self.targets)

    def _fine_to_coarse(self, targets):
        """Map an iterable of fine labels (ints in 0-99) to a list of coarse labels."""
        table = self.FINE_TO_COARSE
        return [table[i] for i in targets]

    def __getitem__(self, index):
        """Return ``(image, coarse_label)`` for the sample at ``index``.

        The image comes from the parent ``__getitem__`` (so the image transform
        is applied there); the fine label the parent returns is discarded.
        """
        img, _ = super().__getitem__(index)
        target = self.coarse_targets[index]
        return img, target


# In[56]:

torch.manual_seed(100)

# 1) Define the train / eval transforms.
#    The normalisation constants are presumably the per-channel CIFAR-100
#    mean / std -- TODO confirm.
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4865, 0.4409),
                         (0.2673, 0.2564, 0.2762)),
])

eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4865, 0.4409),
                         (0.2673, 0.2564, 0.2762)),
])

# 2) Split the 50,000 training images into DISJOINT train / validation index sets.
#    Previously the training slice was indices[:50000], which contained every
#    validation index, so the validation images were also trained on.
full_len = 50000
val_len = 10000
indices = torch.randperm(full_len)
train_indices = indices[:full_len - val_len]
val_indices   = indices[full_len - val_len:]
assert set(train_indices.tolist()).isdisjoint(val_indices.tolist()), "train/val indices overlap"

# 3) Datasets that return coarse labels instead of fine labels.
#    train_full and val_full read the same images but apply different transforms.
train_full = CIFAR100Coarse(root='./data', train=True, download=True, transform=train_transform)
val_full   = CIFAR100Coarse(root='./data', train=True, download=True, transform=eval_transform)
test_dataset  = CIFAR100Coarse(root='./data', train=False, download=True, transform=eval_transform)

train_dataset = Subset(train_full, train_indices)
val_dataset   = Subset(val_full,   val_indices)

# 4) DataLoaders.
use_cuda = torch.cuda.is_available()
common_kwargs = dict(num_workers=8, pin_memory=use_cuda)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True,  **common_kwargs)
# The validation loader must wrap the held-out validation split; it used to wrap
# test_dataset, which made per-epoch "validation" accuracy a test-set measurement.
val_loader   = DataLoader(val_dataset,   batch_size=1000, shuffle=False, **common_kwargs)
test_loader  = DataLoader(test_dataset,  batch_size=10000, shuffle=False, **common_kwargs)

print(f"Train size: {len(train_dataset)}, Val size: {len(val_dataset)}, Test size: {len(test_dataset)}")


In [None]:
class ResNet18_CIFAR(nn.Module):
    """torchvision ResNet-18 (no pretrained weights) with a modified stem and classifier.

    Args:
        num_classes: Size of the final linear layer's output.
    """

    def __init__(self, num_classes=20):
        super().__init__()
        backbone = models.resnet18(weights=None)
        # Swap the stem convolution for a 3x3, stride-1, padding-1 layer without bias.
        backbone.conv1 = nn.Conv2d(
            in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False
        )
        # Turn the max-pool stage into a pass-through.
        backbone.maxpool = nn.Identity()
        # New classification head sized for `num_classes`.
        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
        self.model = backbone

    def forward(self, x):
        """Run the wrapped network on `x` and return its output."""
        logits = self.model(x)
        return logits


In [None]:
class CoupledModel(nn.Module):
    """Two ResNet-18 heads, ``E_net`` and ``S_net``, coupled through a temperature ``T``.

    For an image ``x`` and a temperature ``T`` the per-class score is

        score = -(E(x) - T * softplus(S(x))) / (kb * (T + eps)) - (softplus(S(x)) / (100 * kb)) ** 2

    and the class probabilities are ``softmax(score)`` over the class axis.

    Args:
        width: Not used by this class; accepted so existing call sites keep working.
        num_classes: Number of outputs of each sub-network.
        kb: Scale constant that appears in both terms of the score.
        in_channels: Not used by this class; accepted so existing call sites keep working.
        device: Device hosting both sub-networks. ``None`` (the default) selects
            ``cuda:5`` when that GPU exists, otherwise ``cuda:0`` when any GPU
            exists, otherwise the CPU. (The device used to be hard-coded to
            ``cuda:5``, which failed on hosts without six GPUs.)
    """

    def __init__(self, width=32, num_classes=100, kb=1.0, in_channels=3, device=None):
        super().__init__()
        self.C  = num_classes
        self.kb = kb

        resolved = self._default_device() if device is None else torch.device(device)
        self.device_E = resolved
        self.device_S = resolved

        # Two independent networks with identical architecture.
        self.E_net = ResNet18_CIFAR(num_classes=num_classes).to(self.device_E)
        self.S_net = ResNet18_CIFAR(num_classes=num_classes).to(self.device_S)

    @staticmethod
    def _default_device():
        """Pick ``cuda:5`` if present, else the first GPU, else the CPU."""
        if torch.cuda.is_available():
            if torch.cuda.device_count() > 5:
                return torch.device("cuda:5")
            return torch.device("cuda:0")
        return torch.device("cpu")

    @staticmethod
    def _run_on(net, inputs, device):
        """Call ``net(inputs)``, making ``device`` the current CUDA device when it is a GPU."""
        if device.type == "cuda":
            with torch.cuda.device(device):
                return net(inputs)
        return net(inputs)

    def forward_image(self, x_img):
        """Evaluate both sub-networks on ``x_img``.

        Returns:
            ``(E_mat, S_mat, sub_outs)`` where ``E_mat`` and ``S_mat`` are the
            float32 outputs of ``E_net`` and ``S_net`` (both placed on
            ``self.device_E``) and ``sub_outs`` stacks them along a new dim 1.
        """
        x_E = x_img.to(self.device_E, non_blocking=True)
        E_mat = self._run_on(self.E_net, x_E, self.device_E)

        x_S = x_img.to(self.device_S, non_blocking=True)
        S_mat = self._run_on(self.S_net, x_S, self.device_S)

        # Gather both outputs on device_E as float32.
        E_mat = E_mat.to(self.device_E, non_blocking=True).float()
        S_mat = S_mat.to(self.device_E, non_blocking=True).float()

        sub_outs = torch.stack([E_mat, S_mat], dim=1)
        return E_mat, S_mat, sub_outs

    def _normalize_T(self, T, B, device):
        """Reshape a temperature specification so it broadcasts against ``[B, 1, C]``.

        Accepted inputs (``B`` is the batch size):
            * Python number or 0-dim tensor -> ``[1, 1, 1]``
            * ``[M]``                        -> ``[1, M, 1]`` (grid shared by all samples)
            * ``[B, 1]``                     -> ``[B, 1, 1]`` (one temperature per sample)
            * ``[M, 1]`` with ``M != B``     -> ``[1, M, 1]``
            * ``[B, M]``                     -> ``[B, M, 1]``
            * 3-dim tensors are returned unchanged (apart from the device move).

        Note: a ``[M, 1]`` grid whose length equals ``B`` is indistinguishable
        from per-sample temperatures and is treated as ``[B, 1]``.
        """
        if not torch.is_tensor(T):
            T = torch.tensor(T, dtype=torch.float32, device=device)

        T = T.to(device)
        if T.dim() == 0:             # number -> [1,1,1]
            T = T.view(1, 1, 1)
        elif T.dim() == 1:           # [M] -> [1,M,1]
            T = T.view(1, -1, 1)
        elif T.dim() == 2:
            if T.size(1) == 1:       # [B,1] or [M,1]
                if T.size(0) == B:   # [B,1] -> [B,1,1]
                    T = T.view(B, 1, 1)
                else:                # [M,1] -> [1,M,1]
                    T = T.view(1, -1, 1)
            else:                    # [B,M] -> [B,M,1]
                T = T.unsqueeze(-1)
        # Anything else (e.g. [B,M,1]) is passed through as-is.
        return T

    def forward(self, x_img, T):
        """Compute class probabilities and scores for ``x_img`` at temperature(s) ``T``.

        Args:
            x_img: Image batch; its first dimension is the batch size ``B``.
            T: Temperature specification, see ``_normalize_T``.

        Returns:
            ``(probs, scores, sub_outs)``. With a single temperature per sample
            ``probs`` / ``scores`` are ``[B, C]``; with ``M > 1`` temperatures
            they are the mean over the ``M`` axis of the per-temperature
            probabilities / scores. ``sub_outs`` is the ``[B, 2, C]`` stack
            returned by ``forward_image``.
        """
        eps = 1e-9
        B = x_img.size(0)

        # E(x), S(x): [B,C], [B,C], [B,2,C]
        E_mat, S_mat, sub_outs = self.forward_image(x_img)
        S_pos = F.softplus(S_mat)    # keep S strictly positive

        # Place T on the device that holds E / S (not on the input's device),
        # so the arithmetic below never mixes devices.
        T = self._normalize_T(T, B, E_mat.device)           # [B,M,1] or [1,M,1]

        E_b = E_mat.unsqueeze(1)                            # [B,1,C]
        S_b = S_pos.unsqueeze(1)                            # [B,1,C]

        # [B, M, C]
        scores_bmc = - (E_b - T * S_b) / (self.kb * (T + eps)) - (S_b / (100.0 * self.kb))**2
        probs_bmc  = F.softmax(scores_bmc, dim=2)           # [B, M, C]

        # Collapse the temperature axis to [B,C].
        if scores_bmc.size(1) == 1:
            scores = scores_bmc.squeeze(1)
            probs  = probs_bmc.squeeze(1)
        else:
            scores = scores_bmc.mean(dim=1)
            probs  = probs_bmc.mean(dim=1)

        return probs, scores, sub_outs



In [None]:
class LearnableTSet(nn.Module):
    """K learnable temperatures bounded by ``T_min`` and ``T_max``, preceded by a fixed 1.0.

    Args:
        K: Number of learnable temperatures.
        T_min: Lower bound of the learnable temperatures.
        T_max: Upper bound of the learnable temperatures.
    """

    def __init__(self, K=3, T_min=0.1, T_max=10.0):
        super().__init__()
        self.K, self.T_min, self.T_max = K, T_min, T_max
        # Unconstrained parameters; forward() squashes them with a sigmoid.
        self.raw_lambdas = nn.Parameter(torch.randn(K))

    def forward(self):
        """Return a ``[K+1]`` tensor: the constant 1.0 followed by the K bounded temperatures."""
        span = self.T_max - self.T_min
        learned = self.T_min + span * torch.sigmoid(self.raw_lambdas)   # [K]
        anchor = torch.tensor([1.0], device=learned.device)
        return torch.cat((anchor, learned), dim=0)


In [None]:
@torch.no_grad()
def infer_posterior_T(model, x, y_onehot, T_grid, loss_fn=None):
    """Posterior over a temperature grid: q(T | x, y) ∝ exp(-CE(p(y | x, T), y)).

    Every sample is paired with every temperature, the model is evaluated once
    on the resulting ``N * M`` pairs, and the per-pair cross-entropy is turned
    into a distribution over the ``M`` temperatures with a softmax.

    Args:
        model: Callable as ``model(x, T)`` returning ``(probs, scores, ...)``;
            it is switched to eval mode and left there.
        x: Inputs of shape ``[N, ...]``. Any per-sample shape is accepted (the
            per-sample shape used to be hard-coded to ``3 x 32 x 32``).
        y_onehot: Targets of shape ``[N, C]`` (one-hot or soft labels).
        T_grid: ``M`` temperatures, shaped ``[M]`` or ``[M, 1]``.
        loss_fn: Unused; kept for backward compatibility.

    Returns:
        Tensor ``qT`` of shape ``[N, M]`` whose rows sum to 1.
    """
    model.eval()
    N, M = x.size(0), T_grid.size(0)
    sample_shape = tuple(x.shape[1:])

    # Pair each sample with each temperature; the sample index varies slowest.
    x_rep = x.unsqueeze(1).expand(N, M, *sample_shape).reshape(N * M, *sample_shape)
    T_rep = T_grid.reshape(1, M, 1).expand(N, M, 1).reshape(N * M, 1)
    y_rep = y_onehot.unsqueeze(1).expand(N, M, y_onehot.size(-1)).reshape(N * M, -1)  # [N*M, C]

    # One batched forward pass over all (sample, temperature) pairs.
    _, scores, _ = model(x_rep, T_rep)             # scores: [N*M, C]
    log_probs = F.log_softmax(scores, dim=1)       # [N*M, C]

    # Cross-entropy of each pair against its target.
    ce_vec = -(y_rep * log_probs).sum(dim=1)       # [N*M]

    # Back to a [N, M] table, then normalise over the temperatures.
    ce_mat = ce_vec.view(N, M)
    qT = torch.softmax(-ce_mat, dim=1)             # [N, M]
    return qT




In [None]:
from torch.cuda.amp import autocast, GradScaler

# Loss scaler shared by every training step. It is only switched on when CUDA
# is available, matching the scaler created next to the training loop.
scaler = GradScaler(enabled=torch.cuda.is_available())

def em_train_step_optimized_T(model, x, y_onehot, T_module, optimizer, scheduler=None):
    """Run one optimisation step on a batch, weighting temperatures by their cross-entropy.

    For every sample the cross-entropy ``CE[n, m]`` is computed at each
    temperature of ``T_module()``. The loss is

        loss = sum_{n, m} q[n, m] * CE[n, m] / N,   q = softmax(-lambda_sharp * CE, dim=1)

    with ``lambda_sharp = 5.0``. ``q`` is NOT detached, so gradients also flow
    through the weights.

    NOTE(review): the earlier comments called this a "worst-T" M-step, but with
    the negative sign the weights favour the temperatures with the LOWEST
    cross-entropy. The numerics are left untouched; confirm the intended
    objective. A detached E-step posterior used to be computed and immediately
    overwritten; that dead computation has been removed.

    Args:
        model: Object providing ``forward_image``, ``_normalize_T``, ``kb``
            and ``device_E`` (e.g. ``CoupledModel``).
        x: Input batch, forwarded to ``model.forward_image``.
        y_onehot: ``[N, C]`` targets; must have exactly the shape of the
            network output and live on ``model.device_E``.
        T_module: Callable returning the temperature grid.
        optimizer: Optimizer to step.
        scheduler: Optional LR scheduler; when given it is stepped once per call.

    Returns:
        ``(loss_value, qT)``: the scalar loss as a Python float and the
        detached ``[N, M]`` temperature weights.

    Raises:
        ValueError: if ``y_onehot`` does not have the same shape as the network
            output (a width-1 label tensor used to broadcast silently).
    """
    model.train()
    T_grid = T_module()                           # [M]
    N = y_onehot.size(0)
    device_E = model.device_E
    eps = 1e-9
    lambda_sharp = 5.0

    with autocast(enabled=torch.cuda.is_available()):
        E_mat, S_mat, _ = model.forward_image(x)  # [N, C] on device_E
        if tuple(y_onehot.shape) != tuple(E_mat.shape):
            raise ValueError(
                f"y_onehot has shape {tuple(y_onehot.shape)} but the model outputs "
                f"{tuple(E_mat.shape)}; the label width must equal the model's class count"
            )
        S_pos = F.softplus(S_mat)

        T_norm = model._normalize_T(T_grid, B=N, device=device_E)  # broadcastable to [N,M,1]
        E_b = E_mat.unsqueeze(1)                # [N,1,C]
        S_b = S_pos.unsqueeze(1)                # [N,1,C]
        scores_bmc = - (E_b - T_norm * S_b) / (model.kb * (T_norm + eps)) - (S_b / (100.0 * model.kb)) ** 2
        log_probs_bmc = F.log_softmax(scores_bmc, dim=2)     # [N, M, C]

        # Cross-entropy per (sample, temperature): [N, M]
        ce_bm = -(y_onehot.unsqueeze(1) * log_probs_bmc).sum(dim=2)

        # Temperature weights (sharpened, differentiable) and the weighted loss.
        qT = torch.softmax(-lambda_sharp * ce_bm, dim=1)     # [N, M]
        loss = (qT * ce_bm).sum() / N

    # Backward pass and parameter update through the shared gradient scaler.
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    if scheduler is not None:
        scheduler.step()

    return loss.item(), qT.detach()


In [None]:
@torch.no_grad()
def evaluate_accuracy_posterior_labeled(model, loader, device, T_module):
    """Accuracy of the temperature-marginalised prediction, with label-informed weights.

    For every batch, each sample is evaluated at every temperature of
    ``T_module()``. The posterior ``q(T | x, y) = softmax(-CE)`` over the
    temperatures is computed from the TRUE labels, and the prediction is the
    argmax of ``sum_T q(T | x, y) * p(y | x, T)``.

    NOTE(review): because the temperature weights depend on the ground-truth
    label, this is a label-informed score, not a plain held-out accuracy.

    Args:
        model: Callable as ``model(x, T)`` returning ``(probs, scores, ...)``
            with ``[N*M, C]`` tensors; it is switched to eval mode.
        loader: Iterable of ``(x, y)`` batches; ``y`` holds integer class ids.
        device: Device the batches and the temperature grid are moved to.
        T_module: Callable returning the ``M`` temperatures.

    Returns:
        Fraction of correctly classified samples.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    correct = total = 0

    # Temperature grid, fixed for the whole evaluation.
    T_grid = T_module().detach().to(device)
    M = T_grid.size(0)

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        N = x.size(0)
        sample_shape = tuple(x.shape[1:])

        # Pair every sample with every temperature (sample index varies slowest).
        x_rep = x.unsqueeze(1).repeat(1, M, *([1] * len(sample_shape))).reshape(N * M, *sample_shape)
        T_rep = T_grid.view(1, M, 1).expand(N, M, 1).reshape(N * M, 1)

        probs, scores, *_ = model(x_rep, T_rep)         # [N*M, C]
        Ccls = probs.size(1)
        probs_nm  = probs.view(N, M, Ccls)              # [N, M, C]
        scores_nm = scores.view(N, M, Ccls)             # [N, M, C]

        # Cross-entropy of the true class at each temperature. The log-probability
        # is picked with gather rather than a one-hot multiply: multiplying a 0
        # by a -inf log-probability of some OTHER class would yield NaN.
        log_probs_nm = F.log_softmax(scores_nm, dim=2)  # [N, M, C]
        target = y.long().view(N, 1, 1).expand(N, M, 1)
        ce_mat = -log_probs_nm.gather(2, target).squeeze(2)      # [N, M]

        # Posterior over temperatures: q(T|x,y)
        qT = torch.softmax(-ce_mat, dim=1)              # [N, M]

        # Marginal class probabilities under q.
        probs_marg_q = (qT.unsqueeze(-1) * probs_nm).sum(dim=1)  # [N, C]

        pred = probs_marg_q.argmax(dim=1)
        correct += (pred == y).sum().item()
        total   += N

    if total == 0:
        raise ValueError("loader yielded no samples; accuracy is undefined")
    return correct / total



In [None]:

import os
import torch
import torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1"



device = torch.device("cuda:5" if torch.cuda.is_available() else "cpu")

width = 32
kb = 1.0
num_classes = 20   # CIFAR100Coarse yields coarse labels in 0..19
em_epochs = 201
lr = 1e-3          # only referenced by the commented-out Adam setup below

model = CoupledModel(width=width, num_classes=num_classes, kb=kb, in_channels=3).to(device)
T_module = LearnableTSet(K=4, T_min=0.1, T_max=10.0).to(device)
# optimizer = torch.optim.Adam(list(model.parameters()) + list(T_module.parameters()), lr=lr)
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=501)
params = [
    {"params": model.E_net.parameters(), "lr": 0.01,"weight_decay": 5e-4},
    {"params": model.S_net.parameters(), "lr": 0.4,"weight_decay": 5e-4},
    {"params": T_module.parameters(), "lr": 0.001, "weight_decay": 0.0},  # no L2 regularisation on T_module
]
optimizer = torch.optim.SGD(params,momentum=0.9)

# Cosine schedule measured in EPOCHS (T_max=200 for 201 epochs); it is stepped
# once per epoch at the end of the epoch loop body below.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
scaler = GradScaler(enabled=torch.cuda.is_available())

train_losses = []
val_accuracies = []
T_records = []  # temperature grid after every epoch
for epoch in range(em_epochs):
    model.train()
    total_loss = 0.0

    for x, y in train_loader:
        x = x.to(device)                                   # [N, 3, 32, 32]
        # The one-hot width must equal the model's class count. It used to be
        # 100, which does not match the 20 outputs of the model.
        y_onehot = F.one_hot(y, num_classes=num_classes).float().to(device)  # [N, 20]

        # One AMP training step with the learnable temperatures. The scheduler
        # is deliberately NOT passed here: stepping it per batch would run the
        # 200-step cosine schedule many times within a single epoch.
        loss, qT = em_train_step_optimized_T(model, x, y_onehot, T_module, optimizer, scheduler=None)
        total_loss += loss

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    # Mean training loss of this epoch.
    avg_loss = total_loss / len(train_loader)
    train_losses.append(avg_loss)

    # Validation accuracy and a snapshot of the current temperatures.
    with torch.no_grad():
        T_eval_grid = T_module()
        val_acc = evaluate_accuracy_posterior_labeled(model, val_loader, device, T_module)
        val_accuracies.append(val_acc)

        current_T = T_eval_grid.detach().cpu().numpy().flatten()
        T_records.append(current_T.copy())

    # Report every 10th epoch and the final one.
    if epoch % 10 == 0 or epoch == em_epochs - 1:
        print(f"Epoch {epoch}: Loss = {avg_loss:.6f}, Val Acc = {val_acc:.6f}, T = {current_T}")

print("Training complete.")


In [None]:

@torch.no_grad()
def posterior_T_labeled(model, x, y, T_module, device):
    """Label-informed posterior over the temperature grid for one batch.

    Each image is evaluated at every temperature returned by ``T_module()``;
    the negative log-likelihood of the true label at each temperature is turned
    into ``q(T | x, y) = softmax(-CE)``.

    Args:
        model: Callable as ``model(x, T)`` returning ``(probs, scores, ...)``;
            it is switched to eval mode.
        x: Image batch of shape ``[N, C_img, H, W]``.
        y: Integer class labels of shape ``[N]``.
        T_module: Callable returning the ``M`` temperatures.
        device: Device the inputs and the grid are moved to.

    Returns:
        ``(qT, T_map, idx_map)``: the ``[N, M]`` posterior, the most probable
        temperature per sample ``[N]``, and its index in the grid ``[N]``.
    """
    model.eval()
    x = x.to(device)
    y = y.to(device)

    temps = T_module().to(device)
    num_T = temps.size(0)
    batch, channels, height, width = x.shape
    pairs = batch * num_T

    # Build all (image, temperature) pairs; the image index varies slowest.
    tiled_x = x.unsqueeze(1).repeat(1, num_T, 1, 1, 1).reshape(pairs, channels, height, width)
    tiled_T = temps.view(1, num_T, 1).expand(batch, num_T, 1).reshape(pairs, 1)
    tiled_y = y.unsqueeze(1).repeat(1, num_T).reshape(pairs).long()      # [N*M]

    # Single forward pass over every pair; only the scores are needed here.
    _, scores, *_ = model(tiled_x, tiled_T)                              # [N*M, C]

    # -log p(y_i | x_i, T_m), arranged as an [N, M] table.
    nll = F.nll_loss(F.log_softmax(scores, dim=1), tiled_y, reduction="none")
    ce_table = nll.view(batch, num_T)

    # q(T|x,y) ∝ exp(-CE)
    qT = torch.softmax(-ce_table, dim=1)                                 # [N, M]

    # MAP temperature per sample.
    idx_map = qT.argmax(dim=1)                                           # [N]
    T_map = temps.view(-1)[idx_map]                                      # [N]

    return qT, T_map, idx_map



In [None]:

@torch.no_grad()
def posterior_T_labeled_all(model, loader, T_module, device):
    """Run ``posterior_T_labeled`` over every batch of ``loader`` and concatenate the results.

    Returns:
        ``(qT, T_map, idx_map)`` concatenated along the sample axis, all on the CPU.
    """
    model.eval()
    # One bucket per output of posterior_T_labeled: (qT, T_map, idx_map).
    buckets = ([], [], [])

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        batch_outputs = posterior_T_labeled(model, x, y, T_module, device)
        for bucket, tensor in zip(buckets, batch_outputs):
            bucket.append(tensor.cpu())

    qT_parts, Tmap_parts, idx_parts = buckets
    return torch.cat(qT_parts), torch.cat(Tmap_parts), torch.cat(idx_parts)


# Evaluate on whichever device currently holds the model's parameters
# (alternative: torch.device("cuda" if torch.cuda.is_available() else "cpu")).
device = next(model.parameters()).device

# Posterior over temperatures for every sample of the test loader: [N, M]
posterior_outputs = posterior_T_labeled_all(model, test_loader, T_module, device)
qT_all, Tmap_all, idx_all = posterior_outputs

print("qT_all shape:", qT_all.shape)      # [N, M]
print("Tmap_all shape:", Tmap_all.shape)  # [N]
print(qT_all)



In [None]:

# 
# Hard-assign every sample to its most probable temperature.
max_idx = qT_all.argmax(dim=1)  # [N]

# One-hot encoding of that assignment: [N, M]
num_temperatures = qT_all.shape[1]
qT_onehot = F.one_hot(max_idx, num_classes=num_temperatures).float()

print(qT_onehot)
# Samples per temperature, and the same expressed as a fraction of all samples.
counts = qT_onehot.sum(dim=0)              # [M]
freqs = counts / qT_onehot.shape[0]

print("Sample of each temperature:", counts)
print("Frequnency of each temperature:", freqs)


# In[15]:


# Score the trained model on the test loader with the label-informed posterior weighting.
test_acc = evaluate_accuracy_posterior_labeled(
    model=model, loader=test_loader, device=device, T_module=T_module
)
print(f"Test Accuracy: {test_acc:.6f}")


# In[16]:


import numpy as np

# 
# Convert the collected metrics to NumPy arrays (the names are rebound on purpose).
train_losses, val_accuracies = np.array(train_losses), np.array(val_accuracies)
freqs, qT_all = np.array(freqs), np.array(qT_all)

# One text file per artefact: (file name, values, number format).
_exports = (
    ("CZnew_LT_train_losses.txt", train_losses, "%.10f"),      # mean training loss per epoch
    ("CZnew_LT_val_accuracies.txt", val_accuracies, "%.10f"),  # validation accuracy per epoch
    ("CZnew_LT_T.txt", T_records, "%.10f"),                    # temperature grid per epoch
    ("CZnew_LT_freqs.txt", freqs, "%.10f"),                    # share of samples per temperature
    ("CZnew_LT_qT.txt", qT_all, "%.4f"),                       # per-sample posterior over temperatures
)
for _path, _values, _fmt in _exports:
    np.savetxt(_path, _values, fmt=_fmt)

