In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.models import resnet18

# Set device to GPU if available, otherwise use CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define dataset and data loader
train_dir = "D:\\PhD\\data\\train"
train_dataset = ImageFolder(train_dir, transform=transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]))
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)

# Define SWaV model
class SwAVModel(nn.Module):
    def __init__(self, base_encoder):
        super(SwAVModel, self).__init__()
        self.encoder = base_encoder
        self.projection = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 128)
        )
        self.predictor = nn.Sequential(
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = self.encoder(x)
        z = self.projection(x)
        p = F.normalize(self.predictor(z), dim=1)
        return z, p

# Define base encoder
base_encoder = resnet18(pretrained=True)
modules = list(base_encoder.children())[:-1]
base_encoder = nn.Sequential(*modules)

# Instantiate SWaV model
model = SwAVModel(base_encoder)
model = model.to(device)

# Define loss function
criterion = torch.nn.CrossEntropyLoss()

# Define optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=0.03, weight_decay=0.0001)

# Define SWaV parameters
temperature = 0.1
queue_size = 65536
k = 65536
num_epochs = 100

# Define queue and initialize it with random data
queue = torch.randn(queue_size, 128).cuda()
queue = F.normalize(queue, dim=1)

# Train the SWaV model
for epoch in range(num_epochs):
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        # Move data to device
        inputs, targets = inputs.to(device), targets.to(device)

        # Zero the gradients
        optimizer.zero_grad()

        # Compute embeddings and predictions
        z_i, p_i = model(inputs)
        z_i = F.normalize(z_i, dim=1)

        # Compute logits and loss
        logits = torch.matmul(p_i, queue.T) / temperature
        labels = torch.cat([targets, torch.arange(10).to(device)])
        loss = criterion(logits, labels)

        # Update queue
        with torch.no_grad():
            z_j = queue.clone().detach()
            z_j = F.normalize(z_j, dim=1)
            queue[batch_idx*32:(batch_idx+1)*32] = z_i
            queue[k:] = queue[:-k]

        # Backward pass and optimization step
        loss.backward()
        optimizer.step()

        # Print loss and accuracy every 100 batches
        if batch_idx % 100 == 0:
            with torch.no_grad():
                logits = torch.matmul(p_i, queue.T) / temperature
                labels = torch.arange(batch_idx*32, (batch_idx+1)*32).to(device)
                acc1 = (torch.argmax(logits, dim=1) == labels).float().mean()
            print(f"Epoch [{epoch}/{num_epochs}] Batch [{batch_idx}/{len(train_loader)}] Loss: {loss.item():.4f} Accuracy: {acc1:.4f}")



  from .autonotebook import tqdm as notebook_tqdm


RuntimeError: mat1 and mat2 shapes cannot be multiplied (16384x1 and 512x512)