In [1]:
import sys
sys.path.insert(1,'../')

import warnings
warnings.filterwarnings('ignore')

In [2]:
from data.dataloader import ImageNetA, get_dataloader
from data.datautils import PatchAugmenter, AugmenterTPT
from utils.losses import defaultTPT_loss, patch_loss1, patch_loss2, patch_loss3, patch_loss4
from model.custom_clip import get_coop
from copy import deepcopy
import torch.backends.cudnn as cudnn
import torch

In [3]:
class dotdict(dict):
    """dot.notation access to dictionary attributes"""
    __getattr__ = dict.get
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

args = {
    "imagenet_a_path": "../../Datasets/imagenet-a/",
    "coop_weight_path": "../../model.pth.tar-50",
    "n_aug": 4,
    "n_patches": 16,
    "batch_size": 1,
    "arch": "RN50",
    "device": "cuda:0",
    "learning_rate": 0.005,
    "n_ctx": 4,
    "ctx_init": "",
    "class_token_position": "end",
    "csc": False,
    "run_name": "",
    "augmenter": "PatchAugmenter",
    "loss": "defaultTPT",
    "augmix": True,
    "severity": 1,
    "num_workers": 1,
    "save": False,
    "reduced_size": None,
    "dataset_shuffle": False,
    "save_imgs": False,
    "selection_p_all": 0.1,
    "selection_p_patch": 0.9
}

args = dotdict(args)

In [4]:
def parse_augmenter(args):
    if args.augmenter == "AugmenterTPT":
        args.augmenter = AugmenterTPT(args.n_aug, args.augmix, args.severity)
    elif args.augmenter == "PatchAugmenter":
        print(args.n_patches)
        args.augmenter = PatchAugmenter(
            args.n_aug, args.n_patches, args.augmix, args.severity
        )
    else:
        exit("Augmenter not valid")

parse_augmenter(args)
device = args.device


classnames = ImageNetA.classnames
dataset = ImageNetA(args.imagenet_a_path, transform=args.augmenter)
args.nclasses = len(classnames)
args.classnames = classnames
dataloader = get_dataloader(
    dataset,
    args.batch_size,
    shuffle=args.dataset_shuffle,
    reduced_size=args.reduced_size,
    num_workers=args.num_workers,
)
model = get_coop(args.arch, classnames, args.device, args.n_ctx, args.ctx_init)

print("Use pre-trained soft prompt (CoOp) as initialization")
pretrained_ctx = torch.load(args.coop_weight_path)["state_dict"]["ctx"]

with torch.no_grad():
    model.prompt_learner.ctx.copy_(pretrained_ctx)
    model.prompt_learner.ctx_init_state = pretrained_ctx

for name, param in model.named_parameters():
    if "prompt_learner" not in name:
        param.requires_grad_(False)

model = model.to(args.device)

trainable_param = model.prompt_learner.parameters()
optimizer = torch.optim.AdamW(trainable_param, args.learning_rate)
optim_state = deepcopy(optimizer.state_dict())
scaler = torch.cuda.amp.GradScaler(init_scale=1000)

cudnn.benchmark = True
model.reset_classnames(classnames, args.arch)

16
torch.float32
Random initialization: initializing a generic context
Initial context: "X X X X"
Number of context words (tokens): 4
Use pre-trained soft prompt (CoOp) as initialization


In [5]:
first_batch = next(iter(dataloader))

In [6]:
(imgs, target) = first_batch

In [7]:
images = torch.cat(imgs[1:], dim=0).to(device)  # don't consider view image
orig_img = imgs[1].to(device)
print(images.shape)

torch.Size([85, 3, 224, 224])


In [10]:
print(images[5].shape)


torch.Size([3, 224, 224])


In [11]:
with torch.no_grad():
    with torch.cuda.amp.autocast():
        output = model(images)


In [12]:
if torch.isnan(output).any():
    print("The tensor contains NaN values.")
    
    # Find rows containing NaN
    rows_with_nan = torch.any(torch.isnan(output), dim=1)
    nan_indices = torch.nonzero(rows_with_nan, as_tuple=True)[0]  # Get indices of rows with NaN
    
    for idx in nan_indices:
        print(f"Row {idx} contains NaN values:")
        print(output[idx])
else:
    print("The tensor does not contain NaN values.")



The tensor does not contain NaN values.


In [None]:
print(output.shape)

print(output[6])

In [None]:
print(output.shape)
print(torch.isnan(output).any())
output

In [17]:
def reshape_output_patches(output, args):
    return output.view(-1, args.n_aug + 1, output.shape[-1])

output_reshaped = reshape_output_patches(output, args)
print(torch.isnan(output_reshaped).any())
output_reshaped.shape

tensor(False, device='cuda:0')


torch.Size([17, 5, 200])

In [18]:
mean_output_per_patch = output_reshaped.mean(dim=1)
print(torch.isnan(mean_output_per_patch).any())
mean_output_per_patch.shape

tensor(False, device='cuda:0')


torch.Size([17, 200])

In [21]:
mean_logprob_per_patch = mean_output_per_patch.log_softmax(dim=1)
print(torch.isnan(mean_logprob_per_patch).any())
mean_logprob_per_patch.shape

tensor(False, device='cuda:0')


torch.Size([17, 200])

In [53]:
entropy_per_patch = (mean_logprob_per_patch * torch.exp(mean_logprob_per_patch)).sum(dim=-1)
print(torch.isnan(entropy_per_patch).any())
print(entropy_per_patch)
entropy_per_patch.shape

tensor(False, device='cuda:0')
tensor([-2.2812, -3.4219, -2.6270, -4.1406, -3.7012, -2.7480, -4.1602, -3.0098,
        -3.3828, -2.3730, -3.0996, -2.9980, -3.5996, -3.6172, -3.8203, -3.8145,
        -4.2344], device='cuda:0', dtype=torch.float16)


torch.Size([17])

In [83]:
epsilon = 1e-6
soft_logprob_output = mean_logprob_per_patch * (1/(entropy_per_patch.unsqueeze(dim=1) + epsilon))
soft_logprob_output

tensor([[0.8828, 3.4727, 4.8555,  ..., 4.0391, 5.4297, 2.3711],
        [2.5996, 2.5645, 2.2383,  ..., 3.7461, 1.9814, 1.9727],
        [4.4375, 3.3965, 3.1426,  ..., 4.5977, 3.8652, 2.7500],
        ...,
        [2.5684, 2.0293, 1.8145,  ..., 2.9863, 2.0254, 2.0703],
        [2.1309, 2.1094, 1.9580,  ..., 3.2031, 1.9258, 2.1621],
        [2.0117, 1.3633, 1.5693,  ..., 2.1836, 1.9102, 1.4551]],
       device='cuda:0', dtype=torch.float16)

In [84]:
soft_entropy_loss.shape

torch.Size([17, 200])

In [87]:
logprob_output = soft_entropy_loss.mean(dim=0).log_softmax(dim=0)
print(logprob_output)
logprob_output.shape

tensor([-13.8125, -12.3125, -12.1250, -10.9844, -11.0781, -16.6250,  -7.4570,
        -15.2969, -13.6094, -10.4844, -14.8594, -14.0938, -12.7812, -15.6250,
        -12.8125,  -8.4531, -15.6094, -14.1250, -15.9219,  -9.7344, -11.9688,
        -10.8594, -13.7500, -15.5469, -14.3281,  -9.6875, -10.3906,  -9.2188,
        -15.5781, -15.6094, -10.4219, -12.2656, -11.0000, -12.3594, -15.0000,
        -16.7031, -13.9688, -11.9062, -12.5156,  -9.3281, -11.0938, -11.9375,
        -10.6406, -10.6406, -12.8281,  -9.7031, -11.0938, -14.4375, -14.3594,
        -11.6875, -12.5312, -13.6562, -15.3438, -12.8750, -13.2656, -12.3906,
        -12.2969, -12.3750, -18.0781, -16.3750, -10.1250, -15.7344,  -8.9062,
        -13.6406, -15.9531, -11.4531, -10.7656, -15.2500,  -8.4375, -12.5312,
        -10.9688,  -8.2500,  -7.7539,  -9.2500, -10.5000,  -8.2812,  -9.2188,
        -17.4844, -15.6719, -11.3906,  -7.3945, -19.2031, -10.0781, -15.5469,
         -8.2188, -11.4219, -12.1875, -14.2969, -16.1875, -11.85

torch.Size([200])

In [90]:
entropy_loss = -(logprob_output * torch.exp(logprob_output)).sum(dim=0)
entropy_loss

tensor(1.0596, device='cuda:0', dtype=torch.float16)

In [41]:
a = torch.tensor([[1, 2], [2, 3], [3, 4]])
b = torch.tensor([4,5,6])

print(a)
print(b)
print(a*b.unsqueeze(dim=1))

tensor([[1, 2],
        [2, 3],
        [3, 4]])
tensor([4, 5, 6])
tensor([[ 4,  8],
        [10, 15],
        [18, 24]])


In [63]:
def patch_loss5(outputs, args):
    epsilon = 1e-6

    output_reshaped = reshape_output_patches(output, args)
    mean_output_per_patch = output_reshaped.mean(dim=1)

    mean_logprob_per_patch = mean_output_per_patch.log_softmax(dim=1)
    entropy_per_patch = -(mean_logprob_per_patch * torch.exp(mean_logprob_per_patch)).sum(dim=-1)

    weighted_logprob_per_patch = mean_logprob_per_patch * (1/(entropy_per_patch.unsqueeze(dim=1) + epsilon))
    logprob_output = weighted_logprob_per_patch.mean(dim=0).log_softmax(dim=0)

    entropy_loss = -(logprob_output * torch.exp(logprob_output)).sum(dim=0)

    return entropy_loss