In [1]:
! pip install ftfy regex tqdm -q
! pip install git+https://github.com/openai/CLIP.git -q

In [2]:
import clip
import torch
import torchvision
import numpy as np
import tqdm
import os

In [3]:
# Note: this listing is not the last expression of the cell, so nothing is displayed.
clip.available_models()

# Load the 'RN50' CLIP checkpoint on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('RN50', device)

vocab_size = model.vocab_size
context_length = model.context_length
input_resolution = model.visual.input_resolution

# Total number of scalar parameters (numel() of a 0-dim parameter is 1).
param_count = sum(p.numel() for p in model.parameters())
summary_rows = (
    ("Model parameters:", f"{param_count:,}"),
    ("Input resolution:", input_resolution),
    ("Context length:", context_length),
    ("Vocab size:", vocab_size),
)
for row_label, row_value in summary_rows:
    print(row_label, row_value)

100%|███████████████████████████████████████| 244M/244M [00:05<00:00, 46.7MiB/s]


Model parameters: 102,007,137
Input resolution: 224
Context length: 77
Vocab size: 49408


In [4]:
train_path_to_data = '/kaggle/input/siri-whu-data-set/'
train_dataset = torchvision.datasets.ImageFolder(train_path_to_data, transform=preprocess)

# Show what ImageFolder discovered: the class names and their integer labels.
classnames = train_dataset.classes
num_classes = len(classnames)
print("Class names: {}".format(classnames))
print("Total number of classes: {}".format(num_classes))
print(train_dataset.class_to_idx)  # class name -> integer label (0..num_classes-1)

# Caption templates; each '{}' is a slot for a class name.
# NOTE(review): not consumed in this section -- presumably intended for zero-shot prompts.
templates = [
    'a centered satellite photo of {}.',
    'a centered satellite photo of a {}.',
    'a centered satellite photo of the {}.',
]

print(f"{len(classnames)} classes, {len(templates)} templates")

Class names: ['agriculture', 'commercial', 'harbor', 'idle_land', 'industrial', 'meadow', 'overpass', 'park', 'pond', 'residential', 'river', 'water']
Total number of classes: 12
{'agriculture': 0, 'commercial': 1, 'harbor': 2, 'idle_land': 3, 'industrial': 4, 'meadow': 5, 'overpass': 6, 'park': 7, 'pond': 8, 'residential': 9, 'river': 10, 'water': 11}
12 classes, 3 templates


In [5]:
batch_size = 64
# Shuffling is kept, but a seeded generator makes the batch order -- and hence the
# row order of the features saved to disk later -- reproducible across runs.
loader_seed = 0
_data_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(loader_seed),
)
def get_features(model, _data_loader, model_type='clip', device=None):
    """Encode every batch of ``_data_loader`` with ``model`` and collect the results.

    Args:
        model: for ``model_type == 'clip'``, a module exposing ``encode_image``
            (e.g. the CLIP model); otherwise a module with a ``backbone``
            attribute whose output is flattened from dim 1 (e.g. the SimCLR model).
        _data_loader: iterable yielding ``(images, labels)`` batches.
        model_type: ``'clip'`` selects ``model.encode_image``; any other value
            selects ``model.backbone``.
        device: device the images are moved to. Defaults to the device of the
            model's first parameter (``"cpu"`` if the model has no parameters).

    Returns:
        ``(features, labels)`` as NumPy arrays, concatenated over all batches
        in the order the loader yields them.

    Raises:
        ValueError: if ``_data_loader`` yields no batches.

    The model is switched to eval mode for the duration of the call, so
    BatchNorm uses its running statistics (and does not update them) and
    dropout is disabled. Its previous train/eval mode is restored afterwards.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"

    all_features = []
    all_labels = []
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for images, labels in _data_loader:
                images = images.to(device)
                if model_type == 'clip':
                    features = model.encode_image(images)
                else:
                    features = model.backbone(images).flatten(start_dim=1)
                # Move each batch off the accelerator immediately so device
                # memory does not grow with the size of the dataset.
                all_features.append(features.cpu())
                all_labels.append(labels.cpu())
    finally:
        model.train(was_training)

    if not all_features:
        raise ValueError("_data_loader yielded no batches; nothing to encode")

    return torch.cat(all_features).numpy(), torch.cat(all_labels).numpy()

train_features, train_labels = get_features(model, _data_loader)
feature_dim = train_features.shape[1]  # width of a single feature vector
print(feature_dim)

1024


In [7]:
# Persist the extracted features and labels so later cells can skip re-encoding.
feature_to_path = '/kaggle/working/pretrained_features'
pseudo = ''  # optional filename suffix (empty here)
feature_save_filename = "{}_features{}".format('siri_whu_ds', pseudo)
os.makedirs(feature_to_path, exist_ok=True)

feature_file = os.path.join(feature_to_path, feature_save_filename)
# The features are passed positionally (np.savez stores them as 'arr_0');
# the labels are stored under the key 'label_list'.
np.savez(feature_file, train_features, label_list=train_labels)

In [None]:
!pip install lightly -q

In [11]:
# Note: The model and training settings do not follow the reference settings
# from the paper. The settings are chosen such that the example can easily be
# run on a small dataset with a single GPU.

import torch
import torchvision
from torch import nn

from lightly.loss import NTXentLoss
from lightly.models.modules import SimCLRProjectionHead
from lightly.transforms.simclr_transform import SimCLRTransform


class SimCLR(nn.Module):
    """SimCLR model: a feature backbone followed by a projection head.

    Args:
        backbone: module mapping images to features; its output is flattened
            from dim 1 onward before the projection head.
        input_dim: width of the flattened backbone output. The default 512
            matches the ResNet-18 trunk used in this notebook; pass a different
            value for other backbones.
        hidden_dim: hidden width of the projection head.
        output_dim: width of the projected embedding fed to the contrastive loss.
    """

    def __init__(self, backbone, input_dim=512, hidden_dim=512, output_dim=128):
        super().__init__()
        self.backbone = backbone
        self.projection_head = SimCLRProjectionHead(input_dim, hidden_dim, output_dim)

    def forward(self, x):
        """Return the projected embedding ``z`` for a batch of images ``x``."""
        features = self.backbone(x).flatten(start_dim=1)
        return self.projection_head(features)


# Seed torch's global RNG so the ResNet initialisation, the DataLoader shuffle and
# the worker RNG seeds PyTorch derives from it are repeatable (up to cuDNN nondeterminism).
torch.manual_seed(0)

resnet = torchvision.models.resnet18()
# Drop the final fully-connected layer; the remaining trunk is the SimCLR backbone.
backbone = nn.Sequential(*list(resnet.children())[:-1])
ssl_model = SimCLR(backbone)

device = "cuda" if torch.cuda.is_available() else "cpu"
ssl_model.to(device)

# Each sample yields two augmented views of the same image (unpacked as x0, x1 below).
ssl_transform = SimCLRTransform(input_size=128, gaussian_blur=0.0)
ssl_dataset = torchvision.datasets.ImageFolder(train_path_to_data, transform=ssl_transform)
ssl_dataloader = torch.utils.data.DataLoader(
    ssl_dataset,
    batch_size=128,
    shuffle=True,
    drop_last=True,
    num_workers=4,
)
# With drop_last=True a dataset smaller than one batch yields zero batches, which
# would otherwise surface as a ZeroDivisionError when averaging the epoch loss.
if len(ssl_dataloader) == 0:
    raise ValueError(
        f"ssl_dataloader has no full batches: {len(ssl_dataset)} images < batch_size 128"
    )

criterion = NTXentLoss()
optimizer = torch.optim.SGD(ssl_model.parameters(), lr=0.06)

num_epochs = 10
print("Starting Training")
# Be explicit about the mode: other cells may switch the model to eval.
ssl_model.train()
for epoch in range(num_epochs):
    total_loss = 0.0
    for batch in ssl_dataloader:
        x0, x1 = batch[0]
        x0 = x0.to(device)
        x1 = x1.to(device)
        z0 = ssl_model(x0)
        z1 = ssl_model(x1)
        loss = criterion(z0, z1)
        # detach() keeps the running sum on-device without holding the autograd graph;
        # it is converted to a Python float once per epoch instead of every step.
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    avg_loss = float(total_loss) / len(ssl_dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")

Starting Training
epoch: 00, loss: 5.43341
epoch: 01, loss: 5.36741
epoch: 02, loss: 5.30327
epoch: 03, loss: 5.22993
epoch: 04, loss: 5.23047
epoch: 05, loss: 5.19741
epoch: 06, loss: 5.17265
epoch: 07, loss: 5.14395
epoch: 08, loss: 5.12276
epoch: 09, loss: 5.07642


In [12]:
# The TypeError recorded for this cell came from a kernel still holding an older
# get_features without `model_type`; re-run the feature-extraction cell first.
# Note: this overwrites the CLIP features in train_features/train_labels with
# features from the SimCLR backbone.
train_features, train_labels = get_features(ssl_model, _data_loader, model_type = 'SSL')
feature_dim = train_features.shape[1]  # width of a single feature vector
print(feature_dim)

TypeError: get_features() got an unexpected keyword argument 'model_type'