In [175]:
import argparse
import json
import os
import random
import datetime

import numpy as np
import torch
import torch.nn.functional as F
# import wandb
from torch.utils.data import DataLoader
from tqdm import tqdm

from datasets import *
from models import *
from utils import *
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn.metrics import classification_report

In [177]:
# for jupyter notebook set arguments
parser = argparse.ArgumentParser(description='Masked contrastive learning.')
args = argparse.Namespace(description='Masked contrastive learning.')

# training config:
parser.add_argument('--dataset', default='cifartoy_good', choices=['cifar100', 'cifartoy_bad', 'cifartoy_good', 'cars196', 'sop_split1', 'sop_split2', 'imagenet32'], type=str, help='train dataset')
parser.add_argument('--data_path', default='./data', type=str, help='train dataset')

# model configs: [Almost fixed for all experiments]
parser.add_argument('--arch', default='resnet18')
parser.add_argument('--dim', default=256, type=int, help='feature dimension')
parser.add_argument('--K', default=8192, type=int, help='queue size; number of negative keys')
parser.add_argument('--m', default=0.99, type=float, help='moco momentum of updating key encoder')
parser.add_argument('--t0', default=0.1, type=float, help='softmax temperature for training')

# train configs:
parser.add_argument('--lr', '--learning-rate', default=0.02, type=float, metavar='LR', help='initial learning rate', dest='lr')
# parser.add_argument('--epochs', default=200, type=int, metavar='N', help='number of total epochs')
parser.add_argument('--epochs', default=300, type=int, metavar='N', help='number of total epochs')
parser.add_argument('--warm_up', default=5, type=int, metavar='N', help='number of warmup epochs')
parser.add_argument('--batch_size', default=128, type=int, metavar='N', help='mini-batch size')
parser.add_argument('--wd', default=5e-4, type=float, metavar='W', help='weight decay')
parser.add_argument('--aug_q', default='strong', type=str, help='augmentation strategy for query image')
parser.add_argument('--aug_k', default='weak', type=str, help='augmentation strategy for key image')
parser.add_argument('--gpu_id', default='0', type=str, help='gpuid')

# method configs:
parser.add_argument('--mode', default='maskcon', type=str, choices=['maskcon', 'grafit', 'coins'], help='training mode')

# maskcon-specific hyperparameters:
parser.add_argument('--w', default=0.5, type=float, help='weight of self-invariance')  # not-used if maskcon
parser.add_argument('--t', default=0.05, type=float, help='softmax temperature weight for soft label')

# logger configs
# parser.add_argument('--wandb_id', default='logs',type=str, help='wandb user id')
parser.add_argument('--logs', default='logs',type=str, help='log directory file name')
parser.add_argument('--NO_TSNE', action='store_true', help='Enable TSNE visualization')

_StoreTrueAction(option_strings=['--NO_TSNE'], dest='NO_TSNE', nargs=0, const=True, default=False, type=None, choices=None, required=False, help='Enable TSNE visualization', metavar=None)

In [190]:
def test_retrieval(encoder, test_loader, K, args, epoch, chunks=10, num_samples=200):
    encoder.eval()
    feature_bank, target_bank = [], []
    with torch.no_grad():
        for i, (image, _, fine_label) in enumerate(test_loader):
            image = image.cuda(non_blocking=True)
            label = fine_label.cuda(non_blocking=True)
            output = encoder(image, feat=True)
            feature_bank.append(output)
            target_bank.append(label)
        
        feature = torch.cat(feature_bank, dim=0)
        label = torch.cat(target_bank, dim=0).contiguous()
    
    label = label.unsqueeze(-1)
    feat_norm = F.normalize(feature, dim=1)
    split = torch.tensor(np.linspace(0, len(feat_norm), chunks + 1, dtype=int), dtype=torch.long).to(feature.device)
    recall = [[] for i in K]
    ids = [torch.tensor([]).to(feature.device) for i in K]
    correct = [torch.tensor([]).to(feature.device) for i in K]
    k_max = np.max(K)

    # Collect predictions for classification report
    all_preds = []

    # Initialize dictionaries to store class-wise correct retrievals and total samples
    num_classes = label.max().item() + 1
    class_correct = {k: torch.zeros(num_classes).to(feature.device) for k in K}
    class_total = torch.zeros(num_classes).to(feature.device)

    with torch.no_grad():
        for j in range(chunks):
            torch.cuda.empty_cache()
            part_feature = feat_norm[split[j]: split[j + 1]]
            similarity = torch.einsum('ab,bc->ac', part_feature, feat_norm.T)

            topmax = similarity.topk(k_max + 1)[1][:, 1:]
            del similarity
            retrievalmax = label[topmax].squeeze()
            for k, i in enumerate(K):
                anchor_label = label[split[j]: split[j + 1]].repeat(1, i)
                topi = topmax[:, :i]
                retrieval_label = retrievalmax[:, :i]
                correct_i = torch.sum(anchor_label == retrieval_label, dim=1, keepdim=True)
                correct[k] = torch.cat([correct[k], correct_i], dim=0)
                ids[k] = torch.cat([ids[k], topi], dim=0)

                # Collect predictions for classification report
                if k == 0:  # Assuming we want the predictions for the first K value
                    all_preds.extend(retrieval_label[:, 0].cpu().numpy())

                 # Update class-wise correct retrievals
                for c in range(num_classes):
                    class_mask = (anchor_label[:, 0] == c)
                    class_correct[i][c] += torch.sum(correct_i[class_mask])

            # Update class-wise total samples
            for c in range(num_classes):
                class_total[c] += torch.sum(label[split[j]: split[j + 1]] == c)
                print(f"Class {c}: {class_correct[1][c]} / {class_total[c]}")

        # Calculate recall @ K for each class
        # recall_per_class = {k: torch.zeros(num_classes) for k in K}
        recall_per_class = {k: {} for k in K}
        for k in K:
            for c in range(num_classes):
                recall_per_class[k][c] = float((class_correct[k][c] > 0).int().sum() / class_total[c])
                # if class_total[c] > 0:
                #     recall_per_class[k][c] = (class_correct[k][c] / class_total[c]).item()
                # else:
                #     recall_per_class[k][c] = 0.0
        print(recall_per_class)

        # Print Recall@k for each class
        # for k in K:
        #     print(f"\nRecall@{k} for each class:")
        #     for c in range(num_classes):
        #         print(f"Class {c}: {recall_per_class[k][c]:.4f}")

        # calculate recall @ K
        num_sample = len(feat_norm)
        for k, i in enumerate(K):
            acc_k = float((correct[k] > 0).int().sum() / num_sample)
            recall[k] = acc_k

    # Generate classification report
    y_true = label.cpu().numpy().flatten()
    y_pred = np.array(all_preds)
    class_report = classification_report(y_true, y_pred, digits=4, zero_division=0)


    if not args.NO_TSNE:

        # Initialize lists to store the subsets
        feature_subset_list = []
        label_subset_list = []

        # Get unique labels
        unique_labels = torch.unique(label)

        # Iterate over each unique label
        for unique_label in unique_labels:
            # Get indices of samples with the current label
            label_indices = (label == unique_label).nonzero(as_tuple=True)[0]
            
            # Select the first 100 samples (or fewer if less than 100 samples are available)
            selected_indices = label_indices[:num_samples]
            
            # Append the selected features and labels to the lists
            feature_subset_list.append(feat_norm[selected_indices])
            label_subset_list.append(label[selected_indices])

        # Concatenate the lists to form the final subsets
        feature_subset = torch.cat(feature_subset_list, dim=0)
        label_subset = torch.cat(label_subset_list, dim=0)
        
        # Apply t-SNE on the subset of features
        tsne = TSNE(n_components=2, perplexity=30, random_state=0)
        tsne_features = tsne.fit_transform(feature_subset.cpu().numpy())
        tsne_labels = label_subset.cpu().numpy()



        # tSNE save path
        tsne_save_path = os.path.join(args.logs, args.results_dir, 'test_TSNE')

        if not os.path.exists(tsne_save_path):
            os.mkdir(tsne_save_path)

        # Plot t-SNE
        #  dont show the plot only save
        colors = ['#FF1B1B', '#229122', '#0909FF', '#0DC2C2', '#C107C1', '#BFBF01', '#040404', '#F091F0']

        plt.ioff()
        plt.figure(figsize=(8, 8))

        plt.scatter(tsne_features[:, 0], tsne_features[:, 1], c=tsne_labels, cmap=ListedColormap(colors))
        # Get unique labels
        unique_labels = np.unique(tsne_labels)
        # Create legend patches
        legend_patches = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=colors[i], markersize=10, label=f'{i}') for i in unique_labels]
        # Add legend to plot
        plt.legend(handles=legend_patches, title="Classes", loc='upper left')

        plt.title(f't-SNE visualization {args.dataset}')
        plt.savefig(os.path.join(tsne_save_path, f'tsne_epoch_{epoch}.png'))

    return recall, recall_per_class, class_report


In [191]:
def test_proc(args, model, train_loader, test_loader):
    """### Start testing"""
    # define optimizer
    epoch_start = 0

    test_logs = open(f'{args.logs}/{args.results_dir}/test_logs.txt', 'w')

    # training loop
    best_retrieval_top1 = 0
    best_retrieval_top2 = 0
    best_retrieval_top5 = 0
    best_retrieval_top10 = 0
    best_retrieval_top50 = 0
    best_retrieval_top100 = 0

    model.initiate_memorybank(train_loader)

    for epoch in range(epoch_start, args.epochs):
        if epoch % 1 == 0:
            # weight save path
            weight_save_path = os.path.join(args.logs, args.results_dir, 'weight')
            model_save_path = os.path.join(weight_save_path, f'model_epoch_{epoch}.pth')
            # print(f'loading model from {model_save_path}')
            model_state_dict = torch.load(model_save_path)
            model.load_state_dict(model_state_dict)
            print(f'loaded successfully = {model_save_path}')
            
            retrieval_topk, retrieval_topk_perclass, class_report = test_retrieval(encoder=model.encoder_q, test_loader=test_loader, K=[1, 2, 5, 10, 50, 100], args=args, epoch=epoch, chunks=10)
            retrieval_top1, retrieval_top2, retrieval_top5, retrieval_top10, retrieval_top50, retrieval_top100 = retrieval_topk
            if retrieval_top1 > best_retrieval_top1:
                best_retrieval_top1 = best_retrieval_top1
            if retrieval_top2 > best_retrieval_top2:
                best_retrieval_top2 = best_retrieval_top2
            if retrieval_top5 > best_retrieval_top5:
                best_retrieval_top5 = best_retrieval_top5
            if retrieval_top10 > best_retrieval_top10:
                best_retrieval_top10 = best_retrieval_top10
            if retrieval_top50 > best_retrieval_top50:
                best_retrieval_top50 = best_retrieval_top50
            if retrieval_top100 > best_retrieval_top100:
                best_retrieval_top100 = best_retrieval_top100

            # save statistics
            print(f'Epoch [{epoch}/{args.epochs}]: R@1: {retrieval_top1:.4f}, R@2: {retrieval_top2:.4f}, R@5: {retrieval_top5:.4f}, R@10: {retrieval_top10:.4f},  R@50: {retrieval_top50:.4f},R@100: {retrieval_top100:.4f}')
            test_logs.write(f'Epoch [{epoch}/{args.epochs}]:\n')
            test_logs.write(f'    R@1: {retrieval_top1:.4f}, R@2: {retrieval_top2:.4f}, R@5: {retrieval_top5:.4f}, R@10: {retrieval_top10:.4f},  R@50: {retrieval_top50:.4f},R@100: {retrieval_top100:.4f}\n')

            headers = "    ".join([f"Per_Class R@{k}:" for k in retrieval_topk_perclass.keys()])
            test_logs.write(headers + "\n")

            # Get the maximum number of classes
            max_classes = max(len(v) for v in retrieval_topk_perclass.values())

            for class_id in range(max_classes):
                line = []
                for k in retrieval_topk_perclass.keys():
                    if class_id in retrieval_topk_perclass[k]:
                        line.append(f"Class-{class_id}: {retrieval_topk_perclass[k][class_id]:.4f}")
                    else:
                        line.append(f"Class-{class_id}: N/A")
                test_logs.write("    ".join(line) + "\n")
            
            test_logs.write("\n")

            test_logs.write(f"Classification Report:\n{class_report}\n")
            
            test_logs.flush()

    return model


In [192]:
def test():
    args = parser.parse_args([
                                '--epochs','500',
                                '--dataset', 'cifartoy_bad',
                                '--NO_TSNE', # Action argument to disable t-SNE visualization
                            ])
    

    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    random.seed(1228)
    torch.manual_seed(1228)
    torch.cuda.manual_seed_all(1228)
    np.random.seed(1228)
    torch.backends.cudnn.benchmark = True

    """Define train/test"""
    query_transform = get_augment(args.dataset, args.aug_q)
    key_transform = get_augment(args.dataset, args.aug_k)
    test_transform = get_augment(args.dataset)

    if args.dataset == 'cars196':
        train_dataset = CARS196(root=args.data_path, split='train', transform=DMixTransform([key_transform, query_transform], [1, 1]))
        test_dataset = CARS196(root=args.data_path, split='test', transform=test_transform)
        args.num_classes = 8
        args.size = 224

    elif args.dataset == 'cifar100':
        train_dataset = CIFAR100(root=args.data_path, download=True, transform=DMixTransform([key_transform, query_transform], [1, 1]))
        test_dataset = CIFAR100(root=args.data_path, train=False, download=True, transform=test_transform)
        args.num_classes = 20
        args.size = 32

    elif args.dataset == 'cifartoy_good':
        train_dataset = CIFARtoy(root=args.data_path, split='good', download=True, transform=DMixTransform([key_transform, query_transform], [1, 1]))
        test_dataset = CIFARtoy(root=args.data_path, split='good', train=False, download=True, transform=test_transform)
        args.num_classes = 2
        args.size = 32

    elif args.dataset == 'cifartoy_bad':
        train_dataset = CIFARtoy(root=args.data_path, split='bad', download=True, transform=DMixTransform([key_transform, query_transform], [1, 1]))
        test_dataset = CIFARtoy(root=args.data_path, split='bad', train=False, download=True, transform=test_transform)
        args.num_classes = 2
        args.size = 32

    elif args.dataset == 'sop_split2':
        train_dataset = StanfordOnlineProducts(split='2', root=args.data_path, train=True, transform=DMixTransform([key_transform, query_transform], [1, 1]))
        test_dataset = StanfordOnlineProducts(split='2', root=args.data_path, train=False, transform=test_transform)
        args.num_classes = 12
        args.size = 224

    elif args.dataset == 'sop_split1':
        train_dataset = StanfordOnlineProducts(split='1', root=args.data_path, train=True, transform=DMixTransform([key_transform, query_transform], [1, 1]))
        test_dataset = StanfordOnlineProducts(split='1', root=args.data_path, train=False, transform=test_transform)
        args.num_classes = 12
        args.size = 224

    elif args.dataset == 'imagenet32':
        train_dataset = ImageNetDownSample(root=args.data_path, train=True, transform=DMixTransform([key_transform, query_transform], [1, 1]))
        test_dataset = ImageNetDownSample(root=args.data_path, train=False, transform=test_transform)
        args.num_classes = 12
        args.size = 32
        
    else:
        raise ValueError(f'{args.dataset} is not supported now!')

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=True, pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)
    # train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True, pin_memory=True)
    # test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, pin_memory=True)

    # create trainer
    trainer = MaskCon(num_classes_coarse=args.num_classes, dim=args.dim, K=args.K, m=args.m, T1=args.t0, arch=args.arch, size=args.size, T2=args.t, mode=args.mode).cuda()

    # create logs directory
    now = datetime.datetime.now().strftime('%y%m%d-%p%I%M')
    args.results_dir = f'240627-PM0412-arch_{args.arch}-data_{args.dataset}-{args.mode}'

    if not os.path.exists(args.logs):
        os.mkdir(args.logs)
    if not os.path.exists(f'{args.logs}/{args.results_dir}'):
        # Give an error message
        print(f'No logs found in {args.logs}/{args.results_dir}')
        return None

    print(args)

    test_proc(args, trainer, train_loader, test_loader)

In [193]:
test()

Files already downloaded and verified
Files already downloaded and verified




Namespace(dataset='cifartoy_bad', data_path='./data', arch='resnet18', dim=256, K=8192, m=0.99, t0=0.1, lr=0.02, epochs=500, warm_up=5, batch_size=128, wd=0.0005, aug_q='strong', aug_k='weak', gpu_id='0', mode='maskcon', w=0.5, t=0.05, logs='logs', NO_TSNE=True, num_classes=2, size=32, results_dir='240627-PM0412-arch_resnet18-data_cifartoy_bad-maskcon')
Initiate memory bank!
loaded successfully = logs/240627-PM0412-arch_resnet18-data_cifartoy_bad-maskcon/weight/model_epoch_0.pth
Class 0: 458.0 / 800.0
Class 1: 0.0 / 0.0
Class 2: 0.0 / 0.0
Class 3: 0.0 / 0.0
Class 4: 0.0 / 0.0
Class 5: 0.0 / 0.0
Class 6: 0.0 / 0.0
Class 7: 0.0 / 0.0
Class 0: 564.0 / 1000.0
Class 1: 329.0 / 600.0
Class 2: 0.0 / 0.0
Class 3: 0.0 / 0.0
Class 4: 0.0 / 0.0
Class 5: 0.0 / 0.0
Class 6: 0.0 / 0.0
Class 7: 0.0 / 0.0
Class 0: 564.0 / 1000.0
Class 1: 546.0 / 1000.0
Class 2: 0.0 / 0.0
Class 3: 0.0 / 0.0
Class 4: 0.0 / 0.0
Class 5: 0.0 / 0.0
Class 6: 183.0 / 400.0
Class 7: 0.0 / 0.0
Class 0: 564.0 / 1000.0
Class 1: 

KeyboardInterrupt: 