In [1]:
import torch
import torch.nn as nn
import torchvision
import tqdm
import time
import math
import torchvision.models as models
import torchvision.transforms as transforms

In [2]:
from gpytorch.variational import VariationalStrategy
from gpytorch.models.deep_gps import DeepGPLayer, DeepGP

In [3]:
import sys
import urllib.request
import os
from scipy.io import loadmat
from math import floor
import pandas as pd
import numpy as np

if torch.cuda.is_available():
    device = 'cuda:6' # change here to tune the device we use
else:
    device = 'cpu'
# device = 'cpu'

# load data and pre-processing
# regression: UCI dataset, classification: Image dataset
dataset = 'cifar10' # set here to determine to train which dataset
if dataset == 'red_wine':
    # L=6, D=11, N=1599
    data = pd.read_csv('data/red_wine.csv').values
    X = torch.Tensor(data[:, :11])
    y = torch.Tensor(data[:, 11]).add(-3).long()
elif dataset == 'white_wine':
    # L=7, D=11, N=4898
    data = pd.read_csv('data/white_wine.csv').values
    X = torch.Tensor(data[:, :11])
    y = torch.Tensor(data[:, 11]).add(-3).long()
elif dataset == 'cifar10':
    # L=10, C=3. H=32, W=32
    pass
elif dataset == 'mnist':
    # L=10, C=1, H=28, W=28, label: 0-9
    pass
elif dataset == 'fashion_mnist':
    # L=10, C=1, H=28, W=28, label: 0-9
    pass
elif dataset == 'abalone':
    # L=29, D=7, N=4177
    data = pd.read_excel('data/abalone.xlsx', header=None).values
    X = torch.Tensor(data[:, :8])
    y = torch.Tensor(data[:, 8]).add(-1).long()
elif dataset == 'wilt':
    # L=2, D=5, N=4840
    train_data = pd.read_csv('data/wilt_training.csv').values
    test_data = pd.read_csv('data/wilt_test.csv').values
    train_x = train_data[:, 1:]
    test_x = test_data[:, 1:]
    train_x = train_x - train_x.min(0)[0]  # X.min(0)[0]: min value of every feature
    train_x = 2 * (train_x / train_x.max(0)[0]) - 1 # pre-preocess to [-1,1]
    test_x = test_x - test_x.min(0)[0]  # X.min(0)[0]: min value of every feature
    test_x = 2 * (test_x / test_x.max(0)[0]) - 1 # pre-preocess to [-1,1]
    
    train_x = torch.Tensor(train_x).contiguous()
    test_x = torch.Tensor(test_x).contiguous()
    train_y = torch.Tensor(train_data[:, 0]).long().contiguous()
    test_y = torch.Tensor(test_data[:, 0]).long().contiguous()
    
# pre-processing
if dataset == 'cifar10':
    transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

elif dataset == 'mnist':
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(), 
        torchvision.transforms.Normalize((0.5), (0.5))
    ])

elif dataset == 'fashion_mnist':
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(), 
        torchvision.transforms.Normalize((0.5), (0.5))
    ])
'''    
elif dataset == 'wilt':
    pass
elif dataset == 'red_wine' or 'white_wine' or 'abalone':
    X = X - X.min(0)[0]  # X.min(0)[0]: min value of every feature
    X = 2 * (X / X.max(0)[0]) - 1 # pre-preocess to [-1,1]

    # split data
    ratio = 0.9 # ratio to split train set
    train_n = int(floor(ratio * len(X))) # split train set, ratio: 0.8
    train_x = X[:train_n, :].contiguous()
    train_y = y[:train_n].contiguous()

    test_x = X[train_n:, :].contiguous()
    test_y = y[train_n:].contiguous()
        
# move data to cuda or cpu

if dataset == 'red_wine' or 'white_wine' or 'abalone' or 'wilt':
    train_x, train_y, test_x, test_y = train_x.to(device), train_y.to(device), test_x.to(device), test_y.to(device)

'''
def get_parameter_number(model):
    total_num = sum(p.numel() for p in model.parameters())
    trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}

In [4]:
from torch.utils.data import TensorDataset, DataLoader
if dataset == 'red_wine':
    loader_batch_size = 512
elif dataset == 'white_wine':
    loader_batch_size = 512
elif dataset == 'cifar10':
    loader_batch_size = 512
elif dataset == 'mnist':
    loader_batch_size = 128
elif dataset == 'fashion_mnist':
    loader_batch_size = 256
elif dataset == 'abalone':
    loader_batch_size = 1024
elif dataset == 'wilt':
    loader_batch_size = 1024

if dataset == 'cifar10':
    train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, 
                                                 transform=transform_train)
    test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, 
                                                 transform=transform_test)
elif dataset == 'mnist':
    train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=False, 
                                                 transform=transform)
    test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=False, 
                                                 transform=transform)
elif dataset == 'fashion_mnist':
    train_dataset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, 
                                                 transform=transform)
    test_dataset = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, 
                                                 transform=transform)
'''
elif dataset == 'red_wine' or 'white_wine' or 'abalone' or 'wilt':
    train_dataset = TensorDataset(train_x, train_y)
    test_dataset = TensorDataset(test_x, test_y)
'''
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256)
 

Files already downloaded and verified
Files already downloaded and verified


In [5]:
train_loader1 = DataLoader(train_dataset, batch_size=256, shuffle=True)

In [6]:
print_time = False # set True to print module time
print_norm = False # set True to print norm of f_u (the perfect value of norm after training: 0)
print_metric = True # set True to print metric per epoch
print_param = True # set True to print parameter size of each network
print_loss = False # set True to print each loss during training
sample_times = 1 # Sample k times to evaluate the expectation over q(u)
concat_type = False # set True to concat e with z in generator G
expect_mean = True # Set True to compute mean of q(f) by average over q(u)

class ToyDeepGPHiddenLayer(DeepGPLayer):  # input_dims: size of feature dim
    def __init__(self, input_dims, output_dims, num_inducing=128, mean_type='constant', stein=False, noise_add=True,
                 noise_share=False, noise_dim=1, multi_head=False, vector_type=False):
        # TODO: adjust the size of inducing_points to [inducing_size, input_dim] for each layer
        if stein is False:
            if output_dims is None:
                inducing_points = torch.randn((num_inducing, input_dims), requires_grad=True) # sample from gaussian dist
                batch_shape = torch.Size([])
            else: # inducing_points-dim correspond to: output_dims, num_inducing and input_dims
                inducing_points = torch.randn((output_dims, num_inducing, input_dims), requires_grad=True)
                batch_shape = torch.Size([output_dims])
        else:
            if output_dims is None:
                inducing_points = torch.randn((num_inducing, input_dims), requires_grad=True) # sample from gaussian dist
                batch_shape = torch.Size([])
            else: # inducing_points-dim correspond to: [num_inducing, input_dims]
                inducing_points = torch.randn((num_inducing, input_dims), requires_grad=True)
                batch_shape = torch.Size([output_dims])
        
        self.input_dim = input_dims
        self.output_dim = output_dims
        self.inducing_size = num_inducing

        self.noise_add = noise_add
        self.noise_share = noise_share
        self.noise_dim = noise_dim # get this value from the outer model
        if self.noise_dim is None:
            self.noise_dim = 32
        self.noise_add = noise_add
        if self.noise_add:
            self.concat_type = concat_type
        else:
            self.concat_type = False # only can set True when adding noise

        # self.share_noise = share_noise # noise shared by all layers
        self.multi_head = multi_head
        # self.transformed_noise = transformed_noise
        self.learn_inducing_locations = True # Set True to treat location of z as parameter
        
        self.hutch_times = 1 # number for hutchison estimation of trace
        self.bottleneck_trick = True # Set True to introduce decomposition of jacobian to reduce variance
        self.rademacher_type = False # Set True to sample from Rademecher dist to compute trace
        self.diagonal_type = True # Set True to use diagonal kernel matrix to approximate fully kernel matrix
        self.vector_type = vector_type # Set True to generate u using vector-based network
        
        self.print_time = print_time
        self.print_param = print_param
        self.device = device
        self.sample_times = sample_times # Sample k times to evaluate the expectation over q(u)
        self.expect_mean = expect_mean
        
        variational_distribution = None

        variational_strategy = VariationalStrategy( # transform q(U) to q(F)
            self,
            inducing_points,
            variational_distribution,
            stein_type=stein,
            batch_shape=batch_shape
        )

        super(ToyDeepGPHiddenLayer, self).__init__(variational_strategy, input_dims, output_dims)        
        self.variational_strategy = variational_strategy

    def forward(self, x):
        return None
    
    # will be called when run the forward function of class DeepGP
    def __call__(self, x, train_type=True, disc_type=True, hyp_type=True, trace_layer_list=None, norm_layer_list=None, u_layer_list=None,
                 fiu_layer_list=None, glogpfu_layer_list=None, shared_list=None, transformed_list=None, *other_inputs,
                 **kwargs): 
        """
        Overriding __call__ isn't strictly necessary, but it lets us add concatenation based skip connections
        easily.
        """
        return super().__call__(x, are_samples=False, train_type=train_type, disc_type=disc_type,
                                hyp_type=hyp_type, trace_layer_list=trace_layer_list, norm_layer_list=norm_layer_list,
                                u_layer_list=u_layer_list, fiu_layer_list=fiu_layer_list, 
                                glogpfu_layer_list=glogpfu_layer_list, shared_list=shared_list,
                               transformed_list=transformed_list)

In [7]:
num_output_dims = 10  # number of output_dim in hidden layer
vector_type = True # Set True to generate u using vector-based network
learn_likelihood_covariance = True # Set True to treat the covariance of likelihood as parameter
if dataset == 'cifar10':
    task_dim = 10
    train_x_shape = 64 # carefully examine this term if self.feature_extractor is modified
elif dataset == 'mnist':
    task_dim = 10
    train_x_shape = 28 * 28 # carefully examine this term if self.feature_extractor is modified
elif dataset == 'fashion_mnist':
    task_dim = 10
    train_x_shape = 8 * 8 # carefully examine this term if self.feature_extractor is modified
elif dataset == 'red_wine':
    task_dim = 6
    train_x_shape = train_x.shape[-1]
elif dataset == 'white_wine':
    task_dim = 7
    train_x_shape = train_x.shape[-1]
elif dataset == 'abalone':
    task_dim = 29
    train_x_shape = train_x.shape[-1]
elif dataset == 'wilt':
    task_dim = 2
    train_x_shape = train_x.shape[-1]
        
# we can fine-tune it to obtain a better result
if dataset == 'red_wine':
    noise_dim = 200
    num_layer_inducing = 128
elif dataset == 'white_wine':
    noise_dim = 32
    num_layer_inducing = 128
elif dataset == 'cifar10':
    noise_dim = 200
    num_layer_inducing = 128
elif dataset == 'mnist':
    noise_dim = 200
    num_layer_inducing = 128
elif dataset == 'fashion_mnist':
    noise_dim = 200
    num_layer_inducing = 128
elif dataset == 'abalone':
    noise_dim = 32
    num_layer_inducing = 128
elif dataset == 'wilt':
    noise_dim = 32
    num_layer_inducing = 128

'''   
class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool1 = nn.MaxPool2d(2)
        
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.maxpool2 = nn.MaxPool2d(2)
        
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.maxpool3 = nn.MaxPool2d(2)
        
        self.conv4 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(512)
        self.maxpool4 = nn.MaxPool2d(2)
        
        
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.fc0 = nn.Linear(512, 256)
        self.bn0 = nn.BatchNorm1d(256)
        
        
        
    
        self.fc1 = nn.Linear(256, 128)
        self.bn11 = nn.BatchNorm1d(128)
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(128, 64)

       
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = nn.functional.relu(x)
        x = self.maxpool1(x)
        
        x = self.conv2(x)
        x = self.bn2(x)
        x = nn.functional.relu(x)
        x = self.maxpool2(x)
        
        x = self.conv3(x)
        x = self.bn3(x)
        x = nn.functional.relu(x)
        x = self.maxpool3(x)
        
        x = self.conv4(x)
        x = self.bn4(x)
        x = nn.functional.relu(x)
        x = self.maxpool4(x)
        
       
        
        
        x = self.avgpool(x)
        x = self.flatten(x)
        
        x = self.fc0(x)
        x = self.bn0(x)
        x = nn.functional.relu(x)
        
        
        
        
        x = self.fc1(x)
        x = self.bn11(x)
        x = nn.functional.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        
        
        
       
        return x
        
'''
class ResNet18FeatureExtractor(nn.Module):
    def __init__(self, num_classes=64):
        super(ResNet18FeatureExtractor, self).__init__()
        self.resnet = models.resnet18(pretrained=True)
        self.resnet.conv1 = nn.Conv2d(3, 64, 3, stride=1, padding=1, bias=False)
        self.resnet.maxpool = nn.MaxPool2d(1, 1, 0)
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, x):
        x = self.resnet.conv1(x)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)

        x = self.resnet.layer1(x)
        x = self.resnet.layer2(x)
        x = self.resnet.layer3(x)
        x = self.resnet.layer4(x)

        x = self.resnet.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.resnet.fc(x)
        return x
Extractor = ResNet18FeatureExtractor().to(device)
state_dict = torch.load('ResNet_trained_extractor.pth')
Extractor.load_state_dict(state_dict)

        
class DeepGP(DeepGP): # define the noise outside the layer
    def __init__(self, train_x_shape, stein=False, noise_add=True, noise_share=False, multi_head=False): # L=2
        if noise_add is False and noise_share:
            raise ValueError('Only can share noise across layer when noise is added')
        if noise_share is False and multi_head:
            raise ValueError('Only can use Multi-head Mechanism when noise is shared across layer')
        super().__init__()
        self.vector_type = vector_type
        self.noise_add = noise_add
        self.noise_share = noise_share # share the noise across layer or not
        # fine-tune this term or directly learn as a prior
        self.noise_dim = noise_dim
        self.multi_head = multi_head  # transform noise or not        
        self.back_bone = nn.Sequential(
            nn.Linear(in_features=self.noise_dim, out_features=32),
            nn.Sigmoid(),
            nn.Linear(in_features=32, out_features=self.noise_dim)
        )
        if self.multi_head and print_param:
            print('Generator backbone:', get_parameter_number(self.back_bone))
        
        lls_sigma = torch.Tensor([0.1]) # Set sigma to be a small number
        if learn_likelihood_covariance:
            self.register_parameter("lls_sigma", torch.nn.Parameter(lls_sigma))
        else:
            self.register_buffer("lls_sigma", lls_sigma)
        
        if dataset == 'cifar10': # customize a feature_extractor for each image dataset
            '''
            self.feature_extractor = nn.Sequential(
                nn.Conv2d(in_channels=3, out_channels=3, kernel_size=5),
                nn.MaxPool2d(2),
                nn.Conv2d(in_channels=3, out_channels=3, kernel_size=5),
                nn.MaxPool2d(2),
                nn.Flatten(), # output_dim: 3 * 32 * 32
            )
            '''
            with torch.no_grad():
                self.feature_extractor = Extractor.eval().to(device)
            for param in self.feature_extractor.parameters():
                param.requires_grad = False
                #self.feature_extractor = ResNetFeatureExtractor.eval().to(device)
            #self.feature_extractor=FeatureExtractor()
        elif dataset == 'mnist' or 'fashion_mnist':
            self.feature_extractor = nn.Sequential(
                nn.Conv2d(in_channels=1, out_channels=4, kernel_size=5),
                nn.MaxPool2d(2),
                nn.Conv2d(in_channels=4, out_channels=4, kernel_size=5),
                nn.MaxPool2d(2),
                nn.Flatten(), # output_dim: 28 * 28
            )
        elif dataset == 'red_wine' or 'white_wine' or 'abalone' or 'wilt':
            self.feature_extractor = nn.Flatten()

        self.post_classification = nn.Sequential(
            #nn.Linear(in_features=task_dim, out_features=task_dim),
            #nn.Sigmoid(),
            #nn.Linear(in_features=task_dim, out_features=task_dim),
            nn.Softmax(dim=1),
        )
        self.cross = nn.CrossEntropyLoss() # compute the classification loss to optimize hyper-parameter
        
        hidden_layer = ToyDeepGPHiddenLayer(
            input_dims=train_x_shape,
            output_dims=num_output_dims,
            num_inducing=num_layer_inducing,
            mean_type='linear',
            stein=stein,
            noise_add=noise_add,
            noise_share=self.noise_share,
            noise_dim=self.noise_dim,
            multi_head=self.multi_head,
            vector_type=self.vector_type,
        )
        '''
        hidden_layer2 = ToyDeepGPHiddenLayer(
            input_dims=hidden_layer.output_dim,
            output_dims=num_output_dims,
            num_inducing=num_layer_inducing,
            mean_type='linear',
            stein=stein,
            noise_add=noise_add,
            noise_share=self.noise_share,
            noise_dim=self.noise_dim,
            multi_head=self.multi_head,
            vector_type=self.vector_type,
        )   
        
        
        hidden_layer3 = ToyDeepGPHiddenLayer(
            input_dims=hidden_layer2.output_dim,
            output_dims=num_output_dims,
            num_inducing=num_layer_inducing,
            mean_type='linear',
            stein=stein,
            noise_add=noise_add,
            noise_share=self.noise_share,
            noise_dim=self.noise_dim,
            multi_head=self.multi_head,
            vector_type=self.vector_type,
        )
        
        
        hidden_layer4 = ToyDeepGPHiddenLayer(
            input_dims=hidden_layer3.output_dim,
            output_dims=num_output_dims,
            num_inducing=num_layer_inducing,
            mean_type='linear',
            stein=stein,
            noise_add=noise_add,
            noise_share=self.noise_share,
            noise_dim=self.noise_dim,
            multi_head=self.multi_head,
            vector_type=self.vector_type,
        )
        '''

        last_layer = ToyDeepGPHiddenLayer(
            input_dims=hidden_layer.output_dim,
            output_dims=task_dim,
            num_inducing=num_layer_inducing,
            mean_type='constant',
            stein=stein,
            noise_add=noise_add,
            noise_share=self.noise_share,
            noise_dim=self.noise_dim,
            multi_head=self.multi_head,
            vector_type=self.vector_type,
        )

        self.hidden_layer = hidden_layer
        #self.hidden_layer2 = hidden_layer2
        #self.hidden_layer3 = hidden_layer3
        # self.hidden_layer4 = hidden_layer4
        self.last_layer = last_layer
        
        # register all networks here!!!
        self.hidden_strategy_generator = self.hidden_layer.variational_strategy.generator
        self.hidden_strategy_discriminator = self.hidden_layer.variational_strategy.discriminator
        self.hidden_kernel_method = self.hidden_layer.variational_strategy.kernel_method
        '''
        self.hidden2_strategy_generator = self.hidden_layer2.variational_strategy.generator
        self.hidden2_strategy_discriminator = self.hidden_layer2.variational_strategy.discriminator
        self.hidden2_kernel_method = self.hidden_layer2.variational_strategy.kernel_method
        
        self.hidden3_strategy_generator = self.hidden_layer3.variational_strategy.generator
        self.hidden3_strategy_discriminator = self.hidden_layer3.variational_strategy.discriminator
        self.hidden3_kernel_method = self.hidden_layer3.variational_strategy.kernel_method
        
        self.hidden4_strategy_generator = self.hidden_layer4.variational_strategy.generator
        self.hidden4_strategy_discriminator = self.hidden_layer4.variational_strategy.discriminator
        self.hidden4_kernel_method = self.hidden_layer4.variational_strategy.kernel_method
        '''
        self.last_strategy_generator = self.last_layer.variational_strategy.generator
        self.last_strategy_discriminator = self.last_layer.variational_strategy.discriminator
        self.last_kernel_method = self.last_layer.variational_strategy.kernel_method

    def forward(self, inputs):
        # set train_type=False when evaluate, do not compute the loss to faster the process
        # Generate noise here when noise is shared or transformed
        if self.noise_add:
            if self.noise_share:
                if self.multi_head:
                    if self.vector_type:
                        shared_list = []
                        transformed_list = []
                        for i in range(0, sample_times):
                            shared_noise = 0.1*torch.randn(self.noise_dim, requires_grad=True).to(device)
                            transformed_noise = self.back_bone(shared_noise)
                            shared_list.append(shared_noise)
                            transformed_list.append(transformed_noise)
                    else:
                        shared_list = []
                        transformed_list = []
                        for i in range(0, sample_times):
                            shared_noise = torch.randn((num_layer_inducing, self.noise_dim), requires_grad=True).to(device)
                            transformed_noise = self.back_bone(shared_noise)
                            shared_list.append(shared_noise)
                            transformed_list.append(transformed_noise)
                else:
                    if self.vector_type:
                        shared_list = []
                        transformed_list = None
                        for i in range(0, sample_times):
                            shared_noise = torch.randn(self.noise_dim, requires_grad=True).to(device)
                            shared_list.append(shared_noise)                        
                    else:
                        shared_list = []
                        transformed_list = None
                        for i in range(0, sample_times):
                            shared_noise = torch.randn((num_layer_inducing, self.noise_dim), requires_grad=True).to(device)
                            shared_list.append(shared_noise)  
            else:
                shared_list = None
                transformed_list = None
        else:
            shared_list = None
            transformed_list = None
        inputs, train_type, disc_type, hyp_type, trace_layer_list, norm_layer_list, u_layer_list, fiu_layer_list, glogpfu_layer_list = inputs
        # determine to train backbone network and parameter sigma or not
        if train_type:
            if disc_type:
                self.back_bone.requires_grad_(False)
                self.lls_sigma.requires_grad_(False)
                self.feature_extractor.requires_grad_(False)
                self.post_classification.requires_grad_(False)
            else:
                self.back_bone.requires_grad_(True)
                self.lls_sigma.requires_grad_(False)
                self.feature_extractor.requires_grad_(False)
                self.post_classification.requires_grad_(False)
        else:
            if hyp_type:
                self.back_bone.requires_grad_(False)
                self.lls_sigma.requires_grad_(True)
                self.feature_extractor.requires_grad_(False)
                self.post_classification.requires_grad_(True)
            else: # use for test stage
                self.back_bone.requires_grad_(False)
                self.lls_sigma.requires_grad_(False)
                self.feature_extractor.requires_grad_(False)
                self.post_classification.requires_grad_(False)
        
        hidden_rep1, trace_layer_list, norm_layer_list, u_layer_list, fiu_layer_list, glogpfu_layer_list = self.hidden_layer(
                                                                            inputs, train_type=train_type, disc_type=disc_type,
                                                                           hyp_type=hyp_type, trace_layer_list=trace_layer_list,
                                                                           norm_layer_list=norm_layer_list,
                                                                           u_layer_list=u_layer_list,
                                                                          fiu_layer_list=fiu_layer_list,
                                                                          glogpfu_layer_list=glogpfu_layer_list,
                                                                          shared_list=shared_list,
                                                                          transformed_list=transformed_list)
        '''
        hidden_rep2, trace_layer_list, norm_layer_list, u_layer_list, fiu_layer_list, glogpfu_layer_list = self.hidden_layer2(
                                                                            hidden_rep1, train_type=train_type, disc_type=disc_type,
                                                                           hyp_type=hyp_type, trace_layer_list=trace_layer_list,
                                                                           norm_layer_list=norm_layer_list,
                                                                           u_layer_list=u_layer_list,
                                                                          fiu_layer_list=fiu_layer_list,
                                                                          glogpfu_layer_list=glogpfu_layer_list,
                                                                          shared_list=shared_list,
                                                                          transformed_list=transformed_list)
        
        hidden_rep3, trace_layer_list, norm_layer_list, u_layer_list, fiu_layer_list, glogpfu_layer_list = self.hidden_layer3(
                                                                            hidden_rep2, train_type=train_type, disc_type=disc_type,
                                                                           hyp_type=hyp_type, trace_layer_list=trace_layer_list,
                                                                           norm_layer_list=norm_layer_list,
                                                                           u_layer_list=u_layer_list,
                                                                          fiu_layer_list=fiu_layer_list,
                                                                          glogpfu_layer_list=glogpfu_layer_list,
                                                                          shared_list=shared_list,
                                                                          transformed_list=transformed_list)
        
        hidden_rep4, trace_layer_list, norm_layer_list, u_layer_list, fiu_layer_list, glogpfu_layer_list = self.hidden_layer4(
                                                                            hidden_rep3, train_type=train_type, disc_type=disc_type,
                                                                           hyp_type=hyp_type, trace_layer_list=trace_layer_list,
                                                                           norm_layer_list=norm_layer_list,
                                                                           u_layer_list=u_layer_list,
                                                                          fiu_layer_list=fiu_layer_list,
                                                                          glogpfu_layer_list=glogpfu_layer_list,
                                                                          shared_list=shared_list,
                                                                          transformed_list=transformed_list)
        '''
        
        # get a prob distribution, not a deterministic vector value
        output = self.last_layer(hidden_rep1, train_type=train_type, disc_type=disc_type, hyp_type=hyp_type, 
                                 trace_layer_list=trace_layer_list, norm_layer_list=norm_layer_list, 
                                 u_layer_list=u_layer_list, fiu_layer_list=fiu_layer_list, 
                                 glogpfu_layer_list=glogpfu_layer_list, shared_list=shared_list, 
                                 transformed_list=transformed_list)
        return output

    def predict(self, test_loader):
        correct_sum = 0
        total_sum = 0
        for x_batch, y_batch in test_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            x_batch = self.feature_extractor.forward(x_batch)
            output_batch, _, _, _, _, _ = self.forward((x_batch, False, False, True, [], [], [], [], []))
            mean_batch, covar_batch = output_batch
            f_batch_list = []
            for t in range(0, task_dim):
                f_batch = torch.distributions.MultivariateNormal(loc=mean_batch[t, :], covariance_matrix=covar_batch[t, :, :]).rsample(
                        torch.Size([])) # sample f from q(f)
                f_batch_list.append(f_batch)
            f_batch_total = torch.stack(f_batch_list, dim=1) # size: [batch_size, task_dim]
            preds = self.classification_likelihood(f_batch_total)
            #print(preds)
            #print(y_batch)
            correct = (preds == y_batch).sum().item()
            total = y_batch.size(0)
            print('Batch Accuracy: {}%'.format(100 * correct / total))

            correct_sum += correct
            total_sum += total
                

        total_acc = 100 * correct_sum / total_sum 
        return total_acc
    
    def classification_likelihood_loss(self, y_batch, output_batch, u_layer_list, fiu_layer_list, glogpfu_layer_list):
        mean_batch, covariance_batch = output_batch
        f_batch_list = []
        for t in range(0, task_dim):
            f_batch = torch.distributions.MultivariateNormal(loc=mean_batch[t, :], 
                                                             covariance_matrix=covariance_batch[t, :, :]).rsample(torch.Size([]))
            f_batch_list.append(f_batch)
        f_batch_total = torch.stack(f_batch_list, dim=1) # size: [batch_size, task_dim]
        point_batch = self.post_classification(f_batch_total) # size: [batch_size, task_dim]
        noise_batch = torch.ones_like(f_batch_total).mul(1e-4).to(device) # remove it if training of classification is stable
        point_batch = point_batch.add(noise_batch)
        #f_batch_total = f_batch_total.add(noise_batch)
        y_size = y_batch.size(0)
        one_hot_batch = torch.zeros(y_size, task_dim).long().to(device)
        one_hot_batch.scatter_(dim=1, 
                               index=y_batch.unsqueeze(dim=1), 
                               src=torch.ones(y_size, task_dim).long().to(device)) # [batch_size, task_size]
        prob_batch = torch.matmul(point_batch.unsqueeze(1), one_hot_batch.float().unsqueeze(2))
        prob_batch = prob_batch.squeeze(2).squeeze(1)
        log_prob_batch = torch.log(prob_batch).sum(0) # get the final log-prob value
        
        glogpyf_layer_list = [] # use to contain gradient of likelihood
        for l in range(0, len(u_layer_list)):
            glogpyf_l_list = []
            for k in range(0, sample_times):
                if expect_mean:
                    glogpyf_l = torch.autograd.grad(log_prob_batch, u_layer_list[l][k], retain_graph=True, create_graph=True)[0]
                else:
                    # Set expect_mean == False to generate q(f) only using the first element of u_layer_list
                    glogpyf_l = torch.autograd.grad(log_prob_batch, u_layer_list[l][0], retain_graph=True, create_graph=True)[0]
                glogpyf_l_list.append(glogpyf_l)
            glogpyf_layer_list.append(glogpyf_l_list)
        
        glogpyffu_layer_list = []
        # glogpyffu_layer_list: the second term in the first loss term in each layer
        for l in range(0, len(u_layer_list)):
            glogpyffu_l_list = []
            for k in range(0, sample_times):
                glogpyffu_l = torch.mm(glogpyf_layer_list[l][k].unsqueeze(0), fiu_layer_list[l][k].unsqueeze(1))
                glogpyffu_l_list.append(glogpyffu_l)
            glogpyffu_l_tensor = torch.stack(glogpyffu_l_list, dim=0)
            glogpyffu_l_mean = glogpyffu_l_tensor.mean(0) # average over sample_times
            glogpyffu_layer_list.append(glogpyffu_l_mean)
        glogpyffu_tensor = torch.stack(glogpyffu_layer_list, dim=0)
        glogpyffu_loss = glogpyffu_tensor.sum(0) # sum over all layers
        
        return glogpyffu_loss
    
    def classification_likelihood(self, f_batch):
        # size of f_batch: [batch_size, task_dim]
        # We return the integer of class number that we predict the image belongs to
        #prob_batch = self.post_classification(f_batch)
        #y_predict = torch.max(prob_batch, dim=1)[1] # return the index [0, L-1] of each data point
        y_predict = torch.max(f_batch, dim=1)[1] # return the index [0, L-1] of each data point
        return y_predict

In [8]:
# NOTE: When strange error occur (especially wrong error line), restart the kernel
# If multi_head=True, Transform shared noise by a network and concat with layer-specified inducing points
model = DeepGP(train_x_shape, stein=True, noise_add=True, noise_share=True, multi_head=True)
print('Test whether you have cuda to run the process or not:', torch.cuda.is_available())
model = model.to(device)
print(get_parameter_number(model))


Generator backbone: {'Total': 13032, 'Trainable': 13032}
Generator: {'Total': 56864, 'Trainable': 56864}
Discriminator: {'Total': 42257, 'Trainable': 42257}
Kernel method: {'Total': 66, 'Trainable': 66}
Generator: {'Total': 49952, 'Trainable': 49952}
Discriminator: {'Total': 42257, 'Trainable': 42257}
Kernel method: {'Total': 12, 'Trainable': 12}
Test whether you have cuda to run the process or not: True
{'Total': 11406105, 'Trainable': 204441}


In [9]:
acc_test_list = [] 
acc_train_list = []

In [10]:
print(get_parameter_number(model))

{'Total': 11406105, 'Trainable': 204441}


In [24]:

from gpytorch.settings import num_likelihood_samples
begin_sum = time.time()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)
optimizer1 = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.00025)
optimizer_gen = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)

#def train(num_epochs=3, num_disc=2, num_gen=2, num_hyp=10, num_samples=10, lamda=100, lamda_like=1):

num_epochs = 1
num_disc = 1 # use to train discriminator per epoch
num_gen = 1# use to train generator per epoch
num_hyp = 10 # use to train hyper-parameter per epoch
num_samples = 10 # use to evaluate likelihood
lamda = 100 # hyperparameter for norm loss
lamda_like = 1 # hyperparameter for likelihood loss (will be deleted soon)

if num_epochs == 0:
    raise ValueError('At least one epoch need to be run!!!')

epochs_iter = tqdm.notebook.tqdm(range(num_epochs), desc="Epoch")
for i in epochs_iter: # change here to add adversarial training
    print('New epoch!!!')
    if num_disc != 0:
        correct_epoch_disc = 0 # use to calculate correct prediction in one epoch
        total_epoch_disc = 0 # use to calculate total prediction in one epoch
        # fix hyperparameter v and generator G, update discriminator D num_nc times
        disc_iter = tqdm.notebook.tqdm(range(num_disc), desc='Train Discriminator', leave=False)
        for j in disc_iter:
            # Within each iteration, we will go over each minibatch of data
            minibatch_iter_disc = tqdm.notebook.tqdm(train_loader1, desc="Minibatch-Disc", leave=False)
            correct_iter_disc = 0 # use to calculate correct prediction in one iteration
            total_iter_disc = 0 # use to calculate total prediction in one iteration
            for x_batch_disc, y_batch_disc in minibatch_iter_disc:
                x_batch_disc, y_batch_disc = x_batch_disc.to(device), y_batch_disc.to(device)
                with num_likelihood_samples(num_samples):
                    optimizer.zero_grad()
                    trace_layer_list_disc = []
                    norm_layer_list_disc = []
                    u_layer_list_disc = []
                    fiu_layer_list_disc = []
                    glogpfu_layer_list_disc = []
                    # trace_layer_list: contain the final trace in each layer
                    # norm_layer_list: contain the norm in each layer
                    # u_layer_list: contain u in each layer
                    # fiu_layer_list: contain f_u in each layer
                    # glogpu_layer_list: contain dot product of gradient of logp_u with f_u in each layer
                    # disc_type: determine update discriminator or generator
                    x_batch_disc = model.feature_extractor.forward(x_batch_disc)
                    output_batch_disc, trace_layer_list_disc, norm_layer_list_disc, u_layer_list_disc, fiu_layer_list_disc, glogpfu_layer_list_disc = model.forward(
                        (x_batch_disc, True, True, False, trace_layer_list_disc, norm_layer_list_disc, u_layer_list_disc, fiu_layer_list_disc, glogpfu_layer_list_disc))
                    # size of output:([1, batch_size], [1, batch_size, batch_size])
                    # the second loss term
                    trace_layer_tensor_disc = torch.stack(trace_layer_list_disc, dim=0)
                    trace_loss_disc = trace_layer_tensor_disc.sum(0)  
                    # the third loss term
                    norm_layer_tensor_disc = torch.stack(norm_layer_list_disc, dim=0)
                    norm_loss_disc = norm_layer_tensor_disc.sum(0)
                    if print_norm:
                        print('DISC Norm of f_u:', norm_loss_disc.item())
                    norm_loss_disc = lamda * norm_loss_disc # need to negativate this loss 
                    # the first term in the first loss term 
                    glogpfu_tensor_disc = torch.stack(glogpfu_layer_list_disc, dim=0)
                    glogpu_loss_disc = glogpfu_tensor_disc.sum(0)
                    glogpu_loss_disc = -1 * glogpu_loss_disc
                    # the second term in the first loss term
                    # call the API instead of setting our own function
                    begin = time.time()                        
                    glogpyffu_loss_disc = lamda_like * model.classification_likelihood_loss(y_batch_disc, 
                                            output_batch_disc, u_layer_list_disc, fiu_layer_list_disc, 
                                                glogpfu_layer_list_disc)                            
                    end = time.time()
                    if print_time:
                        print('Compute likelihood time:', str(end - begin), 's')
                    # the first loss term
                    score_loss_disc = glogpu_loss_disc + glogpyffu_loss_disc
                    # change here to form our own loss
                    original_loss_disc = torch.abs(score_loss_disc + trace_loss_disc)
                    loss_disc = original_loss_disc - norm_loss_disc
                    if print_loss:
                        print('DISC prior loss:', glogpu_loss_disc.item())
                        print('DISC likelihood loss:', glogpyffu_loss_disc.item())
                        print('DISC trace loss:', trace_loss_disc.item())
                        print('DISC original loss:', original_loss_disc.item())
                        print('DISC norm loss:', norm_loss_disc.item())
                        print('DISC total loss:', loss_disc.item())

                    torch.autograd.backward(-loss_disc) # nevigate it to max this loss
                    optimizer.step()
                    minibatch_iter_disc.set_postfix(loss=loss_disc.item())

                    mean_batch_disc, covar_batch_disc = output_batch_disc

                    f_batch_list_disc = []
                    for t in range(0, task_dim):
                        f_line_disc = torch.distributions.MultivariateNormal(loc=mean_batch_disc[t, :], 
                                                                              covariance_matrix=covar_batch_disc[t, :, :]).rsample(
                            torch.Size([])) # sample f from q(f)
                        f_batch_list_disc.append(f_line_disc)
                    f_batch_disc = torch.stack(f_batch_list_disc, dim=1) # size: [batch_size, task_dim]
                    '''
                    f_batch_disc = torch.distributions.MultivariateNormal(loc=mean_batch_disc, 
                                                                          covariance_matrix=covar_batch_disc).rsample(torch.Size([]))
                    f_batch_disc = f_batch_disc.squeeze(0) # size: [batch_size]
                    '''
                    y_predict_disc = model.classification_likelihood(f_batch_disc)
                    correct_disc = (y_predict_disc == y_batch_disc).sum().item()
                    total_disc = y_batch_disc.size(0)
                    acc_disc = 100 * correct_disc / total_disc
                    if print_metric:
                        print(f"MINI-BATCH DISC_ACC: {acc_disc}%")
                    correct_iter_disc += correct_disc
                    total_iter_disc += total_disc

            acc_iter_disc = 100 * correct_iter_disc / total_iter_disc
            if print_metric:
                print(f"ITER DISC_ACC: {acc_iter_disc}%")
            correct_epoch_disc += correct_iter_disc
            total_epoch_disc += total_iter_disc

        acc_epoch_disc = 100 * correct_epoch_disc / total_epoch_disc
        if print_metric:
            print(f"EPOCH DISC_ACC: {acc_epoch_disc}%")

        if print_loss:
            print('-----------------------------------')

    if num_gen != 0:
        correct_epoch_gen = 0 # use to calculate correct prediction in one epoch
        total_epoch_gen = 0 # use to calculate total prediction in one epoch
        # fix hyperparameter v and generator G, update discriminator D num_nc times
        gen_iter = tqdm.notebook.tqdm(range(num_gen), desc='Train Generator', leave=False)
        for j in gen_iter:
            # Within each iteration, we will go over each minibatch of data
            minibatch_iter_gen = tqdm.notebook.tqdm(train_loader1, desc="Minibatch-Gen", leave=False)
            correct_iter_gen = 0 # use to calculate correct prediction in one iteration
            total_iter_gen = 0 # use to calculate total prediction in one iteration
            # fix discriminator D and train hyperparameter v with generator G
            for x_batch_gen, y_batch_gen in minibatch_iter_gen:
                x_batch_gen, y_batch_gen = x_batch_gen.to(device), y_batch_gen.to(device)
                with num_likelihood_samples(num_samples):
                    optimizer_gen.zero_grad()
                    trace_layer_list_gen = []
                    norm_layer_list_gen = []
                    u_layer_list_gen = []
                    fiu_layer_list_gen = []
                    glogpfu_layer_list_gen = []
                    x_batch_gen = model.feature_extractor.forward(x_batch_gen)
                    output_batch_gen, trace_layer_list_gen, norm_layer_list_gen, u_layer_list_gen, fiu_layer_list_gen, glogpfu_layer_list_gen = model.forward(
                                                                        (x_batch_gen, True, False, False, trace_layer_list_gen, 
                                                                        norm_layer_list_gen, u_layer_list_gen, fiu_layer_list_gen, 
                                                                        glogpfu_layer_list_gen))
                    trace_layer_tensor_gen = torch.stack(trace_layer_list_gen, dim=0)
                    trace_loss_gen = trace_layer_tensor_gen.sum(0)

                    norm_layer_tensor_gen = torch.stack(norm_layer_list_gen, dim=0)
                    norm_loss_gen = norm_layer_tensor_gen.sum(0)
                    norm_loss_gen = lamda * norm_loss_gen # need to negativate this loss

                    glogpfu_tensor_gen = torch.stack(glogpfu_layer_list_gen, dim=0)
                    glogpu_loss_gen = glogpfu_tensor_gen.sum(0)
                    glogpu_loss_gen = -1 * glogpu_loss_gen  # occur wrong when update generator G

                    # the second term in the first loss term
                    # call the API instead of setting our own function
                    begin = time.time()
                    glogpyffu_loss_gen = lamda_like * model.classification_likelihood_loss(y_batch_gen, 
                                            output_batch_gen, u_layer_list_gen, fiu_layer_list_gen, 
                                                glogpfu_layer_list_gen)
                    end = time.time()
                    if print_time:
                        print('Compute likelihood time:', str(end - begin), 's')
                    # the first loss term
                    score_loss_gen = glogpu_loss_gen + glogpyffu_loss_gen
                    # change here to form our own loss
                    original_loss_gen = torch.abs(score_loss_gen + trace_loss_gen)
                    loss_gen = original_loss_gen - norm_loss_gen

                    if print_loss:
                        print('GEN prior loss:', glogpu_loss_gen.item())
                        print('GEN likelihood loss:', glogpyffu_loss_gen.item())
                        print('GEN trace loss:', trace_loss_gen.item())
                        print('GEN original loss:', original_loss_gen.item())
                        print('GEN norm loss:', norm_loss_gen.item())
                        print('GEN total loss:', loss_gen.item())

                    torch.autograd.backward(loss_gen)
                    optimizer_gen.step()
                    minibatch_iter_gen.set_postfix(loss=loss_gen.item())

                    mean_batch_gen, covar_batch_gen = output_batch_gen

                    f_batch_list_gen = []
                    for t in range(0, task_dim):
                        f_line_gen = torch.distributions.MultivariateNormal(loc=mean_batch_gen[t, :], 
                                                                              covariance_matrix=covar_batch_gen[t, :, :]).rsample(
                            torch.Size([])) # sample f from q(f)
                        f_batch_list_gen.append(f_line_gen)
                    f_batch_gen = torch.stack(f_batch_list_gen, dim=1)
                    '''
                    f_batch_gen = torch.distributions.MultivariateNormal(loc=mean_batch_gen, 
                                                                          covariance_matrix=covar_batch_gen).rsample(torch.Size([]))
                    f_batch_gen = f_batch_gen.squeeze(0) # size: [batch_size]
                    '''
                    y_predict_gen = model.classification_likelihood(f_batch_gen)
                    correct_gen = (y_predict_gen == y_batch_gen).sum().item()
                    total_gen = y_batch_gen.size(0)
                    acc_gen = 100 * correct_gen / total_gen
                    if print_metric:
                        print(f"MINI-BATCH GEN_ACC: {acc_gen}%")
                    correct_iter_gen += correct_gen
                    total_iter_gen += total_gen

            acc_iter_gen = 100 * correct_iter_gen / total_iter_gen
            if print_metric:
                print(f"ITER GEN_ACC: {acc_iter_gen}%")
            correct_epoch_gen += correct_iter_gen
            total_epoch_gen += total_iter_gen

        acc_epoch_gen = 100 * correct_epoch_gen / total_epoch_gen
        if print_metric:
            print(f"EPOCH GEN_ACC: {acc_epoch_gen}%")

        if print_loss:
            print('-----------------------------------')

    if num_hyp != 0:
        correct_epoch_hyp = 0 # use to calculate correct prediction in one epoch
        total_epoch_hyp = 0 # use to calculate total prediction in one epoch
        # fix hyperparameter v and generator G, update discriminator D num_nc times
        hyp_iter = tqdm.notebook.tqdm(range(num_hyp), desc='Train Hyper-parameter', leave=False)
        for j in hyp_iter:
            # Within each iteration, we will go over each minibatch of data
            minibatch_iter_hyp = tqdm.notebook.tqdm(train_loader, desc="Minibatch-Hyper", leave=False)
            correct_iter_hyp = 0 # use to calculate correct prediction in one iteration
            total_iter_hyp = 0 # use to calculate total prediction in one iteration
            # fix discriminator D and train hyperparameter v with generator G
            for x_batch_hyp, y_batch_hyp in minibatch_iter_hyp:
                x_batch_hyp, y_batch_hyp = x_batch_hyp.to(device), y_batch_hyp.to(device)
                with num_likelihood_samples(num_samples):
                    optimizer1.zero_grad()
                    x_batch_hyp = model.feature_extractor.forward(x_batch_hyp)
                   
                    
                    output_batch_hyp, _, _, _, _, _ = model.forward((x_batch_hyp, False, False, True, [], [], [], [], []))
                    # call the API instead of setting our own function
                    mean_batch_hyp, covar_batch_hyp = output_batch_hyp

                    f_batch_list_hyp = []
                    for t in range(0, task_dim):
                        f_line_hyp = torch.distributions.MultivariateNormal(loc=mean_batch_hyp[t, :], 
                                                                              covariance_matrix=covar_batch_hyp[t, :, :]).rsample(
                            torch.Size([])) # sample f from q(f)
                        f_batch_list_hyp.append(f_line_hyp)
                    f_batch_hyp = torch.stack(f_batch_list_hyp, dim=1)
                    '''
                    f_batch_hyp = torch.distributions.MultivariateNormal(loc=mean_batch_hyp, 
                                                                          covariance_matrix=covar_batch_hyp).rsample(torch.Size([]))
                    f_batch_hyp = f_batch_hyp.squeeze(0) # size: [batch_size]
                    '''
                    #prob_hyp = model.post_classification(f_batch_hyp) # may remove softmax to use unnormalized score
                    #loss_hyp = model.cross(prob_hyp, y_batch_hyp)
                    model.feature_extractor.requires_grad_(True)
                    loss_hyp = model.cross(f_batch_hyp, y_batch_hyp)

                    if print_loss:
                            print('HYP total loss:', loss_hyp.item())
                    torch.autograd.backward(loss_hyp)
                    optimizer1.step()
                    minibatch_iter_hyp.set_postfix(loss=loss_hyp.item())

                    y_predict_hyp = model.classification_likelihood(f_batch_hyp)
                    #print(y_predict_hyp)
                    #print(y_batch_hyp)
                    correct_hyp = (y_predict_hyp == y_batch_hyp).sum().item()
                    total_hyp = y_batch_hyp.size(0)
                    acc_hyp = 100 * correct_hyp / total_hyp
                    if print_metric:
                        print(f"MINI-BATCH HYP_ACC: {acc_hyp}%")
                    correct_iter_hyp += correct_hyp
                    total_iter_hyp += total_hyp

            acc_iter_hyp = 100 * correct_iter_hyp / total_iter_hyp
            if print_metric:
                print(f"ITER HYP_ACC: {acc_iter_hyp}%")
            correct_epoch_hyp += correct_iter_hyp
            total_epoch_hyp += total_iter_hyp
            
            # store test rmse each iteration
            acc_test_hyp= model.predict(test_loader)

            acc_test_list.append(acc_test_hyp)
            # store test rmse each iteration
            acc_train_list.append(acc_iter_hyp)


        acc_epoch_hyp = 100 * correct_epoch_hyp / total_epoch_hyp
        if print_metric:
            print(f"EPOCH HYP_ACC: {acc_epoch_hyp}%")

        if print_loss:
            print('-----------------------------------')
end_sum = time.time()            

Epoch:   0%|          | 0/1 [00:00<?, ?it/s]

New epoch!!!


Train Discriminator:   0%|          | 0/1 [00:00<?, ?it/s]

Minibatch-Disc:   0%|          | 0/196 [00:00<?, ?it/s]

MINI-BATCH DISC_ACC: 98.828125%
MINI-BATCH DISC_ACC: 98.828125%
MINI-BATCH DISC_ACC: 99.21875%
MINI-BATCH DISC_ACC: 98.4375%
MINI-BATCH DISC_ACC: 99.21875%
MINI-BATCH DISC_ACC: 99.21875%
MINI-BATCH DISC_ACC: 99.609375%
MINI-BATCH DISC_ACC: 98.4375%
MINI-BATCH DISC_ACC: 99.609375%
MINI-BATCH DISC_ACC: 98.4375%
MINI-BATCH DISC_ACC: 98.828125%
MINI-BATCH DISC_ACC: 98.4375%
MINI-BATCH DISC_ACC: 99.21875%
MINI-BATCH DISC_ACC: 99.21875%
MINI-BATCH DISC_ACC: 99.21875%
MINI-BATCH DISC_ACC: 99.609375%
MINI-BATCH DISC_ACC: 98.4375%
MINI-BATCH DISC_ACC: 99.609375%
MINI-BATCH DISC_ACC: 98.046875%
MINI-BATCH DISC_ACC: 99.609375%
MINI-BATCH DISC_ACC: 99.609375%
MINI-BATCH DISC_ACC: 98.4375%
MINI-BATCH DISC_ACC: 98.4375%
MINI-BATCH DISC_ACC: 98.4375%
MINI-BATCH DISC_ACC: 98.4375%
MINI-BATCH DISC_ACC: 99.21875%
MINI-BATCH DISC_ACC: 99.609375%
MINI-BATCH DISC_ACC: 99.21875%
MINI-BATCH DISC_ACC: 99.609375%
MINI-BATCH DISC_ACC: 99.21875%
MINI-BATCH DISC_ACC: 100.0%
MINI-BATCH DISC_ACC: 98.046875%
MINI-BA

Train Generator:   0%|          | 0/1 [00:00<?, ?it/s]

Minibatch-Gen:   0%|          | 0/196 [00:00<?, ?it/s]

MINI-BATCH GEN_ACC: 98.4375%
MINI-BATCH GEN_ACC: 98.828125%
MINI-BATCH GEN_ACC: 98.828125%
MINI-BATCH GEN_ACC: 99.609375%
MINI-BATCH GEN_ACC: 100.0%
MINI-BATCH GEN_ACC: 98.828125%
MINI-BATCH GEN_ACC: 99.21875%
MINI-BATCH GEN_ACC: 99.21875%
MINI-BATCH GEN_ACC: 99.609375%
MINI-BATCH GEN_ACC: 98.828125%
MINI-BATCH GEN_ACC: 98.828125%
MINI-BATCH GEN_ACC: 98.4375%
MINI-BATCH GEN_ACC: 98.828125%
MINI-BATCH GEN_ACC: 97.65625%
MINI-BATCH GEN_ACC: 99.609375%
MINI-BATCH GEN_ACC: 99.21875%
MINI-BATCH GEN_ACC: 100.0%
MINI-BATCH GEN_ACC: 99.609375%
MINI-BATCH GEN_ACC: 98.4375%
MINI-BATCH GEN_ACC: 98.4375%
MINI-BATCH GEN_ACC: 99.609375%
MINI-BATCH GEN_ACC: 98.828125%
MINI-BATCH GEN_ACC: 98.828125%
MINI-BATCH GEN_ACC: 98.828125%
MINI-BATCH GEN_ACC: 98.828125%
MINI-BATCH GEN_ACC: 99.609375%
MINI-BATCH GEN_ACC: 98.828125%
MINI-BATCH GEN_ACC: 98.4375%
MINI-BATCH GEN_ACC: 98.828125%
MINI-BATCH GEN_ACC: 99.609375%
MINI-BATCH GEN_ACC: 99.609375%
MINI-BATCH GEN_ACC: 99.21875%
MINI-BATCH GEN_ACC: 99.21875%
M

Train Hyper-parameter:   0%|          | 0/10 [00:00<?, ?it/s]

Minibatch-Hyper:   0%|          | 0/391 [00:00<?, ?it/s]

MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 96.875%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 9

Minibatch-Hyper:   0%|          | 0/391 [00:00<?, ?it/s]

MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 96.09375%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BAT

Minibatch-Hyper:   0%|          | 0/391 [00:00<?, ?it/s]

MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 96.875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_A

Minibatch-Hyper:   0%|          | 0/391 [00:00<?, ?it/s]

MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 96.09375%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 100.

Minibatch-Hyper:   0%|          | 0/391 [00:00<?, ?it/s]

MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 96.875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 96.09375%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 99.21875%
M

Minibatch-Hyper:   0%|          | 0/391 [00:00<?, ?it/s]

MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 96.875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 96.875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875

Minibatch-Hyper:   0%|          | 0/391 [00:00<?, ?it/s]

MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 96.875%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP

Minibatch-Hyper:   0%|          | 0/391 [00:00<?, ?it/s]

MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 96.875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 96.09375%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC:

Minibatch-Hyper:   0%|          | 0/391 [00:00<?, ?it/s]

MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 96.875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 96.875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 1

Minibatch-Hyper:   0%|          | 0/391 [00:00<?, ?it/s]

MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 96.875%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 100.0%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 97.65625%
MINI-BATCH HYP_ACC: 98.4375%
MINI-BATCH HYP_ACC: 99.21875%
MINI-BATCH HYP_ACC: 97.

In [10]:
## torch.autograd.set_detect_anomaly(True)

# training stage
#print('----------Train Disc and Gen----------')
#train(num_epochs=1, num_disc=0, num_gen=0, num_hyp=2, num_samples=10, lamda=100, lamda_like=1)
#print('----------Train Hyp----------')
#train(num_epochs=1, num_disc=1, num_gen=1, num_hyp=3, num_samples=10, lamda=2000, lamda_like=1)
#print('Training stage finish!!!')

In [25]:
import gpytorch
import math

# evaluate stage
model.eval()


total_acc = model.predict(test_loader)
print('Total Accuracy: {}%'.format(total_acc))

Batch Accuracy: 92.96875%
Batch Accuracy: 94.921875%
Batch Accuracy: 91.796875%
Batch Accuracy: 94.140625%
Batch Accuracy: 94.140625%
Batch Accuracy: 94.53125%
Batch Accuracy: 94.53125%
Batch Accuracy: 89.84375%
Batch Accuracy: 91.015625%
Batch Accuracy: 91.796875%
Batch Accuracy: 94.53125%
Batch Accuracy: 92.1875%
Batch Accuracy: 92.1875%
Batch Accuracy: 92.578125%
Batch Accuracy: 91.796875%
Batch Accuracy: 91.40625%
Batch Accuracy: 96.09375%
Batch Accuracy: 93.359375%
Batch Accuracy: 92.96875%
Batch Accuracy: 94.140625%
Batch Accuracy: 94.53125%
Batch Accuracy: 94.921875%
Batch Accuracy: 91.015625%
Batch Accuracy: 94.921875%
Batch Accuracy: 95.3125%
Batch Accuracy: 92.578125%
Batch Accuracy: 93.75%
Batch Accuracy: 92.578125%
Batch Accuracy: 95.3125%
Batch Accuracy: 93.359375%
Batch Accuracy: 92.578125%
Batch Accuracy: 93.359375%
Batch Accuracy: 92.96875%
Batch Accuracy: 91.015625%
Batch Accuracy: 93.359375%
Batch Accuracy: 94.53125%
Batch Accuracy: 88.28125%
Batch Accuracy: 95.3125%


In [1]:
import torch
torch.cuda.empty_cache()

In [26]:
print('Write test acc to experiments/cifar10.xlsx')
df = pd.DataFrame(acc_test_list, columns=['ACC'])
df.to_excel('experiments/twonorm5_test.xlsx', index=False)
df_train = pd.DataFrame(acc_train_list, columns=['ACC'])
df_train.to_excel('experiments/twonorm5_train.xlsx', index=False)

Write test acc to experiments/cifar10.xlsx


In [20]:
torch.save(model.state_dict(), 'novi.pt')
file_size = os.path.getsize('novi.pt')
print(f"Model size: {file_size / 1024 / 1024:.2f} MB")
            
              

Model size: 43.61 MB


In [2]:
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision
# define data transforms

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
])
'''
transform_train = transforms.Compose([
        torchvision.transforms.ToTensor(), 
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
transform_test = transforms.Compose([
        torchvision.transforms.ToTensor(), 
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
'''    

# load CIFAR-10 dataset
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# define data loaders
train_loader = DataLoader(trainset, batch_size=128, shuffle=True)
test_loader = DataLoader(testset, batch_size=128, shuffle=False)

# define the model

class FeatureExtractor(nn.Module):
    def __init__(self):
        super(FeatureExtractor, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool1 = nn.MaxPool2d(2)
        
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.maxpool2 = nn.MaxPool2d(2)
        
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.maxpool3 = nn.MaxPool2d(2)
        
        self.conv4 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(512)
        self.maxpool4 = nn.MaxPool2d(2)
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.fc0 = nn.Linear(512, 256)
        self.bn0 = nn.BatchNorm1d(256)
        self.fc1 = nn.Linear(256, 128)
        self.bn11 = nn.BatchNorm1d(128)
        self.dropout = nn.Dropout(0.3)
        self.fc2 = nn.Linear(128, 64)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = nn.functional.relu(x)
        x = self.maxpool1(x)
        
        x = self.conv2(x)
        x = self.bn2(x)
        x = nn.functional.relu(x)
        x = self.maxpool2(x)
        
        x = self.conv3(x)
        x = self.bn3(x)
        x = nn.functional.relu(x)
        x = self.maxpool3(x)
        
        x = self.conv4(x)
        x = self.bn4(x)
        x = nn.functional.relu(x)
        x = self.maxpool4(x)
        
        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.fc0(x)
        x = self.bn0(x)
        x = nn.functional.relu(x)
        x = self.fc1(x)
        x = self.bn11(x)
        x = nn.functional.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x
        

Files already downloaded and verified
Files already downloaded and verified


In [3]:
def train(Extractor, optimizer, criterion, train_loader, device):
    Extractor.train()
    train_loss = 0
    train_acc = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = Extractor(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        pred = output.argmax(dim=1, keepdim=True)
        train_acc += pred.eq(target.view_as(pred)).sum().item()
    train_loss /= len(train_loader.dataset)
    train_acc /= len(train_loader.dataset)

    return train_loss, train_acc


In [4]:
def test(Extractor, criterion, test_loader, device):
    Extractor.eval()
    test_loss = 0
    test_acc = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = Extractor(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            test_acc += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    test_acc /= len(test_loader.dataset)

    return test_loss, test_acc


In [5]:
def main():
    device = 'cuda:6' if torch.cuda.is_available() else 'cpu'

    Extractor = FeatureExtractor().to(device)
    optimizer = optim.Adam(Extractor.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(30):
        train_loss, train_acc = train(Extractor, optimizer, criterion, train_loader, device)
        test_loss, test_acc = test(Extractor, criterion, test_loader, device)

        print(f'Epoch {epoch+1}: Train Loss: {train_loss:.6f}, Train Acc: {train_acc:.6f}, Test Loss: {test_loss:.6f}, Test Acc: {test_acc:.6f}')

    return Extractor

In [6]:
if __name__ == '__main__':
    trained_extractor = main()

Epoch 1: Train Loss: 0.011299, Train Acc: 0.509800, Test Loss: 0.009002, Test Acc: 0.605600
Epoch 2: Train Loss: 0.007533, Train Acc: 0.663700, Test Loss: 0.007095, Test Acc: 0.687700
Epoch 3: Train Loss: 0.006310, Train Acc: 0.720260, Test Loss: 0.006855, Test Acc: 0.694600
Epoch 4: Train Loss: 0.005606, Train Acc: 0.750080, Test Loss: 0.007604, Test Acc: 0.680600
Epoch 5: Train Loss: 0.005098, Train Acc: 0.776780, Test Loss: 0.005310, Test Acc: 0.767900
Epoch 6: Train Loss: 0.004712, Train Acc: 0.793780, Test Loss: 0.004935, Test Acc: 0.788800
Epoch 7: Train Loss: 0.004378, Train Acc: 0.808140, Test Loss: 0.005309, Test Acc: 0.773300
Epoch 8: Train Loss: 0.004074, Train Acc: 0.821460, Test Loss: 0.004841, Test Acc: 0.798700
Epoch 9: Train Loss: 0.003801, Train Acc: 0.832600, Test Loss: 0.004350, Test Acc: 0.816600
Epoch 10: Train Loss: 0.003654, Train Acc: 0.840220, Test Loss: 0.003979, Test Acc: 0.828400
Epoch 11: Train Loss: 0.003427, Train Acc: 0.848340, Test Loss: 0.004438, Test 

In [7]:
trained_extractor.eval()
for param in trained_extractor.parameters():
    param.requires_grad = False


In [20]:
print(trained_extractor)


FeatureExtractor(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (maxpool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv4): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (maxpool4): MaxPool2d(kernel_size=2, strid

In [47]:
# 假设您的已训练好的特征提取器名称为trained_extractor
torch.save(trained_extractor.state_dict(), 'trained_extractor.pth')
