In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.models import ResNet18_Weights
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
import os
import numpy as np
from PIL import Image
import shutil
import matplotlib.pyplot as plt
from tqdm import tqdm

class SatelliteDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = [f for f in os.listdir(root_dir) if f.endswith(('.jpg', '.png', '.jpeg'))]
        
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.images[idx])
        image = Image.open(img_name).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
            
        return image, idx
    
class CNN(nn.Module):
    def __init__(self, feature_dim, num_clusters, pretrained=True, freeze_conv=True):
        super(CNN, self).__init__()
        self.feature_dim = feature_dim
        self.num_clusters = num_clusters
        
        # 初始化一个空的ResNet18模型
        resnet18 = models.resnet18(weights=None)
        
        if pretrained:
            # 加载你的预训练模型
            state_dict = torch.load("pretrained_resnet18.pth")
            # 移除原始模型中最后的全连接层的权重
            state_dict = {k: v for k, v in state_dict.items() if not k.startswith('fc')}
            # 加载修改后的权重到模型中
            resnet18.load_state_dict(state_dict, strict=False)
        
        # 提取除最后全连接层外的所有层
        self.features = nn.Sequential(*list(resnet18.children())[:-1])

        if freeze_conv:
            for param in self.features.parameters():
                param.requires_grad = False

        # 添加新的全连接层用于特征提取和聚类
        self.fc = nn.Sequential(
            nn.Linear(512, self.feature_dim),  # ResNet18的输出特征维度是512
            nn.ReLU(inplace=True)
        )
        self.fc_cluster = nn.Linear(self.feature_dim, self.num_clusters)

    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

def compute_features(dataloader, model, num_samples, device):
    """
        Use CNN model to extract high-level features in images
        Every image in `dataloader` will be fed into `model` to extract features
    """
    model.eval()
    print("Start computing features...")
    features = torch.zeros(num_samples, model.feature_dim, device=device)
    
    with torch.no_grad():
        for i, (images, _) in enumerate(dataloader):
            batch_features = model(images.to(device))
            features[i * dataloader.batch_size: min((i + 1) * dataloader.batch_size, num_samples)] = batch_features.cpu()
            # 索引计算：i * dataloader.batch_size：计算当前批次在整个数据集中的起始索引。i 是当前的批次索引，dataloader.batch_size 是每个批次的大小。min((i + 1) * dataloader.batch_size, num_samples)：计算当前批次在整个数据集中的结束索引。(i + 1) * dataloader.batch_size 是下一个批次的起始索引，但不能超过数据集的总样本数 num_samples。
            # 切片赋值：features[...]：这是一个切片操作，表示将 features 数组中从起始索引到结束索引的部分进行赋值。batch_features.cpu()：将当前批次的特征数据从GPU内存转移到CPU内存。.cpu() 方法将张量从GPU设备转移到CPU设备。
            # 赋值操作：将 batch_features.cpu() 的结果赋值给 features 数组的相应部分。
    print("Finish computing features!")
    return features.cpu().numpy()

def optimize_cluster_number(features, k_range):
    """
    Find optimal cluster number by silhouette score
    Args:
        features: 特征矩阵
        k_range: 要尝试的k值范围
    Returns:
        optimal_k: 最优的聚类数
        silhouette_scores: 不同k值对应的轮廓系数
    """
    silhouette_scores = []
    
    print("Start optimizing cluster number...")
    for k in k_range:
        print(f"Testing k={k}...")
        kmeans = KMeans(n_clusters=k, random_state=42)
        cluster_labels = kmeans.fit_predict(features)
        
        # 计算轮廓系数
        score = silhouette_score(features, cluster_labels)
        silhouette_scores.append(score)
        print(f"silhouette score for k={k}: {score}")
    
    # 获取最优k值
    optimal_k = k_range[np.argmax(silhouette_scores)]
    
    # 绘制轮廓系数与k的关系图
    plt.figure(figsize=(10, 6))
    plt.plot(k_range, silhouette_scores, 'bo-')
    plt.xlabel('Cluster_nums (k)')
    plt.xticks(k_range)
    plt.ylabel('silhouette score')
    plt.title('Inspecting the optimal cluster number for Mountain')
    plt.grid(True)
    plt.savefig('silhouette_scores_Mountain.png')
    plt.close()
    
    return optimal_k, silhouette_scores

def train_deep_cluster(dataset, clusterloader, trainloader,
                       num_clusters=50, num_epochs=10, lr=0.0001):
    # 可以改动num_clusters，如果聚类效果不好
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CNN(feature_dim=512, num_clusters=num_clusters, pretrained=True, freeze_conv=False).to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    losses = []
    print("Start training...")
    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")

        features = compute_features(clusterloader, model, len(dataset), device)     # features for all images
        kmeans = KMeans(n_clusters=num_clusters, n_init=10).fit(features)           # cluster into 50 classes
        cluster_labels = kmeans.labels_                                             # get clustered pseudo labels for all images in dataset (i.e. clusterloader)

        fc_cluster = nn.Linear(model.feature_dim, num_clusters).to(device)          # model.feature_dim = 512
        fc_cluster.weight.data.normal_(0, 0.01)
        fc_cluster.bias.data.zero_()
        optimizer = optim.Adam(model.parameters(), lr=lr)
        
        model.train()
        running_loss = 0.0
        
        train_iter = tqdm(enumerate(trainloader), total=len(trainloader))
        for i, (images, indices) in train_iter:
            images, batch_labels = images.to(device), torch.tensor(cluster_labels[indices], dtype=torch.long).to(device)
            features = model(images)                          # features for current batch, [batchsize, feature_dim]
            # outputs = model.fc_cluster(features)            # cluster labels for current batch, [batchsize, num_clusters]
            outputs = fc_cluster(features)
            loss = criterion(outputs, batch_labels)       
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        
        epoch_loss = running_loss / len(trainloader)
        print(f"Loss: {epoch_loss:.4f}")
        losses.append(epoch_loss)
        
    print("Finish training!")
    torch.save(model.state_dict(), "deep_cluster_Mountain.pth")
    
if __name__ == "__main__":
    root_dir = "Classified_images/Mountain"
    cluster_dir = "Mountain_cluster_results"
    if os.path.exists(cluster_dir):
        shutil.rmtree(cluster_dir)
    os.makedirs(cluster_dir)
    
    transform = transforms.Compose([
        transforms.Resize((343, 343)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    print("Loading datasets...")
    dataset = SatelliteDataset(root_dir, transform=transform)
    clusterloader = DataLoader(dataset, batch_size=128, shuffle=False)
    trainloader = DataLoader(dataset, batch_size=128, shuffle=True)
    print("Finished loading datasets")

    # if the model doesn't exists, starting training
    if not os.path.exists("deep_cluster_Mountain.pth"):
        train_deep_cluster(dataset, clusterloader, trainloader, num_clusters=20)
    
    # load the model
    print("Loading trained model...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CNN(feature_dim=512, num_clusters=20, pretrained=False, freeze_conv=False).to(device)
    model.load_state_dict(torch.load("deep_cluster_Mountain.pth", weights_only=True))

    # extract features
    features = compute_features(clusterloader, model, len(dataset), device)
    
    # check optimal cluster number
    k_range = range(3, 21)
    optimal_k, silhouette_scores = optimize_cluster_number(features, k_range)
    print(f"optimal cluster number: {optimal_k}")

    # cluster all images into optimal cluster number
    kmeans = KMeans(n_clusters=optimal_k).fit(features)
    final_cluster_labels = kmeans.labels_

    # save images to corresponding cluster folders
    for i in range(optimal_k):
        cluster_subdir = os.path.join(cluster_dir, f"cluster_{i}")
        os.makedirs(cluster_subdir)

    for idx, label in enumerate(final_cluster_labels):
        img_name = dataset.images[idx]
        src_path = os.path.join(root_dir, img_name)
        dst_path = os.path.join(cluster_dir, f"cluster_{label}", img_name)
        shutil.copy2(src_path, dst_path)

    print(f"cluster results saved in {cluster_dir}")



Loading datasets...
Finished loading datasets


  state_dict = torch.load("pretrained_resnet18.pth")


Start training...
Epoch 1/10
Start computing features...
Finish computing features!


100%|██████████| 123/123 [07:32<00:00,  3.68s/it]


Loss: 1.6764
Epoch 2/10
Start computing features...
Finish computing features!


100%|██████████| 123/123 [08:16<00:00,  4.04s/it]


Loss: 1.5588
Epoch 3/10
Start computing features...
Finish computing features!


100%|██████████| 123/123 [07:55<00:00,  3.87s/it]


Loss: 1.3668
Epoch 4/10
Start computing features...
Finish computing features!


100%|██████████| 123/123 [07:41<00:00,  3.75s/it]


Loss: 1.3633
Epoch 5/10
Start computing features...
Finish computing features!


100%|██████████| 123/123 [07:50<00:00,  3.83s/it]


Loss: 1.3310
Epoch 6/10
Start computing features...
Finish computing features!


100%|██████████| 123/123 [07:49<00:00,  3.81s/it]


Loss: 1.2667
Epoch 7/10
Start computing features...
Finish computing features!


100%|██████████| 123/123 [08:03<00:00,  3.93s/it]


Loss: 1.2049
Epoch 8/10
Start computing features...
Finish computing features!


100%|██████████| 123/123 [07:18<00:00,  3.56s/it]


Loss: 1.3005
Epoch 9/10
Start computing features...
Finish computing features!


100%|██████████| 123/123 [07:41<00:00,  3.76s/it]


Loss: 1.1763
Epoch 10/10
Start computing features...
Finish computing features!


100%|██████████| 123/123 [08:05<00:00,  3.95s/it]


Loss: 1.1536
Finish training!
Loading trained model...
Start computing features...
Finish computing features!
Start optimizing cluster number...
Testing k=3...


  super()._check_params_vs_input(X, default_n_init=10)


silhouette score for k=3: 0.26088669896125793
Testing k=4...


  super()._check_params_vs_input(X, default_n_init=10)


silhouette score for k=4: 0.30297863483428955
Testing k=5...


  super()._check_params_vs_input(X, default_n_init=10)


silhouette score for k=5: 0.31838294863700867
Testing k=6...


  super()._check_params_vs_input(X, default_n_init=10)


silhouette score for k=6: 0.34452369809150696
Testing k=7...


  super()._check_params_vs_input(X, default_n_init=10)


silhouette score for k=7: 0.3673427999019623
Testing k=8...


  super()._check_params_vs_input(X, default_n_init=10)


silhouette score for k=8: 0.385699063539505
Testing k=9...


  super()._check_params_vs_input(X, default_n_init=10)


silhouette score for k=9: 0.38128748536109924
Testing k=10...


  super()._check_params_vs_input(X, default_n_init=10)


silhouette score for k=10: 0.3758222460746765
Testing k=11...


  super()._check_params_vs_input(X, default_n_init=10)


silhouette score for k=11: 0.3715251684188843
Testing k=12...


  super()._check_params_vs_input(X, default_n_init=10)


silhouette score for k=12: 0.37781214714050293
Testing k=13...


  super()._check_params_vs_input(X, default_n_init=10)


silhouette score for k=13: 0.3685828149318695
Testing k=14...


  super()._check_params_vs_input(X, default_n_init=10)


silhouette score for k=14: 0.36701828241348267
Testing k=15...


  super()._check_params_vs_input(X, default_n_init=10)


silhouette score for k=15: 0.35670366883277893
Testing k=16...


  super()._check_params_vs_input(X, default_n_init=10)


silhouette score for k=16: 0.36327508091926575
Testing k=17...


  super()._check_params_vs_input(X, default_n_init=10)


silhouette score for k=17: 0.3667709231376648
Testing k=18...


  super()._check_params_vs_input(X, default_n_init=10)


silhouette score for k=18: 0.3565512001514435
Testing k=19...


  super()._check_params_vs_input(X, default_n_init=10)


silhouette score for k=19: 0.37154677510261536
Testing k=20...


  super()._check_params_vs_input(X, default_n_init=10)


silhouette score for k=20: 0.3678586483001709
optimal cluster number: 8


  super()._check_params_vs_input(X, default_n_init=10)


cluster results saved in Mountain_cluster_results
