In [73]:
import torch
from CLIP import clip
from PIL import Image
import numpy as np
from torchvision import models, transforms
import torch.nn as nn
import time
import torch.cuda.amp as amp
import torch.backends.cudnn as cudnn
from torch.profiler import profile, record_function, ProfilerActivity
import os
import random

## Model Initialized in two different ways
Traditional ImageNet ResNet50 

In [23]:
# Load ImageNet-pretrained RN50 (standard)
rn50_imagenet = models.resnet50(weights="IMAGENET1K_V1")
print(rn50_imagenet.fc)  # Output: Linear(in_features=2048, out_features=1000)



Linear(in_features=2048, out_features=1000, bias=True)


CLIP's RN50 visual encoder

In [24]:
from CLIP import clip

clip_model, _ = clip.load("RN50", device="cuda")
clip_rn50 = clip_model.visual

# Check the final layers
print(clip_rn50.attnpool) 


100%|███████████████████████████████████████| 244M/244M [01:04<00:00, 3.97MiB/s]


AttentionPool2d(
  (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
  (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
  (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
  (c_proj): Linear(in_features=2048, out_features=1024, bias=True)
)


## Setup zero-shot CLIP

In [8]:
# Load ImageNet class names (1,000 labels)
imagenet_classes = ["tench",
"goldfish",
"great white shark",
"tiger shark",
"hammerhead shark",
"electric ray",
"stingray",
"cock",
"hen",
"ostrich",
"brambling",
"goldfinch",
"house finch",
"junco",
"indigo bunting",
"American robin",
"bulbul",
"jay",
"magpie",
"chickadee",
"American dipper",
"kite",
"bald eagle",
"vulture",
"great grey owl",
"fire salamander",
"smooth newt",
"newt",
"spotted salamander",
"axolotl",
"American bullfrog",
"tree frog",
"tailed frog",
"loggerhead sea turtle",
"leatherback sea turtle",
"mud turtle",
"terrapin",
"box turtle",
"banded gecko",
"green iguana",
"Carolina anole",
"desert grassland whiptail lizard",
"agama",
"frilled-necked lizard",
"alligator lizard",
"Gila monster",
"European green lizard",
"chameleon",
"Komodo dragon",
"Nile crocodile",
"American alligator",
"triceratops",
"worm snake",
"ring-necked snake",
"eastern hog-nosed snake",
"smooth green snake",
"kingsnake",
"garter snake",
"water snake",
"vine snake",
"night snake",
"boa constrictor",
"African rock python",
"Indian cobra",
"green mamba",
"sea snake",
"Saharan horned viper",
"eastern diamondback rattlesnake",
"sidewinder",
"trilobite",
"harvestman",
"scorpion",
"yellow garden spider",
"barn spider",
"European garden spider",
"southern black widow",
"tarantula",
"wolf spider",
"tick",
"centipede",
"black grouse",
"ptarmigan",
"ruffed grouse",
"prairie grouse",
"peacock",
"quail",
"partridge",
"grey parrot",
"macaw",
"sulphur-crested cockatoo",
"lorikeet",
"coucal",
"bee eater",
"hornbill",
"hummingbird",
"jacamar",
"toucan",
"duck",
"red-breasted merganser",
"goose",
"black swan",
"tusker",
"echidna",
"platypus",
"wallaby",
"koala",
"wombat",
"jellyfish",
"sea anemone",
"brain coral",
"flatworm",
"nematode",
"conch",
"snail",
"slug",
"sea slug",
"chiton",
"chambered nautilus",
"Dungeness crab",
"rock crab",
"fiddler crab",
"red king crab",
"American lobster",
"spiny lobster",
"crayfish",
"hermit crab",
"isopod",
"white stork",
"black stork",
"spoonbill",
"flamingo",
"little blue heron",
"great egret",
"bittern",
"crane (bird)",
"limpkin",
"common gallinule",
"American coot",
"bustard",
"ruddy turnstone",
"dunlin",
"common redshank",
"dowitcher",
"oystercatcher",
"pelican",
"king penguin",
"albatross",
"grey whale",
"killer whale",
"dugong",
"sea lion",
"Chihuahua",
"Japanese Chin",
"Maltese",
"Pekingese",
"Shih Tzu",
"King Charles Spaniel",
"Papillon",
"toy terrier",
"Rhodesian Ridgeback",
"Afghan Hound",
"Basset Hound",
"Beagle",
"Bloodhound",
"Bluetick Coonhound",
"Black and Tan Coonhound",
"Treeing Walker Coonhound",
"English foxhound",
"Redbone Coonhound",
"borzoi",
"Irish Wolfhound",
"Italian Greyhound",
"Whippet",
"Ibizan Hound",
"Norwegian Elkhound",
"Otterhound",
"Saluki",
"Scottish Deerhound",
"Weimaraner",
"Staffordshire Bull Terrier",
"American Staffordshire Terrier",
"Bedlington Terrier",
"Border Terrier",
"Kerry Blue Terrier",
"Irish Terrier",
"Norfolk Terrier",
"Norwich Terrier",
"Yorkshire Terrier",
"Wire Fox Terrier",
"Lakeland Terrier",
"Sealyham Terrier",
"Airedale Terrier",
"Cairn Terrier",
"Australian Terrier",
"Dandie Dinmont Terrier",
"Boston Terrier",
"Miniature Schnauzer",
"Giant Schnauzer",
"Standard Schnauzer",
"Scottish Terrier",
"Tibetan Terrier",
"Australian Silky Terrier",
"Soft-coated Wheaten Terrier",
"West Highland White Terrier",
"Lhasa Apso",
"Flat-Coated Retriever",
"Curly-coated Retriever",
"Golden Retriever",
"Labrador Retriever",
"Chesapeake Bay Retriever",
"German Shorthaired Pointer",
"Vizsla",
"English Setter",
"Irish Setter",
"Gordon Setter",
"Brittany Spaniel",
"Clumber Spaniel",
"English Springer Spaniel",
"Welsh Springer Spaniel",
"Cocker Spaniels",
"Sussex Spaniel",
"Irish Water Spaniel",
"Kuvasz",
"Schipperke",
"Groenendael",
"Malinois",
"Briard",
"Australian Kelpie",
"Komondor",
"Old English Sheepdog",
"Shetland Sheepdog",
"collie",
"Border Collie",
"Bouvier des Flandres",
"Rottweiler",
"German Shepherd Dog",
"Dobermann",
"Miniature Pinscher",
"Greater Swiss Mountain Dog",
"Bernese Mountain Dog",
"Appenzeller Sennenhund",
"Entlebucher Sennenhund",
"Boxer",
"Bullmastiff",
"Tibetan Mastiff",
"French Bulldog",
"Great Dane",
"St. Bernard",
"husky",
"Alaskan Malamute",
"Siberian Husky",
"Dalmatian",
"Affenpinscher",
"Basenji",
"pug",
"Leonberger",
"Newfoundland",
"Pyrenean Mountain Dog",
"Samoyed",
"Pomeranian",
"Chow Chow",
"Keeshond",
"Griffon Bruxellois",
"Pembroke Welsh Corgi",
"Cardigan Welsh Corgi",
"Toy Poodle",
"Miniature Poodle",
"Standard Poodle",
"Mexican hairless dog",
"grey wolf",
"Alaskan tundra wolf",
"red wolf",
"coyote",
"dingo",
"dhole",
"African wild dog",
"hyena",
"red fox",
"kit fox",
"Arctic fox",
"grey fox",
"tabby cat",
"tiger cat",
"Persian cat",
"Siamese cat",
"Egyptian Mau",
"cougar",
"lynx",
"leopard",
"snow leopard",
"jaguar",
"lion",
"tiger",
"cheetah",
"brown bear",
"American black bear",
"polar bear",
"sloth bear",
"mongoose",
"meerkat",
"tiger beetle",
"ladybug",
"ground beetle",
"longhorn beetle",
"leaf beetle",
"dung beetle",
"rhinoceros beetle",
"weevil",
"fly",
"bee",
"ant",
"grasshopper",
"cricket",
"stick insect",
"cockroach",
"mantis",
"cicada",
"leafhopper",
"lacewing",
"dragonfly",
"damselfly",
"red admiral",
"ringlet",
"monarch butterfly",
"small white",
"sulphur butterfly",
"gossamer-winged butterfly",
"starfish",
"sea urchin",
"sea cucumber",
"cottontail rabbit",
"hare",
"Angora rabbit",
"hamster",
"porcupine",
"fox squirrel",
"marmot",
"beaver",
"guinea pig",
"common sorrel",
"zebra",
"pig",
"wild boar",
"warthog",
"hippopotamus",
"ox",
"water buffalo",
"bison",
"ram",
"bighorn sheep",
"Alpine ibex",
"hartebeest",
"impala",
"gazelle",
"dromedary",
"llama",
"weasel",
"mink",
"European polecat",
"black-footed ferret",
"otter",
"skunk",
"badger",
"armadillo",
"three-toed sloth",
"orangutan",
"gorilla",
"chimpanzee",
"gibbon",
"siamang",
"guenon",
"patas monkey",
"baboon",
"macaque",
"langur",
"black-and-white colobus",
"proboscis monkey",
"marmoset",
"white-headed capuchin",
"howler monkey",
"titi",
"Geoffroy's spider monkey",
"common squirrel monkey",
"ring-tailed lemur",
"indri",
"Asian elephant",
"African bush elephant",
"red panda",
"giant panda",
"snoek",
"eel",
"coho salmon",
"rock beauty",
"clownfish",
"sturgeon",
"garfish",
"lionfish",
"pufferfish",
"abacus",
"abaya",
"academic gown",
"accordion",
"acoustic guitar",
"aircraft carrier",
"airliner",
"airship",
"altar",
"ambulance",
"amphibious vehicle",
"analog clock",
"apiary",
"apron",
"waste container",
"assault rifle",
"backpack",
"bakery",
"balance beam",
"balloon",
"ballpoint pen",
"Band-Aid",
"banjo",
"baluster",
"barbell",
"barber chair",
"barbershop",
"barn",
"barometer",
"barrel",
"wheelbarrow",
"baseball",
"basketball",
"bassinet",
"bassoon",
"swimming cap",
"bath towel",
"bathtub",
"station wagon",
"lighthouse",
"beaker",
"military cap",
"beer bottle",
"beer glass",
"bell-cot",
"bib",
"tandem bicycle",
"bikini",
"ring binder",
"binoculars",
"birdhouse",
"boathouse",
"bobsleigh",
"bolo tie",
"poke bonnet",
"bookcase",
"bookstore",
"bottle cap",
"bow",
"bow tie",
"brass",
"bra",
"breakwater",
"breastplate",
"broom",
"bucket",
"buckle",
"bulletproof vest",
"high-speed train",
"butcher shop",
"taxicab",
"cauldron",
"candle",
"cannon",
"canoe",
"can opener",
"cardigan",
"car mirror",
"carousel",
"tool kit",
"carton",
"car wheel",
"automated teller machine",
"cassette",
"cassette player",
"castle",
"catamaran",
"CD player",
"cello",
"mobile phone",
"chain",
"chain-link fence",
"chain mail",
"chainsaw",
"chest",
"chiffonier",
"chime",
"china cabinet",
"Christmas stocking",
"church",
"movie theater",
"cleaver",
"cliff dwelling",
"cloak",
"clogs",
"cocktail shaker",
"coffee mug",
"coffeemaker",
"coil",
"combination lock",
"computer keyboard",
"confectionery store",
"container ship",
"convertible",
"corkscrew",
"cornet",
"cowboy boot",
"cowboy hat",
"cradle",
"crane (machine)",
"crash helmet",
"crate",
"infant bed",
"Crock Pot",
"croquet ball",
"crutch",
"cuirass",
"dam",
"desk",
"desktop computer",
"rotary dial telephone",
"diaper",
"digital clock",
"digital watch",
"dining table",
"dishcloth",
"dishwasher",
"disc brake",
"dock",
"dog sled",
"dome",
"doormat",
"drilling rig",
"drum",
"drumstick",
"dumbbell",
"Dutch oven",
"electric fan",
"electric guitar",
"electric locomotive",
"entertainment center",
"envelope",
"espresso machine",
"face powder",
"feather boa",
"filing cabinet",
"fireboat",
"fire engine",
"fire screen sheet",
"flagpole",
"flute",
"folding chair",
"football helmet",
"forklift",
"fountain",
"fountain pen",
"four-poster bed",
"freight car",
"French horn",
"frying pan",
"fur coat",
"garbage truck",
"gas mask",
"gas pump",
"goblet",
"go-kart",
"golf ball",
"golf cart",
"gondola",
"gong",
"gown",
"grand piano",
"greenhouse",
"grille",
"grocery store",
"guillotine",
"barrette",
"hair spray",
"half-track",
"hammer",
"hamper",
"hair dryer",
"hand-held computer",
"handkerchief",
"hard disk drive",
"harmonica",
"harp",
"harvester",
"hatchet",
"holster",
"home theater",
"honeycomb",
"hook",
"hoop skirt",
"horizontal bar",
"horse-drawn vehicle",
"hourglass",
"iPod",
"clothes iron",
"jack-o'-lantern",
"jeans",
"jeep",
"T-shirt",
"jigsaw puzzle",
"pulled rickshaw",
"joystick",
"kimono",
"knee pad",
"knot",
"lab coat",
"ladle",
"lampshade",
"laptop computer",
"lawn mower",
"lens cap",
"paper knife",
"library",
"lifeboat",
"lighter",
"limousine",
"ocean liner",
"lipstick",
"slip-on shoe",
"lotion",
"speaker",
"loupe",
"sawmill",
"magnetic compass",
"mail bag",
"mailbox",
"tights",
"tank suit",
"manhole cover",
"maraca",
"marimba",
"mask",
"match",
"maypole",
"maze",
"measuring cup",
"medicine chest",
"megalith",
"microphone",
"microwave oven",
"military uniform",
"milk can",
"minibus",
"miniskirt",
"minivan",
"missile",
"mitten",
"mixing bowl",
"mobile home",
"Model T",
"modem",
"monastery",
"monitor",
"moped",
"mortar",
"square academic cap",
"mosque",
"mosquito net",
"scooter",
"mountain bike",
"tent",
"computer mouse",
"mousetrap",
"moving van",
"muzzle",
"nail",
"neck brace",
"necklace",
"nipple",
"notebook computer",
"obelisk",
"oboe",
"ocarina",
"odometer",
"oil filter",
"organ",
"oscilloscope",
"overskirt",
"bullock cart",
"oxygen mask",
"packet",
"paddle",
"paddle wheel",
"padlock",
"paintbrush",
"pajamas",
"palace",
"pan flute",
"paper towel",
"parachute",
"parallel bars",
"park bench",
"parking meter",
"passenger car",
"patio",
"payphone",
"pedestal",
"pencil case",
"pencil sharpener",
"perfume",
"Petri dish",
"photocopier",
"plectrum",
"Pickelhaube",
"picket fence",
"pickup truck",
"pier",
"piggy bank",
"pill bottle",
"pillow",
"ping-pong ball",
"pinwheel",
"pirate ship",
"pitcher",
"hand plane",
"planetarium",
"plastic bag",
"plate rack",
"plow",
"plunger",
"Polaroid camera",
"pole",
"police van",
"poncho",
"billiard table",
"soda bottle",
"pot",
"potter's wheel",
"power drill",
"prayer rug",
"printer",
"prison",
"projectile",
"projector",
"hockey puck",
"punching bag",
"purse",
"quill",
"quilt",
"race car",
"racket",
"radiator",
"radio",
"radio telescope",
"rain barrel",
"recreational vehicle",
"reel",
"reflex camera",
"refrigerator",
"remote control",
"restaurant",
"revolver",
"rifle",
"rocking chair",
"rotisserie",
"eraser",
"rugby ball",
"ruler",
"running shoe",
"safe",
"safety pin",
"salt shaker",
"sandal",
"sarong",
"saxophone",
"scabbard",
"weighing scale",
"school bus",
"schooner",
"scoreboard",
"CRT screen",
"screw",
"screwdriver",
"seat belt",
"sewing machine",
"shield",
"shoe store",
"shoji",
"shopping basket",
"shopping cart",
"shovel",
"shower cap",
"shower curtain",
"ski",
"ski mask",
"sleeping bag",
"slide rule",
"sliding door",
"slot machine",
"snorkel",
"snowmobile",
"snowplow",
"soap dispenser",
"soccer ball",
"sock",
"solar thermal collector",
"sombrero",
"soup bowl",
"space bar",
"space heater",
"space shuttle",
"spatula",
"motorboat",
"spider web",
"spindle",
"sports car",
"spotlight",
"stage",
"steam locomotive",
"through arch bridge",
"steel drum",
"stethoscope",
"scarf",
"stone wall",
"stopwatch",
"stove",
"strainer",
"tram",
"stretcher",
"couch",
"stupa",
"submarine",
"suit",
"sundial",
"sunglass",
"sunglasses",
"sunscreen",
"suspension bridge",
"mop",
"sweatshirt",
"swimsuit",
"swing",
"switch",
"syringe",
"table lamp",
"tank",
"tape player",
"teapot",
"teddy bear",
"television",
"tennis ball",
"thatched roof",
"front curtain",
"thimble",
"threshing machine",
"throne",
"tile roof",
"toaster",
"tobacco shop",
"toilet seat",
"torch",
"totem pole",
"tow truck",
"toy store",
"tractor",
"semi-trailer truck",
"tray",
"trench coat",
"tricycle",
"trimaran",
"tripod",
"triumphal arch",
"trolleybus",
"trombone",
"tub",
"turnstile",
"typewriter keyboard",
"umbrella",
"unicycle",
"upright piano",
"vacuum cleaner",
"vase",
"vault",
"velvet",
"vending machine",
"vestment",
"viaduct",
"violin",
"volleyball",
"waffle iron",
"wall clock",
"wallet",
"wardrobe",
"military aircraft",
"sink",
"washing machine",
"water bottle",
"water jug",
"water tower",
"whiskey jug",
"whistle",
"wig",
"window screen",
"window shade",
"Windsor tie",
"wine bottle",
"wing",
"wok",
"wooden spoon",
"wool",
"split-rail fence",
"shipwreck",
"yawl",
"yurt",
"website",
"comic book",
"crossword",
"traffic sign",
"traffic light",
"dust jacket",
"menu",
"plate",
"guacamole",
"consomme",
"hot pot",
"trifle",
"ice cream",
"ice pop",
"baguette",
"bagel",
"pretzel",
"cheeseburger",
"hot dog",
"mashed potato",
"cabbage",
"broccoli",
"cauliflower",
"zucchini",
"spaghetti squash",
"acorn squash",
"butternut squash",
"cucumber",
"artichoke",
"bell pepper",
"cardoon",
"mushroom",
"Granny Smith",
"strawberry",
"orange",
"lemon",
"fig",
"pineapple",
"banana",
"jackfruit",
"custard apple",
"pomegranate",
"hay",
"carbonara",
"chocolate syrup",
"dough",
"meatloaf",
"pizza",
"pot pie",
"burrito",
"red wine",
"espresso",
"cup",
"eggnog",
"alp",
"bubble",
"cliff",
"coral reef",
"geyser",
"lakeshore",
"promontory",
"shoal",
"seashore",
"valley",
"volcano",
"baseball player",
"bridegroom",
"scuba diver",
"rapeseed",
"daisy",
"yellow lady's slipper",
"corn",
"acorn",
"rose hip",
"horse chestnut seed",
"coral fungus",
"agaric",
"gyromitra",
"stinkhorn mushroom",
"earth star",
"hen-of-the-woods",
"bolete",
"ear of corn",
"toilet paper"] 

In [27]:
class CLIPClassifier:

    def __init__(self, imagenet_classes, model_type="transformer"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.imagenet_classes = imagenet_classes
        self.text_inputs = torch.cat([
            clip.tokenize(f"a photo of a {c}") for c in imagenet_classes
        ]).to(self.device)
        if model_type == "transformer":
            self.model, self.preprocess = clip.load("ViT-B/16", self.device)
        elif model_type == "rn50":
            self.model, self.preprocess = clip.load("RN50", self.device)
        else:
            raise ValueError("model_type must be 'transformer' or 'rn50'")
        self.model_type = model_type
    
    def classify_image(self, image_path, top_k = 5):
        image = self.preprocess(Image.open(image_path)).unsqueeze(0).to(self.device)
        with torch.no_grad():
            image_features = self.model.encode_image(image)
            text_features = self.model.encode_text(self.text_inputs)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        logit_scale = self.model.logit_scale.exp()
        logits = logit_scale * (image_features @ text_features.T)
        probs = logits.softmax(dim=-1)

        # Get top k predictions
        probs_np = probs.detach().cpu().numpy().flatten()
        top_indices = np.argsort(probs_np)[-top_k:][::-1]
        
        # Print results
        for idx in top_indices:
            print(f"{self.imagenet_classes[idx]}: {probs_np[idx]:.4f}")
        
        return probs_np

In [18]:
classifier = CLIPClassifier(imagenet_classes, model_type="transformer")

probabilities = classifier.classify_image("/home2/shaon/computer_vision/Q3/test_image/dog.jpg") 
# print(probabilities)

100%|███████████████████████████████████████| 335M/335M [01:54<00:00, 3.07MiB/s]


Golden Retriever: 0.6646
Labrador Retriever: 0.1155
Irish Setter: 0.0658
Afghan Hound: 0.0164
English Setter: 0.0136


In [19]:
classifier = CLIPClassifier(imagenet_classes, model_type="transformer")

probabilities = classifier.classify_image("/home2/shaon/computer_vision/Q3/test_image/cat.jpg") 

tabby cat: 0.3586
Siamese cat: 0.2625
tiger cat: 0.1449
Persian cat: 0.0240
Egyptian Mau: 0.0165


## Picked Imagenette is a subset of 10 easily classified classes from Imagenet (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute).

n01440764--->tench
n02102040--->English springer
n02979186--->cassette player
n03000684--->chain saw
n03028079--->church
n03394916--->French horn
n03417042--->garbage truck
n03425413--->gas pump
n03445777--->golf ball
n03888257--->parachute



In [61]:
# 1. ResNet50 Classifier (ImageNet-pretrained)
class RN50Classifier:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = models.resnet50(weights='IMAGENET1K_V1')
        num_features = self.model.fc.in_features
        self.model.fc = nn.Linear(num_features, 10)
        self.model.to(self.device)


        self.model.eval()
        
        # ImageNet normalization
        self.preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])
        
        # Load ImageNet class names
        self.imagenet_classes = [
            'tench', 'English springer', 'cassette player', 'chain saw', 
            'church', 'French horn', 'garbage truck', 'gas pump', 
            'golf ball', 'parachute'
        ]

    def classify_image(self, image_path, top_k=5):

        try:
            image = Image.open(image_path)
            # Check if image is grayscale (1 channel)
            if image.mode == 'L':
                print("Grayscale image detected, skipping classification")
                return None
                
            image = self.preprocess(image).unsqueeze(0).to(self.device)
            
            with torch.no_grad():
                outputs = self.model(image)
                probs = torch.nn.functional.softmax(outputs, dim=1)[0]
            
            probs_np = probs.cpu().numpy()
            top_indices = np.argsort(probs_np)[-top_k:][::-1]
            
            for idx in top_indices:
                print(f"{self.imagenet_classes[idx]}: {probs_np[idx]:.4f}")
            
            return probs_np
        except Exception as e:
            print(f"Error processing image: {str(e)}")
            return None

In [67]:
class ImagenetteLoader:
    def __init__(self, root_path, val_only=True):
        self.class_map = {
            'n01440764': 'tench',
            'n02102040': 'English springer',
            'n02979186': 'cassette player',
            'n03000684': 'chain saw',
            'n03028079': 'church',
            'n03394916': 'French horn',
            'n03417042': 'garbage truck',
            'n03425413': 'gas pump',
            'n03445777': 'golf ball',
            'n03888257': 'parachute'
        }
        self.samples = []
        
        for class_id in self.class_map:
            val_path = os.path.join(root_path, 'train', class_id)
            for img_name in os.listdir(val_path):
                self.samples.append((
                    os.path.join(val_path, img_name),
                    self.class_map[class_id]
                ))
    
    def get_random_samples(self, n=None):
        if n is None:
            return self.samples
        return random.sample(self.samples, n)


In [80]:
#3. Comparison Test
def compare_models():

    class_map = {
        'n01440764': 'tench',
        'n02102040': 'English springer',
        'n02979186': 'cassette player',
        'n03000684': 'chain saw',
        'n03028079': 'church',
        'n03394916': 'French horn',
        'n03417042': 'garbage truck',
        'n03425413': 'gas pump',
        'n03445777': 'golf ball',
        'n03888257': 'parachute'
    }
    
    # Initialize models


    clip_model = CLIPClassifier(list(class_map.values()), model_type="transformer")
    rn50_model = RN50Classifier()
    loader = ImagenetteLoader('/ssd_scratch/cvit/shaon/cv_5/Q3/imagenette2')
    
    # Find cases where CLIP works but RN50 fails
    print("Finding images where CLIP succeeds and RN50 fails...")
    found = 0
    for img_path, true_label in loader.get_random_samples(50):
        print(f"\nTesting: {true_label} ({img_path})")
        
        # CLIP prediction
        print("CLIP Predictions:")
        clip_probs = clip_model.classify_image(img_path)
        clip_top5 = set(np.argsort(clip_probs)[-5:][::-1])
        
        # RN50 prediction
        print("\nRN50 Predictions:")
        rn50_probs = rn50_model.classify_image(img_path)
        rn50_top5 = set(np.argsort(rn50_probs)[-5:][::-1])
        
        # Check conditions
        correct_label_idx = list(class_map.values()).index(true_label)
        
        if (correct_label_idx in clip_top5) and (correct_label_idx not in rn50_top5):
            print(f"\nFOUND CASE: CLIP works, RN50 fails - {img_path}")
            found += 1
            if found >= 2:
                break

    # Find cases where RN50 works but CLIP fails (similar logic)
    # 
    print("Finding images where RN50 works and CLIP fails...")
    found = 0
    for img_path, true_label in loader.get_random_samples(50):
        print(f"\nTesting: {true_label} ({img_path})")
        
        # CLIP prediction
        print("CLIP Predictions:")
        clip_probs = clip_model.classify_image(img_path)
        clip_top5 = set(np.argsort(clip_probs)[-5:][::-1])
        
        # RN50 prediction
        print("\nRN50 Predictions:")
        rn50_probs = rn50_model.classify_image(img_path)
        rn50_top5 = set(np.argsort(rn50_probs)[-5:][::-1])
        
        # Check conditions
        correct_label_idx = list(class_map.values()).index(true_label)
        
        if (correct_label_idx in rn50_top5) and (correct_label_idx not in clip_top5):
            print(f"\nFOUND CASE: RN50 works, CLIP fails - {img_path}")
            found += 1
            if found >= 1:
                break

if __name__ == "__main__":
    compare_models()

Finding images where CLIP succeeds and RN50 fails...

Testing: garbage truck (/ssd_scratch/cvit/shaon/cv_5/Q3/imagenette2/train/n03417042/n03417042_28569.JPEG)
CLIP Predictions:
garbage truck: 0.9995
chain saw: 0.0003
gas pump: 0.0003
English springer: 0.0000
church: 0.0000

RN50 Predictions:
tench: 0.2032
French horn: 0.1540
chain saw: 0.1084
golf ball: 0.0972
English springer: 0.0905

FOUND CASE: CLIP works, RN50 fails - /ssd_scratch/cvit/shaon/cv_5/Q3/imagenette2/train/n03417042/n03417042_28569.JPEG

Testing: English springer (/ssd_scratch/cvit/shaon/cv_5/Q3/imagenette2/train/n02102040/n02102040_875.JPEG)
CLIP Predictions:
English springer: 1.0000
garbage truck: 0.0000
golf ball: 0.0000
cassette player: 0.0000
gas pump: 0.0000

RN50 Predictions:
tench: 0.2977
French horn: 0.1192
church: 0.1085
parachute: 0.0958
chain saw: 0.0890

FOUND CASE: CLIP works, RN50 fails - /ssd_scratch/cvit/shaon/cv_5/Q3/imagenette2/train/n02102040/n02102040_875.JPEG
Finding images where RN50 works and CLI

In [79]:
image_path = "/home2/shaon/computer_vision/Q3/test_image/test.jpeg"

class_map = {
    'n01440764': 'tench',
    'n02102040': 'English springer',
    'n02979186': 'cassette player',
    'n03000684': 'chain saw',
    'n03028079': 'church',
    'n03394916': 'French horn',
    'n03417042': 'garbage truck',
    'n03425413': 'gas pump',
    'n03445777': 'golf ball',
    'n03888257': 'parachute'
}
# with clip model
clip_model = CLIPClassifier(list(class_map.values()), model_type="transformer")

print("CLIP Predictions:")
clip_model.classify_image(image_path)

#with rn50 model

rn50_model = RN50Classifier()
print("RN50 Predictions:")
rn50_probs= rn50_model.classify_image(image_path)




CLIP Predictions:
cassette player: 0.9980
French horn: 0.0007
golf ball: 0.0006
tench: 0.0002
garbage truck: 0.0002
RN50 Predictions:
English springer: 0.1289
French horn: 0.1260
cassette player: 0.1138
golf ball: 0.1104
tench: 0.1089


In [76]:
class CLIPClassifier_new:
    def __init__(self, imagenet_classes, model_type="transformer", use_fp16=False):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.imagenet_classes = imagenet_classes
        self.text_inputs = torch.cat([
            clip.tokenize(f"a photo of a {c}") for c in imagenet_classes
        ]).to(self.device)
        if model_type == "transformer":
            self.model, self.preprocess = clip.load("ViT-B/16", self.device)
        elif model_type == "rn50":
            self.model, self.preprocess = clip.load("RN50", self.device)
        else:
            raise ValueError("model_type must be 'transformer' or 'rn50'")
        self.model_type = model_type
        self.use_fp16 = use_fp16
        if use_fp16:
            # Only convert the visual part of the model to FP16
            self.model.visual = self.model.visual.half()
        self.scaler = amp.GradScaler(enabled=use_fp16)
    
    def benchmark_forward_pass(self, image_path, num_runs=100):
        image = self.preprocess(Image.open(image_path)).unsqueeze(0).to(self.device)
        if self.use_fp16:
            image = image.half()
        
        # Warmup
        for _ in range(10):
            with torch.no_grad():
                _ = self.model.encode_image(image)
        
        # Benchmark
        times = []
        for _ in range(num_runs):
            start_time = time.time()
            with torch.no_grad():
                _ = self.model.encode_image(image)
            torch.cuda.synchronize()
            end_time = time.time()
            times.append(end_time - start_time)
        
        return np.mean(times), np.std(times)
    
    def classify_image(self, image_path, top_k=5):
        try:
            image = Image.open(image_path)
            if image.mode == 'L':
                print(f"Grayscale image detected, skipping: {image_path}")
                return None
                
            image = self.preprocess(image).unsqueeze(0).to(self.device)
            if self.use_fp16:
                image = image.half()
            
            with torch.no_grad():
                image_features = self.model.encode_image(image)
                text_features = self.model.encode_text(self.text_inputs)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            text_features /= text_features.norm(dim=-1, keepdim=True)
            logit_scale = self.model.logit_scale.exp()
            logits = logit_scale * (image_features @ text_features.T)
            probs = logits.softmax(dim=-1)

            probs_np = probs.detach().cpu().numpy().flatten()
            top_indices = np.argsort(probs_np)[-top_k:][::-1]
            
            for idx in top_indices:
                print(f"{self.imagenet_classes[idx]}: {probs_np[idx]:.4f}")
            
            return probs_np
        except Exception as e:
            print(f"Error processing image {image_path}: {str(e)}")
            return None

def compare_fp16_fp32():
    # Initialize models
    clip_fp32 = CLIPClassifier_new(imagenet_classes, model_type="rn50", use_fp16=False)
    clip_fp16 = CLIPClassifier_new(imagenet_classes, model_type="rn50", use_fp16=True)
    
    # Load some test images
    loader = ImagenetteLoader('/ssd_scratch/cvit/shaon/cv_5/Q3/imagenette2')
    test_images = loader.get_random_samples(5)
    
    # Benchmark performance
    print("\nBenchmarking FP32 model...")
    fp32_mean, fp32_std = clip_fp32.benchmark_forward_pass(test_images[0][0])
    print(f"FP32 Mean time: {fp32_mean*1000:.2f}ms ± {fp32_std*1000:.2f}ms")
    
    print("\nBenchmarking FP16 model...")
    fp16_mean, fp16_std = clip_fp16.benchmark_forward_pass(test_images[0][0])
    print(f"FP16 Mean time: {fp16_mean*1000:.2f}ms ± {fp16_std*1000:.2f}ms")
    
    print(f"\nSpeedup: {fp32_mean/fp16_mean:.2f}x")
    
    # Compare outputs
    print("\nComparing outputs for 5 test images:")
    for img_path, true_label in test_images:
        print(f"\nImage: {true_label}")
        print("FP32 predictions:")
        fp32_probs = clip_fp32.classify_image(img_path)
        print("\nFP16 predictions:")
        fp16_probs = clip_fp16.classify_image(img_path)
        
        if fp32_probs is not None and fp16_probs is not None:
            diff = np.abs(fp32_probs - fp16_probs)
            print(f"Max difference: {np.max(diff):.6f}")
            print(f"Mean difference: {np.mean(diff):.6f}")
    
    # Memory usage comparison
    print("\nMemory usage comparison:")
    image = clip_fp32.preprocess(Image.open(test_images[0][0])).unsqueeze(0).to(clip_fp32.device)
    
    with profile(activities=[ProfilerActivity.CUDA], profile_memory=True) as prof:
        with torch.no_grad():
            _ = clip_fp32.model.encode_image(image)
    print("\nFP32 memory usage:")
    print(prof.key_averages().table(sort_by="cuda_memory_usage", row_limit=10))
    
    with profile(activities=[ProfilerActivity.CUDA], profile_memory=True) as prof:
        with torch.no_grad():
            _ = clip_fp16.model.encode_image(image.half())
    print("\nFP16 memory usage:")
    print(prof.key_averages().table(sort_by="cuda_memory_usage", row_limit=10))

# Run the comparison
compare_fp16_fp32()


Benchmarking FP32 model...
FP32 Mean time: 8.90ms ± 0.14ms

Benchmarking FP16 model...
FP16 Mean time: 9.57ms ± 0.09ms

Speedup: 0.93x

Comparing outputs for 5 test images:

Image: cassette player
FP32 predictions:
tape player: 0.4407
radio: 0.3433
cassette player: 0.1115
cassette: 0.0340
CD player: 0.0096

FP16 predictions:
tape player: 0.4343
radio: 0.3489
cassette player: 0.1115
cassette: 0.0335
CD player: 0.0094
Max difference: 0.006348
Mean difference: 0.000014

Image: church
FP32 predictions:
Grayscale image detected, skipping: /ssd_scratch/cvit/shaon/cv_5/Q3/imagenette2/train/n03028079/n03028079_14278.JPEG

FP16 predictions:
Grayscale image detected, skipping: /ssd_scratch/cvit/shaon/cv_5/Q3/imagenette2/train/n03028079/n03028079_14278.JPEG

Image: gas pump
FP32 predictions:
gas pump: 0.7847
vending machine: 0.0486
payphone: 0.0471
coffeemaker: 0.0090
carton: 0.0077

FP16 predictions:
gas pump: 0.7812
vending machine: 0.0492
payphone: 0.0477
coffeemaker: 0.0090
carton: 0.0077
Ma

STAGE:2025-04-20 20:29:14 12100:12100 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2025-04-20 20:29:14 12100:12100 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2025-04-20 20:29:14 12100:12100 ActivityProfilerController.cpp:321] Completed Stage: Post Processing
STAGE:2025-04-20 20:29:14 12100:12100 ActivityProfilerController.cpp:311] Completed Stage: Warm Up
STAGE:2025-04-20 20:29:14 12100:12100 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2025-04-20 20:29:14 12100:12100 ActivityProfilerController.cpp:321] Completed Stage: Post Processing
