# Importing Libraries

In [6]:
%load_ext autoreload
%autoreload 2
import torch
import torchvision
import numpy as np
from tqdm import tqdm
from TextToConcept import TextToConcept
import os
import torchvision.transforms as transforms
import torch.nn as nn

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
# In case if multiple GPUs
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'  # GPU I

# Check if CUDA is available and list available CUDA devices
if torch.cuda.is_available():
    print("CUDA is available. Number of GPUs available: ", torch.cuda.device_count())
    print("GPU Name:", torch.cuda.get_device_name())

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

CUDA is available. Number of GPUs available:  1
GPU Name: Quadro RTX 6000
Using device: cuda


In [8]:
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

In [9]:
std_transform_without_normalization = torchvision.transforms.Compose([
    torchvision.transforms.Resize(224),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor()])


std_transform_with_normalization = torchvision.transforms.Compose([
    torchvision.transforms.Resize(224),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)])

#
In this part, we load the DINO model. We also use ``forward_features(x)`` that takes a tensor as the input and outputs the representation (features) of input $x$ when it is passed through the model and ``get_normalizer``, which is the normalizer that the models uses to preprocess the input.

In [10]:
# Local path to the cloned dino repository
local_repo_path = '../Weights/dino_pretrain.pth'  # Replace with your local path

model = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50', pretrained=True)
model.load_state_dict(torch.load(local_repo_path , weights_only=True))

def forward_features(x):
    # Assuming the model returns the last layer's output which might include a classification token or similar
    features = model(x)  # This call should be adjusted based on the actual output format
    # You might need to process the features here, e.g., selecting the right tensor or applying global pooling
    return features.squeeze()  # Adjust this based on the actual structure of your features


# Attach the custom forward method to the model
model.forward_features = forward_features

model.get_normalizer = torchvision.transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
model.has_normalizer = True

Using cache found in /home/umair.nawaz/.cache/torch/hub/facebookresearch_dino_main


<b>Initiating Text-To-Concept Object</b><br>
In this section, we initiate ``TextToConcept`` object which turns the vision encoder into a model capable of integrating language and vision.

In [15]:
text_to_concept = TextToConcept(model, 'Dino')

RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


<b>Loading the Linear Aligner</b><br>

In [30]:
text_to_concept.load_linear_aligner('/home/umair.nawaz/Research_Work/Submission/AgriClip/Weights/Aligned_Models/Agri_Dino_aligner_V6.pth')

### Zero-shot Classifier

In [13]:

from torchvision.datasets import ImageFolder

# Specify the directory where images are stored
directory_path = '/home/umair.nawaz/Research_Work/Main-DATA/My_Surgical/downstream-final/Banana Deficiency'

# Create the dataset using the ImageFolder class
dataset = ImageFolder(root=directory_path,
                            transform=std_transform_with_normalization)


####

# 20% of images are fairly enough.
# num_of_samples = int(0.5 * len(cifar_dataset))
# cifar_dset = torch.utils.data.Subset(cifar_dataset, np.random.choice(np.arange(len(cifar_dataset)), num_of_samples, replace=False))

####


loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=8)

# Get class name mappings
idx_to_class = {v: k for k, v in dataset.class_to_idx.items()}
# Convert dictionary values to a list, sorted by key
class_names_list = [idx_to_class[i] for i in sorted(idx_to_class.keys())]
print(class_names_list)

['boron', 'calcium', 'healthy', 'iron', 'magnesium', 'manganese', 'potassium', 'sulphur', 'zinc']


In [14]:
cifar_zeroshot_classifier = text_to_concept.get_zero_shot_classifier(class_names_list,
                                                                     prompts=['a photo contain {} deficiency'])

NameError: name 'text_to_concept' is not defined

In [34]:
# cifar_zeroshot_classifier = text_to_concept.get_zero_shot_classifier(cifar_classes,
#                                                                      prompts=['a pixelated photo of a {}'])a type of fish named {}

### Zero-shot performance on CIFAR-10
After loading CIFAR-10, we use `cifar_zeroshot_classifier(x)` to get logits of the classification problem when input $x$ is given.

In [35]:
# cifar = torchvision.datasets.CIFAR10(root='data/',
#                                      download=True,
#                                      train=False,
#                                      transform=std_transform_with_normalization)

In [36]:
# loader = torch.utils.data.DataLoader(cifar, batch_size=16, shuffle=True, num_workers=8)
correct, total = 0, 0
with torch.no_grad():
    for data in tqdm(loader):
        x, y = data[:2]
        x = x.to(device)

        try:
            # Attempt to compute the outputs
            outputs = cifar_zeroshot_classifier(x).detach().cpu()
            _, predicted = outputs.max(1)  # Find the index of the max log-probability
            total += y.size(0)
            correct += predicted.eq(y).sum().item()
        except IndexError as e:
            print(f"Error processing batch: {e}")
            print(f"x.shape: {x.shape}, y: {y}")
            continue  # Skip this batch and continue with the next

  0%|          | 0/97 [00:00<?, ?it/s]

100%|██████████| 97/97 [00:06<00:00, 14.55it/s]


In [37]:
f'Zeroshot Accuracy: {100.*correct/total:.2f}'

'Zeroshot Accuracy: 14.19'