In [1]:
import os,json
import torch
import torch.nn as nn
from dataclasses import dataclass
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torch.optim as optim
from torch.nn import CrossEntropyLoss

import sys
sys.path.append("/root/autodl-tmp/CLIP")
import clip

## 定义模型

In [2]:
import math
# Default compute device: CUDA when a GPU is visible, otherwise CPU, so the
# notebook can also run on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Classifier(nn.Module):
    """Binary change classifier over a pair of patch-feature sequences.

    ``forward`` goes through the ``Classifier`` method only:
    1. Both feature maps get the same learned 2-D position embedding.
    2. They are concatenated channel-wise and prefixed with a learned CLS token.
    3. The result goes through ``transformer_encoder_classifier``.
    4. The CLS output is projected to 2 logits.

    Several sub-modules built in ``__init__`` are not used by ``forward``:
    ``l_embedding``, ``temporal_embedding``, ``transformer_encoder``,
    ``transformer_decoder``, ``conv_dif``, ``linear_dif``, ``pre_linear`` and
    ``class_embedding_A/B``. Presumably they are kept so the state_dict lines
    up with the pretrained checkpoint; confirm before removing them.

    Args:
        img_feature_h: size of the learned row-position table; the patch grid
            height must not exceed it.
        img_feature_w: size of the learned column-position table; the patch
            grid width must not exceed it.
        d_model: width of each input patch token (default 768). It must be
            divisible by 8 because every attention layer uses ``nhead=8``.
    """

    def __init__(self, img_feature_h=7, img_feature_w=7, d_model=768):
        super(Classifier, self).__init__()

        self.img_feature_h = img_feature_h
        self.img_feature_w = img_feature_w

        # Token width; the original comment tied it to
        # self.gpt.transformer.wte.weight.shape[1].
        self.gpt_embedding_size = d_model

        self.d_model = self.gpt_embedding_size
        half_d = self.d_model // 2

        # 1-D (sequence) position embedding for up to `l` positions.
        l = 50
        self.l_embedding = nn.Embedding(l, self.d_model)

        # 2-D position embedding: half of the channels encode the column
        # index, the other half the row index (see position_embedding_2D_func).
        self.w_embedding = nn.Embedding(img_feature_w, half_d)
        self.h_embedding = nn.Embedding(img_feature_h, half_d)
        self.temporal_embedding = nn.Embedding(2, self.d_model)

        encoder_self_layer = nn.TransformerEncoderLayer(1 * self.d_model, nhead=8,
                                                        dim_feedforward=4 * self.d_model)
        self.transformer_encoder = nn.TransformerEncoder(encoder_self_layer, num_layers=2)

        # Runs on the channel-wise concatenation of both images (2 * d_model).
        encoder_self_layer_classifier = nn.TransformerEncoderLayer(2 * self.d_model, nhead=8,
                                                                   dim_feedforward=4 * self.d_model)
        self.transformer_encoder_classifier = nn.TransformerEncoder(encoder_self_layer_classifier, num_layers=3)

        decoder_layer = nn.TransformerDecoderLayer(self.d_model, nhead=8, dim_feedforward=self.d_model * 2)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, 1)

        self.classifier_projection = nn.Linear(2 * self.d_model, 2)

        # Learned CLS tokens, initialised with std d_model ** -0.5.
        scale = self.d_model ** -0.5
        self.class_embedding_classifier_changeflag = nn.Parameter(scale * torch.randn(1, 2 * self.d_model))
        self.class_embedding_A = nn.Parameter(scale * torch.randn(1, self.d_model))
        self.class_embedding_B = nn.Parameter(scale * torch.randn(1, self.d_model))

        self.conv_dif = nn.Sequential(
            nn.Conv2d(self.d_model, half_d, kernel_size=3),
            nn.BatchNorm2d(half_d),
            nn.ReLU(inplace=True),
            nn.AdaptiveMaxPool2d((1, 1))
        )
        self.linear_dif = nn.Linear(half_d, 2)

        self.pre_linear = nn.Linear(self.gpt_embedding_size, self.d_model)

    def position_embedding_1D_func(self, embedding_text):
        """Add learned 1-D position embeddings to a sequence.

        Args:
            embedding_text: (N, L, d_model) tensor with L <= 50.

        Returns:
            Tensor of the same shape with ``l_embedding(0..L-1)`` added to
            every sample.
        """
        seq_len = embedding_text.shape[1]
        # Build the index on the input's device instead of a notebook global.
        pos_l = torch.arange(seq_len, device=embedding_text.device)
        position_embedding = self.l_embedding(pos_l)  # (L, d_model)
        # Broadcast over the batch dimension.
        return embedding_text + position_embedding.unsqueeze(0)  # NLD

    def position_embedding_2D_func(self, img_feat_A, img_feat_B):
        """Add the same learned 2-D position embedding to two feature maps.

        The L tokens are treated as a square h x w grid in row-major order
        (h = w = sqrt(L)). Token ``i * w + j`` receives
        ``[w_embedding(j) | h_embedding(i)]``.

        Args:
            img_feat_A: (N, L, d_model) tensor.
            img_feat_B: (N, L, d_model) tensor; its L defines the grid.

        Returns:
            Tuple ``(img_feat_A, img_feat_B)`` with the embedding added.

        Raises:
            ValueError: if L is not a perfect square.
        """
        seq_len = img_feat_B.shape[1]
        h = int(math.sqrt(seq_len))
        w = h
        if h * w != seq_len:
            raise ValueError(
                f"expected a square number of patch tokens, got {seq_len}"
            )
        dev = img_feat_B.device
        embed_w = self.w_embedding(torch.arange(w, device=dev))  # (w, d/2)
        embed_h = self.h_embedding(torch.arange(h, device=dev))  # (h, d/2)
        # (h, w, d_model): entry [i, j] = [column embedding j | row embedding i]
        position_embedding = torch.cat([embed_w.unsqueeze(0).expand(h, -1, -1),
                                        embed_h.unsqueeze(1).expand(-1, w, -1)],
                                       dim=-1)
        # Flatten the grid row-major; broadcasting handles the batch dim.
        position_embedding = position_embedding.reshape(h * w, -1)
        img_feat_A = img_feat_A + position_embedding  # NLD
        img_feat_B = img_feat_B + position_embedding  # NLD

        return img_feat_A, img_feat_B

    def Siamese_bridge_net(self, class_embedding, img_feat):
        """Prepend a CLS token and refine tokens with ``transformer_encoder``.

        Not called by ``forward``.

        Args:
            class_embedding: (1, d_model) learned token, e.g.
                ``class_embedding_A``.
            img_feat: (N, L, d_model) tensor.

        Returns:
            ``(cls_A, img_refine)``: the (N, d_model) CLS output and the
            (N, L, d_model) refined tokens.
        """
        conc_A = torch.cat(
            [class_embedding.unsqueeze(0).expand(img_feat.shape[0], *class_embedding.shape),
             img_feat], dim=1)
        # The encoder uses the default (L, N, D) layout (batch_first=False).
        conc_A = self.transformer_encoder(conc_A.permute(1, 0, 2)).permute(1, 0, 2)  # NLD
        cls_A = conc_A[:, 0, :]
        img_refine = conc_A[:, 1:, :]  # NLD
        return cls_A, img_refine

    def Classifier(self, img_feat):
        """Predict change logits for a pair of feature maps.

        Args:
            img_feat: tensor whose ``[:, 0]`` and ``[:, 1]`` slices are the two
                (N, L, d_model) maps, with L a perfect square.

        Returns:
            (N, 2) logits from ``classifier_projection``.
        """
        img_feat_A = img_feat[:, 0, ...]  # (N, L, d_model)
        img_feat_B = img_feat[:, 1, ...]

        # Shared 2-D image position embedding.
        img_feat_A, img_feat_B = self.position_embedding_2D_func(img_feat_A, img_feat_B)  # NLD

        # Channel-wise concatenation of both maps -> (N, L, 2 * d_model).
        img_feat = torch.cat([img_feat_A, img_feat_B], dim=-1)
        cls_token = self.class_embedding_classifier_changeflag.unsqueeze(0).expand(img_feat.shape[0], -1, -1)
        conc_A = torch.cat([cls_token, img_feat], dim=1)

        # The encoder uses the default (L, N, D) layout (batch_first=False).
        conc_A = self.transformer_encoder_classifier(conc_A.permute(1, 0, 2)).permute(1, 0, 2)  # NLD
        changeflag = self.classifier_projection(conc_A[:, 0, :])

        return changeflag

    def forward(self, featuremap):
        """Return (N, 2) change logits for ``featuremap`` (see ``Classifier``)."""
        changeflag = self.Classifier(featuremap)
        return changeflag

In [3]:
@dataclass
class ChangeClassifierConfig:
    """Configuration for the CLIP-based change classifier.

    The field is type-annotated so ``@dataclass`` turns it into a constructor
    argument, e.g. ``ChangeClassifierConfig(mm_vision_tower="ViT-L/14")``.
    """
    # CLIP model name passed to clip.load
    mm_vision_tower: str = "ViT-B/32"

class ChangeClassifier(nn.Module):
    """Frozen CLIP image encoder followed by a pretrained ``Classifier`` head.

    Args:
        config: object with ``mm_vision_tower`` (model name passed to
            ``clip.load``). It may also define ``classifier_ckpt``, the path of
            the classifier checkpoint; when absent, the default PromptCC path
            below is used.
    """

    _DEFAULT_CKPT = "/root/autodl-tmp/GeoChat/hf-models/PromptCC/cls_model.pth.tar"

    def __init__(self, config):
        super(ChangeClassifier, self).__init__()
        # Load CLIP on the GPU when available, otherwise on the CPU.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.vision_tower, self.preprocess = clip.load(config.mm_vision_tower, device=device, jit=False)
        # Freeze the vision tower.
        self.vision_tower.requires_grad_(False)

        self.classifier = Classifier()
        ckpt_path = getattr(config, "classifier_ckpt", self._DEFAULT_CKPT)
        # map_location="cpu": the head is still on the CPU here, and this lets
        # a GPU-saved checkpoint load on CPU-only machines.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        # strict=False: keys not present in Classifier are ignored (and missing
        # ones left at their init). Presumably the checkpoint holds a larger
        # model; verify the loaded keys if accuracy looks off.
        self.classifier.load_state_dict(ckpt['model_state_dict()'], strict=False)
        self.classifier.eval()

    def forward(self, images):
        """Classify a batch of image pairs.

        Args:
            images: 5-D tensor whose dim 1 indexes the images of each sample;
                presumably (B, 2, C, H, W).

        Returns:
            Logits from ``self.classifier``.
        """
        assert images.ndim == 5
        batch, n_views = images.shape[:2]
        # (B, V, C, H, W) -> (B*V, C, H, W), sample-major order.
        flat_images = images.flatten(0, 1)
        # The locally patched CLIP's encode_image returns a pair; the second
        # element is used as per-patch features.
        _, image_features = self.vision_tower.encode_image(flat_images)
        # (B*V, ...) -> (B, V, ...), undoing the flatten above.
        image_features = image_features.reshape(batch, n_views, *image_features.shape[1:])
        change_pred = self.classifier(image_features)
        return change_pred

## 定义数据集

In [4]:
@dataclass
class DataConfig:
    """Paths and preprocessing used to build a LazySupervisedDataset.

    Fields are type-annotated so ``@dataclass`` turns them into constructor
    arguments, e.g. ``DataConfig(data_path="test.json")``.
    """
    # JSON index file with the dataset entries
    data_path: str = "/root/autodl-tmp/LEVIR-MCI-dataset/ChangeChat_classify.json"
    # directory the entries' image file names are relative to
    image_folder: str = "/root/autodl-tmp/LEVIR-MCI-dataset/images"
    # callable applied to each image (e.g. CLIP's preprocess); set before use
    image_processor: object = None

class LazySupervisedDataset(Dataset):
    """Dataset that loads image groups (and optional change labels) lazily.

    The JSON file at ``config.data_path`` holds a list of entries. Each entry
    has an ``"image"`` key: one file name or a list of file names relative to
    ``config.image_folder``. Images are opened and passed through
    ``config.image_processor`` only when an item is requested.

    Args:
        config: object exposing ``data_path``, ``image_folder`` and
            ``image_processor``. It may also define ``label_key``, the entry
            key holding the change label (default ``"changeflag2"``).
    """

    def __init__(
        self,
        config
    ):
        super(LazySupervisedDataset, self).__init__()
        # `with` closes the index file once parsed.
        with open(config.data_path, "r", encoding="utf-8") as index_file:
            self.list_data_dict = json.load(index_file)
        self.processor = config.image_processor
        self.image_folder = config.image_folder
        # The same key is used to detect and to read the label.
        self.label_key = getattr(config, "label_key", "changeflag2")

    def __len__(self):
        return len(self.list_data_dict)

    def __getitem__(self, i):
        """Return ``{"images": (K, ...) tensor[, "change_labels": tensor]}``.

        K is the number of image files in the entry. ``"change_labels"`` is
        present only when the entry has ``self.label_key``.

        Raises:
            TypeError: if ``"image"`` is neither a str nor a list, or no
                image processor is configured.
            ValueError: if the entry lists no images.
        """
        entry = self.list_data_dict[i]
        image_file = entry["image"]

        if isinstance(image_file, str):
            image_file_list = [image_file]
        elif isinstance(image_file, list):
            image_file_list = image_file
        else:
            raise TypeError(
                f"entry {i}: 'image' must be a str or list, got {type(image_file).__name__}"
            )
        if not image_file_list:
            raise ValueError(f"entry {i}: 'image' lists no files")
        if self.processor is None:
            raise TypeError("image_processor is not set on the dataset config")

        imageList = []
        for _image_file in image_file_list:
            path = os.path.join(self.image_folder, _image_file).strip()
            # convert() returns a new, fully loaded image, so the file can be
            # closed right after.
            with Image.open(path) as img:
                imageList.append(self.processor(img.convert("RGB")))

        # Stack the images of this sample into one tensor.
        data_dict = {"images": torch.stack(imageList, dim=0)}

        if self.label_key in entry:
            data_dict["change_labels"] = torch.tensor(entry[self.label_key])
        return data_dict

In [5]:
# Compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = ChangeClassifierConfig()
model = ChangeClassifier(config)
model = model.to(device)

# Datasets and data loaders. Each split gets its own DataConfig, so the train
# config is not mutated after the train dataset has been built (re-running
# this cell always yields the same objects).
dataset_config = DataConfig()
dataset_config.image_processor = model.preprocess
train_dataset = LazySupervisedDataset(dataset_config)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Validation / test split.
test_dataset_config = DataConfig()
test_dataset_config.data_path = "/root/autodl-tmp/GeoChat/load/Test_CC_gt.json"
test_dataset_config.image_processor = model.preprocess
test_dataset = LazySupervisedDataset(test_dataset_config)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# Only trainable parameters go to the optimizer (the CLIP tower is frozen).
params = [p for p in model.parameters() if p.requires_grad]
num_params = sum(p.numel() for p in params) / 1_000_000  # trainable params, in millions
optimizer = optim.Adam(params, lr=1e-4)
criterion = CrossEntropyLoss()
criterion = CrossEntropyLoss()

In [10]:
import tqdm
def calculate_accuracy(model, data_loader, device):
    """Evaluate a classifier over a data loader.

    Args:
        model: module mapping a batch of images to outputs whose dim 1 indexes
            classes.
        data_loader: iterable of dicts with ``"images"`` and
            ``"change_labels"`` tensors.
        device: device each batch is moved to.

    Returns:
        ``(accuracy, predicts_all, labels_all)``: the fraction of correct
        predictions, plus flat lists of predicted and true class indices in
        loader order.

    Raises:
        ValueError: if the loader yields no samples.
    """
    was_training = model.training
    model.eval()
    predicts_all = []
    labels_all = []
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for batch in tqdm.tqdm(data_loader):
                images = batch['images'].to(device)
                labels = batch['change_labels'].to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs, 1)
                predicts_all.extend(predicted.tolist())
                labels_all.extend(labels.tolist())
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        # Restore the caller's train/eval mode so a later training loop is
        # not silently run in eval mode.
        model.train(was_training)
    if total == 0:
        raise ValueError("data_loader yielded no samples; accuracy is undefined")
    return correct / total, predicts_all, labels_all

In [12]:
# Evaluate on the test split; keep the overall accuracy plus the per-sample
# predictions and labels for the metric computations below.
accuracy, predicts_all, labels_all = calculate_accuracy(
    model=model,
    data_loader=test_loader,
    device=device,
)

100%|██████████| 16/16 [00:20<00:00,  1.30s/it]


In [24]:
from sklearn.metrics import accuracy_score, recall_score
# Test-split metrics: 准确率 = accuracy, 召回率 = recall. recall_score uses
# its default positive label 1 (presumably the "changed" class).
print(f"准确率：{accuracy_score(labels_all, predicts_all)}")
print(f"召回率：{recall_score(labels_all, predicts_all)}")

准确率：0.9160186625194401
召回率：0.8578838174273858
