In [1]:
import torch
import wandb
import os
import math
from color.CEConv.models.resnet_variational import ResNet18 as ResNet18_partial

from color.generate_data import generate_102flower_data
import matplotlib.pyplot as plt
import yaml
from easydict import EasyDict
import numpy as np
os.environ["CUDA_VISIBLE_DEVICES"] = "7"
from tqdm import tqdm
import seaborn as sns
sns.set()
plt.rc('axes', labelsize=15)   # x,y축 label 폰트 크기
plt.rc('xtick', labelsize=15)
plt.rc('ytick', labelsize=15)

  from .autonotebook import tqdm as notebook_tqdm


In [20]:
run_path = "kim-hyunsu/partial_equiv/0vh427nx"
file = "/mnt/home/yegonkim/home/partial_equiv_project/partial_equiv/wandb/run-20240323_060932-0vh427nx/files/config.yaml"
with open(file) as f:
    used_args = yaml.safe_load(f)
args = dict()
for k in used_args:
    v = used_args[k]
    if isinstance(v, dict) and not k.startswith("_"):
        if "." in k:
            temp = args
            for _k in k.split(".")[:-1]:
                if temp.get(_k) is None:
                    temp[_k] = dict()
                temp = temp[_k]
            temp[k.split(".")[-1]] = v["value"]
        else:
            args[k] = v["value"]
args = EasyDict(args)
print(args)
torch.manual_seed(args.seed)
np.random.seed(args.seed)

os.environ["WANDB_MODE"] = "dryrun"
os.environ["HYDRA_FULL_ERROR"] = "1"

wandb.init(
    project="partial_equiv",
    entity="kim-hyunsu",
    reinit=True,
)


{'model': {'rot': 3, 'variational': True, 'version': 'v1.2', 'partial': False, 'insta': False, 'insta_params': {'num_samples': 1, 'lambda_entropy': 0.0001, 'h_min': -1.5, 'h_max': 2}}, 'seed': 2024, 'task': 'flowers', 'train': {'batch_size': 64, 'epochs': 700, 'lamda': 0.01, 'lamda2': 0.01, 'lr': 0.0002, 'lr_probs': 2e-05, 'valid_every': 10, 'weight_decay': 0.01, 'do': True}, 'wandb': {'entity': 'kim-hyunsu', 'mode': 'online', 'project': 'partial_equiv'}, 'type': 'color', 'dataset': 'Flowers102', 'pretrained': None, 'device': 'cuda', 'comment': '', 'no_workers': 1}


In [21]:
### Get samples from test dataset
dataset = generate_102flower_data(size=224)

test_loader = torch.utils.data.DataLoader(
        dataset["test"],
        batch_size=1,
        shuffle=False,
        num_workers=1,
    )

In [24]:
gumbel_no_iterations = math.ceil(len(dataset["train"]) / float(args.train.batch_size))  # Iter per epoch
gumbel_no_iterations = args.train.epochs * gumbel_no_iterations
model = ResNet18_partial(pretrained=False, progress=False, rotations=args.model.rot, num_classes=102,
                groupcosetmaxpool=True, separable=True,
                gumbel_no_iterations=gumbel_no_iterations,
                version=args.model.version
        ).to(args.device)
# load checkpoint
# model_checkpoint = wandb.restore('checkpoint.pt', run_path)
# model.load_state_dict(
#             torch.load(model_checkpoint.name, map_location=args.device)["model"],
#             strict=True,
#         )
model = torch.nn.DataParallel(model)
checkpoint_path = "/mnt/home/yegonkim/home/partial_equiv_project/partial_equiv/wandb/run-20240323_060932-0vh427nx/files/checkpoint.pt"
model.load_state_dict(
            torch.load(checkpoint_path, map_location=args.device)["model"],
            strict=True,
        )


<All keys matched successfully>

In [25]:
def get_group_elements(target_labels, count):
    samples = []
    for target in target_labels:
        count = 9
        with torch.no_grad():
            for images, labels in tqdm(test_loader):
                images = images.to(args.device)
                labels = labels.to(args.device)
                outputs = model(images)
                prob_class = torch.softmax(outputs, dim=-1)[0,labels[0]]
                if prob_class < 0.01:
                    continue
                if torch.any(labels==target) and count > 0:
                    count -= 1
                    continue
                if torch.any(labels==target) and count == 0:
                    images = images[labels==target]
                    model(images)
                    probs = None
                    module = None
                    for m in model.modules():
                        if getattr(m, "entropy", None) is None:
                            continue
                        probs = m.probs_all
                        module = m
                    assert len(probs.shape) == 2
                    prob_rotations = torch.arange(0,module.out_rotations).to(probs).view(1,-1)
                    prob_rotations = torch.softmax(prob_rotations/probs,dim=-1)
                    sample_rotation = (prob_rotations > (1/(module.out_rotations+1))).float()
                    samples.append(sample_rotation.sum(-1).squeeze().detach().cpu())
                    break
    return samples

In [26]:
def plot_bar(jitterings, predictions):
    plt.figure(dpi=300)

    plt.bar(jitterings, predictions, width=0.08, color="C1")

    plt.grid(False)

    plt.ylabel('Confidence for Corresponding Class')
    plt.xlabel('Hue Shift')
    plt.xticks(jitterings)

    plt.tight_layout()

    plt.savefig(f'images/group_elements_bar.png')
    plt.show()

In [27]:
target_labels = torch.arange(20)
samples = get_group_elements(target_labels, 0)
print(samples)
# plot_bar(target_labels, samples)

  0%|          | 9/6149 [00:02<26:38,  3.84it/s]  
  0%|          | 29/6149 [00:01<04:04, 25.03it/s]
  1%|          | 71/6149 [00:01<02:43, 37.11it/s]
  1%|▏         | 90/6149 [00:02<02:33, 39.53it/s]
  2%|▏         | 125/6149 [00:02<02:15, 44.30it/s]
  3%|▎         | 170/6149 [00:03<02:09, 46.11it/s]
100%|██████████| 6149/6149 [01:50<00:00, 55.88it/s]
  4%|▎         | 216/6149 [00:04<02:00, 49.41it/s]
  5%|▍         | 282/6149 [00:05<01:58, 49.56it/s]
100%|██████████| 6149/6149 [01:44<00:00, 58.70it/s]
  5%|▌         | 336/6149 [00:06<01:47, 54.18it/s]
100%|██████████| 6149/6149 [01:44<00:00, 59.03it/s]
  8%|▊         | 465/6149 [00:08<01:43, 54.78it/s]
  8%|▊         | 501/6149 [00:09<01:41, 55.56it/s]
  8%|▊         | 522/6149 [00:09<01:41, 55.30it/s]
100%|██████████| 6149/6149 [01:43<00:00, 59.27it/s]
 10%|▉         | 592/6149 [00:10<01:41, 54.95it/s]
 11%|█         | 657/6149 [00:11<01:37, 56.43it/s]
 12%|█▏        | 720/6149 [00:12<01:36, 56.45it/s]
 12%|█▏        | 758/6149 [00:

[tensor(2.), tensor(2.), tensor(2.), tensor(2.), tensor(2.), tensor(2.), tensor(1.), tensor(2.), tensor(2.), tensor(2.), tensor(2.), tensor(2.), tensor(2.), tensor(2.), tensor(1.), tensor(2.)]



