In [2]:
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torch.nn.functional as F
from tqdm import tqdm
from math import sqrt
from torch.distributions.normal import Normal
from torch.distributions.multivariate_normal import MultivariateNormal
from evaluation import Evaluator

  from .autonotebook import tqdm as notebook_tqdm


In [50]:
%matplotlib inline
%pylab inline

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style="white", font="Arial")
colors = sns.color_palette("Paired", n_colors=12).as_hex()

Populating the interactive namespace from numpy and matplotlib


In [4]:

class full_model(nn.Module):
    """Wrap a feature network and, optionally, attach a loss head to it.

    ``args.lgm`` attaches an ``LGM`` head and ``args.vc`` a
    ``VariationalClassification`` head, both stored as ``self.criterion``.
    The flags are checked in that order, so ``vc`` wins when both are set;
    when neither is set no ``criterion`` attribute is created.
    """

    def __init__(self, args, network, num_classes=1000, num_feats=512) -> None:
        super().__init__()
        self.network = network

        if args.lgm:
            self.criterion = LGM(args, num_classes=num_classes, feat_dim=num_feats)
        if args.vc:
            self.criterion = VariationalClassification(
                args, num_classes=num_classes, feat_dim=num_feats
            )

    def forward(self, x):
        """Return the wrapped network's output for ``x``; the head is not applied here."""
        features = self.network(x)
        return features

        # tsampled_accuracy += ((Tsampled < 0) == torch.ones_like(Tsampled)).sum().item()
        # treal_accuracy += ((Treal > 0) == torch.ones_like(Treal)).sum().item()

        
        


class Discriminators_2l(nn.Module):
    """One small two-layer scorer per class, stored as stacked parameters.

    For a sample ``z`` with class ``c`` the score is
    ``W2[c] . tanh(W1[c] @ z + b1[c]) + b2[c]``.
    The hidden width equals ``feat_dim``.
    """

    def __init__(self, num_classes, feat_dim):
        super().__init__()
        self.hidden_dim = feat_dim
        # Parameter shapes: W1 (C, hidden, feat), b1 (C, hidden), W2 (C, hidden), b2 (C,)
        self.W1 = nn.Parameter(torch.randn(num_classes, self.hidden_dim, feat_dim))
        self.b1 = nn.Parameter(torch.zeros(num_classes, self.hidden_dim))
        self.W2 = nn.Parameter(torch.randn(num_classes, self.hidden_dim))
        self.b2 = nn.Parameter(torch.zeros(num_classes))
        # The attribute is called ``relu`` but the activation is a Tanh.
        self.relu = nn.Tanh()

    def forward(self, Z, y):
        """Score each row of ``Z`` with the scorer of its class.

        Z: features, expected as (B, feat_dim).
        y: integer class index per row, expected as (B,).
        Returns one score per row.
        """
        first_w = self.W1[y]   # (B, hidden, feat)
        first_b = self.b1[y]   # (B, hidden)
        second_w = self.W2[y]  # (B, hidden)
        second_b = self.b2[y]  # (B,)

        # Batched mat-vec written as broadcast-multiply followed by a sum over feat.
        hidden = self.relu(torch.sum(first_w * Z[:, None], dim=-1) + first_b)
        return torch.sum(second_w * hidden, dim=-1) + second_b

    def reset(self,):
        """Re-initialise the weights (Xavier uniform) and zero the biases in place."""
        torch.nn.init.xavier_uniform_(self.W1)
        torch.nn.init.zeros_(self.b1)
        torch.nn.init.xavier_uniform_(self.W2)
        torch.nn.init.zeros_(self.b2)




def pairwise_dist(mat, num_classes):
    """Sum the Euclidean distances over all unordered pairs of the first ``num_classes`` rows of ``mat``.

    Returns the float ``0.0`` when there are fewer than two rows to compare.
    """
    total = 0.0
    for first in range(num_classes - 1):
        # Walk the partner index downwards, from the last row to first + 1.
        for second in reversed(range(first + 1, num_classes)):
            gap = mat[first, :] - mat[second, :]
            total = total + gap.pow(2).sum(-1).sqrt()
    return total




class Discriminators_1l(nn.Module):
    """One linear scorer per class: ``score = W1[c] . z + b1[c]`` for a sample of class ``c``."""

    def __init__(self, num_classes, feat_dim):
        super().__init__()
        # Parameter shapes: W1 (C, feat), b1 (C,)
        self.W1 = nn.Parameter(torch.randn(num_classes, feat_dim))
        self.b1 = nn.Parameter(torch.zeros(num_classes))

    def forward(self, Z, y):
        """Score each row of ``Z`` with the linear scorer of its class.

        Z: features, expected as (B, feat_dim).
        y: integer class index per row, expected as (B,).
        Returns one score per row.
        """
        weight = self.W1[y]  # (B, feat)
        bias = self.b1[y]    # (B,)
        return torch.sum(weight * Z, dim=-1) + bias

    def reset(self,):
        """Re-initialise the weights (Xavier uniform) and zero the biases in place."""
        torch.nn.init.xavier_uniform_(self.W1)
        torch.nn.init.zeros_(self.b1)

class VariationalClassification(nn.Module):
    """Gaussian-head classification loss plus a per-class discriminator term.

    ``forward`` returns the loss used to update the feature network and the
    Gaussian head; ``discriminator_train`` then takes one optimisation step
    on the discriminator using the tensors cached by the last ``forward``.
    """

    def __init__(self, args, num_classes, feat_dim , ) -> None:
        super().__init__()
        self.NLL = nn.CrossEntropyLoss()
        self.VC = AdversarialContrastiveLoss(num_classes, feat_dim, alpha=0.0)

        # args.disc_layers == 2 selects the two-layer discriminator, anything else the linear one.
        if args.disc_layers == 2:
            self.classifier = Discriminators_2l(num_classes=num_classes, feat_dim=feat_dim)
        else:
            self.classifier = Discriminators_1l(num_classes=num_classes, feat_dim=feat_dim)

        # The discriminator has its own optimiser, stepped only in discriminator_train().
        self.disc_optimizer = torch.optim.Adam(
            list(self.classifier.parameters()), lr=9e-4, weight_decay=1e-4
        )

        self.args = args

    def forward(self, outputs, targets):
        """Return ``(loss, logits)`` for features ``outputs`` and labels ``targets``.

        loss = CE(logits, targets) + args.l1 * mean(likelihood)
               + args.l2 * mean(discriminator score of ``outputs``)

        Side effect: caches ``self.real``, ``self.sampled`` and ``self.targets``
        for the following ``discriminator_train`` call.
        """
        cfg = self.args
        # detach_features=True: the likelihood term is computed from detached features.
        logits, self.real, self.sampled, likelihood = self.VC(
            outputs, label=targets, detach_features=True
        )
        self.targets = targets

        real_scores = self.classifier(outputs, targets).squeeze(-1)

        gaussian_term = self.NLL(logits, targets) + cfg.l1 * (likelihood.mean())
        disc_term = real_scores.mean()
        total = gaussian_term + cfg.l2 * disc_term
        return total, logits

    def discriminator_train(self):
        """Take one discriminator step: real features are labelled 1, sampled features 0.

        Uses the tensors cached by the most recent ``forward`` call.
        """
        sampled_scores = self.classifier(self.sampled.detach(), self.targets).squeeze(-1)
        real_scores = self.classifier(self.real.detach(), self.targets).squeeze(-1)

        self.disc_optimizer.zero_grad()
        real_part = F.binary_cross_entropy_with_logits(real_scores, torch.ones_like(real_scores))
        sampled_part = F.binary_cross_entropy_with_logits(sampled_scores, torch.zeros_like(sampled_scores))
        dual_loss = (real_part + sampled_part)
        dual_loss.backward()
        self.disc_optimizer.step()
        self.disc_optimizer.zero_grad()


class LGM(nn.Module):
    """Gaussian-head classification loss without a discriminator.

    loss = CE(logits, targets) + args.l1 * mean(likelihood)
    """

    def __init__(self, args, num_classes, feat_dim , ) -> None:
        super().__init__()
        self.NLL = nn.CrossEntropyLoss()
        self.VC = AdversarialContrastiveLoss(num_classes, feat_dim, alpha=0.0)
        self.args = args

    def forward(self, outputs, targets):
        """Return ``(loss, logits)`` for features ``outputs`` and labels ``targets``.

        Side effect: stores ``self.real``, ``self.sampled`` and ``self.targets``.
        """
        cfg = self.args
        # detach_features=False: the likelihood term is computed from the undetached features.
        logits, self.real, self.sampled, likelihood = self.VC(
            outputs, label=targets, detach_features=False
        )
        self.targets = targets

        total = self.NLL(logits, targets) + cfg.l1 * (likelihood.mean())
        return total, logits



class AdversarialContrastiveLoss(nn.Module):
    """Class-conditional Gaussian head with learnable centres and fixed unit variances.

    The logit computation follows:
    Weitao Wan, Yuanyi Zhong, Tianpeng Li, Jiansheng Chen,
    "Rethinking Feature Distribution for Loss Functions in Image Classification",
    CVPR 2018 (re-implemented by yirong mao, 2018-07-02).

    Besides the logits, ``forward`` returns the detached input features and one
    sample per row drawn from the Gaussian of that row's class.
    """

    def __init__(self, num_classes, feat_dim , alpha):
        super().__init__()
        self.feat_dim = feat_dim
        self.num_classes = num_classes
        self.alpha = alpha  # stored only; not read by forward()

        # Class centres (learned).
        self.centers = nn.Parameter(torch.randn(num_classes, feat_dim))

        # Per-class, per-dimension log-variances, fixed at zero (not learned).
        # Registered as a buffer so they follow the module across .to()/.cuda()
        # instead of being pinned to CUDA at construction time.
        # persistent=False keeps state_dict() identical to before (only "centers").
        self.register_buffer(
            "log_covs", torch.zeros(num_classes, feat_dim), persistent=False
        )

    def _class_logits(self, feat):
        """Return ``(logits, dist)``, each of shape (N, C), for features of shape (N, D).

        ``dist`` is the variance-weighted squared distance to every class centre,
        and ``logits = -0.5 * (sum of log-variances + dist)``.
        """
        log_covs = self.log_covs.unsqueeze(0)                 # 1 x C x D
        covs = torch.exp(log_covs)                            # 1 x C x D
        diff = feat.unsqueeze(1) - self.centers.unsqueeze(0)  # N x C x D
        dist = torch.sum(diff * (diff / covs), dim=-1)        # N x C, eq.(18)
        logits = -0.5 * (log_covs.sum(dim=-1) + dist)         # N x C
        return logits, dist

    def forward(self, feat, label=None, detach_features=False):
        """Compute Gaussian logits, a per-sample likelihood term and class samples.

        feat: features of shape (N, D).
        label: integer class per row; when ``None`` the class with the smallest
            weighted distance is used instead.
        detach_features: when true, the likelihood term is computed from
            ``feat.detach()``, so its gradient reaches the centres but not ``feat``.

        Returns ``(logits, real, sampled, likelihood)``:
            logits     (N, C), always computed from the undetached ``feat``
            real       ``feat.detach()``
            sampled    (N, D) detached draw from N(centers[label], exp(log_covs[label]))
            likelihood (N,), the negated logit of each row's ``label``
        """
        batch_size, feat_dim = feat.shape[0], feat.shape[1]

        logits, dist = self._class_logits(feat)
        if label is None:
            label = torch.argmin(dist, dim=-1)

        likelihood_logits = logits
        if detach_features:
            likelihood_logits, _ = self._class_logits(feat.detach())

        # Draw on the centres' device rather than sampling on CPU and calling .cuda().
        noise = torch.randn(
            batch_size, feat_dim, device=self.centers.device, dtype=self.centers.dtype
        )
        sampled = noise * torch.sqrt(torch.exp(self.log_covs[label])) + self.centers[label]

        rows = torch.arange(batch_size, device=likelihood_logits.device)
        likelihood = -likelihood_logits[rows, label]

        return logits, feat.detach(), sampled.detach(), likelihood


In [102]:
from argparse import Namespace

# Experiment configuration read by the cells below.
args = Namespace(
    l1=0.1,          # weight of the likelihood term in the LGM / VC losses
    l2=0.1,          # weight of the discriminator term in the VC loss
    vc=False,        # attach the VariationalClassification head
    lgm=False,       # attach the LGM head
    ce=True,         # plain cross-entropy on the network's linear layer
    disc_layers=1,   # 2 selects Discriminators_2l, anything else Discriminators_1l
    dim=2,           # input / feature dimensionality
    num_classes=5,   # number of classes
)

In [62]:
n_train = 1000
n_test = 100
d = args.dim
classes = args.num_classes

# Class means are drawn from N(0, 10^2) per coordinate; every class has unit variance.
means = torch.empty(classes, d).normal_() * 10
covs = torch.ones(classes, d)


def _draw_blobs(rounds):
    """Draw ``rounds`` points per class from the current ``means`` / ``covs``.

    Classes are visited in order within each round, so labels come out as
    0, 1, ..., classes-1, 0, 1, ...  Returns ``(points, labels)`` as a float
    tensor of shape (rounds * classes, d) and a LongTensor.
    """
    points, labels = [], []
    for _ in range(rounds):
        for label in range(classes):
            noise = torch.empty(d).normal_()
            points.append(noise * torch.sqrt(covs[label, :]) + means[label, :])
            labels.append(label)
    return torch.stack(points, dim=0), torch.LongTensor(labels)


train_x, train_y = _draw_blobs(n_train)

# Perturb the class means in place before drawing the test set, so the test
# points come from shifted Gaussians.
means += torch.empty(classes, d).normal_()
test_x, test_y = _draw_blobs(n_test)

train_ds = TensorDataset(train_x, train_y)
test_ds = TensorDataset(test_x, test_y)

train_dl = DataLoader(train_ds, batch_size=1024, shuffle=True, num_workers=8)
test_dl = DataLoader(test_ds, batch_size=2048, num_workers=8)


In [85]:
# Scatter the first two coordinates of the train and test points and save to disk.
# An explicit figure replaces the chain of pyplot state resets, and is closed
# after saving so nothing is left behind for later plotting cells.
fig, ax = plt.subplots()
ax.scatter(x=train_x[:, 0], y=train_x[:, 1], label="train")
ax.scatter(x=test_x[:, 0], y=test_x[:, 1], label="test")
ax.set(xlabel="x[0]", ylabel="x[1]", title="Synthetic data")
ax.legend()
fig.savefig("data.png")
plt.close(fig)

In [104]:
class network(nn.Module):
    """Two-layer Tanh MLP feature extractor with a separate linear layer ``lin``.

    ``forward`` returns the ``dim``-dimensional features only; ``lin`` is not
    applied there and is left for the caller to use.
    """

    def __init__(self, dim=10, num_classes=100):
        super().__init__()
        layers = [nn.Linear(dim, dim), torch.nn.Tanh(), nn.Linear(dim, dim)]
        self.net = nn.Sequential(*layers)
        self.lin = nn.Linear(dim, num_classes)

    def forward(self, x):
        """Return the features ``net(x)``."""
        features = self.net(x)
        return features
    
    
class network1(nn.Module):
    """Single linear feature extractor with a separate linear layer ``lin``.

    ``forward`` returns the ``dim``-dimensional features only; ``lin`` is not
    applied there and is left for the caller to use.
    """

    def __init__(self, dim=10, num_classes=100):
        super().__init__()
        self.net = nn.Linear(in_features=dim, out_features=dim)
        self.lin = nn.Linear(in_features=dim, out_features=num_classes)

    def forward(self, x):
        """Return the features ``net(x)``."""
        features = self.net(x)
        return features
    
    
# Build the model, its loss and its optimisers for the selected mode.
if args.vc or args.lgm:
    # Gaussian-head modes: the feature network is wrapped together with its loss head.
    net = network(dim=args.dim, num_classes=args.num_classes).cuda()
    model = full_model(
        network=net, args=args, num_classes=args.num_classes, num_feats=args.dim
    ).cuda()

    # optimizer_2 updates the Gaussian head, optimizer_1 the feature network.
    optimizer_2 = torch.optim.SGD(list(model.criterion.VC.parameters()), lr=0.01)
    optimizer_1 = torch.optim.SGD(list(model.network.parameters()), lr=0.01)

    criterion = model.criterion
else:
    # Plain cross-entropy mode: a single Adam optimiser over the whole network.
    model = network(dim=args.dim, num_classes=args.num_classes).cuda()

    criterion = nn.CrossEntropyLoss().cuda()
    optimizer_2 = None
    optimizer_1 = torch.optim.Adam(list(model.parameters()))

    



In [107]:
# Train for 50 epochs; after each epoch report training accuracy and ECE.
curr_loss = None  # exponential moving average of the batch loss (decay 0.99)
args.vis = True   # collect the features of every batch for the scatter plot cell
for epoch_no in range(50):
    all_preds = []
    all_targets = []
    confidences = []
    feats = []  # reset every epoch, so only the last epoch's features survive the loop
    for id_, (x, y) in enumerate(train_dl):
        x = x.cuda()
        y = y.cuda()
        z = model(x)
        if args.vis:
            feats.append(z.detach().cpu().numpy())

        if args.vc or args.lgm:
            # The Gaussian head produces its own logits from the features.
            loss, logits = criterion(z, y)
        else:
            # Cross-entropy mode: logits come from the network's linear layer.
            # Computed here rather than under `if args.ce:` because full_model
            # has no `lin` (ce together with vc/lgm raised AttributeError) and
            # because `logits` was undefined or stale when no mode flag was set.
            logits = model.lin(z)
            loss = criterion(logits, y)

        _, preds = logits.max(-1)

        confs, _ = torch.max(F.softmax(logits, dim=-1), -1)
        confidences.extend(confs.detach().cpu().numpy().tolist())

        # Backward pass and update
        optimizer_1.zero_grad()
        if args.vc or args.lgm:
            optimizer_2.zero_grad()
        loss.backward()
        optimizer_1.step()
        if args.vc or args.lgm:
            optimizer_2.step()

        # Discriminator step (VC only). It runs on detached tensors and clears
        # its own gradients; optimizer_1 / optimizer_2 are zeroed again at the
        # start of the next iteration, so no extra zero_grad calls are needed.
        if args.vc:
            criterion.discriminator_train()

        all_preds.extend(preds.detach().cpu().numpy().tolist())
        all_targets.extend(y.cpu().numpy().tolist())

        loss_ = loss.item()
        if curr_loss is None:
            curr_loss = loss_
        else:
            curr_loss = curr_loss * 0.99 + loss_ * 0.01
        if id_ % 100 == 0:
            print(curr_loss)

    all_preds = np.array(all_preds)
    all_targets = np.array(all_targets)
    is_correct = all_preds == all_targets
    accuracies = is_correct.tolist()

    ECE, MCE, _, _, _ = Evaluator.create_fine_bins(accuracies, confidences)
    ECE = ECE * 100

    acc = is_correct.mean()
    print("Accuracy: {}".format(acc))
    print("ECE: {}".format(ECE))
    
    
    

1.6190098524093628
Accuracy: 0.197
ECE: 17.6784133797884
1.6185319493001387
Accuracy: 0.1974
ECE: 17.648651630580424
1.6180991704585437
Accuracy: 0.1978
ECE: 17.635757595598697
1.6174373027503184
Accuracy: 0.198
ECE: 17.58661621540785
1.6166732802086468
Accuracy: 0.198
ECE: 17.526279639303684
1.6156846222330836
Accuracy: 0.1982
ECE: 17.482070382535458
1.6147540597563284
Accuracy: 0.1982
ECE: 17.40840059131384
1.6136342882083965
Accuracy: 0.1984
ECE: 17.358012440204618
1.612353584463792
Accuracy: 0.1986
ECE: 17.3059077963233
1.6110255340281825
Accuracy: 0.199
ECE: 17.267948443591592
1.6095873513470003
Accuracy: 0.199
ECE: 17.188887337744234
1.607976194656722
Accuracy: 0.199
ECE: 17.09521143645048
1.6062903605234082
Accuracy: 0.199
ECE: 17.00354207187891
1.6043402350628486
Accuracy: 0.199
ECE: 16.90524053543806
1.6024046887099876
Accuracy: 0.199
ECE: 16.80406642705202
1.6003155324349283
Accuracy: 0.1992
ECE: 16.71880835771561
1.597965437197662
Accuracy: 0.1992
ECE: 16.604429086148738
1.5

In [106]:
# Stack the per-batch feature arrays collected by the training loop.
# A new name is used so `feats` stays a list and this cell can be re-run.
feats_arr = np.concatenate(feats, axis=0)

# Scatter the first two feature coordinates once (the same points used to be
# drawn twice) on an explicit figure that is closed after saving.
fig, ax = plt.subplots()
ax.scatter(x=feats_arr[:, 0], y=feats_arr[:, 1])
ax.set(xlabel="z[0]", ylabel="z[1]", title="Learned features")
fig.savefig("z.png")
plt.close(fig)

In [46]:
# Evaluate on the test loader: accuracy and expected calibration error (ECE).
all_preds = []
all_targets = []
confidences = []
# No parameters are updated here, so skip building the autograd graph.
with torch.no_grad():
    for id_, (x, y) in enumerate(test_dl):
        x = x.cuda()
        y = y.cuda()
        z = model(x)

        if args.vc or args.lgm:
            # The Gaussian head produces its own logits from the features.
            loss, logits = criterion(z, y)
        else:
            # Cross-entropy mode: logits come from the network's linear layer.
            # Computed here rather than under `if args.ce:` because full_model
            # has no `lin` (ce together with vc/lgm raised AttributeError) and
            # because `logits` was undefined or stale when no mode flag was set.
            logits = model.lin(z)
            loss = criterion(logits, y)

        _, preds = logits.max(-1)

        confs, _ = torch.max(F.softmax(logits, dim=-1), -1)
        confidences.extend(confs.detach().cpu().numpy().tolist())

        all_preds.extend(preds.detach().cpu().numpy().tolist())
        all_targets.extend(y.cpu().numpy().tolist())

        loss_ = loss.item()

all_preds = np.array(all_preds)
all_targets = np.array(all_targets)

is_correct = all_preds == all_targets
acc = is_correct.mean()
accuracies = is_correct.tolist()

ECE, MCE, _, _, _ = Evaluator.create_fine_bins(accuracies, confidences)
ECE = ECE * 100

print("Accuracy: {}".format(acc))
print("ECE: {}".format(ECE))

Accuracy: 0.54684
ECE: 8.558109813943506


In [None]:
# Show the shape of the target-label array built by the most recently run training/evaluation cell.
all_targets.shape

In [None]:

# NOTE(review): `samples` is not defined in any cell visible in this notebook;
# presumably a list of equally-shaped tensors left over from a removed cell - confirm.
# NOTE(review): this writes "data.png", the same file the train/test scatter cell writes.
sample_points = torch.stack(samples, dim=0)  # stacked once instead of once per axis

fig, ax = plt.subplots()
ax.scatter(x=sample_points[:, 0], y=sample_points[:, 1])
fig.savefig("data.png")
plt.close(fig)

# plot(torch.stack(samples, dim=0), torch.LongTensor(y))

