In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.transforms import ToTensor
from torchvision.datasets import CIFAR10
from torchvision.models import resnet50
from torch.utils.data import DataLoader
import os
from torchvision import transforms
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
transform = ToTensor()
dataset = CIFAR10(root = './cifar10', train = True, transform = transform, download=True)
loader = DataLoader(dataset, batch_size=64, shuffle=True)

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
C = 128
N = loader.batch_size
K = 1024

Files already downloaded and verified


In [3]:
def get_resnet50(output_dim):
    model = resnet50(pretrained=False)
    model.fc = nn.Linear(model.fc.in_features, output_dim)
    return model

In [4]:
f_q = get_resnet50(C).to(device)
f_k = get_resnet50(C).to(device)
f_k.load_state_dict(f_q.state_dict())

queue = torch.randn(C, K).to(device)
queue_ptr = 0

In [5]:
m = 0.99
optimizer = optim.Adam(f_q.parameters(), lr = 0.001)

def aug(x):
    return x + 0.1 * torch.randn_like(x)

def info_nce_loss(q, k, queue, temperature = 0.07):
    q = nn.functional.normalize(q, dim = 1, p = 2)
    k = nn.functional.normalize(k, dim = 1, p = 2)
    queue = nn.functional.normalize(queue, dim = 1, p = 2)

    positive_similarity = torch.bmm(q.view(N, 1, C), k.view(N, C, 1))
    negative_similarity = torch.mm(q, queue)

    logits = torch.cat([positive_similarity.squeeze(-1), negative_similarity], dim = 1)

    labels = torch.zeros(N, dtype=torch.long).to(device)

    loss = nn.CrossEntropyLoss()(logits/temperature, labels)
    return loss





In [6]:
num_epochs = 5
for epoch in range(num_epochs):
    total_loss = 0.0
    for x, _ in loader:
        if x.size(0) < N:
            continue  # 跳过大小小于 batch_size 的批次
        
        x = x.to(device)
        # print(x.shape)
        x_q = aug(x)
        x_k = aug(x)
        q = f_q(x_q)
        k = f_k(x_k)
        # print(q.shape)

        k = k.detach()

        loss = info_nce_loss(q, k, queue)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        with torch.no_grad():
            for param_q, param_k in zip(f_q.parameters(), f_k.parameters()):
                param_k.data = param_k.data * m + param_q.data * (1-m)

            batch_size = k.size(0)
            # print(queue[:, queue_ptr:queue_ptr + batch_size].shape, k.T.shape, batch_size, queue_ptr, queue_ptr + batch_size)
            queue[:, queue_ptr:queue_ptr + batch_size] = k.T[:, :batch_size]
            queue_ptr = (queue_ptr + batch_size) % K

        total_loss += loss.item() 

    average_loss = total_loss / len(loader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {average_loss}')

    if epoch == 0: 
        best_loss = average_loss
    if average_loss < best_loss:
        best_loss = average_loss
        torch.save(f_q.state_dict(), 'best_model.pth')


torch.Size([64, 3, 32, 32])
torch.Size([64, 128])
torch.Size([64, 3, 32, 32])
torch.Size([64, 128])
torch.Size([64, 3, 32, 32])
torch.Size([64, 128])
torch.Size([64, 3, 32, 32])
torch.Size([64, 128])
torch.Size([64, 3, 32, 32])
torch.Size([64, 128])
torch.Size([64, 3, 32, 32])
torch.Size([64, 128])
torch.Size([64, 3, 32, 32])
torch.Size([64, 128])
torch.Size([64, 3, 32, 32])
torch.Size([64, 128])
torch.Size([64, 3, 32, 32])
torch.Size([64, 128])
torch.Size([64, 3, 32, 32])
torch.Size([64, 128])
torch.Size([64, 3, 32, 32])
torch.Size([64, 128])
torch.Size([64, 3, 32, 32])
torch.Size([64, 128])
torch.Size([64, 3, 32, 32])
torch.Size([64, 128])
torch.Size([64, 3, 32, 32])
torch.Size([64, 128])
torch.Size([64, 3, 32, 32])
torch.Size([64, 128])
torch.Size([64, 3, 32, 32])
torch.Size([64, 128])
torch.Size([64, 3, 32, 32])
torch.Size([64, 128])
torch.Size([64, 3, 32, 32])
torch.Size([64, 128])
torch.Size([64, 3, 32, 32])
torch.Size([64, 128])
torch.Size([64, 3, 32, 32])
torch.Size([64, 128])


KeyboardInterrupt: 

In [None]:
import torch
import torch.nn as nn
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torchvision.models import resnet50
import PIL

transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # 将图像转换为3通道
    transforms.Resize((32, 32), interpolation=PIL.Image.LANCZOS),  # 调整大小
    transforms.ToTensor(),
])

# Load the MNIST dataset
mnist_transform = ToTensor()
mnist_dataset = MNIST(root='./mnist', train=True, transform=transform, download=True)
mnist_loader = DataLoader(mnist_dataset, batch_size=64, shuffle=True)

# Load the pre-trained ResNet50 model as an encoder
encoder = get_resnet50(C).to(device)
encoder.load_state_dict(torch.load('best_model.pth'))

# Define a simple classifier (you can adjust this based on your specific needs)
class MLPClassifier(nn.Module):
    def __init__(self, input_dim, hidden_dim1, hidden_dim2, output_dim):
        super(MLPClassifier, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim1)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_dim2, output_dim)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        return x

# Check the output dimension of the encoder
with torch.no_grad():
    sample_input = torch.randn(64, 3, 32, 32).to(device)
    encoder_output = encoder(sample_input)
    encoder_output_dim = encoder_output.size(1)

# Create the classifier
classifier = MLPClassifier(encoder_output_dim,512,256,10).to(device)

# Print the model summary (optional)

# Define the optimizer and loss function
classifier_optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()


  "Argument interpolation should be of type InterpolationMode instead of int. "


In [None]:
# Training loop
num_epochs = 5
for epoch in range(num_epochs):
    total_loss = 0.0
    for x, labels in mnist_loader:
        x = x.to(device)
        labels = labels.to(device)

        # Encode the images using the pre-trained encoder
        with torch.no_grad():
            encoded_images = encoder(x)

        # Forward pass through the classifier
        logits = classifier(encoded_images)

        # Compute the loss and update the model
        loss = criterion(logits, labels)
        classifier_optimizer.zero_grad()
        loss.backward()
        classifier_optimizer.step()

        total_loss += loss.item()

    average_loss = total_loss / len(mnist_loader)
    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {average_loss}')

# Save the trained classifier
torch.save(classifier.state_dict(), 'trained_classifier.pth')


Epoch [1/5], Loss: 2.344380499203322
Epoch [2/5], Loss: 2.226782302358257
Epoch [3/5], Loss: 2.21560400305018
Epoch [4/5], Loss: 2.203023105160768
Epoch [5/5], Loss: 2.190199889099674


In [None]:
for x, _ in loader:
    print(x, _)

tensor([[[[0.8549, 0.8471, 0.8510,  ..., 0.7961, 0.7608, 0.7647],
          [0.8667, 0.8627, 0.8627,  ..., 0.8431, 0.8275, 0.8510],
          [0.8588, 0.8549, 0.8549,  ..., 0.8706, 0.8078, 0.8431],
          ...,
          [0.8196, 0.7647, 0.4824,  ..., 0.7333, 0.7725, 0.7647],
          [0.8235, 0.8039, 0.5686,  ..., 0.7686, 0.7725, 0.7647],
          [0.8118, 0.7843, 0.6196,  ..., 0.7569, 0.7569, 0.7569]],

         [[0.9294, 0.9216, 0.9255,  ..., 0.8784, 0.8353, 0.7412],
          [0.9412, 0.9373, 0.9373,  ..., 0.8196, 0.7451, 0.7529],
          [0.9333, 0.9294, 0.9333,  ..., 0.7137, 0.6588, 0.7686],
          ...,
          [0.9176, 0.8353, 0.4627,  ..., 0.8078, 0.8588, 0.8471],
          [0.9059, 0.8706, 0.5569,  ..., 0.8549, 0.8588, 0.8510],
          [0.8902, 0.8549, 0.5882,  ..., 0.8392, 0.8431, 0.8353]],

         [[0.9843, 0.9765, 0.9804,  ..., 0.9020, 0.8471, 0.6784],
          [0.9961, 0.9922, 0.9922,  ..., 0.7804, 0.7059, 0.7059],
          [0.9882, 0.9843, 0.9843,  ..., 0

KeyboardInterrupt: 