# DOC and DOC3 results on the MVTec AD 'tile' category
# Tables 5 and 6 (correlation results)

## Imports

In [None]:
import os
import time
import torch
import logging
import torchvision
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score

import torchvision.transforms as transforms

import mvtec as md   

import warnings
warnings.filterwarnings('ignore')

# Run on the GPU whenever PyTorch can see one; otherwise use the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(device)

### For code optimization

In [None]:
# uncomment for code optimization
#torch.backends.cudnn.benchmark = True

# Use half of the CPU cores as DataLoader workers.  os.cpu_count() returns None
# when the core count cannot be determined (int(None / 2) would raise TypeError);
# fall back to 0 workers, i.e. loading in the main process.
NUM_WORKERS = (os.cpu_count() or 0) // 2
print('number of workers: ', NUM_WORKERS)

## 1. Helper Functions

### 1a. Evaluation (Loss, Correlations, Eval Metrics)

In [None]:
# Loss Function
class HingeLoss(nn.Module):
    """Mean hinge loss, optionally smoothed with softplus.

        smooth=False:  mean(max(0, margin - ytrue * ypred))
        smooth=True:   mean(softplus(margin - ytrue * ypred))
    """
    def __init__(self):
        super(HingeLoss,self).__init__()

    def forward(self, ypred, ytrue, margin = 1.0, smooth = False):
        """Compute the loss.

        Args:
            ypred: model scores; any shape, flattened to 1-D.
            ytrue: +1/-1 targets (tensor or array-like); flattened to 1-D and
                cast to ypred's dtype/device.  Must have as many elements as
                ypred, or exactly one (broadcast to all predictions).
            margin: hinge margin.
            smooth: use softplus instead of ReLU.

        Returns:
            Scalar tensor.
        """
        # Flatten BOTH operands.  Squeezing only ypred let an (N, 1) label tensor
        # broadcast against (N,) predictions into an (N, N) matrix, pairing every
        # prediction with every label (O(N^2) memory, wrong for mixed labels).
        ypred = ypred.reshape(-1)
        ytrue = torch.as_tensor(ytrue, dtype=ypred.dtype, device=ypred.device).reshape(-1)
        gap = margin - ytrue * ypred
        if smooth:
            # F.softplus uses the same defaults (beta=1, threshold=20) as nn.Softplus().
            return torch.mean(F.softplus(gap))
        return torch.mean(torch.relu(gap))


def SigmaTh2(trn_dat, univ_dat, net, feat): # Complete Data (not data loaders)
    """Normalised correlation between training and universum representations.

    Both batches are mapped through ``net(..., feat=feat)``.  With Z the (n, d)
    training features, U the (m, d) universum features and
    V = kron(U, [[1], [-1]]) (each universum row followed by its negation):

        sig_num  = trace(V Z^T Z V^T) = ||V Z^T||_F^2 = 2 ||U Z^T||_F^2
        ZTZ      = trace(Z^T Z)       = ||Z||_F^2
        VVTr     = trace(V V^T)       = ||V||_F^2   = 2 ||U||_F^2
        I_VVT    = trace(I_2m + V V^T) = 2m + VVTr
        sig_den  = ZTZ * VVTr
        sig_den2 = ZTZ * I_VVT

    The traces are computed as Frobenius norms, so the large Gram matrices
    (Z^T Z is d x d, e.g. 12288 x 12288 for raw 64x64 RGB pixels) and V itself
    are never materialised.

    Args:
        trn_dat: tensor holding the whole training batch (not a DataLoader).
        univ_dat: tensor holding the whole universum batch.
        net: model whose forward accepts a ``feat`` keyword.  It is evaluated in
            eval mode under no_grad; its previous train/eval mode is restored.
        feat: representation level forwarded to ``net`` (e.g. 'none', 'cnn').

    Returns:
        (sig_num/sig_den, sig_num/sig_den2, sig_num, sig_den, sig_den2,
         ZTZ, VVTr, I_VVT)
    """
    # Put the data where the model lives (CPU if the model has no parameters).
    try:
        dev = next(net.parameters()).device
    except StopIteration:
        dev = torch.device('cpu')

    # Eval mode: BatchNorm must use its running statistics, not this batch's.
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            Z = net(trn_dat.to(dev).float(), feat = feat) # 'cnn' and 'final'
            U = net(univ_dat.to(dev).float(), feat = feat)
    finally:
        net.train(was_training)

    # float64 keeps the large sums of squares accurate.
    Z = Z.cpu().numpy().astype(np.float64).reshape(Z.shape[0], -1)
    U = U.cpu().numpy().astype(np.float64).reshape(U.shape[0], -1)

    UZT = U @ Z.T
    sig_num = 2.0 * np.sum(UZT * UZT)
    ZTZ = np.sum(Z * Z)
    VVTr = 2.0 * np.sum(U * U)
    I_VVT = 2 * U.shape[0] + VVTr
    sig_den = ZTZ * VVTr
    sig_den2 = ZTZ * I_VVT

    return (sig_num/sig_den, sig_num/sig_den2, sig_num, sig_den, sig_den2, ZTZ, VVTr, I_VVT)


def evalmetrics(y_true,scores):
    """Return the area under the ROC curve of `scores` for the binary labels `y_true`."""
    return roc_auc_score(y_true, scores)

## 2. Models

![LeNET](LeNET.png)

In [None]:
class CIFAR10_LeNet(nn.Module):
    """LeNet-style CNN scorer with access to intermediate representations.

    The network has three stages of conv(5x5) -> BatchNorm (affine=False) ->
    leaky ReLU -> 2x2 max-pool, followed by bias-free linear layers 8192 -> 128
    -> 64 -> 1.  ``fc1`` is sized for a 128 x 8 x 8 map, i.e. 3-channel
    64 x 64 inputs after three halvings.
    """

    def __init__(self):
        super(CIFAR10_LeNet, self).__init__()

        self.rep_dim = 128
        self.pool = nn.MaxPool2d(2, 2)

        self.conv1 = nn.Conv2d(3, 32, 5, bias=False, padding=2)
        self.bn2d1 = nn.BatchNorm2d(32, eps=1e-04, affine=False)
        self.conv2 = nn.Conv2d(32, 64, 5, bias=False, padding=2)
        self.bn2d2 = nn.BatchNorm2d(64, eps=1e-04, affine=False)
        self.conv3 = nn.Conv2d(64, 128, 5, bias=False, padding=2)
        self.bn2d3 = nn.BatchNorm2d(128, eps=1e-04, affine=False)
        self.fc1 = nn.Linear(128 * 8 * 8, self.rep_dim, bias=False)
        self.fc2 = nn.Linear(self.rep_dim, int(self.rep_dim/2), bias=False)
        self.fc3 = nn.Linear(int(self.rep_dim/2), 1, bias=False)

    def forward(self, x, feat = 'default'):
        """Return the representation of batch ``x`` selected by ``feat``.

        feat:
            'none'    -> the input itself, flattened to (batch, -1)
            'cnn'     -> flattened output of the last conv stage, (batch, 8192)
            'final'   -> leaky-ReLU output of fc2, (batch, 64)
            otherwise -> the score from fc3, (batch, 1)

        Exact comparisons are used.  The former ``feat in 'none'`` was a
        substring test, so values such as '' or 'n' silently returned the raw
        input instead of a score.
        """
        if feat == 'none':
            return x.view(x.size(0), -1)

        x = self.conv1(x)
        x = self.pool(F.leaky_relu(self.bn2d1(x)))
        x = self.conv2(x)
        x = self.pool(F.leaky_relu(self.bn2d2(x)))
        x = self.conv3(x)
        x = self.pool(F.leaky_relu(self.bn2d3(x)))
        x = x.view(x.size(0), -1)

        if feat == 'cnn':
            return x

        x = F.leaky_relu(self.fc1(x))
        x = F.leaky_relu(self.fc2(x))

        if feat == 'final':
            return x

        x = self.fc3(x)
        return x

## 3. Data

##############################################################################################
### To use our data loader, download the full MVTec AD dataset from:
### https://www.mvtec.com/company/research/datasets/mvtec-ad
### and save it in the folder ./mvtec
##############################################################################################

### 3.1 Data loaders

In [None]:
IMSHAPE = 64
BATCH_SIZE = 100

transform = transforms.Compose(
    [transforms.ToTensor()])


def _mvtec(category, train):
    """Return the MVTec split (train=True/False) of `category`, resized to IMSHAPE."""
    return md.MVTEC(root="./mvtec", train=train, transform=transform,
                    resize=IMSHAPE, interpolation=3, category=category)


def _full_batch_loader(dataset, shuffle=True):
    """Return a DataLoader whose first batch is the whole of `dataset`.

    The correlation code draws a single batch from these loaders, so the batch
    size is taken from the dataset length rather than hard-coded image counts
    (which silently truncate the set if the counts ever differ).
    """
    return torch.utils.data.DataLoader(dataset, batch_size=max(len(dataset), 1), shuffle=shuffle,
                                       pin_memory=True, num_workers=NUM_WORKERS)


##############################################################################################
## To use our data loader, download the full MVTec AD dataset from:
## https://www.mvtec.com/company/research/datasets/mvtec-ad
## and save it in the folder ./mvtec
##
###############################################################################################
#                                     Train Data
###############################################################################################

train_dat = _mvtec('tile', train=True)
train_loader = torch.utils.data.DataLoader(train_dat, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True, num_workers=NUM_WORKERS)


###############################################################################################
#                                     Evaluations on:
###############################################################################################

# Test Data (one batch holding the whole tile test split; order is irrelevant for AUC)
test_dat = _mvtec('tile', train=False)
test_loader = _full_batch_loader(test_dat, shuffle=False)

# Train (used to compute correlations): one batch with the whole tile training split
train_loader_eval = _full_batch_loader(train_dat, shuffle=False)


# Objects Eval: whole train / test split of every object category
objs = ['bottle', 'cable', 'capsule', 'hazelnut', 'metal_nut', 'pill', 'transistor']

obj_data_loader_eval = list()
obj_data_tst_loader_eval = list()

for obj in objs:
    obj_data_loader_eval.append(_full_batch_loader(_mvtec(obj, train=True)))
    obj_data_tst_loader_eval.append(_full_batch_loader(_mvtec(obj, train=False)))


# Texture Eval: whole train / test split of every texture category
texts = ['carpet', 'leather', 'wood']

text_data_loader_eval = list()
text_data_tst_loader_eval = list()

for text in texts:
    text_data_loader_eval.append(_full_batch_loader(_mvtec(text, train=True)))
    text_data_tst_loader_eval.append(_full_batch_loader(_mvtec(text, train=False)))

## 4. Methods

In [None]:
# Independent repetitions (each with a freshly initialised network) per experiment.
# The paper reports summary statistics over 10 repetitions; set n_reps = 10 to reproduce them.
n_reps = 1

### DOC (HINGE)

In [None]:
EPOCH = 1000

# force=True: basicConfig does nothing once the root logger has a handler, so
# without it a re-run (or a run after another experiment cell) would keep
# logging into the previously configured file.
logging.basicConfig(filename='MVTec_tile_DOC.log', level=logging.INFO, force=True)


def _report(msg):
    """Print `msg` and write it to the active log file."""
    print(msg)
    logging.info(msg)


def _test_auc(model):
    """Score all of `test_loader` with `model` in eval mode and return the ROC AUC.

    Test labels equal to 0 are mapped to -1 before scoring.
    """
    model.eval()
    labels_all, scores_all = [], []
    with torch.no_grad():
        for inputs, labels in test_loader:
            labels_all += labels.cpu().numpy().tolist()
            scores_all += model(inputs.to(device)).cpu().numpy().reshape(-1).tolist()
    tstlabels = np.array(labels_all)
    tstlabels[tstlabels == 0] = -1
    return evalmetrics(tstlabels, np.array(scores_all))


def _report_corr(tag, trinputs, uinputs, model):
    """Report SigmaTh2 correlations on raw pixels ('none') and CNN features ('cnn')."""
    corr_raw = SigmaTh2(trinputs, uinputs, model, feat='none')
    corr_cnn = SigmaTh2(trinputs, uinputs, model, feat='cnn')
    _report('Correlation Vals ({}) = {}, {}'.format(tag, corr_raw[0], corr_cnn[0]))


def _first_batches(trn_loaders, tst_loaders):
    """Concatenate the first batch of every (train, test) loader pair, in order."""
    parts = []
    for trn_loader, tst_loader in zip(trn_loaders, tst_loaders):
        parts.append(next(iter(trn_loader))[0])
        parts.append(next(iter(tst_loader))[0])
    return torch.cat(parts)


for lam in [0.01]:
    for lr in [1e-5]:

        _report('================================')
        _report('lam = {}, lr = {}'.format(lam, lr))

        for repetition in range(n_reps):

            _report('--------------------------------')
            _report('repetition number: ')
            _report(repetition)

            net = CIFAR10_LeNet().to(device)
            optimizer = torch.optim.Adam(net.parameters(), lr=lr)
            train_loss = HingeLoss()

            for epoch in range(EPOCH + 1):
                # _test_auc() switches the net to eval mode; switch back every epoch,
                # otherwise all epochs after the first train with frozen BatchNorm statistics.
                net.train()
                loss_epoch = 0.0
                n_batches = 0
                epoch_start_time = time.time()

                for inputs, _ in train_loader:
                    inputs = inputs.to(device)
                    outputs = net(inputs)
                    ytr = torch.ones(len(inputs), device=device)  # every training image is labelled +1

                    # Zero the network parameter gradients
                    optimizer.zero_grad()

                    # Loss: smoothed hinge + lam * (sum of squared weights).
                    # `lam` stays a Python float instead of being rebound to a tensor every batch.
                    loss = train_loss(outputs, ytr, smooth=True)
                    l2_reg = sum(param.pow(2).sum() for param in net.parameters())
                    loss = loss + lam * l2_reg

                    loss.backward()
                    optimizer.step()

                    loss_epoch += loss.item()
                    n_batches += 1

                # Evaluate every 200 epochs and always after the last one, so the
                # "Final" line below describes the final model.
                if epoch % 200 == 0 or epoch == EPOCH:
                    Tst_auc_score = _test_auc(net)

                    # log epoch statistics
                    epoch_train_time = time.time() - epoch_start_time
                    _report('  Epoch {}/{}\t Time: {:.3f}\t Loss: {:.8f} \t Test AUC :{:.4f} '
                            .format(epoch, EPOCH, epoch_train_time, loss_epoch / n_batches, Tst_auc_score))

            # Final Solution
            _report(' Final (TOT. EPOCHS {})::  Loss: {:.8f} \t   AUC (Test) :{:.4f} '
                    .format(EPOCH, loss_epoch / n_batches, Tst_auc_score))

            # CORRELATION VALUES (next(iter(...)): DataLoader iterators have no .next() in recent PyTorch)
            trinputs, _ = next(iter(train_loader_eval))

            # UNIV NOISE: uniform [0, 1) images shaped like the training batch
            _report_corr('NOISE', trinputs, torch.from_numpy(np.random.rand(*trinputs.shape)), net)

            # UNIV OBJs
            _report_corr('OBJECTS', trinputs,
                         _first_batches(obj_data_loader_eval, obj_data_tst_loader_eval), net)

            # UNIV TEXT
            _report_corr('Texture', trinputs,
                         _first_batches(text_data_loader_eval, text_data_tst_loader_eval), net)

            print('----------------------------------------------------------------------------------------------')

### DOC3 (Noise Universum)

In [None]:
EPOCH = 1000

# force=True: basicConfig does nothing once the root logger has a handler, so
# without it this cell would keep logging into the first experiment's file.
logging.basicConfig(filename='MVTec_tile_DOC3_Univ_noise.log', level=logging.INFO, force=True)


def _report(msg):
    """Print `msg` and write it to the active log file."""
    print(msg)
    logging.info(msg)


def _test_auc(model):
    """Score all of `test_loader` with `model` in eval mode and return the ROC AUC.

    Test labels equal to 0 are mapped to -1 before scoring.
    """
    model.eval()
    labels_all, scores_all = [], []
    with torch.no_grad():
        for inputs, labels in test_loader:
            labels_all += labels.cpu().numpy().tolist()
            scores_all += model(inputs.to(device)).cpu().numpy().reshape(-1).tolist()
    tstlabels = np.array(labels_all)
    tstlabels[tstlabels == 0] = -1
    return evalmetrics(tstlabels, np.array(scores_all))


def _report_corr(tag, trinputs, uinputs, model):
    """Report SigmaTh2 correlations on raw pixels ('none') and CNN features ('cnn')."""
    corr_raw = SigmaTh2(trinputs, uinputs, model, feat='none')
    corr_cnn = SigmaTh2(trinputs, uinputs, model, feat='cnn')
    _report('Correlation Vals ({}) = {}, {}'.format(tag, corr_raw[0], corr_cnn[0]))


for lam in [0.1]:
    for lr in [5e-6]:
        for Cu in [0.1]:

            # Universum scores are penalised outside [dpos, -dneg] = [1 - eps, 1 + eps].
            eps = 0.0
            dpos = 1.0-eps
            dneg = -(1.0+eps)

            _report('================================')
            _report('lam = {}, lr = {}, Cu = {}'.format(lam, lr, Cu))

            for repetition in range(n_reps):

                _report('--------------------------------')
                _report('repetition number: ')
                _report(repetition)

                net = CIFAR10_LeNet().to(device)
                optimizer = torch.optim.Adam(net.parameters(), lr=lr)
                hinge = HingeLoss()  # stateless, shared by the labelled and universum terms

                for epoch in range(EPOCH + 1):
                    # _test_auc() switches the net to eval mode; switch back every epoch,
                    # otherwise all epochs after the first train with frozen BatchNorm statistics.
                    net.train()
                    loss_epoch = 0.0
                    n_batches = 0
                    epoch_start_time = time.time()

                    for inputs, _ in train_loader:
                        inputs = inputs.to(device)
                        outputs = net(inputs)
                        ytr = torch.ones(len(inputs), device=device)  # every training image is labelled +1

                        # Noise universum: uniform [0, 1) images shaped like the training images
                        XU = torch.from_numpy(np.random.rand(BATCH_SIZE, *inputs.shape[1:])).to(device).float()
                        outputsU = net(XU)
                        yupos = torch.ones(len(XU), device=device)
                        yuneg = -yupos

                        # Zero the network parameter gradients
                        optimizer.zero_grad()

                        # Loss
                        loss = hinge(outputs, ytr, smooth=True)

                        # Universum term: relu(dpos - f) + relu(f + dneg); with eps = 0 it equals |f - 1|
                        lossunlab = hinge(outputsU, yupos, dpos) + hinge(outputsU, yuneg, dneg)
                        loss = loss + Cu * lossunlab

                        # L2 weight penalty; `lam` stays a Python float.
                        l2_reg = sum(param.pow(2).sum() for param in net.parameters())
                        loss = loss + lam * l2_reg

                        loss.backward()
                        optimizer.step()

                        loss_epoch += loss.item()
                        n_batches += 1

                    # Evaluate every 200 epochs and always after the last one.
                    if epoch % 200 == 0 or epoch == EPOCH:
                        Tst_auc_score = _test_auc(net)

                        # log epoch statistics
                        epoch_train_time = time.time() - epoch_start_time
                        _report('  Epoch {}/{}\t Time: {:.3f}\t Loss: {:.8f} \t Test AUC :{:.4f} '
                                .format(epoch, EPOCH, epoch_train_time, loss_epoch / n_batches, Tst_auc_score))

                # Final Solution
                _report(' Final (TOT. EPOCHS {})::  Loss: {:.8f} \t   AUC (Test) :{:.4f} '
                        .format(EPOCH, loss_epoch / n_batches, Tst_auc_score))

                # CORRELATION VALUES
                trinputs, _ = next(iter(train_loader_eval))

                # UNIV NOISE: uniform [0, 1) images shaped like the training batch
                _report_corr('NOISE', trinputs, torch.from_numpy(np.random.rand(*trinputs.shape)), net)

                print('----------------------------------------------------------------------------------------------')

### DOC3 (Objects Universum)

In [None]:
# Universum loaders for the object categories.  Each loader yields half a
# training batch; the training cell concatenates one train and one test
# half-batch of the same category into one universum batch.
univ_data_loader = list()
univ_data_tst_loader = list()
half_batch = int(BATCH_SIZE/2)

for category in objs:
    for is_train, target in ((True, univ_data_loader), (False, univ_data_tst_loader)):
        split = md.MVTEC(root="./mvtec", train=is_train, transform=transform, resize=IMSHAPE, interpolation=3, category=category)
        target.append(torch.utils.data.DataLoader(split, batch_size=half_batch, shuffle=True, pin_memory=True, num_workers=NUM_WORKERS))

In [None]:
EPOCH = 1000

# force=True: basicConfig does nothing once the root logger has a handler, so
# without it this cell would keep logging into the first experiment's file.
logging.basicConfig(filename='MVTec_tile_DOC3_Univ_objects.log', level=logging.INFO, force=True)

# univ_data_loader is rebound by both the object and the texture loader cells;
# refuse to run with the wrong universum instead of silently mixing them up.
if len(univ_data_loader) != len(objs) or len(univ_data_tst_loader) != len(objs):
    raise RuntimeError('univ_data_loader does not hold the object loaders; '
                       're-run the object universum loader cell before this one')


def _report(msg):
    """Print `msg` and write it to the active log file."""
    print(msg)
    logging.info(msg)


def _test_auc(model):
    """Score all of `test_loader` with `model` in eval mode and return the ROC AUC.

    Test labels equal to 0 are mapped to -1 before scoring.
    """
    model.eval()
    labels_all, scores_all = [], []
    with torch.no_grad():
        for inputs, labels in test_loader:
            labels_all += labels.cpu().numpy().tolist()
            scores_all += model(inputs.to(device)).cpu().numpy().reshape(-1).tolist()
    tstlabels = np.array(labels_all)
    tstlabels[tstlabels == 0] = -1
    return evalmetrics(tstlabels, np.array(scores_all))


def _report_corr(tag, trinputs, uinputs, model):
    """Report SigmaTh2 correlations on raw pixels ('none') and CNN features ('cnn')."""
    corr_raw = SigmaTh2(trinputs, uinputs, model, feat='none')
    corr_cnn = SigmaTh2(trinputs, uinputs, model, feat='cnn')
    _report('Correlation Vals ({}) = {}, {}'.format(tag, corr_raw[0], corr_cnn[0]))


def _first_batches(trn_loaders, tst_loaders):
    """Concatenate the first batch of every (train, test) loader pair, in order."""
    parts = []
    for trn_loader, tst_loader in zip(trn_loaders, tst_loaders):
        parts.append(next(iter(trn_loader))[0])
        parts.append(next(iter(tst_loader))[0])
    return torch.cat(parts)


for lam in [0.005]:
    for lr in [1e-4]:
        for Cu in [2.0]:

            # Universum scores are penalised outside [dpos, -dneg] = [1 - eps, 1 + eps].
            eps = 0.0
            dpos = 1.0-eps
            dneg = -(1.0+eps)

            _report('================================')
            _report('lam = {}, lr = {}, Cu = {}'.format(lam, lr, Cu))

            for repetition in range(n_reps):

                _report('--------------------------------')
                _report('repetition number: ')
                _report(repetition)

                net = CIFAR10_LeNet().to(device)
                optimizer = torch.optim.Adam(net.parameters(), lr=lr)
                hinge = HingeLoss()  # stateless, shared by the labelled and universum terms

                for epoch in range(EPOCH + 1):
                    # _test_auc() switches the net to eval mode; switch back every epoch,
                    # otherwise all epochs after the first train with frozen BatchNorm statistics.
                    net.train()
                    loss_epoch = 0.0
                    n_batches = 0
                    epoch_start_time = time.time()

                    for inputs, _ in train_loader:
                        inputs = inputs.to(device)
                        outputs = net(inputs)
                        ytr = torch.ones(len(inputs), device=device)  # every training image is labelled +1

                        # Object universum: one random category, a train and a test half-batch.
                        uint = np.random.randint(0, len(univ_data_loader))
                        XU1, _ = next(iter(univ_data_loader[uint]))
                        XU2, _ = next(iter(univ_data_tst_loader[uint]))
                        XU = torch.cat((XU1, XU2)).to(device).float()
                        outputsU = net(XU)
                        # Label count follows the actual universum batch size.
                        yupos = torch.ones(len(XU), device=device)
                        yuneg = -yupos

                        # Zero the network parameter gradients
                        optimizer.zero_grad()

                        # Loss
                        loss = hinge(outputs, ytr, smooth=True)

                        # Universum term: relu(dpos - f) + relu(f + dneg); with eps = 0 it equals |f - 1|
                        lossunlab = hinge(outputsU, yupos, dpos) + hinge(outputsU, yuneg, dneg)
                        loss = loss + Cu * lossunlab

                        # L2 weight penalty; `lam` stays a Python float.
                        l2_reg = sum(param.pow(2).sum() for param in net.parameters())
                        loss = loss + lam * l2_reg

                        loss.backward()
                        optimizer.step()

                        loss_epoch += loss.item()
                        n_batches += 1

                    # Evaluate every 200 epochs and always after the last one.
                    if epoch % 200 == 0 or epoch == EPOCH:
                        Tst_auc_score = _test_auc(net)

                        # log epoch statistics
                        epoch_train_time = time.time() - epoch_start_time
                        _report('  Epoch {}/{}\t Time: {:.3f}\t Loss: {:.8f} \t Test AUC :{:.4f} '
                                .format(epoch, EPOCH, epoch_train_time, loss_epoch / n_batches, Tst_auc_score))

                # Final Solution
                _report(' Final (TOT. EPOCHS {})::  Loss: {:.8f} \t   AUC (Test) :{:.4f} '
                        .format(EPOCH, loss_epoch / n_batches, Tst_auc_score))

                # CORRELATION VALUES
                trinputs, _ = next(iter(train_loader_eval))

                # UNIV OBJs
                _report_corr('OBJECTS', trinputs,
                             _first_batches(obj_data_loader_eval, obj_data_tst_loader_eval), net)

                print('----------------------------------------------------------------------------------------------')

### DOC3 (Textures Universum)

In [None]:
# Universum loaders for the texture categories.  Each loader yields half a
# training batch; the training cell concatenates one train and one test
# half-batch of the same category into one universum batch.
univ_data_loader = list()
univ_data_tst_loader = list()
half_batch = int(BATCH_SIZE/2)

for category in texts:
    for is_train, target in ((True, univ_data_loader), (False, univ_data_tst_loader)):
        split = md.MVTEC(root="./mvtec", train=is_train, transform=transform, resize=IMSHAPE, interpolation=3, category=category)
        target.append(torch.utils.data.DataLoader(split, batch_size=half_batch, shuffle=True, pin_memory=True, num_workers=NUM_WORKERS))

In [None]:
EPOCH = 1000

# force=True: basicConfig does nothing once the root logger has a handler, so
# without it this cell would keep logging into the first experiment's file.
logging.basicConfig(filename='MVTec_tile_DOC3_Univ_textures.log', level=logging.INFO, force=True)

# univ_data_loader is rebound by both the object and the texture loader cells;
# refuse to run with the wrong universum instead of silently mixing them up.
if len(univ_data_loader) != len(texts) or len(univ_data_tst_loader) != len(texts):
    raise RuntimeError('univ_data_loader does not hold the texture loaders; '
                       're-run the texture universum loader cell before this one')


def _report(msg):
    """Print `msg` and write it to the active log file."""
    print(msg)
    logging.info(msg)


def _test_auc(model):
    """Score all of `test_loader` with `model` in eval mode and return the ROC AUC.

    Test labels equal to 0 are mapped to -1 before scoring.
    """
    model.eval()
    labels_all, scores_all = [], []
    with torch.no_grad():
        for inputs, labels in test_loader:
            labels_all += labels.cpu().numpy().tolist()
            scores_all += model(inputs.to(device)).cpu().numpy().reshape(-1).tolist()
    tstlabels = np.array(labels_all)
    tstlabels[tstlabels == 0] = -1
    return evalmetrics(tstlabels, np.array(scores_all))


def _report_corr(tag, trinputs, uinputs, model):
    """Report SigmaTh2 correlations on raw pixels ('none') and CNN features ('cnn')."""
    corr_raw = SigmaTh2(trinputs, uinputs, model, feat='none')
    corr_cnn = SigmaTh2(trinputs, uinputs, model, feat='cnn')
    _report('Correlation Vals ({}) = {}, {}'.format(tag, corr_raw[0], corr_cnn[0]))


def _first_batches(trn_loaders, tst_loaders):
    """Concatenate the first batch of every (train, test) loader pair, in order."""
    parts = []
    for trn_loader, tst_loader in zip(trn_loaders, tst_loaders):
        parts.append(next(iter(trn_loader))[0])
        parts.append(next(iter(tst_loader))[0])
    return torch.cat(parts)


for lam in [0.1]:
    for lr in [5e-6]:
        for Cu in [0.1]:

            # Universum scores are penalised outside [dpos, -dneg] = [1 - eps, 1 + eps].
            eps = 0.0
            dpos = 1.0-eps
            dneg = -(1.0+eps)

            _report('================================')
            _report('lam = {}, lr = {}, Cu = {}'.format(lam, lr, Cu))

            for repetition in range(n_reps):

                _report('--------------------------------')
                _report('repetition number: ')
                _report(repetition)

                net = CIFAR10_LeNet().to(device)
                optimizer = torch.optim.Adam(net.parameters(), lr=lr)
                hinge = HingeLoss()  # stateless, shared by the labelled and universum terms

                for epoch in range(EPOCH + 1):
                    # _test_auc() switches the net to eval mode; switch back every epoch,
                    # otherwise all epochs after the first train with frozen BatchNorm statistics.
                    net.train()
                    loss_epoch = 0.0
                    n_batches = 0
                    epoch_start_time = time.time()

                    for inputs, _ in train_loader:
                        inputs = inputs.to(device)
                        outputs = net(inputs)
                        ytr = torch.ones(len(inputs), device=device)  # every training image is labelled +1

                        # Texture universum: one random category, a train and a test half-batch.
                        uint = np.random.randint(0, len(univ_data_loader))
                        XU1, _ = next(iter(univ_data_loader[uint]))
                        XU2, _ = next(iter(univ_data_tst_loader[uint]))
                        XU = torch.cat((XU1, XU2)).to(device).float()
                        outputsU = net(XU)
                        # Label count follows the actual universum batch size.
                        yupos = torch.ones(len(XU), device=device)
                        yuneg = -yupos

                        # Zero the network parameter gradients
                        optimizer.zero_grad()

                        # Loss
                        loss = hinge(outputs, ytr, smooth=True)

                        # Universum term: relu(dpos - f) + relu(f + dneg); with eps = 0 it equals |f - 1|
                        lossunlab = hinge(outputsU, yupos, dpos) + hinge(outputsU, yuneg, dneg)
                        loss = loss + Cu * lossunlab

                        # L2 weight penalty; `lam` stays a Python float.
                        l2_reg = sum(param.pow(2).sum() for param in net.parameters())
                        loss = loss + lam * l2_reg

                        loss.backward()
                        optimizer.step()

                        loss_epoch += loss.item()
                        n_batches += 1

                    # Evaluate every 200 epochs and always after the last one.
                    if epoch % 200 == 0 or epoch == EPOCH:
                        Tst_auc_score = _test_auc(net)

                        # log epoch statistics
                        epoch_train_time = time.time() - epoch_start_time
                        _report('  Epoch {}/{}\t Time: {:.3f}\t Loss: {:.8f} \t Test AUC :{:.4f} '
                                .format(epoch, EPOCH, epoch_train_time, loss_epoch / n_batches, Tst_auc_score))

                # Final Solution
                _report(' Final (TOT. EPOCHS {})::  Loss: {:.8f} \t   AUC (Test) :{:.4f} '
                        .format(EPOCH, loss_epoch / n_batches, Tst_auc_score))

                # CORRELATION VALUES
                trinputs, _ = next(iter(train_loader_eval))

                # UNIV TEXT
                _report_corr('Texture', trinputs,
                             _first_batches(text_data_loader_eval, text_data_tst_loader_eval), net)

                print('----------------------------------------------------------------------------------------------')