In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from einops import rearrange


In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:

# Define Masked Autoencoder
class MaskedAutoencoder(nn.Module):
    """Small convolutional autoencoder trained to reconstruct randomly masked images.

    Args:
        img_size: Kept in the signature for compatibility; no layer reads it.
        hidden_dim: Number of channels the encoder outputs (and the decoder consumes).
        mask_ratio: Threshold applied to uniform noise in ``forward``; elements whose
            noise value is not strictly greater than it are zeroed out.
    """

    def __init__(self, img_size=64, hidden_dim=256, mask_ratio=0.75):
        super().__init__()
        self.mask_ratio = mask_ratio

        # Every convolution is 3x3 / stride 1 / padding 1, so spatial size is preserved.
        same_conv = dict(kernel_size=3, stride=1, padding=1)

        encoder_layers = [
            nn.Conv2d(1, 64, **same_conv),
            nn.ReLU(),
            nn.Conv2d(64, hidden_dim, **same_conv),
            nn.ReLU(),
        ]
        decoder_layers = [
            nn.Conv2d(hidden_dim, 64, **same_conv),
            nn.ReLU(),
            nn.Conv2d(64, 1, **same_conv),
            nn.Sigmoid(),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Mask ``x`` element-wise, then encode and decode it.

        Returns:
            A ``(reconstructed, mask)`` tuple; ``mask`` is a float tensor shaped like
            ``x`` holding 1.0 where the input was kept and 0.0 where it was hidden.
        """
        noise = torch.rand_like(x)
        mask = torch.gt(noise, self.mask_ratio).float()
        reconstructed = self.decoder(self.encoder(x * mask))
        return reconstructed, mask

In [4]:
# Load the saved MAE model.
# map_location remaps the tensors stored in the checkpoint onto `device`, so a
# checkpoint that holds CUDA tensors can still be loaded when `device` is the CPU
# (without it, torch.load tries to restore tensors to the device they were saved from).
mae_finetune = MaskedAutoencoder().to(device)
mae_finetune.load_state_dict(torch.load("mae_pretrained.pth", map_location=device))

print("Pretrained MAE model loaded successfully!")


Pretrained MAE model loaded successfully!


In [5]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms

# Map every class name to the folder that holds its samples
dataset_dirs = dict([
    ("axion", "dataset/Dataset/axion"),
    ("cdm", "dataset/Dataset/cdm"),
    ("no_sub", "dataset/Dataset/no_sub"),
])


In [6]:

class LensDatasetWithLabels(Dataset):
    """Dataset of ``.npy`` images paired with an integer class label.

    Args:
        dataset_dirs: Mapping of class name -> directory. Every class name must be a
            key of ``class_labels`` ("axion", "cdm", "no_sub"); anything else raises
            ``KeyError``.
        transform: Optional callable applied to the float32 image array before it
            is returned.

    Attributes are filled eagerly in ``__init__``: ``data`` holds file paths and
    ``labels`` the matching integer labels, index-aligned.
    """

    def __init__(self, dataset_dirs, transform=None):
        self.data = []
        self.labels = []
        self.transform = transform
        self.class_labels = {"axion": 0, "cdm": 1, "no_sub": 2}

        # Gather file paths and corresponding labels
        for class_name, dir_path in dataset_dirs.items():
            label = self.class_labels[class_name]
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> file mapping identical across runs and machines.
            for file_name in sorted(os.listdir(dir_path)):
                if file_name.endswith('.npy'):
                    self.data.append(os.path.join(dir_path, file_name))
                    self.labels.append(label)

    def __len__(self):
        """Return the number of ``.npy`` files found across all class directories."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        ``label`` is a 0-dim ``torch.long`` tensor. ``image`` is whatever
        ``transform`` returns, or the float32 numpy array when no transform is set.
        """
        img_path = self.data[idx]
        # SECURITY NOTE: allow_pickle=True lets np.load unpickle arbitrary objects,
        # which can execute code. Only point this dataset at trusted files.
        img = np.load(img_path, allow_pickle=True)

        # If the loaded array is of object type, extract the image from the first element.
        if isinstance(img, np.ndarray) and img.dtype == object:
            img = img[0]

        # Ensure the image is a float32 numpy array
        img = np.array(img, dtype=np.float32)

        # If image has more than 2 dimensions, keep index 0 of the third axis.
        # NOTE(review): this assumes a channels-last (H, W, C) layout — confirm
        # against the data; a channels-first array would be sliced incorrectly.
        if img.ndim > 2:
            img = img[:, :, 0]

        # Apply the optional transform
        if self.transform:
            img = self.transform(img)

        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return img, label


In [7]:

# Define the transformation pipeline
data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Create the full dataset
fine_tune_dataset = LensDatasetWithLabels(dataset_dirs, transform=data_transform)

# Split the dataset: 90% for training, 10% for validation
train_size = int(0.9 * len(fine_tune_dataset))
val_size = len(fine_tune_dataset) - train_size

# Seed the split so the same permutation of indices is drawn on every run.
# An unseeded split hands out a different validation set after each kernel
# restart, which makes validation numbers incomparable between runs.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, val_dataset = random_split(
    fine_tune_dataset, [train_size, val_size], generator=split_generator
)

# Create DataLoaders for each dataset
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)


In [8]:

# Sanity-check both DataLoaders by pulling a single batch from each
for heading, loader in (("Training batch:", train_loader),
                        ("Validation batch:", val_loader)):
    print(heading)
    for images, labels in loader:
        print("Batch shape:", images.shape)  # Expected: (batch_size, 1, H, W)
        print("Labels:", labels)
        break  # one batch is enough for the check


Training batch:
Batch shape: torch.Size([32, 1, 64, 64])
Labels: tensor([2, 2, 0, 1, 0, 0, 2, 2, 0, 1, 2, 2, 1, 2, 2, 2, 0, 1, 2, 2, 2, 2, 0, 2,
        0, 2, 0, 1, 1, 2, 1, 1])
Validation batch:
Batch shape: torch.Size([32, 1, 64, 64])
Labels: tensor([1, 2, 1, 2, 0, 1, 1, 2, 0, 2, 2, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 2, 1, 1,
        1, 1, 1, 1, 1, 2, 0, 1])


In [9]:
import torch.nn as nn
import torch.optim as optim

# Reuse the encoder of the loaded MAE model as the classifier backbone.
# This is a reference to the same module object, not a copy: any training that
# updates `pretrained_encoder` also updates `mae_finetune.encoder`.
pretrained_encoder = mae_finetune.encoder  # nn.Sequential built in MaskedAutoencoder.__init__

class Classifier(nn.Module):
    """Classification head on top of a convolutional encoder.

    The encoder's feature map is globally average-pooled to one value per channel
    and passed through a single linear layer that produces the class logits.

    Args:
        encoder: Module mapping an image batch to a ``(batch, hidden_dim, H, W)`` tensor.
        hidden_dim: Channel count of the encoder output (input size of the linear layer).
        num_classes: Number of logits to produce.
    """

    def __init__(self, encoder, hidden_dim=256, num_classes=3):
        super().__init__()
        self.encoder = encoder
        # Collapse each channel's spatial map to a single averaged value
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        # Linear read-out from pooled features to class scores
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return unnormalised class logits of shape ``(batch, num_classes)``."""
        feature_map = self.encoder(x)               # (batch, hidden_dim, H, W)
        channel_means = self.pool(feature_map)      # (batch, hidden_dim, 1, 1)
        flat = channel_means.flatten(1)             # (batch, hidden_dim)
        return self.fc(flat)

# Build the classifier around the pre-trained encoder, then move it to the
# selected device (Module.to returns the module itself)
classifier = Classifier(pretrained_encoder)
classifier = classifier.to(device)

# Show the layer structure of the assembled model
print(classifier)


Classifier(
  (encoder): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
  )
  (pool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=256, out_features=3, bias=True)
)


In [10]:
import torch
import torch.nn as nn
import torch.optim as optim

# Relies on classifier, train_loader, val_loader and device from earlier cells
num_epochs = 20
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(classifier.parameters(), lr=1e-4)


def _run_epoch(loader, training):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    With ``training=True`` the classifier is put in train mode and the optimizer
    steps after every batch; otherwise the classifier is put in eval mode and the
    pass runs with gradients disabled.
    """
    if training:
        classifier.train()
    else:
        classifier.eval()

    loss_sum = 0.0
    hits = 0
    seen = 0

    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            if training:
                optimizer.zero_grad()

            outputs = classifier(images)
            loss = criterion(outputs, labels)

            if training:
                loss.backward()
                optimizer.step()

            # criterion returns the batch mean, so weight it by the batch size
            loss_sum += loss.item() * images.size(0)
            predicted = outputs.max(dim=1).indices
            seen += labels.size(0)
            hits += (predicted == labels).sum().item()

    return loss_sum / len(loader.dataset), hits / seen


for epoch in range(num_epochs):
    train_loss, train_acc = _run_epoch(train_loader, training=True)
    val_loss, val_acc = _run_epoch(val_loader, training=False)

    print(f"Epoch [{epoch+1}/{num_epochs}] "
          f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
          f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")


Epoch [1/20] Train Loss: 1.1019, Train Acc: 0.3395 | Val Loss: 1.0970, Val Acc: 0.3469
Epoch [2/20] Train Loss: 1.0988, Train Acc: 0.3445 | Val Loss: 1.0995, Val Acc: 0.3349
Epoch [3/20] Train Loss: 1.0959, Train Acc: 0.3626 | Val Loss: 1.0937, Val Acc: 0.3415
Epoch [4/20] Train Loss: 1.0922, Train Acc: 0.3765 | Val Loss: 1.0888, Val Acc: 0.4396
Epoch [5/20] Train Loss: 1.0871, Train Acc: 0.4056 | Val Loss: 1.0839, Val Acc: 0.4530
Epoch [6/20] Train Loss: 1.0802, Train Acc: 0.4277 | Val Loss: 1.0739, Val Acc: 0.5849
Epoch [7/20] Train Loss: 1.0712, Train Acc: 0.4622 | Val Loss: 1.0652, Val Acc: 0.4405
Epoch [8/20] Train Loss: 1.0604, Train Acc: 0.4814 | Val Loss: 1.0507, Val Acc: 0.4983
Epoch [9/20] Train Loss: 1.0465, Train Acc: 0.5079 | Val Loss: 1.0365, Val Acc: 0.5163
Epoch [10/20] Train Loss: 1.0329, Train Acc: 0.5285 | Val Loss: 1.0260, Val Acc: 0.5038
Epoch [11/20] Train Loss: 1.0175, Train Acc: 0.5403 | Val Loss: 1.0202, Val Acc: 0.5739
Epoch [12/20] Train Loss: 1.0023, Train A

In [None]:
# Persist only the classifier's parameters (its state_dict), not the module object;
# to restore, build a Classifier with the same architecture and call load_state_dict.
# NOTE(review): the "_65" suffix presumably records an accuracy figure — confirm.
torch.save(classifier.state_dict(), "MAE_classifier_model_65.pth")