In [None]:
import torch
import os
import random
from sklearn.model_selection import train_test_split
from shutil import copyfile
import torch.nn as nn
import torchvision
import torch.optim as optim
import torchvision.models as models
from tqdm import tqdm
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision import datasets
from torch.utils.data import random_split
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import pandas as pd
from sklearn.metrics import f1_score
from sklearn.metrics import precision_recall_fscore_support as score
from collections import Counter
import time

In [None]:
# Evaluation preprocessing (no data augmentation): resize, then convert the
# PIL image to a float tensor.
# NOTE(review): Resize with a single int scales the shorter side to 224 and keeps
# the aspect ratio, so default batching only works if all images share an aspect
# ratio. There is also no Normalize step; confirm this matches how the checkpoint
# was trained.
transform = transforms.Compose(
    [
        transforms.Resize(size=224),
        transforms.ToTensor(),
    ]
)

In [None]:
# Root of the ImageFolder-style training set (one sub-directory per class).
# TODO: fill in the real path; '' is a placeholder.
dataset_dir_train = ''

In [None]:
# Labelled image dataset: each sub-directory of the root is one class.
train_dataset = ImageFolder(root=dataset_dir_train, transform=transform)

In [None]:
# Take the class list from the dataset itself: ImageFolder sorts the class
# sub-directories and assigns label i to train_dataset.classes[i]. os.listdir()
# returns entries in arbitrary order and also lists plain files, which would map
# labels to the wrong class names (and miscount num_classes).
class_names = train_dataset.classes
num_classes = len(class_names)

In [None]:
# Show the number of classes as cell output, as a quick sanity check before building the model head.
num_classes

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ConvNeXt-Base with its final classifier layer replaced by one sized for our classes.
# Reading in_features from the existing layer avoids hard-coding the backbone width.
# NOTE(review): `pretrained=` is deprecated in newer torchvision in favour of `weights=`;
# the ImageNet weights are replaced anyway when a later cell calls load_state_dict.
model = models.convnext_base(pretrained=True)
model.classifier[2] = nn.Linear(model.classifier[2].in_features, num_classes)

model = model.to(device)

In [None]:
# A single loader over the training set, used only for evaluation / feature
# extraction; shuffle=False keeps the batches in dataset order.
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=False,
    pin_memory=True,
)

In [None]:
# Load the fine-tuned weights (path is a placeholder to fill in).
# map_location remaps tensors saved on another device (e.g. a CUDA checkpoint
# opened on a CPU-only machine) onto the device the model was moved to.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(torch.load('', map_location=device))

In [None]:
def evaluate_model(model, train_loader, device, class_names=None):
    """Compute a per-class target feature vector from correctly classified samples.

    Runs ``model`` in eval mode (without gradients) over every batch of
    ``train_loader``. For each image, the backbone output ``model.features(images)``
    is global-average pooled into one feature vector. The prediction is the argmax
    of ``model(images)``. For each class, the result is the mean feature vector of
    that class's samples that the model predicted correctly.

    Args:
        model: network with a ``features`` sub-module and a forward pass that
            returns class scores. Assumes ``features`` yields a
            (batch, channels, H, W) map, as torchvision ConvNeXt does.
        train_loader: iterable of ``(images, labels)`` batches; labels are
            integer indices into ``class_names``.
        device: device the model lives on; each batch is moved there.
        class_names: class name for each label index. Defaults to
            ``train_loader.dataset.classes`` (set by ``ImageFolder``), i.e. the
            mapping the labels were generated from.

    Returns:
        dict mapping class name -> 1-D CPU tensor (mean pooled feature), in
        ``class_names`` order. Classes with no correctly classified sample are
        omitted.

    Raises:
        ValueError: if ``class_names`` is not given and the loader's dataset has
            no ``classes`` attribute.
    """
    if class_names is None:
        class_names = getattr(train_loader.dataset, "classes", None)
        if class_names is None:
            raise ValueError(
                "class_names must be passed when train_loader.dataset has no `classes` attribute"
            )
    num_classes = len(class_names)

    model.eval()
    # Keep running per-class sums/counts rather than every feature vector, so
    # memory stays O(num_classes * feature_dim) regardless of dataset size.
    feature_sums = None  # allocated on the first batch, once feature_dim is known
    correct_counts = torch.zeros(num_classes, dtype=torch.long, device=device)

    with torch.no_grad():
        for images, labels in tqdm(train_loader, desc="Evaluating"):
            images = images.to(device)
            labels = labels.to(device)

            # Pool the backbone feature map to one vector per image: (batch, feature_dim).
            feature_maps = model.features(images)
            features = nn.functional.adaptive_avg_pool2d(feature_maps, (1, 1)).flatten(1)

            # Softmax is monotonic, so the argmax over raw scores is the predicted class.
            predicted = model(images).argmax(dim=1)

            if feature_sums is None:
                feature_sums = torch.zeros(
                    num_classes, features.shape[1], dtype=features.dtype, device=features.device
                )

            # Accumulate features of correctly classified samples into their class row.
            is_correct = predicted == labels
            correct_labels = labels[is_correct]
            feature_sums.index_add_(0, correct_labels, features[is_correct])
            correct_counts += torch.bincount(correct_labels, minlength=num_classes)

    target_feature_vectors = {}
    if feature_sums is None:
        # Empty loader: nothing was seen, so no class has a target vector.
        return target_feature_vectors

    feature_sums = feature_sums.cpu()
    correct_counts = correct_counts.cpu()
    for class_idx, cls in enumerate(class_names):
        count = int(correct_counts[class_idx])
        if count > 0:
            target_feature_vectors[cls] = feature_sums[class_idx] / count

    return target_feature_vectors

In [None]:
# Per-class mean feature vectors of correctly classified training images.
target_vectors = evaluate_model(model=model, train_loader=train_loader, device=device)

# Persist the target feature vectors (output path is a placeholder to fill in).
torch.save(obj=target_vectors, f='')