<a href="https://colab.research.google.com/github/weichang888/AD/blob/main/%E3%80%8Cblue_ipynb%E3%80%8D%E7%9A%84%E5%89%AF%E6%9C%AC%E3%80%8D%E7%9A%84%E5%89%AF%E6%9C%AC.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive

# Make Google Drive available under /content/drive; the dataset, checkpoints
# and logs used below all live under "/content/drive/My Drive/AD/".
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


### Data loader

In [2]:
from torch.utils.data import Dataset
import torch
import glob
from PIL import Image
from pathlib import Path
from torchvision import transforms
import numpy as np
class TrainDataset(Dataset):
    """Training images of one MVTec AD category.

    Images are collected from ``<root_dir>/<obj_name>/train/*/*.png`` (sorted)
    and each item is returned as ``{"image": transform(rgb_image)}``.

    Args:
        root_dir: Dataset root, as ``str`` or ``Path``. A trailing separator is
            optional.
        obj_name: Category sub-directory name, e.g. ``"bottle"``.
        transform: Optional callable applied to each PIL RGB image. When
            omitted, a default resize -> ToTensor -> ImageNet-Normalize
            pipeline is built from ``resize_shape``.
        resize_shape: Side length of the square resize used by the default
            transform. Required when ``transform`` is None.

    Raises:
        ValueError: if neither ``transform`` nor ``resize_shape`` is given.
    """

    def __init__(self, root_dir, obj_name, transform=None, resize_shape=None):
        self.root_dir = Path(root_dir)
        self.obj_name = obj_name
        self.resize_shape = resize_shape
        # Join with pathlib so the pattern is correct whether or not root_dir
        # ends with a separator. Plain string concatenation produced
        # "<root>bottle/..." and silently matched nothing without the slash.
        pattern = self.root_dir / self.obj_name / "train" / "*" / "*.png"
        self.image_names = sorted(glob.glob(str(pattern)))

        if transform is not None:
            self.transform = transform
        else:
            if resize_shape is None:
                raise ValueError("resize_shape is required when no transform is given")
            self.transform = transforms.Compose([
                transforms.Resize((resize_shape, resize_shape)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        """Return ``{"image": transformed RGB image}`` for sample ``idx``."""
        img = Image.open(str(self.image_names[idx])).convert("RGB")
        return {"image": self.transform(img)}

class TestDataset(Dataset):
    """Test images of one MVTec AD category, with image labels and pixel masks.

    Images come from ``<root_dir>/<obj_name>/test/<defect_type>/*.png``. The
    defect-type directory ``"good"`` marks normal samples. Masks for anomalous
    samples are read from
    ``<root_dir>/<obj_name>/ground_truth/<defect_type>/<stem>_mask.png``.

    Each item is a dict:
        ``"image"``: the transformed RGB image;
        ``"label"``: float32 array ``[0]`` for good and ``[1]`` for anomaly;
        ``"gt_mask"``: ``(1, H, W)`` mask tensor (all zeros for good samples).

    Args:
        root_dir: Dataset root (``str`` or ``Path``); a trailing separator is optional.
        obj_name: Category name, e.g. ``"bottle"``.
        transform: Optional callable for the RGB image. The default is
            resize -> ToTensor -> ImageNet-Normalize and needs ``resize_shape``.
        resize_shape: Side of the square resize used by the default transforms.
        gt_transform: Optional callable for the mask image. The default is a
            resize to ``resize_shape`` (when given) followed by ToTensor.

    Raises:
        ValueError: if neither ``transform`` nor ``resize_shape`` is given.
    """

    def __init__(self, root_dir, obj_name, transform=None, resize_shape=None, gt_transform=None):
        self.root_dir = Path(root_dir)
        self.obj_name = obj_name
        self.resize_shape = resize_shape
        category_dir = self.root_dir / self.obj_name
        self.image_names = sorted(glob.glob(str(category_dir / "test" / "*" / "*.png")))
        # Derive the mask root from root_dir rather than a hard-coded Drive
        # path, so another data location cannot silently read the wrong masks.
        # Kept as a str ending in "/" as before.
        self.gt_root = str(category_dir / "ground_truth") + "/"

        if transform is not None:
            self.transform = transform
        else:
            if resize_shape is None:
                raise ValueError("resize_shape is required when no transform is given")
            self.transform = transforms.Compose([
                transforms.Resize((resize_shape, resize_shape)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])

        # Always define gt_transform. It used to exist only when transform was
        # None, so a custom transform crashed (AttributeError) on the first
        # anomalous sample.
        if gt_transform is not None:
            self.gt_transform = gt_transform
        else:
            gt_steps = []
            if resize_shape is not None:
                gt_steps.append(transforms.Resize((resize_shape, resize_shape)))
            gt_steps.append(transforms.ToTensor())
            self.gt_transform = transforms.Compose(gt_steps)

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        """Return image, label (good: 0, anomaly: 1) and pixel mask for ``idx``."""
        img_path = Path(self.image_names[idx])
        # The parent directory name is the defect type ("good" = normal).
        label = img_path.parent.name
        img = self.transform(Image.open(img_path).convert("RGB"))

        if label == "good":
            gt_img = np.array([0], dtype=np.float32)
            # Empty mask sized like the image: resize_shape when known,
            # otherwise the transformed image's spatial size.
            if self.resize_shape is not None:
                height = width = self.resize_shape
            else:
                height, width = img.shape[-2:]
            gt_pix = torch.zeros([1, height, width])
        else:
            gt_img = np.array([1], dtype=np.float32)
            # Use the full file stem (e.g. "000") rather than its first three characters.
            gt_path = Path(self.gt_root) / label / f"{img_path.stem}_mask.png"
            # Single-channel mask regardless of how the PNG was stored.
            gt_pix = self.gt_transform(Image.open(gt_path).convert("L"))

        return {"image": img, "label": gt_img, "gt_mask": gt_pix}


### Loss

In [3]:
import torch
import torch.nn as nn

def get_ano_map(feature1, feature2):
    """Compare two feature maps position by position.

    Args:
        feature1, feature2: Tensors of identical shape. Channels are reduced
            along dim 1, so presumably (N, C, H, W) — TODO confirm for other ranks.

    Returns:
        A tuple ``(ano_map, loss, mse)``:
        ano_map: ``1 - cosine_similarity`` over dim 1, with a singleton dim
            re-inserted at position 1.
        loss: Scalar. The mean of ``ano_map`` per sample, averaged over the batch.
        mse: Squared error averaged over dim 1, also with dim 1 re-inserted.
    """
    squared_err = nn.functional.mse_loss(feature1, feature2, reduction='none')
    channel_mse = squared_err.mean(dim=1)

    similarity = nn.functional.cosine_similarity(feature1, feature2, dim=1)
    distance = 1 - similarity

    per_sample = distance.reshape(distance.shape[0], -1).mean(-1)
    return distance.unsqueeze(1), per_sample.mean(), channel_mse.unsqueeze(1)

class CosineLoss(nn.Module):
    """Mean cosine distance between two feature maps.

    ``forward`` computes ``1 - cosine_similarity`` along dim 1. It averages
    that over all remaining positions of each sample, then over the batch, and
    returns the resulting scalar tensor.
    """

    def __init__(self):
        super().__init__()

    def forward(self, feature1, feature2):
        distance = 1 - nn.functional.cosine_similarity(feature1, feature2, dim=1)
        batch_size = distance.shape[0]
        return distance.reshape(batch_size, -1).mean(-1).mean()


# x1 = torch.rand(2,10,50,50)

# x2 = torch.rand(2,10,50,50)

# cos = CosineLoss()

# print(cos(x1, x2))

### Model

In [4]:
import torch
import torch.nn as nn
from torchvision.models import wide_resnet50_2


class ConvBlock(nn.Module):
    """Bottleneck residual block with a projection (1x1 conv + BN) shortcut.

    The main path is: 1x1 conv (stride ``stride``) -> BN -> ReLU ->
    ``kernel_size`` conv -> BN -> ReLU -> 1x1 conv -> BN. The shortcut is a
    strided 1x1 conv + BN, so both paths yield ``filters[2]`` channels at the
    same resolution. Their sum goes through a ReLU.

    Args:
        in_channel: Number of input channels.
        kernel_size: Odd kernel size of the middle conv.
        filters: ``(F1, F2, F3)`` output channels of the three convs.
        stride: Stride of the first conv and of the shortcut.
    """

    def __init__(self, in_channel, kernel_size, filters, stride):
        super(ConvBlock, self).__init__()
        F1, F2, F3 = filters
        # "Same" padding for the middle conv, so the main path and the
        # shortcut keep matching spatial sizes for any odd kernel_size. It was
        # hard-coded to 1, which only works for kernel_size == 3.
        padding = kernel_size // 2
        self.stage = nn.Sequential(
            nn.Conv2d(in_channel, F1, 1, stride=stride, padding=0, bias=False),
            nn.BatchNorm2d(F1),
            nn.ReLU(True),
            nn.Conv2d(F1, F2, kernel_size, stride=1, padding=padding, bias=False),
            nn.BatchNorm2d(F2),
            nn.ReLU(True),
            nn.Conv2d(F2, F3, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(F3),
        )
        # Attribute names (and Sequential indices) form the state_dict keys of saved checkpoints.
        self.shortcut_1 = nn.Conv2d(in_channel, F3, 1, stride=stride, padding=0, bias=False)
        self.batch_1 = nn.BatchNorm2d(F3)
        self.relu_1 = nn.ReLU(inplace=True)

    def forward(self, X):
        shortcut = self.batch_1(self.shortcut_1(X))
        return self.relu_1(self.stage(X) + shortcut)

class IndentityBlock(nn.Module):
    """Bottleneck residual block with an identity shortcut.

    The main path is: 1x1 conv -> BN -> ReLU -> ``kernel_size`` conv -> BN ->
    ReLU -> 1x1 conv -> BN. It is added to the unchanged input and passed
    through a ReLU, so ``filters[2]`` must equal ``in_channel``.

    Args:
        in_channel: Number of input channels (equal to ``filters[2]``).
        kernel_size: Odd kernel size of the middle conv.
        filters: ``(F1, F2, F3)`` output channels of the three convs.
    """

    def __init__(self, in_channel, kernel_size, filters):
        super(IndentityBlock, self).__init__()
        F1, F2, F3 = filters
        # "Same" padding so the residual addition works for any odd
        # kernel_size. It was hard-coded to 1 (only correct for 3).
        padding = kernel_size // 2
        self.stage = nn.Sequential(
            nn.Conv2d(in_channel, F1, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(F1),
            nn.ReLU(True),
            nn.Conv2d(F1, F2, kernel_size, stride=1, padding=padding, bias=False),
            nn.BatchNorm2d(F2),
            nn.ReLU(True),
            nn.Conv2d(F2, F3, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(F3),
        )
        self.relu_1 = nn.ReLU(True)

    def forward(self, X):
        return self.relu_1(self.stage(X) + X)

class ConvTransposeBlock(nn.Module):
    """Upsampling bottleneck residual block (doubles height and width).

    The main path is: 2x2 stride-2 transposed conv -> BN -> ReLU ->
    ``kernel_size`` conv -> BN -> ReLU -> 1x1 conv -> BN. The shortcut is a
    2x2 stride-2 transposed conv + BN. Both paths give ``filters[2]`` channels
    at twice the input resolution, and their sum goes through a ReLU.

    Args:
        in_channel: Number of input channels.
        kernel_size: Odd kernel size of the middle conv.
        filters: ``(F1, F2, F3)`` output channels of the three layers.
    """

    def __init__(self, in_channel, kernel_size, filters):
        super(ConvTransposeBlock, self).__init__()
        F1, F2, F3 = filters
        # "Same" padding keeps the main path aligned with the shortcut for any
        # odd kernel_size. It was hard-coded to 1 (only correct for 3).
        padding = kernel_size // 2
        self.stage = nn.Sequential(
            nn.ConvTranspose2d(in_channel, F1, kernel_size=2, stride=2, padding=0, bias=False),
            nn.BatchNorm2d(F1),
            nn.ReLU(True),
            nn.Conv2d(F1, F2, kernel_size, stride=1, padding=padding, bias=False),
            nn.BatchNorm2d(F2),
            nn.ReLU(True),
            nn.Conv2d(F2, F3, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(F3),
        )
        # Attribute names (and Sequential indices) form the state_dict keys of saved checkpoints.
        self.shortcut_1 = nn.ConvTranspose2d(in_channel, F3, kernel_size=2, stride=2, padding=0, bias=False)
        self.batch_1 = nn.BatchNorm2d(F3)
        self.relu_1 = nn.ReLU(inplace=True)

    def forward(self, X):
        shortcut = self.batch_1(self.shortcut_1(X))
        return self.relu_1(self.stage(X) + shortcut)

class Encoder(nn.Module):
    """ImageNet-pretrained Wide-ResNet-50-2 feature extractor.

    ``forward`` runs the stem and ``layer1``..``layer3`` and returns their
    outputs. For a 256x256 input their shapes are (N, 256, 64, 64),
    (N, 512, 32, 32) and (N, 1024, 16, 16).

    Args:
        freeze: If True (default), every backbone parameter gets
            ``requires_grad=False``. The training loop only optimizes the
            decoder, so tracking backbone gradients just wasted memory and left
            stale ``.grad`` tensors. Pass False to restore the old behavior.
    """

    def __init__(self, freeze=True):
        super(Encoder, self).__init__()
        self.wRes50 = wide_resnet50_2(pretrained=True)
        if freeze:
            self.wRes50.requires_grad_(False)

    def forward(self, x):
        """Return the ``layer1``, ``layer2`` and ``layer3`` feature maps of ``x``."""
        x = self.wRes50.conv1(x)
        x = self.wRes50.bn1(x)
        x = self.wRes50.relu(x)
        x = self.wRes50.maxpool(x)

        feature1 = self.wRes50.layer1(x)         # [N, 256, 64, 64] for a 256x256 input
        feature2 = self.wRes50.layer2(feature1)  # [N, 512, 32, 32]
        feature3 = self.wRes50.layer3(feature2)  # [N, 1024, 16, 16]
        return feature1, feature2, feature3

class OCBE(nn.Module):
    """Fuses three encoder scales into a single compact code.

    ``branch1`` downsamples the layer1 features twice (stride-2 3x3 convs,
    256 -> 512 -> 1024 channels). ``branch2`` downsamples the layer2 features
    once (512 -> 1024). Both then line up with the layer3 features. The three
    1024-channel maps are concatenated (3072 channels) and projected back to
    1024 channels by a 1x1 conv (``merge``). A residual stage (``resblock``)
    then produces 2048 channels at half resolution: 8x8 for a 256x256 image.
    """

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels, kernel_size=3, stride=2, padding=1):
        # conv -> BN -> ReLU, created in this order so parameter initialisation
        # consumes the RNG exactly as the explicit layer lists did.
        return [
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=kernel_size, stride=stride, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]

    def __init__(self):
        super(OCBE, self).__init__()
        # Sub-module names and Sequential indices are the state_dict keys of
        # saved checkpoints, so they must not change.
        self.branch1 = nn.Sequential(*self._conv_bn_relu(256, 512),
                                     *self._conv_bn_relu(512, 1024))
        self.branch2 = nn.Sequential(*self._conv_bn_relu(512, 1024))
        self.merge = nn.Sequential(*self._conv_bn_relu(3072, 1024, kernel_size=1, stride=1, padding=0))
        bottleneck = [512, 512, 2048]
        self.resblock = nn.Sequential(
            ConvBlock(in_channel=1024, kernel_size=3, filters=bottleneck, stride=2),
            IndentityBlock(in_channel=2048, kernel_size=3, filters=bottleneck),
            IndentityBlock(in_channel=2048, kernel_size=3, filters=bottleneck),
        )

    def forward(self, x1, x2, x3):
        fused = torch.cat((self.branch1(x1), self.branch2(x2), x3), dim=1)  # [N, 3072, 16, 16]
        return self.resblock(self.merge(fused))  # merge: [N, 1024, 16, 16] -> resblock: [N, 2048, 8, 8]

class Decoder(nn.Module):
    """Upsamples the 2048-channel bottleneck code back through three scales.

    Each stage starts with a stride-2 ``ConvTransposeBlock`` (doubling H and W)
    followed by several ``IndentityBlock``s. ``forward`` returns the outputs of
    ``layer1``, ``layer2`` and ``layer3`` (256, 512 and 1024 channels). For
    the 8x8 code produced from a 256x256 image, their spatial sizes are 64x64,
    32x32 and 16x16.
    """

    def __init__(self):
        super(Decoder, self).__init__()
        # Stages are built deepest-first (layer3, layer2, layer1), the same
        # order as before, so parameter initialisation is unchanged. The
        # attribute names are state_dict keys of saved checkpoints.
        self.layer3 = self._make_stage(2048, [512, 1024, 1024], num_identity=5)
        self.layer2 = self._make_stage(1024, [256, 512, 512], num_identity=3)
        self.layer1 = self._make_stage(512, [128, 256, 256], num_identity=2)

    @staticmethod
    def _make_stage(in_channel, filters, num_identity):
        # One upsampling block, then num_identity residual blocks whose input
        # width is the upsampling block's output width (filters[-1]).
        blocks = [ConvTransposeBlock(in_channel=in_channel, kernel_size=3, filters=filters)]
        for _ in range(num_identity):
            blocks.append(IndentityBlock(in_channel=filters[-1], kernel_size=3, filters=filters))
        return nn.Sequential(*blocks)

    def forward(self, x):
        feature3 = self.layer3(x)         # [N, 1024, 16, 16] for a 256x256 image
        feature2 = self.layer2(feature3)  # [N, 512, 32, 32]
        feature1 = self.layer1(feature2)  # [N, 256, 64, 64]
        return feature1, feature2, feature3

class OcbeAndDecoder(nn.Module):
    """The trainable part of the model: ``OCBE`` followed by ``Decoder``.

    It takes the three encoder feature maps and returns the decoder outputs
    ``(feature1, feature2, feature3)`` in the same scale order.
    """

    def __init__(self):
        super(OcbeAndDecoder, self).__init__()
        self.ocbe = OCBE()
        self.decoder = Decoder()

    def forward(self, e_feature1, e_feature2, e_feature3):
        return self.decoder(self.ocbe(e_feature1, e_feature2, e_feature3))

### Test

In [5]:
!pip install matplotlib




In [6]:
import torch
import torch.nn as nn
import os
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score
import numpy as np
from scipy.ndimage import gaussian_filter
import matplotlib.pyplot as plt
from torchvision.transforms import ToPILImage

def get_ano_map(feature1, feature2):
    """Per-position anomaly map, scalar cosine loss and MSE map for two feature maps.

    NOTE(review): an identical ``get_ano_map`` is already defined in the loss
    cell. Once this cell runs, this definition shadows it.

    Args:
        feature1, feature2: Tensors of identical shape. Channels are reduced
            along dim 1 (presumably (N, C, H, W) — TODO confirm).

    Returns:
        ``(ano_map, loss, mse)``:
        ano_map: ``1 - cosine_similarity`` over dim 1, unsqueezed at dim 1.
        loss: Scalar. The per-sample mean of ``ano_map``, averaged over the batch.
        mse: Squared error averaged over dim 1, unsqueezed at dim 1.
    """
    cosine = nn.functional.cosine_similarity(feature1, feature2, dim=1)
    ano_map = 1 - cosine
    loss = ano_map.reshape(ano_map.shape[0], -1).mean(dim=-1).mean()
    mse = nn.functional.mse_loss(feature1, feature2, reduction='none').mean(dim=1)
    return ano_map.unsqueeze(1), loss, mse.unsqueeze(1)

def save_image(image, ano_map, gt_mask, output_dir, idx,
               mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Save a side-by-side figure of the input image, anomaly map and ground-truth mask.

    Args:
        image: (C, H, W) image tensor. With ``TestDataset``'s default
            transform this is normalized with ``mean``/``std``.
        ano_map: Anomaly-map tensor; singleton dims are squeezed for display.
        gt_mask: Ground-truth mask tensor; singleton dims are squeezed.
        output_dir: Existing directory that ``prediction_{idx}.png`` is written to.
        idx: Index used in the file name.
        mean, std: Per-channel normalization to undo before display. Pass
            ``mean=None`` to show ``image`` as-is (clipped to [0, 1]).
    """
    img = image.detach().cpu().float()
    if mean is not None:
        # Undo Normalize. Converting a normalized tensor straight to uint8
        # (as ToPILImage does) wraps negative / >1 values and renders garbage colours.
        mean_t = torch.as_tensor(mean, dtype=img.dtype).view(-1, 1, 1)
        std_t = torch.as_tensor(std, dtype=img.dtype).view(-1, 1, 1)
        img = img * std_t + mean_t
    img = img.clamp(0.0, 1.0).permute(1, 2, 0).numpy()

    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    ax[0].imshow(img)
    ax[0].set_title("Original Image")

    # 'jet' colormap with bilinear interpolation for the continuous anomaly map.
    ax[1].imshow(ano_map.cpu().numpy().squeeze(), cmap='jet', interpolation='bilinear')
    ax[1].set_title("Anomaly Map")

    ax[2].imshow(gt_mask.cpu().numpy().squeeze(), cmap='gray', interpolation='bilinear')
    ax[2].set_title("Ground Truth Mask")

    fig.savefig(os.path.join(output_dir, f"prediction_{idx}.png"))
    # Close this specific figure; many figures are produced in a loop.
    plt.close(fig)

def test(obj_name, ckp_dir, data_dir, reshape_size, output_dir, save_images=True):
    """Evaluate a saved ``OcbeAndDecoder`` checkpoint on one category's test split.

    For every test image, encoder and decoder features at three scales are
    compared with ``get_ano_map``. The three cosine-distance maps are
    upsampled to ``reshape_size``, summed and smoothed with a Gaussian
    (sigma=2). The image score is the maximum of that map.

    Args:
        obj_name: MVTec category, e.g. ``"bottle"``.
        ckp_dir: Path of the ``OcbeAndDecoder`` state_dict to load.
        data_dir: Dataset root passed to ``TestDataset``.
        reshape_size: Square input / anomaly-map resolution.
        output_dir: Directory for the per-image figures (created if missing).
        save_images: Whether to write one figure per test image (default True, as before).

    Returns:
        ``(test_loss_total, auroc_img, auroc_pix)``: the summed cosine loss over
        the test set, plus image-level and pixel-level ROC AUC rounded to 3 decimals.

    Raises:
        ValueError: if the test split contains no images.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    os.makedirs(output_dir, exist_ok=True)

    encoder = Encoder()
    encoder.to(device)
    ocbe_decoder = OcbeAndDecoder()
    ocbe_decoder.load_state_dict(torch.load(str(ckp_dir), map_location='cpu'))
    ocbe_decoder.to(device)

    encoder.eval()
    ocbe_decoder.eval()

    test_dataset = TestDataset(root_dir=data_dir, obj_name=obj_name, resize_shape=reshape_size)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

    test_loss_total = 0
    scores, labels = [], []
    # One array per image, concatenated once at the end. Extending a Python
    # list with every pixel (reshape_size**2 per image) was slow and memory-hungry.
    gt_px_chunks, pr_px_chunks = [], []

    with torch.no_grad():
        for idx, sample_test in enumerate(test_loader):
            image, label, gt = sample_test["image"], sample_test["label"], sample_test["gt_mask"]
            # Binarise the (interpolated) mask at 0.5.
            gt = (gt > 0.5).float()

            e_features = encoder(image.to(device))
            d_features = ocbe_decoder(*e_features)

            s_al = 0
            loss = 0.0
            for e_feat, d_feat in zip(e_features, d_features):
                ano_map, level_loss, _ = get_ano_map(e_feat, d_feat)
                s_al = s_al + nn.functional.interpolate(
                    ano_map, size=(reshape_size, reshape_size), mode='bilinear', align_corners=True)
                loss += level_loss.item()

            # Smooth the summed map; tune sigma here.
            s_al = gaussian_filter(s_al.squeeze().cpu().numpy(), sigma=2)

            gt_px_chunks.append(gt.cpu().numpy().astype(int).ravel())
            pr_px_chunks.append(s_al.ravel())
            scores.append(float(s_al.max()))
            labels.append(label.numpy().squeeze())
            test_loss_total += loss

            if save_images:
                save_image(image[0], torch.tensor(s_al), gt[0], output_dir, idx)

    if not scores:
        raise ValueError("The test dataset is empty. Please check the dataset path and contents.")

    auroc_img = round(roc_auc_score(np.array(labels), np.array(scores)), 3)
    auroc_pix = round(roc_auc_score(np.concatenate(gt_px_chunks), np.concatenate(pr_px_chunks)), 3)
    return test_loss_total, auroc_img, auroc_pix



### Train

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import os
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import datetime
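# The commented-out imports below are only needed when these pieces live in separate .py modules; in this notebook they are defined in earlier cells.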
# from data_loader import TrainDataset
# from model import Encoder, OcbeAndDecoder
# from loss import CosineLoss
# from test import test

class Args:
    """Hyper-parameters and paths for a training run (a plain attribute container)."""

    def __init__(self):
        self.obj_id = 0          # category index into obj_names; -1 trains every category
        self.bs = 16             # batch size
        self.lr = 0.005          # Adam learning rate
        self.epochs = 200        # number of training epochs
        self.gpu_id = 0          # exported as CUDA_VISIBLE_DEVICES in the __main__ block
        self.data_path = "/content/drive/My Drive/AD/datasets/MVTec/"     # dataset root
        self.checkpoint_path = "/content/drive/My Drive/AD/checkpoints/"  # checkpoint root directory
        self.test_interval = 5   # evaluate and checkpoint every N epochs (0 disables)

def _train_one_epoch(encoder, ocbe_decoder, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader`` and return the summed batch losses."""
    ocbe_decoder.train()
    epoch_loss = 0.0
    for sample in train_loader:
        images = sample["image"].to(device)
        # The encoder is never optimised: run it without autograd so no graph
        # and no parameter gradients are built for the backbone.
        with torch.no_grad():
            e_features = encoder(images)
        d_features = ocbe_decoder(*e_features)

        # Sum of the per-scale cosine losses.
        loss = sum(criterion(e_feat, d_feat) for e_feat, d_feat in zip(e_features, d_features))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
    return epoch_loss


def train(obj_name, args):
    """Train the OCBE + decoder on one MVTec category and track the best AUROCs.

    Every ``args.test_interval`` epochs (starting at epoch 0), the current
    weights are saved and evaluated with ``test``. A checkpoint is deleted
    again unless it improves image-level or pixel-level AUROC. Losses and
    metrics are logged to TensorBoard.

    Args:
        obj_name: MVTec category name.
        args: Configuration with ``bs``, ``lr``, ``epochs``, ``data_path``,
            ``checkpoint_path`` and ``test_interval``. An optional
            ``resize_shape`` defaults to 256.

    Returns:
        ``(run_time, auroc_img_best, auroc_pix_best, img_step, pix_step)``.

    Raises:
        ValueError: if the training split contains no images.
    """
    resize_shape = getattr(args, "resize_shape", 256)
    print(f"Start training {obj_name}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    cur_time = f"{datetime.datetime.now():%Y-%m-%d_%H_%M_%S}"
    run_time = f"{obj_name}_lr{args.lr}_bs{args.bs}_{cur_time}"

    train_dataset = TrainDataset(root_dir=args.data_path, obj_name=obj_name, resize_shape=resize_shape)
    print(f"Number of training samples: {len(train_dataset)}")
    if len(train_dataset) == 0:
        raise ValueError("The training dataset is empty. Please check the dataset path and contents.")
    train_loader = DataLoader(train_dataset, batch_size=args.bs, shuffle=True, drop_last=True,
                              num_workers=2, persistent_workers=True, pin_memory=True, prefetch_factor=2)

    # Create the directory the checkpoints are actually written to. It used to
    # be a hard-coded Drive path that only matched the default checkpoint_path.
    ckp_dir = os.path.join(args.checkpoint_path, "WRes50", run_time)
    os.makedirs(ckp_dir, exist_ok=True)

    # Encoder first, as before, so the decoder's random initialisation is unchanged.
    encoder = Encoder()
    ocbe_decoder = OcbeAndDecoder()
    encoder.to(device)
    ocbe_decoder.to(device)
    encoder.eval()

    cos_similarity = CosineLoss()
    optimizer = torch.optim.Adam(ocbe_decoder.parameters(), betas=(0.5, 0.999), lr=args.lr)

    auroc_img_best, img_step = 0, 0
    auroc_pix_best, pix_step = 0, 0

    writer = SummaryWriter(log_dir=f"/content/drive/My Drive/AD/logs/WRes50/{run_time}/")
    try:
        for step in tqdm(range(args.epochs), ascii=True):
            train_loss_total = _train_one_epoch(encoder, ocbe_decoder, train_loader,
                                                cos_similarity, optimizer, device)
            writer.add_scalar("train_loss", train_loss_total, step)

            if args.test_interval > 0 and step % args.test_interval == 0:
                ckp_path = os.path.join(ckp_dir, f"epoch{step}.pth")
                torch.save(ocbe_decoder.state_dict(), ckp_path)
                output_dir = f"/content/drive/My Drive/AD/predict/{obj_name}/{step}/"
                test_loss, auroc_img, auroc_pix = test(obj_name=obj_name, ckp_dir=ckp_path, data_dir=args.data_path,
                                                       reshape_size=resize_shape, output_dir=output_dir)
                writer.add_scalar("test_loss", test_loss, step)
                writer.add_scalar("auroc_img", auroc_img, step)
                writer.add_scalar("auroc_pix", auroc_pix, step)

                # Keep the checkpoint only if it improves at least one metric.
                if auroc_img <= auroc_img_best and auroc_pix <= auroc_pix_best:
                    os.remove(ckp_path)

                if auroc_img > auroc_img_best:
                    auroc_img_best, img_step = auroc_img, step
                if auroc_pix > auroc_pix_best:
                    auroc_pix_best, pix_step = auroc_pix, step
    finally:
        # Flush and release the TensorBoard event file even if training fails.
        writer.close()

    return run_time, auroc_img_best, auroc_pix_best, img_step, pix_step

def write2txt(filename, content):
    """Append ``str(content)`` plus a newline to ``filename``, creating it if missing.

    The file is opened as UTF-8 explicitly, so non-ASCII text does not depend
    on the platform's default encoding.
    """
    with open(filename, 'a', encoding='utf-8') as f:
        f.write(str(content) + "\n")

if __name__ == "__main__":
    # Uses the Args class defined at the top of this cell (the identical
    # redefinition that used to live here only shadowed it).
    args = Args()

    # NOTE(review): CUDA_VISIBLE_DEVICES is generally honoured only if CUDA has
    # not been initialised yet in this kernel — confirm when switching GPUs.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
    obj_names = [
        'bottle', 'cable', 'capsule', 'carpet', 'grid', 'hazelnut',
        'leather', 'metal_nut', 'pill', 'screw', 'tile', 'toothbrush',
        'transistor', 'wood', 'zipper'
    ]

    # -1 trains every category; any other value must be a valid index.
    # Previously, other negative values silently picked a category from the end.
    obj_id = int(args.obj_id)
    if obj_id == -1:
        selected_objs = obj_names
    elif 0 <= obj_id < len(obj_names):
        selected_objs = [obj_names[obj_id]]
    else:
        raise ValueError(f"obj_id must be -1 (all categories) or in [0, {len(obj_names) - 1}], got {args.obj_id}")

    log_txt_name = f"/content/drive/My Drive/AD/logs_txt/{datetime.datetime.now():%Y-%m-%d_%H_%M_%S}.txt"
    os.makedirs(os.path.dirname(log_txt_name), exist_ok=True)

    write2txt(log_txt_name, "log title")
    for obj_name in selected_objs:
        print(f"Training for category: {obj_name}")
        model_name, auroc_img_best, auroc_pix_best, img_step, pix_step = train(obj_name, args)
        write2txt(log_txt_name, f"{model_name} || auroc_img: {auroc_img_best} epoch: {img_step} || auroc_pix: {auroc_pix_best} epoch: {pix_step}")




Training for category: bottle
Start training bottle


Downloading: "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth" to /root/.cache/torch/hub/checkpoints/wide_resnet50_2-95faca4d.pth
100%|██████████| 132M/132M [00:00<00:00, 197MB/s]


Number of training samples: 213


  self.pid = os.fork()
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
