In [None]:
import argparse
import os
import shlex
from pprint import pprint

from matplotlib import pyplot as plt
import torch
from torch import nn

from main import parse_option
from learner import Learner

In [None]:
@torch.no_grad()
def visualize_prompt(prompt_learner: nn.Module, image_size: int = 224):
    """Draw the result of applying a prompt module to an all-ones image.

    An all-ones batch of shape ``(1, 3, image_size, image_size)`` is passed
    through ``prompt_learner``; the output is moved to the CPU, clamped to
    ``[0, 1]`` and shown with ``plt.imshow``. The whole call runs under
    ``torch.no_grad()``, so no autograd graph is recorded.

    Args:
        prompt_learner: Module that is called on the image batch. Its output
            is indexed as ``[0]`` and permuted ``(1, 2, 0)``, i.e. it is
            expected to be a channels-first image batch.
        image_size: Height and width of the synthetic input image. Defaults
            to 224, the value that used to be hard-coded.
    """
    # Create the input on the same device as the module's tensors; a CPU input
    # fed to a module that lives on the GPU fails with a device mismatch.
    anchor = next(prompt_learner.parameters(), None)
    if anchor is None:
        # Modules without parameters may still keep their state in buffers.
        anchor = next(prompt_learner.buffers(), None)
    device = anchor.device if anchor is not None else torch.device("cpu")

    fake_img = torch.ones(1, 3, image_size, image_size, device=device)
    prompted_img = prompt_learner(fake_img).cpu()
    # Keep values inside the range imshow accepts for float RGB data.
    prompted_img = torch.clamp(prompted_img, 0, 1)
    print("Visualizing prompt...")
    # imshow wants channels-last (H, W, C), so move the channel axis to the end.
    plt.imshow(prompted_img[0].permute(1, 2, 0).numpy())
    

def parse_option(str_opt):
    """Parse run options from a command-line style string and derive run paths.

    This notebook-local definition shadows the ``parse_option`` imported from
    ``main`` at the top of the notebook.

    Args:
        str_opt: Either a single string of options, tokenized with
            ``shlex.split`` so that quoted values containing spaces (for
            example ``--text_prompt_template``) stay in one piece, or an
            already tokenized sequence of strings, which is used as-is.
            ``shlex.split`` treats backslashes as escape characters; pass a
            sequence to avoid any tokenization.

    Returns:
        The ``argparse.Namespace`` of parsed options, extended with:

        * ``num_workers`` capped at ``os.cpu_count()`` when that is known,
        * ``filename`` generated from the main hyper-parameters unless
          ``--filename`` was given explicitly,
        * ``device`` set to ``"cuda"`` when CUDA is available, else ``"cpu"``,
        * ``model_folder`` set to ``model_dir/filename``.

    Side effects:
        Creates ``model_folder`` on disk if it does not exist yet.
    """
    parser = argparse.ArgumentParser("Visual Prompting for CLIP")

    parser.add_argument("--print_freq", type=int, default=10, help="print frequency")
    parser.add_argument("--save_freq", type=int, default=50, help="save frequency")
    parser.add_argument("--batch_size", type=int, default=128, help="batch_size")
    parser.add_argument(
        "--num_workers", type=int, default=16, help="num of workers to use"
    )
    parser.add_argument(
        "--epochs", type=int, default=1000, help="number of training epochs"
    )
    parser.add_argument(
        "--square_size",
        type=int,
        default=8,
        help="size of each square in checkboard prompt",
    )
    # optimization
    parser.add_argument("--optim", type=str, default="sgd", help="optimizer to use")
    parser.add_argument("--learning_rate", type=float, default=40, help="learning rate")
    parser.add_argument("--weight_decay", type=float, default=0, help="weight decay")
    parser.add_argument(
        "--warmup", type=int, default=1000, help="number of steps to warmup for"
    )
    parser.add_argument("--momentum", type=float, default=0.9, help="momentum")
    parser.add_argument("--patience", type=int, default=1000)

    # model
    parser.add_argument("--model", type=str, default="clip")
    parser.add_argument("--arch", type=str, default="ViT-B/32")
    parser.add_argument(
        "--method",
        type=str,
        default="padding",
        choices=["padding", "random_patch", "fixed_patch", "checkers"],
        help="choose visual prompting method",
    )
    parser.add_argument(
        "--prompt_size", type=int, default=30, help="size for visual prompts"
    )
    parser.add_argument(
        "--text_prompt_template",
        type=str,
        default="This is a photo of a {}",
    )
    parser.add_argument(
        "--visualize_prompt",
        action="store_true",
        help="visualize the (randomly initialized) prompt and save it to a file for debugging",
    )

    # dataset
    parser.add_argument("--root", type=str, default="./data", help="dataset")
    parser.add_argument("--dataset", type=str, default="cifar100", help="dataset")
    parser.add_argument("--image_size", type=int, default=224, help="image size")
    parser.add_argument(
        "--test_noise",
        default=False,
        action="store_true",
        help="whether to add noise to the test images",
    )

    # other
    parser.add_argument(
        "--seed", type=int, default=0, help="seed for initializing training"
    )
    parser.add_argument(
        "--model_dir", type=str, default="./save/models", help="path to save models"
    )
    parser.add_argument(
        "--image_dir", type=str, default="./save/images", help="path to save images"
    )
    parser.add_argument("--filename", type=str, default=None, help="filename to save")
    parser.add_argument("--trial", type=int, default=1, help="number of trials")
    parser.add_argument(
        "--resume", type=str, default=None, help="path to resume from checkpoint"
    )
    parser.add_argument(
        "--evaluate",
        default=False,
        action="store_true",
        help="evaluate model test set",
    )
    parser.add_argument("--gpu", type=int, default=None, help="gpu to use")
    parser.add_argument(
        "--use_wandb",
        default=False,
        action="store_true",
        help="whether to use wandb",
    )
    parser.add_argument("--verbose", action="store_true")

    # A plain str.split() would break quoted values that contain spaces;
    # shlex applies shell-style quoting rules instead.
    if isinstance(str_opt, str):
        argv = shlex.split(str_opt)
    else:
        argv = list(str_opt)
    args = parser.parse_args(argv)

    # os.cpu_count() returns None when the count cannot be determined; in that
    # case leave the requested worker count untouched instead of calling
    # min(int, None), which raises TypeError.
    cpu_count = os.cpu_count()
    if cpu_count is not None:
        args.num_workers = min(args.num_workers, cpu_count)

    # Only generate a run name when the caller did not supply --filename;
    # previously an explicit --filename was silently discarded.
    if args.filename is None:
        args.filename = (
            "{}_{}_{}_{}_{}_{}_lr_{}_decay_{}_bsz_{}_warmup_{}_trial_{}".format(
                args.method,
                args.prompt_size,
                args.dataset,
                args.model,
                args.arch,
                args.optim,
                args.learning_rate,
                args.weight_decay,
                args.batch_size,
                args.warmup,
                args.trial,
            )
        )

    args.device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE: the default --arch ("ViT-B/32") contains "/", so the generated
    # filename adds an extra directory level under model_dir. This is left
    # unchanged so that the on-disk layout stays the same.
    args.model_folder = os.path.join(args.model_dir, args.filename)
    os.makedirs(args.model_folder, exist_ok=True)

    return args

In [None]:
# Options for this run, written as a command-line style string.
# Option names are spelled out in full: "--visualize" only worked through
# argparse prefix abbreviation and would stop parsing as soon as a second
# option starting with "--visualize" is added.
opt_str = "--dataset cifar10 --epochs 10 --method checkers --prompt_size 2 --num_workers 3 --print_freq 50 --patience 5 --visualize_prompt"
args = parse_option(opt_str)

In [None]:
# Build the Learner (imported from the project's `learner` module) from the
# parsed options in `args`.
learn = Learner(args)

In [None]:
# Train by calling the Learner's run() method.
# NOTE(review): presumably the run length is governed by --epochs/--patience
# from the options cell — confirm in learner.py.
learn.run()

In [None]:
# Evaluate with the "test" argument and print whatever evaluate() returns.
# NOTE(review): the return type/scale of Learner.evaluate is not visible in
# this notebook — confirm whether it is a fraction or a percentage.
test_acc = learn.evaluate("test")
print(f"test accuracy: {test_acc}")

In [None]:
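# Optional: visualize a standalone checkers prompt with visualize_prompt().
# NOTE: CheckersPrompt is not imported in the cells above; import it before
# uncommenting these lines.
#prompt = CheckersPrompt(224, 2)
#visualize_prompt(prompt)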

In [None]:
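# Optional: time one forward pass of `prompt` (created in the commented-out
# cell above) on an all-ones image batch.
#dummy_in = torch.ones(1, 3, 224, 224)
#%timeit x = prompt(dummy_in)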