### Preliminaries
In this section, we import the required libraries and define the standard transformations used to load datasets. Some models require input normalization and others do not, so we define one transformation with normalization and one without.

In [1]:
%load_ext autoreload
%autoreload 2
import torch
import torchvision
import numpy as np
from tqdm import tqdm
from TextToConcept import TextToConcept
import os

In [2]:
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'  # GPU I

# Check if CUDA is available and list available CUDA devices
if torch.cuda.is_available():
    print("CUDA is available. Number of GPUs available: ", torch.cuda.device_count())
    print("GPU Name:", torch.cuda.get_device_name())

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

CUDA is available. Number of GPUs available:  1
GPU Name: NVIDIA GeForce RTX 4090
Using device: cuda


In [3]:
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

In [4]:
std_transform_without_normalization = torchvision.transforms.Compose([
    torchvision.transforms.Resize(224),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor()])


std_transform_with_normalization = torchvision.transforms.Compose([
    torchvision.transforms.Resize(224),
    torchvision.transforms.CenterCrop(224),
    torchvision.transforms.ToTensor(), 
    torchvision.transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)])
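
To make the normalization step concrete, here is a minimal NumPy sketch of the per-channel arithmetic, `(x - mean) / std`, applied to a channel-first image. The helper name `normalize_chw` is hypothetical and only mirrors the arithmetic; it is not part of torchvision.

```python
import numpy as np

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def normalize_chw(image, mean, std):
    """Per-channel (x - mean) / std for an array shaped (C, H, W)."""
    mean = np.asarray(mean, dtype=np.float64).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=np.float64).reshape(-1, 1, 1)
    return (np.asarray(image, dtype=np.float64) - mean) / std


# A constant mid-gray image with values in [0, 1], the range ToTensor() produces.
gray = np.full((3, 224, 224), 0.5)
normalized = normalize_chw(gray, IMAGENET_MEAN, IMAGENET_STD)
print(normalized[:, 0, 0])  # one normalized value per channel
```

Because the mean and standard deviation differ per channel, the same input intensity maps to a slightly different value in each channel.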

### Resnet50
In this part, we load the model. The cells below load DINO checkpoints (`dino_vitb16` and `dino_resnet50`).
To use the ``TextToConcept`` framework, the model should implement these functions/attributes:
+ ``forward_features(x)``, which takes a tensor as input and returns the representation (features) of input $x$ when it is passed through the model.
+ ``get_normalizer``, the normalizer that the model uses to preprocess its input; e.g., Resnet18 uses the standard ImageNet normalizer.
+ ``has_normalizer``, an attribute that should be `True` when the model needs a normalizer.
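
The three requirements above can be checked before a model is handed to ``TextToConcept``. The sketch below is only an illustration: the helper `check_text_to_concept_interface` and the `DummyModel` class are hypothetical names, and the check assumes that ``get_normalizer`` is a callable, as it is in this notebook where it is set to a `Normalize` transform.

```python
def check_text_to_concept_interface(model):
    """Return a list of problems; an empty list means all three members exist."""
    problems = []
    if not callable(getattr(model, "forward_features", None)):
        problems.append("forward_features must be callable")
    if not hasattr(model, "has_normalizer"):
        problems.append("has_normalizer attribute is missing")
    elif model.has_normalizer and not callable(getattr(model, "get_normalizer", None)):
        problems.append("get_normalizer must be callable when has_normalizer is True")
    return problems


class DummyModel:
    """Stand-in object used only to exercise the check (hypothetical)."""
    has_normalizer = True

    def forward_features(self, x):
        return x

    def get_normalizer(self, x):
        return x


print(check_text_to_concept_interface(DummyModel()))  # []
print(check_text_to_concept_interface(object()))      # two problems reported
```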

In [27]:
model = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')

encoder = torch.nn.Sequential(*list(model.children())[:-2])
model.forward_features = lambda x : encoder(x)
model.get_normalizer = torchvision.transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
model.has_normalizer = True

Using cache found in /home/umair.nawaz/.cache/torch/hub/facebookresearch_dino_main


In [7]:
# Local path to the pretrained DINO checkpoint (a state dict, not a repository)
local_repo_path = '/home/umair.nawaz/Research_Work/Submission/AgriClip/Weights/dino_pretrain.pth'  # Replace with your local path

# Build the architecture via torch.hub, then overwrite its weights with the local checkpoint
model = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50', pretrained=True)
model.load_state_dict(torch.load(local_repo_path, weights_only=True))

# Load the dino_resnet50 model from the local repo
# model = torch.hub.load(local_repo_path, 'dino_resnet50', pretrained=False)


Using cache found in /home/umair.nawaz/.cache/torch/hub/facebookresearch_dino_main


<All keys matched successfully>

In [28]:
import torch
import torchvision.transforms as transforms
import torch.nn as nn

# Load DINO Vision Transformer model
# model = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16', pretrained=True)
# model = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50', pretrained=True)

def forward_features(x):
    # Assumes the model returns the last layer's output, e.g. one feature vector per input (such as a classification token)
    features = model(x)  # Adjust this call if the model returns a different output format
    # Note: squeeze() removes every size-1 dimension, including the batch dimension when the batch size is 1
    return features.squeeze()  # Adjust this based on the actual structure of your features


# Attach the custom forward method to the model
model.forward_features = forward_features

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

model.get_normalizer = torchvision.transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
model.has_normalizer = True

# Create a dummy input tensor
input_tensor = torch.randn(1, 3, 224, 224)  # Adjust size if needed

# Normalize the input
normalized_input = model.get_normalizer(input_tensor)

# Get features
features = model.forward_features(normalized_input)
print(features.shape)  # Check the shape of the output features


torch.Size([768])
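
The one-dimensional shape printed above comes from `squeeze()` dropping the batch dimension of size 1. The NumPy sketch below shows the effect on stand-in arrays; `flatten_keep_batch` is a hypothetical helper. In PyTorch, `features.flatten(start_dim=1)` is the analogous call that keeps the batch axis.

```python
import numpy as np

# Stand-ins for encoder outputs of shape (batch, feature_dim).
single = np.zeros((1, 768))
batch = np.zeros((32, 768))

print(np.squeeze(single).shape)  # (768,): the batch axis is lost
print(np.squeeze(batch).shape)   # (32, 768): unchanged


def flatten_keep_batch(features):
    """Collapse everything except the leading batch axis."""
    return features.reshape(features.shape[0], -1)


print(flatten_keep_batch(single).shape)  # (1, 768)
```

Keeping the batch axis matters for any downstream code that indexes dimension 1, such as taking the maximum over classes for a batch that happens to contain a single image.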


### Linear Aligner

<b>Initiating Text-To-Concept Object</b><br>
In this section, we instantiate a ``TextToConcept`` object, which turns the vision encoder (e.g., Resnet50) into a model that can integrate language and vision. This gives the encoder access to some of the abilities of vision-language models.

In [29]:
text_to_concept = TextToConcept(model, 'Dino')

We can either train the aligner or load an existing one.

#### Training Linear Aligner

<b>Loading ImageNet Dataset to Train the Aligner</b><br>
We note that even $20\%$ of the ImageNet training samples suffice to train an effective linear aligner.
See Appendix A of our paper for more details on the sample efficiency of linear alignment.

In [82]:
# Load the image dataset used to train the aligner.
dset = torchvision.datasets.ImageFolder(root='/share/sdb/umairnawaz/My_Surgical/down-data/',
                                    #  split='train',
                                     transform=std_transform_with_normalization)

# Randomly sample 99% of the images without replacement (about 20% would already be enough).
num_of_samples = int(0.99 * len(dset))
dset = torch.utils.data.Subset(dset, np.random.choice(np.arange(len(dset)), num_of_samples, replace=False))

In [10]:
num_of_samples

248638
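
The subset above is drawn with an unseeded random generator, so it changes from run to run. A minimal sketch of a reproducible variant follows; the helper `sample_indices` and its `seed` parameter are hypothetical and not part of the notebook.

```python
import numpy as np


def sample_indices(dataset_size, fraction, seed=0):
    """Draw a reproducible random subset of indices without replacement."""
    rng = np.random.default_rng(seed)
    num_samples = int(fraction * dataset_size)
    return rng.choice(dataset_size, size=num_samples, replace=False)


indices = sample_indices(dataset_size=1000, fraction=0.2)
print(len(indices), len(set(indices.tolist())))  # sample count and number of unique indices
```

The resulting indices could then be passed to `torch.utils.data.Subset(dset, indices.tolist())` in the same way as in the cell above.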

<b>Training the Linear Aligner</b><br>
After creating the object, we need to train the aligner.
+ To train the aligner, call ``train_linear_aligner``, which obtains the representations of the given model (e.g., Resnet50) on ``dset`` as well as those of a vision-language model such as CLIP. These representations can also be loaded from disk. The function then solves for the linear transformation and obtains the optimal alignment from the model's space to the vision-language space.
+ Calling ``save_linear_aligner`` stores the linear aligner so that it can be reused later.
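
To illustrate what solving for the linear transformation means, the sketch below fits an affine map from one feature space to another on toy data and reports the mean squared error and $R^2$. It uses a closed-form least-squares solve as a stand-in, which is an assumption: it is not the implementation inside ``train_linear_aligner``, and the helper names `fit_affine_aligner` and `mse_and_r2` are hypothetical.

```python
import numpy as np


def fit_affine_aligner(model_reps, clip_reps):
    """Least-squares fit of clip_reps ~= model_reps @ W + b."""
    ones = np.ones((model_reps.shape[0], 1))
    augmented = np.hstack([model_reps, ones])  # extra column of ones for the bias
    solution, *_ = np.linalg.lstsq(augmented, clip_reps, rcond=None)
    return solution[:-1], solution[-1]  # W, b


def mse_and_r2(pred, target):
    """Mean squared error (per element) and coefficient of determination."""
    residual = np.sum((target - pred) ** 2)
    total = np.sum((target - target.mean(axis=0)) ** 2)
    return residual / target.size, 1.0 - residual / total


rng = np.random.default_rng(0)
model_reps = rng.normal(size=(500, 16))  # toy "model" features
true_w = rng.normal(size=(16, 8))
clip_reps = model_reps @ true_w + 0.01 * rng.normal(size=(500, 8))  # toy targets

W, b = fit_affine_aligner(model_reps, clip_reps)
mse, r2 = mse_and_r2(model_reps @ W + b, clip_reps)
print(f"MSE: {mse:.5f}, R^2: {r2:.4f}")
```

An $R^2$ close to 1 means the affine map explains almost all of the variance in the target space. A negative value means the predictions are worse than always predicting the target mean.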

In [83]:
path1 = '/share/sdb/umairnawaz/T2C/Text-to-concept/Aligned_Models/DINO/1/DINO.npy'
path2 = '/share/sdb/umairnawaz/T2C/Text-to-concept/Aligned_Models/CLIP/1/CLIP.npy'

text_to_concept.train_linear_aligner(dset,
                                     save_reps=False, load_reps=True,
                                     path_to_model=path1, path_to_clip_model=path2,
                                     epochs=70)


text_to_concept.save_linear_aligner('./Aligned_Models/Agri_Dino_aligner_V7_3_vitb16.pth')

Loading representations ...
Training linear aligner ...
Linear alignment: ((248638, 2048)) --> ((248638, 512)).
Initial MSE, R^2: 5.796, -0.289
Epoch number, loss: 0, 1.487
Epoch number, loss: 1, 1.223
Epoch number, loss: 2, 1.153
Epoch number, loss: 3, 1.113
Epoch number, loss: 4, 1.085
Epoch number, loss: 5, 1.064
Epoch number, loss: 6, 1.047
Epoch number, loss: 7, 1.033
Epoch number, loss: 8, 1.021
Epoch number, loss: 9, 1.011
Epoch number, loss: 10, 1.001
Epoch number, loss: 11, 0.994
Epoch number, loss: 12, 0.986
Epoch number, loss: 13, 0.980
Epoch number, loss: 14, 0.974
Epoch number, loss: 15, 0.969
Epoch number, loss: 16, 0.964
Epoch number, loss: 17, 0.959
Epoch number, loss: 18, 0.955
Epoch number, loss: 19, 0.951
Epoch number, loss: 20, 0.948
Epoch number, loss: 21, 0.944
Epoch number, loss: 22, 0.941
Epoch number, loss: 23, 0.938
Epoch number, loss: 24, 0.935
Epoch number, loss: 25, 0.933
Epoch number, loss: 26, 0.930
Epoch number, loss: 27, 0.928
Epoch number, loss: 28, 0.

<b>Loading the Linear Aligner</b><br>
We can also use an existing linear aligner. To do so, we call the function ``load_linear_aligner``.

In [30]:
text_to_concept.load_linear_aligner('/home/umair.nawaz/Research_Work/Submission/AgriClip/Weights/Aligned_Models/Agri_Dino_aligner_V6.pth')

### Zero-shot Classifier
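The original example uses CIFAR-10, a <i>$10$-way</i> classification problem, with prompts of the form `a pixelated photo of a {c}`.
The cells below instead load a custom image folder with $9$ classes and use the prompt `a photo contain {} deficiency` to get appropriate concepts in vision-language space.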

In [31]:
# cifar_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

In [32]:
# cifar = torchvision.datasets.CIFAR10(root='/home/umair.nawaz/Research_Work/Main-DATA/My_Surgical/downstream/crops/dataset_22/',
#                                      download=False,
#                                      train=False,
#                                      transform=None)

from torchvision.datasets import ImageFolder

# Directory where the evaluation images are stored (one sub-folder per class)
directory_path = '/home/umair.nawaz/Research_Work/Main-DATA/My_Surgical/downstream/crops/dataset_22'

# Create the dataset using the ImageFolder class
cifar_dataset = ImageFolder(root=directory_path,
                            transform=std_transform_with_normalization)


####

# 20% of images are fairly enough.
# num_of_samples = int(0.5 * len(cifar_dataset))
# cifar_dset = torch.utils.data.Subset(cifar_dataset, np.random.choice(np.arange(len(cifar_dataset)), num_of_samples, replace=False))

####


loader = torch.utils.data.DataLoader(cifar_dataset, batch_size=32, shuffle=True, num_workers=8)

# Get class name mappings
idx_to_class = {v: k for k, v in cifar_dataset.class_to_idx.items()}
# Convert dictionary values to a list, sorted by key
class_names_list = [idx_to_class[i] for i in sorted(idx_to_class.keys())]
print(class_names_list)

['boron', 'calcium', 'healthy', 'iron', 'magnesium', 'manganese', 'potassium', 'sulphur', 'zinc']


In [33]:
cifar_zeroshot_classifier = text_to_concept.get_zero_shot_classifier(class_names_list,
                                                                     prompts=['a photo contain {} deficiency'])
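
As a rough illustration of how a prompt template turns class names into text queries, the sketch below expands the template used above for the nine class names. It assumes the placeholder is filled with `str.format`, and the helper `expand_prompts` is hypothetical rather than part of ``TextToConcept``.

```python
class_names = ['boron', 'calcium', 'healthy', 'iron', 'magnesium',
               'manganese', 'potassium', 'sulphur', 'zinc']
templates = ['a photo contain {} deficiency']


def expand_prompts(class_names, templates):
    """Map each class name to the prompts built from every template."""
    return {name: [t.format(name) for t in templates] for name in class_names}


prompts = expand_prompts(class_names, templates)
print(prompts['zinc'])     # ['a photo contain zinc deficiency']
print(prompts['healthy'])  # ['a photo contain healthy deficiency']
```

Printing the expanded prompts is a cheap sanity check: with this template the `healthy` class becomes "a photo contain healthy deficiency", which may not describe that class well.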

In [34]:
# cifar_zeroshot_classifier = text_to_concept.get_zero_shot_classifier(cifar_classes,
#                                                                      prompts=['a pixelated photo of a {}'])
# Alternative prompt: 'a type of fish named {}'

### Zero-shot performance on CIFAR-10
After loading the dataset, we call `cifar_zeroshot_classifier(x)` to obtain the classification logits for an input $x$.
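
A common way to produce such logits is to take the cosine similarity between each aligned image embedding and each class text embedding. The sketch below shows that idea on toy data. It is an assumption about the general technique, not a description of the classifier's internals, and all names and values in it are hypothetical.

```python
import numpy as np


def l2_normalize(x):
    """Scale each row to unit Euclidean length."""
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def zero_shot_logits(image_embeddings, text_embeddings):
    """Cosine similarity between every image and every class embedding."""
    return l2_normalize(image_embeddings) @ l2_normalize(text_embeddings).T


rng = np.random.default_rng(0)
text_embeddings = rng.normal(size=(9, 512))  # one row per class (toy values)

# Toy images: each one is its class embedding plus a little noise.
labels = np.array([0, 3, 8, 3])
image_embeddings = text_embeddings[labels] + 0.1 * rng.normal(size=(4, 512))

logits = zero_shot_logits(image_embeddings, text_embeddings)
predicted = logits.argmax(axis=1)
print(logits.shape, predicted)
```

The predicted class for each image is the column with the largest similarity, which is the role `outputs.max(1)` plays in the evaluation loop.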

In [35]:
# cifar = torchvision.datasets.CIFAR10(root='data/',
#                                      download=True,
#                                      train=False,
#                                      transform=std_transform_with_normalization)

In [36]:
# loader = torch.utils.data.DataLoader(cifar, batch_size=16, shuffle=True, num_workers=8)
correct, total = 0, 0
with torch.no_grad():
    for data in tqdm(loader):
        x, y = data[:2]
        x = x.to(device)

        try:
            # Attempt to compute the outputs
            outputs = cifar_zeroshot_classifier(x).detach().cpu()
            _, predicted = outputs.max(1)  # Index of the largest logit for each sample
            total += y.size(0)
            correct += predicted.eq(y).sum().item()
        except IndexError as e:
            print(f"Error processing batch: {e}")
            print(f"x.shape: {x.shape}, y: {y}")
            continue  # Skip this batch and continue with the next

100%|██████████| 97/97 [00:06<00:00, 14.55it/s]


In [37]:
f'Zeroshot Accuracy: {100.*correct/total:.2f}'

'Zeroshot Accuracy: 14.19'
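
For context, uniform random guessing over the nine classes listed earlier gives about 11.11% accuracy, so the figure above is only slightly above chance. The sketch below computes that baseline and adds a per-class accuracy helper that can show which classes drive the result. The helper `per_class_accuracy` is hypothetical, and the predictions passed to it are toy values, not results from this notebook.

```python
import numpy as np

num_classes = 9
chance_accuracy = 100.0 / num_classes
print(f"Chance accuracy: {chance_accuracy:.2f}")  # 11.11


def per_class_accuracy(predicted, labels, num_classes):
    """Accuracy in percent for each class; NaN where a class has no samples."""
    predicted = np.asarray(predicted)
    labels = np.asarray(labels)
    accuracies = np.full(num_classes, np.nan)
    for c in range(num_classes):
        mask = labels == c
        if mask.any():
            accuracies[c] = 100.0 * np.mean(predicted[mask] == c)
    return accuracies


# Toy predictions and labels, only to show the call.
acc = per_class_accuracy([0, 1, 1, 2], [0, 1, 2, 2], num_classes=3)
print(acc)
```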