In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import random
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def random_state(seed=0):
    """Seed every RNG the notebook touches: PyTorch (CPU), NumPy and Python's `random`.

    When CUDA is available, all CUDA devices are seeded as well and cuDNN is put
    in deterministic mode with autotuning (benchmark) disabled, trading some
    speed for run-to-run reproducibility.
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


random_state(2025)

In [2]:
# Autoencoder under test: `SwinMemAE` from model/v21.py.
# Earlier variants tried: model.ff.FreqSplitMemAE, model.v14.UNetMemAEv13GatedDecoder,
# model.v17.DualEncoderMatchingAE.
from model.v21 import SwinMemAE as AE
from trainer.initing import weights_init

random_state(2025)  # reseed so weight initialisation does not depend on earlier cells

SIZE = 400          # memory size (passed as `mem_size`)
latent_dim = 2048   # embedding width (passed as `embed_dim`)

model = AE(mem_size=SIZE, embed_dim=latent_dim)
model = model.to(device)
model.apply(weights_init);


In [3]:
# List every parameter with its shape; frozen (requires_grad=False) ones are flagged instead.
for name, param in model.named_parameters():
    detail = param.data.shape if param.requires_grad else "not requires grad"
    print(name, detail)

encoder.patch.proj.weight torch.Size([2048, 1, 4, 4])
encoder.patch.proj.bias torch.Size([2048])
encoder.block1.attn.qkv.weight torch.Size([6144, 2048])
encoder.block1.attn.proj.weight torch.Size([2048, 2048])
encoder.block1.attn.proj.bias torch.Size([2048])
encoder.block2.attn.qkv.weight torch.Size([6144, 2048])
encoder.block2.attn.proj.weight torch.Size([2048, 2048])
encoder.block2.attn.proj.bias torch.Size([2048])
memory.keys not requires grad
memory.values not requires grad
decoder.decode.0.weight torch.Size([2048, 1024, 2, 2])
decoder.decode.0.bias torch.Size([1024])
decoder.decode.2.weight torch.Size([1024, 1, 2, 2])
decoder.decode.2.bias torch.Size([1])


In [4]:
# Training hyper-parameters.
EPOCHS, BATCH_SIZE, LEARNING_RATE, PATIENCE = (
    500,   # maximum number of epochs
    64,    # mini-batch size
    1e-3,  # learning rate
    50,    # early-stopping patience (epochs without validation improvement)
)

In [5]:
import shutil, os

# Start every run with an empty log directory so logs from previous runs don't mix in.
log_dir = "./multi_log"
if not os.path.exists(log_dir):
    print(f"{log_dir} does not exist.")
else:
    shutil.rmtree(log_dir)
    print(f"{log_dir} has been deleted.")

./multi_log has been deleted.


In [6]:
# Fashion-MNIST as float tensors via ToTensor only (no augmentation).
to_tensor = transforms.Compose([transforms.ToTensor()])
trainset = datasets.FashionMNIST(root='./.data/', train=True, download=True, transform=to_tensor)
testset = datasets.FashionMNIST(root='./.data/', train=False, download=True, transform=to_tensor)

# One-class setup: the training set keeps only the SELECT_NORMAL class.
SELECT_NORMAL = 2
normal_mask = trainset.targets == SELECT_NORMAL
trainset.data, trainset.targets = trainset.data[normal_mask], trainset.targets[normal_mask]

# Test set: classes 2, 4 and 6 only.
test_label = [2, 4, 6]
actual_testdata = torch.isin(testset.targets, torch.tensor(test_label))
testset.data, testset.targets = testset.data[actual_testdata], testset.targets[actual_testdata]

train_data_size, test_data_size = len(trainset), len(testset)
print("Train data size:", train_data_size, "Test data size:", test_data_size)

# Split the normal-class training images into train / validation (10% held out, seeded).
n_val = int(len(trainset) * 0.1)
n_train = len(trainset) - n_val
augset, valset = torch.utils.data.random_split(
    trainset, [n_train, n_val], generator=torch.Generator().manual_seed(2025)
)

# Sanity-check the split sizes.
print("Train data size:", len(augset), "Val data size:", len(valset), "Test data size:", len(testset))

Train data size: 6000 Test data size: 3000
Train data size: 5400 Val data size: 600 Test data size: 3000


In [7]:
BATCH_SIZE = 64
# Only the training split is shuffled; validation and test order stay fixed.
train_loader = DataLoader(augset, shuffle=True, batch_size=BATCH_SIZE)
val_loader = DataLoader(valset, shuffle=False, batch_size=BATCH_SIZE)
# One image per batch so each test sample is scored on its own.
test_loader = DataLoader(testset, shuffle=False, batch_size=1)

In [8]:
from eval_module.eval_show import EvalDataset
# Display set: the training loops take its first 16 items every 10 epochs and log
# their reconstructions. Its contents are defined in eval_module.eval_show (not shown here).
show_dataset = EvalDataset()

In [9]:
from loss.losses import FlexibleLoss

# Single-term losses (kept for experiments; the training cells in this notebook
# optimize `loss_fused` / `loss_combined` instead).
loss_base, loss_fft, loss_center = (
    FlexibleLoss(term, loss_weights={term: 1.0}, reduction="mean", epoch=100).to(device)
    for term in ("mse", "fft", "center_crop")
)

# Training objective: weighted sum of three reconstruction terms.
loss_fused = FlexibleLoss(
    mode="charbonnier+ms-ssim+gradient",
    loss_weights={
        "charbonnier": 0.8,  # smooth, robust pixel reconstruction
        "ms-ssim": 0.15,     # texture / structure preservation
        "gradient": 0.05,    # edge emphasis
    },
    reduction="mean",
).to(device)



In [10]:
import torch
import torch.nn.functional as F
import os
from trainer.logger import LoggerMixin, GPUUsageLoggerMixin

In [11]:
def orthogonality_loss(memory):
    """Penalize overlap between memory slots.

    Each row of ``memory`` (shape ``(num_slots, memory_dim)``) is L2-normalized.
    The loss is the mean squared difference between the resulting Gram matrix of
    cosine similarities and the identity. It is 0 when slots are mutually orthogonal.
    """
    unit_rows = F.normalize(memory, dim=1)
    cosine = unit_rows @ unit_rows.T  # (num_slots, num_slots)
    target = torch.eye(cosine.size(0), device=cosine.device)
    return F.mse_loss(cosine, target)

def training_loss(model, x, loss_fused, lambda_align=0.0, lambda_orth=0.01):
    """Composite objective: reconstruction + weighted alignment + weighted memory orthogonality.

    Expects ``model(x, return_latent=True)`` to return ``(x_hat, f_latent, f_mem)``.

    Returns:
        ``(total_loss, recon_loss, align_loss, orth_loss)``.

    NOTE(review): the printed parameter list shows ``memory.keys`` / ``memory.values``,
    which suggests ``model.memory`` is a submodule rather than a tensor. If so,
    ``orthogonality_loss(model.memory)`` would fail in ``F.normalize``. Confirm which
    tensor should be passed.
    """
    x_hat, f_latent, f_mem = model(x, return_latent=True)

    # 1. reconstruction
    recon_loss = loss_fused(x_hat, x)
    # 2. pull f_mem toward the latent; detach() means gradients flow only through f_mem
    align_loss = F.mse_loss(f_mem, f_latent.detach())
    # 3. keep memory slots mutually orthogonal
    orth_loss = orthogonality_loss(model.memory)

    total_loss = recon_loss + lambda_align * align_loss + lambda_orth * orth_loss
    return total_loss, recon_loss, align_loss, orth_loss


In [12]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
def train(
    model,
    train_loader,
    val_loader,
    loss_fused,
    optimizer,
    scheduler=None,
    device='cuda',
    num_epochs=100,
    early_stop_patience=10,
    log_dir="./runs",
    show_dataset=None
):
    """Train a reconstruction model with validation-based early stopping.

    Each epoch does the following:
      * one optimization pass over ``train_loader`` minimizing
        ``loss_fused(reconstruction, input)``;
      * the same loss evaluated on ``val_loader``;
      * logging of both losses and GPU usage through the project logger;
      * every 10 epochs, logging of reconstructions of the first 16 items of
        ``show_dataset``.

    Training stops after ``early_stop_patience`` consecutive epochs without a
    strictly lower validation loss. On completion, the weights from the best
    validation epoch are restored into ``model`` and saved to
    ``<log_dir>/weight/best_model.pth``. The GPU monitor and logger are shut down
    even if training is interrupted.

    Args:
        model: module returning the reconstruction, or a tuple whose last item is it.
        train_loader / val_loader: iterables of ``(x, label)`` batches.
        loss_fused: callable ``(reconstruction, target) -> scalar loss tensor``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: optional LR scheduler, stepped once per epoch.
        device: device the batches are moved to.
        num_epochs: maximum number of epochs.
        early_stop_patience: epochs without improvement before stopping.
        log_dir: logger directory; also the root for the saved weights.
        show_dataset: optional dataset for periodic image logging.

    Returns:
        ``model`` (the same object) holding the best-validation weights.
    """
    import copy  # only needed for the best-weights snapshot

    class Trainer(LoggerMixin, GPUUsageLoggerMixin):
        def __init__(self, log_dir):
            LoggerMixin.__init__(self, log_dir)
            GPUUsageLoggerMixin.__init__(self)

    trainer = Trainer(log_dir)
    trainer.start_gpu_monitor()

    best_val_loss = float('inf')
    best_epoch = None
    patience_counter = 0
    best_model_state = None

    try:
        for epoch in range(num_epochs):
            model.train()
            total_train_loss = 0.0
            train_loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]", leave=False)

            for x, _ in train_loop:
                x = x.to(device)
                output = model(x)
                # Some model variants return a tuple whose last item is the
                # reconstruction (same convention as the image logging below).
                if isinstance(output, tuple):
                    output = output[-1]
                loss = loss_fused(output, x)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_train_loss += loss.item()
                train_loop.set_postfix(loss=loss.item())

            avg_train_loss = total_train_loss / len(train_loader)

            # Validation
            model.eval()
            total_val_loss = 0.0
            val_loop = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]", leave=False)
            with torch.no_grad():
                for x_val, _ in val_loop:
                    x_val = x_val.to(device)
                    output = model(x_val)
                    if isinstance(output, tuple):
                        output = output[-1]
                    loss = loss_fused(output, x_val)
                    total_val_loss += loss.item()
                    val_loop.set_postfix(val_loss=loss.item())

            avg_val_loss = total_val_loss / len(val_loader)
            trainer.log_losses(avg_train_loss, avg_val_loss, epoch)
            trainer.log_gpu_usage(epoch)

            # Image logging (the model is still in eval mode from validation).
            if show_dataset and epoch % 10 == 0:
                with torch.no_grad():
                    sample_x, label = next(iter(DataLoader(show_dataset, batch_size=16)))
                    sample_x = sample_x.to(device)
                    if hasattr(model, 'T'):
                        # Models exposing `T` take an extra per-sample integer tensor in [0, T).
                        t = torch.randint(0, model.T, (sample_x.size(0),), device=device)
                        output = model(sample_x, t)
                    else:
                        output = model(sample_x)
                    if isinstance(output, tuple):
                        output = output[-1]
                    trainer.log_images(sample_x, label, output, epoch)

            # Early stopping
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                best_epoch = epoch + 1
                patience_counter = 0
                # state_dict() returns references to the live tensors. Without a deep
                # copy, later optimizer steps would overwrite this "best" snapshot and
                # the last-epoch weights would be restored and saved instead.
                best_model_state = copy.deepcopy(model.state_dict())
            else:
                patience_counter += 1
                if patience_counter >= early_stop_patience:
                    print("🛑 Early stopping triggered!")
                    break

            if scheduler:
                scheduler.step()

        print("Training complete!")
        if best_model_state is None:
            # e.g. num_epochs == 0, or every validation loss was NaN.
            print("No validation improvement recorded; keeping the current weights.")
        else:
            print(f"Best validation loss: {best_val_loss:.6f} at epoch {best_epoch}")
            model.load_state_dict(best_model_state)

        path = os.path.join(log_dir, "./weight/best_model.pth")
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save(model.state_dict(), path)  # weights only
        print(f"Model weights saved to {path}")
    finally:
        # Always release the GPU monitor and logger, including on KeyboardInterrupt.
        trainer.stop_gpu_monitor()
        trainer.save_gpu_peak_to_log(log_dir)
        trainer.close_logger()
    return model


In [13]:
# Seed, initialize the model memory from training data, then build optimizer/scheduler and train.
random_state(2025)
from model.v21 import init_memory
init_memory(model, train_loader, device=device, top_n=SIZE)

# Create the optimizer only after init_memory. If init_memory swaps in new
# Parameter objects, an optimizer built earlier would hold stale references and
# never update them. NOTE(review): init_memory's internals are not visible here; confirm.
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
# NOTE(review): T_max=100 is well below EPOCHS=500. After epoch 100 the cosine
# schedule climbs back toward the base LR instead of annealing once; confirm intended.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)

# Train the model
model = train(
    model,
    train_loader,
    val_loader,
    loss_fused=loss_fused,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    num_epochs=EPOCHS,
    early_stop_patience=PATIENCE,
    log_dir="./multi_log",
    show_dataset=show_dataset,  # sample set for image logging
)

[✅] Initialized memory with 400 samples and enabled training.


                                                                                  

KeyboardInterrupt: 

In [None]:
import torch
import torch.nn.functional as F

def mse_topk_loss(x_hat, x, k_ratio=0.1, reduction='mean'):
    """Top-k MSE: per sample, the mean of the largest ``k_ratio`` fraction of squared pixel errors.

    Args:
        x_hat: reconstructed images, e.g. (B, 1, H, W).
        x: original images, same shape as ``x_hat``.
        k_ratio: fraction of pixels to keep (e.g. 0.1 keeps the top 10%). The
            resulting count is clamped to [1, number of pixels], so small images
            or small ratios never select zero pixels (which would give NaN).
        reduction: ``'none'`` returns the per-sample (B,) scores; anything else
            returns their mean as a scalar.

    Returns:
        A (B,) score tensor or a scalar tensor, depending on ``reduction``.
    """
    diff = (x - x_hat) ** 2
    # reshape (not view) so non-contiguous inputs (slices, transposes) also work
    diff_flat = diff.reshape(x.size(0), -1)  # (B, H*W)
    n_pix = diff_flat.size(1)
    k = min(n_pix, max(1, int(n_pix * k_ratio)))
    topk_vals, _ = torch.topk(diff_flat, k, dim=1)

    score = topk_vals.mean(dim=1)  # (B,)
    return score if reduction == 'none' else score.mean()


In [None]:
def inference(model, test_loader, device, loss_fn1, loss_fn2, aplpa=0.5, beta=0.5, normal_label=2):
    """Compute one anomaly score and one binary label per test sample.

    The score of a sample is ``aplpa * mean(loss_fn1(recon, x)) + beta * mean(loss_fn2(recon, x))``,
    computed on that sample alone. This makes results independent of the loader's
    batch size (previously only ``batch_size=1`` was handled correctly). A loss
    whose weight is 0 is not evaluated, so it cannot inject NaN through ``0 * NaN``.

    Args:
        model: module returning the reconstruction, or a tuple whose last item is it.
        test_loader: iterable of ``(x, labels)`` batches.
        device: device the inputs are moved to.
        loss_fn1, loss_fn2: callables ``(reconstruction, target) -> tensor``.
        aplpa: weight of ``loss_fn1`` ("alpha"; name kept for backward compatibility).
        beta: weight of ``loss_fn2``.
        normal_label: class id treated as normal (label 0); every other class gets 1.

    Returns:
        ``(scores, labels)``: two lists of Python floats / ints, one entry per sample.
    """
    model.eval()
    all_outputs = []
    all_labels = []

    with torch.no_grad():
        for x, labels in test_loader:
            x = x.to(device)
            if hasattr(model, 'T'):
                t = torch.randint(0, model.T, (x.size(0),), device=device)
                output = model(x, t)
            else:
                output = model(x)
            if isinstance(output, tuple):
                output = output[-1]

            # Score each sample separately; slicing keeps the batch dim for the loss fns.
            for i in range(x.size(0)):
                x_i = x[i:i + 1]
                out_i = output[i:i + 1]
                score = 0.0
                if aplpa != 0:
                    score += loss_fn1(out_i, x_i).mean().item() * aplpa
                if beta != 0:
                    score += loss_fn2(out_i, x_i).mean().item() * beta
                all_outputs.append(score)
                all_labels.append(0 if int(labels[i]) == normal_label else 1)

    return all_outputs, all_labels

            

In [None]:
def _make_loss(mode, weights):
    """Mean-reduced FlexibleLoss with epoch=100, moved to `device`."""
    return FlexibleLoss(mode, loss_weights=weights, reduction="mean", epoch=100).to(device)


# Candidate anomaly-scoring losses.
loss_fn = _make_loss("mse+ms-ssim", {"mse": 0.6, "ms-ssim": 0.4})
loss_ssim = _make_loss("ms-ssim", {"ms-ssim": 1.0})
loss_mse = _make_loss("mse", {"mse": 1.0})
loss_combined = _make_loss("mse+ms-ssim", {"mse": 0.8, "ms-ssim": 0.2})
loss_fft = _make_loss("fft", {"fft": 1.0})
loss_grad = _make_loss("gradient", {"gradient": 1.0})

model.eval()
# Score with loss_combined only: beta=0.0 gives loss_grad zero weight.
outputs, labels = inference(model, test_loader, device, loss_combined, loss_grad, aplpa=1.0, beta=0.0)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Inputs: `outputs` are anomaly scores; `labels` are 0 for normal, 1 for abnormal.
def plot_score_distribution_by_label(outputs, labels, bins=20):
    """Overlay density histograms of anomaly scores for normal vs. abnormal samples.

    Both classes share one set of bin edges computed over all scores. Bar widths
    are therefore identical and the two densities can be compared bar for bar.
    The x-axis is zoomed to the 1st-99th percentile of all scores to hide outliers.

    Args:
        outputs: sequence of anomaly scores.
        labels: sequence of 0 (normal) / 1 (abnormal), aligned with ``outputs``.
        bins: number of histogram bins shared by both classes.

    Raises:
        ValueError: if ``outputs`` is empty.
    """
    outputs = np.asarray(outputs, dtype=float)
    labels = np.asarray(labels)
    if outputs.size == 0:
        raise ValueError("outputs is empty; nothing to plot")

    normal_scores = outputs[labels == 0]
    abnormal_scores = outputs[labels == 1]

    # Shared edges: with per-class automatic binning the two histograms get
    # different bar widths, so their densities are not directly comparable.
    bin_edges = np.histogram_bin_edges(outputs, bins=bins)
    x_min, x_max = np.percentile(outputs, [1, 99])

    fig, ax = plt.subplots(figsize=(8, 5))
    for scores, name, color in (
        (normal_scores, 'Normal (label=0)', 'skyblue'),
        (abnormal_scores, 'Abnormal (label=1)', 'salmon'),
    ):
        # A density histogram of an empty class would be all-NaN; skip it.
        if scores.size:
            ax.hist(scores, bins=bin_edges, alpha=0.5, label=name, color=color, density=True)

    ax.set_xlabel("Anomaly Score")
    ax.set_ylabel("Density")
    ax.set_title("Anomaly Score Distribution by Label")
    # Equal limits (all scores identical) would give a singular x-range; only zoom when meaningful.
    if x_max > x_min:
        ax.set_xlim(x_min, x_max)
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    plt.show()


plot_score_distribution_by_label(outputs, labels, bins=50)

In [None]:
def thresholding(outputs, labels, threshold):
    """Binarize anomaly scores at ``threshold`` and score them against ground truth.

    A sample is predicted abnormal (1) when its score is strictly greater than
    ``threshold``, otherwise normal (0).

    Returns:
        ``(accuracy_percent, f1)``: accuracy in percent, and the F1 score of the
        abnormal class (1). Precision, recall and F1 fall back to 0 when their
        denominator is 0.
    """
    scores = np.array(outputs)
    truth = np.array(labels)
    pred = (scores > threshold).astype(int)  # 0: normal, 1: abnormal

    acc_pct = np.mean(pred == truth) * 100

    predicted_pos = pred == 1
    actual_pos = truth == 1
    tp = np.sum(predicted_pos & actual_pos)
    fp = np.sum(predicted_pos & (truth == 0))
    fn = np.sum((pred == 0) & actual_pos)

    def _safe_ratio(num, den):
        return num / den if den > 0 else 0

    precision = _safe_ratio(tp, tp + fp)
    recall = _safe_ratio(tp, tp + fn)
    f1 = _safe_ratio(2 * (precision * recall), precision + recall)
    return acc_pct, f1

threshold = 0.0018 # manually chosen decision threshold (score > threshold => abnormal)
# NOTE(review): if this value was read off the test-score histogram above, the
# metrics printed below are optimistically biased; prefer choosing it on held-out data.
accuracy, f1_score = thresholding(outputs, labels, threshold)
print(f"Threshold: {threshold}, Accuracy: {accuracy:.2f}%, F1 Score: {f1_score:.2f}")

#### Final training on the full training set

In [None]:
import os
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

def final_training(
    model,
    train_loader,
    loss_fused,
    optimizer,
    scheduler=None,
    device='cuda',
    num_epochs=100,
    log_dir="./runs",
    show_dataset=None
):
    """Train ``model`` for a fixed number of epochs, with no validation or early stopping.

    Intended for the final refit on the full training set. Each epoch logs the
    training loss and GPU usage through the project logger. Every 10 epochs it
    also logs reconstructions of the first 16 items of ``show_dataset``. The
    final weights are saved to ``<log_dir>/weight/final_model.pth``. The GPU
    monitor and logger are shut down even if training is interrupted.

    Args:
        model: module returning the reconstruction, or a tuple whose last item is it.
        train_loader: iterable of ``(x, label)`` batches.
        loss_fused: callable ``(reconstruction, target) -> scalar loss tensor``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: optional LR scheduler, stepped once per epoch.
        device: device the batches are moved to.
        num_epochs: number of epochs to run.
        log_dir: logger directory; also the root for the saved weights.
        show_dataset: optional dataset for periodic image logging.

    Returns:
        ``model`` (the same object), left in train mode.
    """
    class Trainer(LoggerMixin, GPUUsageLoggerMixin):
        def __init__(self, log_dir):
            LoggerMixin.__init__(self, log_dir)
            GPUUsageLoggerMixin.__init__(self)

    trainer = Trainer(log_dir)
    trainer.start_gpu_monitor()

    try:
        for epoch in range(num_epochs):
            model.train()
            total_train_loss = 0.0
            train_loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]", leave=False)

            for x, _ in train_loop:
                x = x.to(device)

                output = model(x)
                if isinstance(output, tuple):
                    output = output[-1]

                loss = loss_fused(output, x)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_train_loss += loss.item()
                train_loop.set_postfix(loss=loss.item())

            avg_train_loss = total_train_loss / len(train_loader)
            # There is no validation set here; the train loss also fills the val slot.
            trainer.log_losses(avg_train_loss, avg_train_loss, epoch)
            trainer.log_gpu_usage(epoch)

            # Image logging
            if show_dataset and epoch % 10 == 0:
                # Use eval mode for the display pass. In train mode, any BatchNorm layers
                # would update their running statistics from these samples (even under
                # no_grad), and dropout would perturb the logged reconstructions.
                model.eval()
                with torch.no_grad():
                    sample_x, label = next(iter(DataLoader(show_dataset, batch_size=16)))
                    sample_x = sample_x.to(device)
                    output = model(sample_x)
                    if isinstance(output, tuple):
                        output = output[-1]
                    trainer.log_images(sample_x, label, output, epoch)
                model.train()

            if scheduler:
                scheduler.step()

        print("Training complete!")
        path = os.path.join(log_dir, "./weight/final_model.pth")
        os.makedirs(os.path.dirname(path), exist_ok=True)
        torch.save(model.state_dict(), path)
        print(f"✅ Final model weights saved to {path}")
    finally:
        # Always release the GPU monitor and logger, including on KeyboardInterrupt.
        trainer.stop_gpu_monitor()
        trainer.save_gpu_peak_to_log(log_dir)
        trainer.close_logger()
    return model

# Final training: refit a fresh model on the full normal-class training set
# (train + validation images), for a fixed number of epochs without early stopping.
from model.v21 import init_memory

random_state(2025)  # reproducible initialisation and shuffling for the final run

# Same constructor arguments as the model trained above. A bare AE() would use
# the class defaults, which may be a different architecture.
model = AE(mem_size=SIZE, embed_dim=latent_dim).to(device)
model.apply(weights_init)

full_dataset = trainset  # all normal-class training images (no validation hold-out)
# Separate name so the train/val-split `train_loader` is not silently replaced.
full_loader = DataLoader(full_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Initialize the model memory from the data, as in the first training run.
init_memory(model, full_loader, device=device, top_n=SIZE)

# Build the optimizer after init_memory so it sees the final parameter objects.
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100, eta_min=1e-6)
FINAL_EPOCHS = 300  # fixed epoch budget for the refit (kept separate from EPOCHS)

# NOTE(review): the refit optimizes `loss_combined` (mse + ms-ssim), while the first
# run optimized `loss_fused` (charbonnier + ms-ssim + gradient). Confirm this is intended.
model = final_training(
    model,
    full_loader,
    loss_fused=loss_combined,
    optimizer=optimizer,
    scheduler=scheduler,
    device=device,
    num_epochs=FINAL_EPOCHS,
    log_dir="./multi_log_final",
    show_dataset=show_dataset,  # sample set for image logging
)

In [None]:
# Per-term losses with reduction="none" (`loss_chamber` is the Charbonnier term).
loss_chamber = FlexibleLoss("charbonnier", loss_weights={"charbonnier": 1.0}, reduction="none", epoch=100).to(device)
loss_center_crop = FlexibleLoss("center_crop", loss_weights={"center_crop": 1.0}, reduction="none", epoch=100).to(device)
# Composite scoring loss for the inference call below. It is moved to `device` like
# every other FlexibleLoss in this notebook, so any internal buffers it may hold live
# on the same device as the model outputs. Its reduction is FlexibleLoss's default;
# inference() takes .mean() of the result either way.
loss_com = FlexibleLoss(
    mode='charbonnier+center_crop+gradient',
    loss_weights={"charbonnier": 0.5, "center_crop": 0.4, "gradient": 0.1}
).to(device)


In [None]:
# Score the test set with `loss_com` (beta=0.0 gives the `loss_ssim` term zero weight),
# then plot the per-label score distributions.
outputs, labels = inference(model, test_loader, device, loss_com,loss_ssim, aplpa=1.0, beta=0.0)
plot_score_distribution_by_label(outputs, labels, bins=50)

In [None]:
threshold = 0.038# manually chosen decision threshold for the loss_com scores (score > threshold => abnormal)
# NOTE(review): if this value was read off the test-score histogram, the metrics
# below are optimistically biased; prefer choosing it on held-out data.
accuracy, f1_score = thresholding(outputs, labels, threshold)
print(f"Threshold: {threshold}, Accuracy: {accuracy:.2f}%, F1 Score: {f1_score:.2f}")