In [1]:
import os
import random
import numpy as np
import torch
import torch.optim as optim
import argparse
import torch.nn.functional as F
from torch.utils.data import DataLoader
from functools import partial

import net
from hyptorch.pmath import dist_matrix
from proxy_anchor import dataset
from proxy_anchor.utils import calc_recall_at_k
from sampler import UniqueClassSempler
from proxy_anchor.dataset import CUBirds, SOP, Cars
from proxy_anchor.dataset.Inshop import Inshop_Dataset

In [2]:
seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

In [3]:
path = '/data/xuyunhao/datasets'
ds = 'CUB'
num_samples = 2
bs = 200
lr = 1e-5
t = 0.2
emb = 512
ep = 100
local_rank = 0
workers = 4
optimizer = 'adamw'
lr_decay_step = 10
lr_decay_gamma = 0.5

model =  'resnet34'
hyp_c = 0
clip_r  = 2.3
resize = 224
crop = 224
gpu_id = 5
bn_freeze = 1
freezer = True

# 损失函数

In [4]:
def contrastive_loss(e0, e1, tau):
    # x0 and x1 - positive pair
    # tau - temperature
    # hyp_c - hyperbolic curvature, "0" enables sphere mode
    dist_e = lambda x, y: -torch.cdist(x, y, p=2)
    
    dist_e0 = dist_e(e0, e0)   
    dist_e1 = dist_e(e0, e1)

    bsize = e0.shape[0]
    target = torch.arange(bsize).cuda()
    eye_mask = torch.eye(bsize).cuda() * 1e9
    
    logits00 = dist_e0 / tau - eye_mask
    logits01 = dist_e1 / tau
    
    logits = torch.cat([logits01, logits00], dim=1)
    logits -= logits.max(1, keepdim=True)[0].detach()
    loss = F.cross_entropy(logits, target)
    stats = {
        "logits/min": logits01.min().item(),
        "logits/mean": logits01.mean().item(),
        "logits/max": logits01.max().item(),
        "logits/acc": (logits01.argmax(-1) == target).float().mean().item(),
    }
    return loss, stats

# 模型

In [5]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import torch.nn.init as init
import hyptorch.nn as hypnn
from torchvision.models import resnet18
from torchvision.models import resnet34
from torchvision.models import resnet50
from torchvision.models import resnet101

## resnet18

In [6]:
class Resnet18(nn.Module):
    def __init__(self,embedding_size, pretrained=True, bn_freeze = True):
        super(Resnet18, self).__init__()

        self.model = resnet18(pretrained)
        self.embedding_size = embedding_size
        self.num_ftrs = self.model.fc.in_features
        self.model.gap = nn.AdaptiveAvgPool2d(1)
        self.model.gmp = nn.AdaptiveMaxPool2d(1)
        
        self.Elayer = NormLayer()
        self.model.embedding = nn.Sequential(nn.Linear(self.num_ftrs, self.embedding_size), self.Elayer)

        self._initialize_weights()

        if bn_freeze:
            for m in self.model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    m.weight.requires_grad_(False)
                    m.bias.requires_grad_(False)


    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)

        avg_x = self.model.gap(x)
        max_x = self.model.gmp(x)

        x = max_x + avg_x
        
        x = x.view(x.size(0), -1)
        x_e = self.model.embedding(x)
        return x_e

    def _initialize_weights(self):
        init.kaiming_normal_(self.model.embedding[0].weight, mode='fan_out')
        init.constant_(self.model.embedding[0].bias, 0)

class NormLayer(nn.Module):
    def forward(self, x):
        return F.normalize(x, p=2, dim=1)

# resnet34

In [7]:
class Resnet34(nn.Module):
    def __init__(self,embedding_size, pretrained=True, bn_freeze = True):
        super(Resnet34, self).__init__()

        self.model = resnet34(pretrained)
        self.embedding_size = embedding_size
        self.num_ftrs = self.model.fc.in_features
        self.model.gap = nn.AdaptiveAvgPool2d(1)
        self.model.gmp = nn.AdaptiveMaxPool2d(1)
        
        self.Elayer = NormLayer()
        self.model.embedding = nn.Sequential(nn.Linear(self.num_ftrs, self.embedding_size), self.Elayer)

        self._initialize_weights()

        if bn_freeze:
            for m in self.model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    m.weight.requires_grad_(False)
                    m.bias.requires_grad_(False)

    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)

        avg_x = self.model.gap(x)
        max_x = self.model.gmp(x)

        x = avg_x + max_x
        
        x = x.view(x.size(0), -1)
        x_e = self.model.embedding(x)
        return x_e

    def _initialize_weights(self):
        init.kaiming_normal_(self.model.embedding[0].weight, mode='fan_out')
        init.constant_(self.model.embedding[0].bias, 0)

class NormLayer(nn.Module):
    def forward(self, x):
        return F.normalize(x, p=2, dim=1)

In [8]:
def evaluate(get_emb_f, ds_name):
    emb_head = get_emb_f(ds_type="eval")
    recall_head = get_recall(*emb_head, ds_name)
    return recall_head

def get_recall(e, y, ds_name):
    if ds_name == "CUB" or ds_name == "Cars":
        k_list = [1, 2, 4, 8, 16, 32]
    elif ds_name == "SOP":
        k_list = [1, 10, 100, 1000]

    dist_m = torch.empty(len(e), len(e), device="cuda")
    for i in range(len(e)):
        dist_m[i : i + 1] = -torch.cdist(e[i : i + 1], e, p=2)

    y_cur = y[dist_m.topk(1 + max(k_list), largest=True)[1][:, 1:]]
    y = y.cpu()
    y_cur = y_cur.float().cpu()
    recall = [calc_recall_at_k(y, y_cur, k) for k in k_list]
    print(recall)
    return recall[0]

def get_emb(
    model,
    ds,
    path,
    ds_type="eval",
    world_size=1,
    num_workers=8,
):
    eval_tr = dataset.utils.make_transform(
        is_train = True, 
        is_inception = (model == 'bn_inception')
    )
    ds_eval = ds(path, ds_type, eval_tr)
    if world_size == 1:
        sampler = None
    else:
        sampler = torch.utils.data.distributed.DistributedSampler(ds_eval)
    dl_eval = DataLoader(
        dataset=ds_eval,
        batch_size=100,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
        sampler=sampler,
    )
    model.eval()
    e, y = eval_dataset(model, dl_eval)
    y = y.cuda()
    model.train()
    return e, y

def eval_dataset(model, dl):
    all_xe, all_y = [], []
    for x, y in dl:
        with torch.no_grad():
            x = x.cuda(non_blocking=True)
            e= model(x)
            all_xe.append(e)
        all_y.append(y)
    return torch.cat(all_xe), torch.cat(all_y)

In [9]:
torch.cuda.set_device(gpu_id)
world_size = int(os.environ.get("WORLD_SIZE", 1))

train_tr = dataset.utils.make_transform(
    is_train = True, 
    is_inception = (model == 'bn_inception')
)

ds_list = {"CUB": CUBirds, "SOP": SOP, "Cars": Cars, "Inshop": Inshop_Dataset}
ds_class = ds_list[ds]
ds_train = ds_class(path, "train", train_tr)

sampler = UniqueClassSempler(
    ds_train.ys, num_samples, local_rank, world_size
)
dl_train = DataLoader(
    dataset=ds_train,
    sampler=sampler,
    batch_size=bs,
    num_workers=workers,
    pin_memory=True,
)

if model.find('resnet18')+1:
    model = Resnet18(embedding_size=emb, pretrained=True, bn_freeze = bn_freeze).cuda().train() 
elif model.find('resnet34')+1:
    model = Resnet34(embedding_size=emb, pretrained=True, bn_freeze = bn_freeze).cuda().train() 
elif model.find('resnet50')+1:
    model = Resnet50(embedding_size=emb, pretrained=True, bn_freeze = bn_freeze).cuda().train() 
elif model.find('resnet101')+1:
    model = Resnet101(embedding_size=emb, pretrained=True, bn_freeze = bn_freeze).cuda().train() 
loss_f = partial(contrastive_loss, tau=t)

get_emb_f = partial(
    get_emb,
    model=model,
    ds=ds_class,
    path=path,
    num_workers=workers,
    world_size=world_size,
)
if freezer == True:
    embedding_param = list(model.model.embedding.parameters())
    for param in list(set(model.parameters()).difference(set(embedding_param))):
        param.requires_grad = False
    optimizer = optim.AdamW(embedding_param, lr=lr)
else:
    optimizer = optim.AdamW(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma = lr_decay_gamma)
print("Training for {} epochs.".format(ep))

r0= evaluate(get_emb_f, ds)
print("The recall before train: ", r0)

losses_list = []
best_recall= 0
best_epoch = 0

for epoch in range(0, ep):
    model.train()
    if bn_freeze:
        modules = model.model.modules()
        for m in modules: 
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    losses_per_epoch = []
    sampler.set_epoch(epoch)
    stats_ep = []
    for x, y in dl_train:
        y = y.view(len(y) // num_samples, num_samples)
        assert (y[:, 0] == y[:, -1]).all()
        s1 = y[:, 0].tolist()
        assert len(set(s1)) == len(s1)

        x = x.cuda(non_blocking=True)
        e = model(x)
        e = e.view(len(x) // num_samples, num_samples, emb)
        loss = 0
        for i in range(num_samples):
            for j in range(num_samples):
                if i != j:
                    l, st = loss_f(e[:, i], e[:, j])
                    loss += l
                    stats_ep.append({**st, "loss": l.item()})

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        optimizer.step()
        
    scheduler.step()        
    rh= evaluate(get_emb_f, ds)
    stats_ep = {k: np.mean([x[k] for x in stats_ep]) for k in stats_ep[0]}
    stats_ep = {"recall": rh, **stats_ep}
    if rh > best_recall :
        best_recall = rh
        best_epoch = epoch
    print("epoch:",epoch,"recall: ", rh)
    print("best epoch:",best_epoch,"best recall: ", best_recall)

Training for 100 epochs.
[0.4176232275489534, 0.5406819716407832, 0.6561444969615124, 0.762322754895341, 0.8529709655638082, 0.9098582039162728]
The recall before train:  0.4176232275489534
[0.41812964213369347, 0.5408507765023632, 0.6568197164078325, 0.7641796083727211, 0.8507765023632681, 0.9076637407157326]
epoch: 0 recall:  0.41812964213369347
best epoch: 0 best recall:  0.41812964213369347
[0.4161039837947333, 0.5410195813639432, 0.6617150573936529, 0.7667116812964213, 0.849763673193788, 0.9137407157326131]
epoch: 1 recall:  0.4161039837947333
best epoch: 0 best recall:  0.41812964213369347
[0.4150911546252532, 0.5401755570560433, 0.6583389601620526, 0.7721134368669818, 0.8553342336259284, 0.9144159351789332]
epoch: 2 recall:  0.4150911546252532
best epoch: 0 best recall:  0.41812964213369347
[0.41188386225523294, 0.5349426063470628, 0.6580013504388926, 0.7668804861580013, 0.8561782579338285, 0.9152599594868333]
epoch: 3 recall:  0.41188386225523294
best epoch: 0 best recall:  0.4

[0.41644159351789334, 0.5357866306549629, 0.6635719108710331, 0.7648548278190412, 0.8436866981769074, 0.9064821066846726]
epoch: 39 recall:  0.41644159351789334
best epoch: 34 best recall:  0.42218095881161377
[0.4169480081026334, 0.5351114112086428, 0.6585077650236327, 0.7609723160027009, 0.8430114787305875, 0.9061444969615124]
epoch: 40 recall:  0.4169480081026334
best epoch: 34 best recall:  0.42218095881161377
[0.4112086428089129, 0.537812288993923, 0.6546252532072924, 0.7629979743416611, 0.8433490884537475, 0.9112086428089129]
epoch: 41 recall:  0.4112086428089129
best epoch: 34 best recall:  0.42218095881161377
[0.40597569209993245, 0.525658338960162, 0.6485482781904118, 0.7537137069547603, 0.8438555030384876, 0.9112086428089129]
epoch: 42 recall:  0.40597569209993245
best epoch: 34 best recall:  0.42218095881161377
[0.40968939905469276, 0.5322417285617825, 0.6477042538825118, 0.7552329507089804, 0.8477380148548278, 0.9107022282241729]
epoch: 43 recall:  0.40968939905469276
best 

[0.4042876434841323, 0.5337609723160027, 0.6595205941931127, 0.7662052667116813, 0.8517893315327482, 0.912052667116813]
epoch: 79 recall:  0.4042876434841323
best epoch: 65 best recall:  0.4238690074274139
[0.4113774476704929, 0.5320729237002025, 0.6547940580688724, 0.7651924375422012, 0.8433490884537475, 0.9085077650236327]
epoch: 80 recall:  0.4113774476704929
best epoch: 65 best recall:  0.4238690074274139
[0.4012491559756921, 0.5305536799459825, 0.6539500337609723, 0.7662052667116813, 0.8546590141796083, 0.9157663740715732]
epoch: 81 recall:  0.4012491559756921
best epoch: 65 best recall:  0.4238690074274139
[0.40698852126941254, 0.5351114112086428, 0.6612086428089129, 0.7682309250506415, 0.8531397704253882, 0.9181296421336934]
epoch: 82 recall:  0.40698852126941254
best epoch: 65 best recall:  0.4238690074274139
[0.40580688723835245, 0.5217758271438218, 0.6465226198514518, 0.7535449020931803, 0.8431802835921675, 0.9086765698852127]
epoch: 83 recall:  0.40580688723835245
best epoch