In [1]:
import torch
import torch.nn as nn

import pickle
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import sys
sys.path.append('/workspace/experiment')
from data.cifar10 import SplitCifar10, train_transform, val_transform
from data.capsule_split import get_splits

import os

import matplotlib.pyplot as plt


In [32]:
def weights_init(m):
    """Custom initializer intended for ``nn.Module.apply``.

    Dispatch is by class name:

    * class name contains ``'Conv'``: weight ~ N(0, 0.05); any conv bias is
      left untouched.
    * class name contains ``'BatchNorm'``: weight ~ N(1, 0.02), bias = 0.
    * anything else: not modified.

    Parameters that do not exist are skipped instead of raising
    AttributeError. Examples are ``BatchNorm2d(affine=False)``, whose
    ``weight``/``bias`` are None, and a custom container such as
    ``ConvBlock`` whose name matches but has no ``weight``.

    Args:
        m: the module being visited by ``apply``.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            weight.data.normal_(0.0, 0.05)
    elif 'BatchNorm' in classname:
        if weight is not None:
            weight.data.normal_(1.0, 0.02)
        if bias is not None:
            bias.data.fill_(0)

class classifier32(nn.Module):
    """Nine-layer convolutional classifier that also exposes intermediate features.

    Layout: three blocks. Each block is ``Dropout2d(0.2)`` followed by three
    conv-BN-LeakyReLU(0.2) layers, and the third conv of each block has
    stride 2. These are followed by global average pooling and a 128 -> num_classes
    linear head. All parameters are initialised by ``weights_init``.

    For a 3 x 32 x 32 input (presumably CIFAR-sized images -- confirm) the block
    outputs are 128 x 16 x 16, 128 x 8 x 8 and 128 x 4 x 4 respectively.

    Args:
        num_classes: number of output logits.
        **kwargs: accepted for call-site compatibility and ignored.

    ``forward`` returns a dict:
        ``'logit'``: (N, num_classes) classifier output.
        ``'l3'``:    (N, 128) pooled and flattened output of ``block2``.
        ``'l2'``:    un-pooled feature map produced by ``block1``.
        ``'l1'``:    un-pooled feature map produced by ``block0``.
    """

    def __init__(self, num_classes=2, **kwargs):
        # Zero-argument super(): ``super(self.__class__, self)`` recurses
        # forever as soon as this class is subclassed.
        super().__init__()
        self.num_classes = num_classes
        # Block 0: 3 -> 64 -> 64 -> 128 channels, last conv halves H and W.
        self.conv1 = nn.Conv2d(3,       64,     3, 1, 1, bias=False)
        self.conv2 = nn.Conv2d(64,      64,     3, 1, 1, bias=False)
        self.conv3 = nn.Conv2d(64,     128,     3, 2, 1, bias=False)

        # Block 1: 128 channels throughout, last conv halves H and W.
        self.conv4 = nn.Conv2d(128,    128,     3, 1, 1, bias=False)
        self.conv5 = nn.Conv2d(128,    128,     3, 1, 1, bias=False)
        self.conv6 = nn.Conv2d(128,    128,     3, 2, 1, bias=False)

        # Block 2: 128 channels throughout, last conv halves H and W.
        self.conv7 = nn.Conv2d(128,    128,     3, 1, 1, bias=False)
        self.conv8 = nn.Conv2d(128,    128,     3, 1, 1, bias=False)
        self.conv9 = nn.Conv2d(128,    128,     3, 2, 1, bias=False)

        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)

        self.bn4 = nn.BatchNorm2d(128)
        self.bn5 = nn.BatchNorm2d(128)
        self.bn6 = nn.BatchNorm2d(128)

        self.bn7 = nn.BatchNorm2d(128)
        self.bn8 = nn.BatchNorm2d(128)
        self.bn9 = nn.BatchNorm2d(128)

        self.fc1 = nn.Linear(128, num_classes)
        self.dr1 = nn.Dropout2d(0.2)
        self.dr2 = nn.Dropout2d(0.2)
        self.dr3 = nn.Dropout2d(0.2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.apply(weights_init)

    @staticmethod
    def _conv_bn_act(x, conv, bn):
        """Apply ``conv`` -> ``bn`` -> LeakyReLU(0.2)."""
        # Functional activation: no new nn.LeakyReLU module is built per call.
        return nn.functional.leaky_relu(bn(conv(x)), negative_slope=0.2)

    def block0(self, x):
        """Dropout on the input, then conv1..conv3 (conv3 has stride 2)."""
        x = self.dr1(x)
        x = self._conv_bn_act(x, self.conv1, self.bn1)
        x = self._conv_bn_act(x, self.conv2, self.bn2)
        return self._conv_bn_act(x, self.conv3, self.bn3)

    def block1(self, x):
        """Dropout, then conv4..conv6 (conv6 has stride 2)."""
        x = self.dr2(x)
        x = self._conv_bn_act(x, self.conv4, self.bn4)
        x = self._conv_bn_act(x, self.conv5, self.bn5)
        return self._conv_bn_act(x, self.conv6, self.bn6)

    def block2(self, x):
        """Dropout, then conv7..conv9 (conv9 has stride 2)."""
        x = self.dr3(x)
        x = self._conv_bn_act(x, self.conv7, self.bn7)
        x = self._conv_bn_act(x, self.conv8, self.bn8)
        return self._conv_bn_act(x, self.conv9, self.bn9)

    def forward(self, x):
        """Run all three blocks and the head; see the class docstring for the output dict."""
        l1 = self.block0(x)
        l2 = self.block1(l1)
        l3 = self.block2(l2)

        # (N, 128, 1, 1) -> (N, 128)
        l3 = torch.flatten(self.avgpool(l3), 1)
        y = self.fc1(l3)

        return {
            'logit': y,
            'l3': l3,
            'l2': l2,
            'l1': l1,
        }


In [27]:
class FeatureDiscriminator(nn.Module):
    """Map a 128-channel feature map to one raw logit per sample.

    The layers are:

    * ``Dropout2d(0.2)``;
    * three conv-BN-LeakyReLU(0.2) layers, where the last conv has stride 2;
    * global average pooling;
    * a 128 -> 1 linear layer.

    No sigmoid is applied.

    NOTE(review): 128 input channels match the ``l1``/``l2`` maps returned by
    ``classifier32``. Presumably those are what this is fed -- confirm.
    This module does not call ``weights_init``, so it keeps PyTorch's default
    initialisation.

    Input: (N, 128, H, W). Output: (N, 1).
    """

    def __init__(self):
        # Zero-argument super(): ``super(self.__class__, self)`` recurses
        # forever as soon as this class is subclassed.
        super().__init__()
        self.dr = nn.Dropout2d(0.2)
        self.conv1 = nn.Conv2d(128,    128,     3, 1, 1, bias=False)
        self.conv2 = nn.Conv2d(128,    128,     3, 1, 1, bias=False)
        self.conv3 = nn.Conv2d(128,    128,     3, 2, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(128)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(128)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(128, 1)

    def forward(self, x):
        """Return (N, 1) logits for an (N, 128, H, W) feature map."""
        x = self.dr(x)
        # Functional activation: no new nn.LeakyReLU module is built per call.
        for conv, bn in ((self.conv1, self.bn1),
                         (self.conv2, self.bn2),
                         (self.conv3, self.bn3)):
            x = nn.functional.leaky_relu(bn(conv(x)), negative_slope=0.2)

        # (N, 128, 1, 1) -> (N, 128)
        x = torch.flatten(self.avgpool(x), 1)
        logit = self.fc(x)

        return logit
        

In [28]:
# Evaluation data for split 0 of CIFAR-10. The split dict partitions the
# classes into 'known_classes' and 'unknown_classes'. Both datasets use the
# test partition (train=False) with val_transform.
DATA_ROOT = os.environ.get('CIFAR10_ROOT', '/datasets')  # override via env var instead of editing the path
EVAL_LOADER_KWARGS = dict(batch_size=32, shuffle=False, num_workers=4)

split = get_splits('cifar10', 0)
known_data = SplitCifar10(DATA_ROOT, train=False, transform=val_transform, split=split['known_classes'])
open_data = SplitCifar10(DATA_ROOT, train=False, transform=val_transform, split=split['unknown_classes'])
# One shared settings dict so the known and open loaders cannot drift apart.
known_loader = DataLoader(known_data, **EVAL_LOADER_KWARGS)
open_loader = DataLoader(open_data, **EVAL_LOADER_KWARGS)

In [33]:
# NOTE(review): 6 is presumably len(split['known_classes']) -- confirm.
NUM_KNOWN_CLASSES = 6

model = classifier32(num_classes=NUM_KNOWN_CLASSES)
desc = FeatureDiscriminator()  # the feature discriminator (not a "description")

In [34]:
# One batch of known-class test images (val_transform) for quick shape checks.
first_known_batch = next(iter(known_loader))
test_input, test_target = first_known_batch

In [35]:
# Inspect the model's outputs on one batch without side effects. In train mode,
# every re-run of this cell would update the BatchNorm running statistics and
# apply dropout. Grad tracking would also build an unused autograd graph.
# The previous train/eval mode is restored afterwards so later cells are
# unaffected.
was_training = model.training
model.eval()
with torch.no_grad():
    out = model(test_input)
model.train(was_training)

In [77]:
# 32 random dummy labels drawn uniformly from {0, ..., 5} (high is exclusive).
y = torch.randint(low=0, high=6, size=(32,))

In [79]:
def cal_index(y, num_classes=6):
    """Return a label tensor whose every entry differs from the matching entry of ``y``.

    Labels are first proposed by shuffling ``y`` within the batch. Any position
    that still equals its original label is then re-drawn uniformly from
    ``0 .. num_classes - 1``, repeating until no position matches.

    Args:
        y: 1-D label tensor (the first dimension is the batch).
        num_classes: size of the label space used for re-draws. Defaults to 6,
            the value previously hard-coded here.

    Returns:
        A new tensor with the same shape, dtype and device as ``y``, such that
        ``result[i] != y[i]`` for every ``i``. ``y`` itself is not modified.

    Raises:
        ValueError: if ``num_classes < 2``. Otherwise no differing label could
            be drawn and the re-draw loop would never terminate.
    """
    if num_classes < 2:
        raise ValueError(
            f"num_classes must be >= 2 to draw a different label, got {num_classes}"
        )
    batch_size = y.size(0)
    # Keep the permutation as int64: an index tensor cast to y's dtype
    # (e.g. float or int32) cannot be used for indexing. Only the device follows y.
    perm = torch.randperm(batch_size, device=y.device)
    shuffled = y[perm]  # advanced indexing copies, so y is never mutated below
    clash = shuffled == y
    while clash.any().item():
        n_clash = int(clash.sum())
        shuffled[clash] = torch.randint(
            0, num_classes, (n_clash,), dtype=y.dtype, device=y.device
        )
        clash = shuffled == y
    return shuffled
    

In [13]:
# 32 random integers in {0, ..., 8} -- randint's upper bound is exclusive.
a = torch.randint(low=0, high=9, size=(32,))

In [14]:
# Boolean mask of entries strictly greater than 6 (i.e. equal to 7 or 8 here).
a.gt(6)

tensor([False, False, False,  True, False,  True, False, False,  True, False,
        False,  True, False,  True, False, False, False, False, False, False,
        False,  True, False,  True,  True, False, False, False,  True, False,
         True,  True])

In [1]:
# Show every output of the forward pass together with its shape; the key
# names alone do not reveal whether the feature maps have the expected size.
{name: tuple(feat.shape) for name, feat in out.items()}

NameError: name 'out' is not defined