In [6]:
import warnings
# Suppress every warning, of any category, for the rest of the session.
# NOTE(review): this also hides NumPy RuntimeWarnings (e.g. "Mean of empty slice",
# "Degrees of freedom <= 0"), which are the only signal that a statistic became NaN.
# Consider restricting the filter to specific categories or modules.
warnings.filterwarnings("ignore")

In [7]:
import os
import torch
import torch.nn as nn
from torchvision import models, transforms
from sklearn.metrics import accuracy_score
from sklearn.preprocessing import StandardScaler
import numpy as np
from tqdm import tqdm

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Directory holding cached feature arrays; created up front so later saves succeed
FEATURE_DIR = "feat_embedding_1_10"
os.makedirs(FEATURE_DIR, exist_ok=True)

# Per-channel normalisation statistics (the standard ImageNet mean / std)
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to each image before it is fed to EfficientNet-B3
_preprocessing_steps = [
    transforms.ToPILImage(),
    transforms.Resize(300),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
transform = transforms.Compose(_preprocessing_steps)

# EfficientNet-B3 model without the final layer for feature extraction
class EfficientNetB3WithoutFC(nn.Module):
    """Pretrained EfficientNet-B3 with its last top-level child removed.

    ``forward`` maps a batch of images to one flat feature vector per image.
    """

    def __init__(self):
        super().__init__()
        backbone = models.efficientnet_b3(weights=models.EfficientNet_B3_Weights.DEFAULT)
        # Keep every top-level child except the last one.
        # NOTE(review): presumably the dropped child is the classification head and the
        # kept children already end in the backbone's own average pool — confirm against
        # the torchvision EfficientNet definition.
        kept_children = list(backbone.children())[:-1]
        self.features = nn.Sequential(*kept_children)
        # Collapse whatever spatial extent remains to 1x1 before flattening
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        pooled = self.pool(self.features(x))
        batch_size = pooled.size(0)
        # One row per image, all remaining dimensions flattened together
        return pooled.view(batch_size, -1)

feature_extractor = EfficientNetB3WithoutFC().to(device)

# Put the extractor in evaluation mode. Because this call is the last expression of the
# cell and returns the module, its repr (the architecture) is also shown as cell output.
feature_extractor.eval()

EfficientNetB3WithoutFC(
  (features): Sequential(
    (0): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(3, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): SiLU(inplace=True)
      )
      (1): Sequential(
        (0): MBConv(
          (block): Sequential(
            (0): Conv2dNormActivation(
              (0): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=40, bias=False)
              (1): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
            (1): SqueezeExcitation(
              (avgpool): AdaptiveAvgPool2d(output_size=1)
              (fc1): Conv2d(40, 10, kernel_size=(1, 1), stride=(1, 1))
              (fc2): Conv2d(10, 40, kernel_size=(1, 1), stride=(1, 1))
              (activation): SiLU(inplace=True)
              (s

In [8]:
# Function to extract and cache features in batches

def load_or_compute_features(images, dataset_name, transform, model, batch_size=64):
    """Return the feature matrix for ``images``, using an on-disk cache when available.

    Features are cached as ``<FEATURE_DIR>/<dataset_name>_features.npy``. If that file
    exists it is loaded and returned without looking at ``images``; the cache is keyed
    by ``dataset_name`` only, so callers must use a distinct name per dataset.

    Images are preprocessed lazily, one batch at a time, so peak memory is bounded by
    ``batch_size`` rather than by the size of the whole dataset.

    Args:
        images: Iterable of images; each item is passed to ``transform``. It may be
            unsized (e.g. a generator), in which case the progress bar has no total.
        dataset_name: Cache key, also used in progress messages.
        transform: Callable mapping one image to a tensor. All outputs must have the
            same shape so that they can be stacked into a batch.
        model: Callable mapping a batch tensor to a tensor of per-image features.
            Its train/eval mode is not changed here; the caller is responsible for it.
        batch_size: Number of images preprocessed and passed to ``model`` at once.

    Returns:
        numpy.ndarray with one row of features per image, in input order.

    Raises:
        ValueError: If ``batch_size`` is not positive, or if ``images`` is empty and
            no cache file exists.
    """
    cache_path = os.path.join(FEATURE_DIR, f"{dataset_name}_features.npy")
    if os.path.exists(cache_path):
        print(f"Loading cached features for {dataset_name}...")
        return np.load(cache_path)

    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    def iter_batches():
        # Transform and stack only `batch_size` images at a time instead of
        # materialising the whole preprocessed dataset as one tensor.
        pending = []
        for img in images:
            pending.append(transform(img))
            if len(pending) == batch_size:
                yield torch.stack(pending)
                pending = []
        if pending:  # final, possibly smaller, batch
            yield torch.stack(pending)

    try:
        num_batches = -(-len(images) // batch_size)  # ceiling division
    except TypeError:
        num_batches = None  # unsized iterable: progress bar without a total

    features = []
    with torch.no_grad():
        progress = tqdm(iter_batches(), total=num_batches,
                        desc=f"Extracting features for {dataset_name}")
        for batch in progress:
            feature_batch = model(batch.to(device)).cpu().numpy()
            features.append(feature_batch)

    if not features:
        raise ValueError(f"No images given for {dataset_name}; cannot extract features")

    features = np.vstack(features)
    np.save(cache_path, features)  # Cache features for future reuse
    return features



# Define LwP Classifier using Mahalanobis distance
class LwPClassifierMahalanobis:
    """Nearest-prototype classifier that measures closeness with Mahalanobis distance.

    Every class is summarised by the mean of its feature vectors (its prototype) and by
    the inverse of its ridge-regularised sample covariance. A sample is assigned to the
    class whose prototype is nearest under that class's own metric.
    """

    # Ridge added to the diagonal of each covariance matrix so that it can be inverted
    REGULARIZATION = 1e-6
    # Weights used by ``update`` to blend stored statistics with those of a new batch
    OLD_WEIGHT = 0.85
    NEW_WEIGHT = 0.15
    # A sample covariance is undefined for fewer than two observations
    MIN_SAMPLES = 2

    def __init__(self):
        # dict: label -> mean feature vector; populated by ``fit``
        self.prototypes = None
        # dict: label -> INVERSE regularised covariance matrix; populated by ``fit``.
        # The attribute name is historical and kept for compatibility.
        self.covariances = None
        # Sorted array of the unique labels passed to ``fit``
        self.classes = None

    def _inverse_covariance(self, class_features):
        """Return the inverse of the ridge-regularised sample covariance of the rows."""
        # atleast_2d keeps the single-feature case (where np.cov returns a 0-d array) working
        cov_matrix = np.atleast_2d(np.cov(class_features, rowvar=False))
        cov_matrix = cov_matrix + np.eye(cov_matrix.shape[0]) * self.REGULARIZATION
        return np.linalg.inv(cov_matrix)

    def fit(self, X, y):
        """Estimate one prototype and one inverse covariance per class.

        Args:
            X: Array-like of shape (n_samples, n_features).
            y: Array-like of n_samples labels.

        Raises:
            ValueError: If a class has fewer than ``MIN_SAMPLES`` samples, because its
                covariance would be NaN and would silently corrupt every prediction.
        """
        X = np.asarray(X)
        y = np.asarray(y)
        classes = np.unique(y)
        prototypes = {}
        inverse_covariances = {}

        for label in classes:
            class_features = X[y == label]
            if len(class_features) < self.MIN_SAMPLES:
                raise ValueError(
                    f"Class {label!r} has {len(class_features)} sample(s); at least "
                    f"{self.MIN_SAMPLES} are required to estimate a covariance matrix"
                )
            prototypes[label] = np.mean(class_features, axis=0)
            inverse_covariances[label] = self._inverse_covariance(class_features)

        # Commit only after every class succeeded, so a failure leaves no partial state
        self.classes = classes
        self.prototypes = prototypes
        self.covariances = inverse_covariances

    def mahalanobis_distance(self, x, prototype, cov_inv):
        """Return the Mahalanobis distance between ``x`` and ``prototype``.

        ``cov_inv`` is the inverse covariance matrix defining the metric.
        """
        delta = x - prototype
        squared = delta.T @ cov_inv @ delta
        # Round-off can push the quadratic form marginally below zero; clamp so that
        # the square root never turns a valid distance into NaN.
        return np.sqrt(np.maximum(squared, 0.0))

    def predict(self, X):
        """Return, for every row of ``X``, the label of the nearest prototype.

        Ties are resolved in favour of the class that comes first in ``self.classes``.

        Raises:
            RuntimeError: If called before ``fit``.
        """
        if self.classes is None:
            raise RuntimeError("fit must be called before predict")

        X = np.asarray(X)
        if X.size == 0:
            return self.classes[:0]

        # Squared distances are compared: the square root is monotonic, so the nearest
        # class is the same, and one matrix product per class replaces a Python loop
        # over every (sample, class) pair.
        squared = np.empty((X.shape[0], len(self.classes)))
        for column, label in enumerate(self.classes):
            delta = X - self.prototypes[label]
            squared[:, column] = np.einsum("ij,ij->i", delta @ self.covariances[label], delta)
        return self.classes[np.argmin(squared, axis=1)]

    def update(self, X, y):
        """Blend stored class statistics with statistics estimated from a new batch.

        Classes that receive fewer than ``MIN_SAMPLES`` samples in the batch keep their
        current statistics: a mean or covariance of so few rows is NaN and would poison
        the classifier.

        Args:
            X: Array-like of shape (n_samples, n_features).
            y: Array-like of n_samples labels (typically pseudo-labels from ``predict``).

        Raises:
            RuntimeError: If called before ``fit``.
        """
        if self.classes is None:
            raise RuntimeError("fit must be called before update")

        X = np.asarray(X)
        y = np.asarray(y)
        for label in self.classes:
            if label not in self.prototypes:
                continue
            class_features = X[y == label]
            if len(class_features) < self.MIN_SAMPLES:
                continue  # not enough new evidence for this class

            new_mean = np.mean(class_features, axis=0)
            self.prototypes[label] = (
                self.OLD_WEIGHT * self.prototypes[label] + self.NEW_WEIGHT * new_mean
            )

            # The INVERSE covariances are blended directly.
            # NOTE(review): this differs from inverting a blended covariance; it is kept
            # unchanged so that existing results are preserved.
            self.covariances[label] = (
                self.OLD_WEIGHT * self.covariances[label]
                + self.NEW_WEIGHT * self._inverse_covariance(class_features)
            )


In [9]:
# Sequentially adapt the classifier with pseudo-labels and evaluate it; all features are obtained through the caching helper
def process_datasets_sequentially(train_paths, eval_paths, lwp, scaler):
    """Adapt ``lwp`` to each later training set with pseudo-labels, evaluating as it goes.

    ``lwp`` must already be fitted; ``train_paths[0]`` is never read here. The classifier
    is first scored on evaluation set 1. Then, for every later index ``i`` (pairs are
    taken until the shorter of the two path lists runs out):

    1. features of training set ``i`` are pseudo-labelled with ``lwp.predict`` and fed to
       ``lwp.update``;
    2. the updated classifier is scored on evaluation sets ``1 .. i-1``;
    3. it is scored on evaluation set ``i``.

    Features and labels of each evaluation set are kept in memory after their first use,
    so earlier sets are not read from disk again at every stage.

    Args:
        train_paths: Paths of files that ``torch.load`` turns into a mapping with a
            ``'data'`` entry.
        eval_paths: Paths of files that ``torch.load`` turns into a mapping with
            ``'data'`` and ``'targets'`` entries.
        lwp: Fitted classifier exposing ``predict(features)`` and
            ``update(features, labels)``.
        scaler: Unused; accepted so that existing callers keep working.

    Returns:
        list[list[float]]: one inner list per stage. Entry ``k`` of stage ``i`` is the
        accuracy (in [0, 1]) on evaluation set ``k + 1`` of the classifier as it stands
        after training set ``i`` has been processed; stage 1 is the initial classifier.
    """

    def load_eval_set(index, path):
        # NOTE(security): torch.load unpickles the file; only load trusted dataset files.
        eval_data = torch.load(path)
        features = load_or_compute_features(
            eval_data['data'], f"eval_{index}", transform, feature_extractor
        )
        return features, eval_data['targets']

    def report_accuracy(index, features, labels):
        accuracy = accuracy_score(labels, lwp.predict(features))
        print(f"Accuracy on evaluation dataset {index}: {accuracy * 100:.2f}%")
        return accuracy

    # (features, labels) of every evaluation set seen so far, in dataset order
    seen_eval_sets = [load_eval_set(1, eval_paths[0])]
    accuracy_history = [[report_accuracy(1, *seen_eval_sets[0])]]

    for i, (train_path, eval_path) in enumerate(zip(train_paths[1:], eval_paths[1:]), start=2):
        print(f"\nProcessing training dataset {i}...")
        # NOTE(security): torch.load unpickles the file; only load trusted dataset files.
        train_data = torch.load(train_path)
        train_features = load_or_compute_features(
            train_data['data'], f"train_{i}", transform, feature_extractor
        )

        # Self-training step: label the new data with the current model, then adapt to it
        pseudo_labels = lwp.predict(train_features)
        lwp.update(train_features, pseudo_labels)

        # Re-score every earlier evaluation set with the updated classifier
        print(f"\nEvaluating updated classifier on datasets 1 to {i-1}...")
        stage_accuracies = [
            report_accuracy(j, features, labels)
            for j, (features, labels) in enumerate(seen_eval_sets, start=1)
        ]

        # Score the evaluation set that belongs to the current stage
        print(f"\nProcessing evaluation dataset {i}...")
        current_eval_set = load_eval_set(i, eval_path)
        seen_eval_sets.append(current_eval_set)
        stage_accuracies.append(report_accuracy(i, *current_eval_set))

        accuracy_history.append(stage_accuracies)

    return accuracy_history

In [10]:
# Locations of the ten train / eval splits
_DATA_ROOT = "dataset/part_one_dataset"
_DATASET_IDS = range(1, 11)
train_paths = [f"{_DATA_ROOT}/train_data/{idx}_train_data.tar.pth" for idx in _DATASET_IDS]
eval_paths = [f"{_DATA_ROOT}/eval_data/{idx}_eval_data.tar.pth" for idx in _DATASET_IDS]

# Images, labels and (cached) features of the first training split
train_data_1 = torch.load(train_paths[0])
D1_images, D1_labels = train_data_1['data'], train_data_1['targets']
D1_features = load_or_compute_features(D1_images, "train_1", transform, feature_extractor)

# Fit the prototype classifier on the first split's features and labels
lwp = LwPClassifierMahalanobis()
lwp.fit(D1_features, np.array(D1_labels))

# Handed to process_datasets_sequentially below
scaler = StandardScaler()

process_datasets_sequentially(train_paths, eval_paths, lwp, scaler)

Loading cached features for train_1...
Loading cached features for eval_1...
Accuracy on evaluation dataset 1: 90.96%

Processing training dataset 2...
Loading cached features for train_2...

Evaluating updated classifier on datasets 1 to 1...
Loading cached features for eval_1...
Accuracy on evaluation dataset 1: 91.24%

Processing evaluation dataset 2...
Loading cached features for eval_2...
Accuracy on evaluation dataset 2: 92.72%

Processing training dataset 3...
Loading cached features for train_3...

Evaluating updated classifier on datasets 1 to 2...
Loading cached features for eval_1...
Accuracy on evaluation dataset 1: 91.00%
Loading cached features for eval_2...
Accuracy on evaluation dataset 2: 93.04%

Processing evaluation dataset 3...
Loading cached features for eval_3...
Accuracy on evaluation dataset 3: 91.80%

Processing training dataset 4...
Loading cached features for train_4...

Evaluating updated classifier on datasets 1 to 3...
Loading cached features for eval_1...