[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Spandan-Madan/generalization_to_OOD_category_viewpoint_cominations/blob/main/demos/increasing_in_distribution_combinations.ipynb)

# Overview


This demo shows the impact of increasing in-distribution combinations on out-of-distribution generalization. Specifically, we reproduce the results for the MNIST Rotation dataset on the `SHARED` architecture.

As shown below, this architecture enforces parameter sharing between the two tasks (category prediction and viewpoint prediction).

As described in the paper, our results show that increasing data diversity (i.e., the number of in-distribution combinations) leads to a substantial increase in out-of-distribution performance, even though the total number of training images (dataset size) is held constant.

![Shared Architecture](https://github.com/Spandan-Madan/generalization_to_OOD_category_viewpoint_cominations/blob/main/docs/images/Shared.png?raw=1)
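
To make the idea of parameter sharing concrete, here is a minimal sketch of a two-head network with a shared trunk. It is not the repository's `SHARED` model: the class name `TinySharedNet`, the layer sizes, and the single-channel 28×28 input are all assumptions chosen for illustration.

```python
import torch
import torch.nn as nn


class TinySharedNet(nn.Module):
    """Hypothetical two-head network: one shared trunk, two linear heads."""

    def __init__(self, num_categories=10, num_viewpoints=10):
        super().__init__()
        # Shared trunk: its parameters are used by both tasks.
        self.trunk = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        # Task-specific heads on top of the shared features.
        self.category_head = nn.Linear(8, num_categories)
        self.viewpoint_head = nn.Linear(8, num_viewpoints)

    def forward(self, x):
        features = self.trunk(x)
        return self.category_head(features), self.viewpoint_head(features)


model = TinySharedNet()
images = torch.randn(4, 1, 28, 28)  # fake batch; the input shape is an assumption
category_logits, viewpoint_logits = model(images)
print(category_logits.shape, viewpoint_logits.shape)
```

A "separate" setup would instead train two independent networks, one per task, so that no trunk parameters are shared.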

In [None]:
import os
def create_folder(path):
    if not os.path.isdir(path):
        os.mkdir(path)

If running on Google Colab, the code below does the following:
- clones the repository
- sets up the necessary folders
- downloads the MNIST Rotation dataset to the appropriate location
- unzips the MNIST Rotation dataset

#### If you're not running on Colab, please follow the download instructions to get the mnist_rotation dataset using:

```
cd utils
bash download_mnist_rotation.sh
```

#### If not using Google Colab, please proceed below only after downloading the dataset

In [None]:
if 'google.colab' in str(get_ipython()):
    print('Cloning code base to colab....')
    !git clone https://github.com/Spandan-Madan/generalization_to_OOD_category_viewpoint_cominations.git
    !cd generalization_to_OOD_category_viewpoint_cominations/utils && bash download_mnist_rotation.sh
    CODE_ROOT = "generalization_to_OOD_category_viewpoint_cominations/"
else:
    CODE_ROOT = '..'

In [None]:
from __future__ import print_function, division
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import copy
import os
from PIL import ImageFile
import random
ImageFile.LOAD_TRUNCATED_IMAGES = True
import argparse
import pickle
import sys
sys.path.append('%s/res/'%CODE_ROOT)
from models.models import get_model
from loader.loader import get_loader

In [None]:
from tqdm.notebook import tqdm
from IPython.display import clear_output
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_context("poster")
sns.set_palette("Set1", 8, .75)
sns.despine()

This demo trains networks with 1, 3, 6 and 9 in-distribution combinations of the MNIST-Rotation dataset, and plots performance on out-of-distribution combinations from the dataset.

To run on a different set of architectures, please change the `ARCHS` list below.
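
As a rough illustration of what "in-distribution combinations" means, the sketch below enumerates (category, viewpoint) pairs on a 10 × 10 grid and keeps `k` viewpoints per category. The cyclic selection rule, the helper name `seen_combinations`, and the value of `TOTAL_IMAGES` are assumptions made for this example; they are not taken from the repository's dataset lists.

```python
from itertools import product

NUM_CATEGORIES = 10
NUM_VIEWPOINTS = 10
TOTAL_IMAGES = 1800  # hypothetical fixed training-set size


def seen_combinations(k):
    """Pair each category with k viewpoints, chosen cyclically (illustrative)."""
    return {
        (c, (c + offset) % NUM_VIEWPOINTS)
        for c in range(NUM_CATEGORIES)
        for offset in range(k)
    }


all_combinations = set(product(range(NUM_CATEGORIES), range(NUM_VIEWPOINTS)))

for k in (1, 3, 6, 9):
    seen = seen_combinations(k)
    unseen = all_combinations - seen
    # The dataset size stays fixed, so more combinations means fewer images each.
    images_per_combination = TOTAL_IMAGES // len(seen)
    print(k, len(seen), len(unseen), images_per_combination)
```

With the total held constant, adding combinations spreads the same number of images over more (category, viewpoint) pairs, which is the sense in which diversity increases while dataset size does not.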

In [None]:
DATASET_NAMES = ['mnist_rotation_one_by_nine', 'mnist_rotation_three_by_nine',
                 'mnist_rotation_six_by_nine', 'mnist_rotation_nine_by_nine']
NUM_EPOCHS = 5
BATCH_SIZE = 100
ARCHS =['EARLY_BRANCHING_COMBINED','SPLIT_AFTER_ONE_BLOCK', 'LATE_BRANCHING_COMBINED', 'LATE_BRANCHING_COMBINED_WIDER',  'LATE_BRANCHING_COMBINED_ONE_FOURTH', 'MULTITASK_INCEPTION_WIDE', 'MULTITASK_RESNEXT_WIDE',  'MULTITASK_RESNEXT', 'LATE_BRANCHING_COMBINED_HALF','Multitask_Resnet_Early_New',  'SPLIT_AFTER_THREE_BLOCKS', 'Two_Block_Encoder_Long_Decoder']

image_transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])


GPU = 1
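
The `image_transform` above first scales pixels to [0, 1] with `ToTensor` and then standardizes them with `Normalize`, using the mean and standard deviation commonly quoted for MNIST. The arithmetic can be checked by hand with a small sketch; the helper `normalize` is a hypothetical stand-in for the per-pixel operation.

```python
MEAN, STD = 0.1307, 0.3081  # constants passed to transforms.Normalize above


def normalize(pixel):
    """Per-pixel standardization: subtract the mean, divide by the std."""
    return (pixel - MEAN) / STD


# A black pixel (0.0) and a white pixel (1.0) after ToTensor scaling.
print(round(normalize(0.0), 4), round(normalize(1.0), 4))
```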

In [None]:
NUM_CLASSES = (10,10,10,10)
loader_new = get_loader('multi_attribute_loader_file_list_mnist_rotation')

file_list_root = '%s/dataset_lists/mnist_rotation_lists/'%CODE_ROOT
att_path = '%s/dataset_lists/combined_attributes.p'%CODE_ROOT

In [None]:
shuffles = {'train':True,'val':True,'test':False}

In [None]:
data_dir = '%s/data/'%CODE_ROOT

In [None]:
all_dsets = {}
all_dset_loaders = {}
all_dset_sizes = {}

for DATASET_NAME in DATASET_NAMES:
    file_lists = {}
    dsets = {}
    dset_loaders = {}
    dset_sizes = {}
    for phase in ['train','val','test']:
        file_lists[phase] = "%s/%s_list_%s.txt"%(file_list_root,phase,DATASET_NAME)
        dsets[phase] = loader_new(file_lists[phase],att_path, image_transform, data_dir)
        dset_loaders[phase] = torch.utils.data.DataLoader(dsets[phase], batch_size=BATCH_SIZE, shuffle = shuffles[phase], num_workers=2,drop_last=True)
        dset_sizes[phase] = len(dsets[phase])
    all_dsets[DATASET_NAME] = dsets
    all_dset_loaders[DATASET_NAME] = dset_loaders
    all_dset_sizes[DATASET_NAME] = dset_sizes

In [None]:
multi_losses = [nn.CrossEntropyLoss(),nn.CrossEntropyLoss(),nn.CrossEntropyLoss(),nn.CrossEntropyLoss()]

In [None]:
class Inception(nn.Module):
    def __init__(self, num_class=3, training=True):
        super(Inception, self).__init__()
        model = models.inception_v3(pretrained=True)
        self.inception_conv1 = model.Conv2d_1a_3x3
        self.inception_conv2 = model.Conv2d_2a_3x3
        self.inception_conv3 = model.Conv2d_2b_3x3
        self.maxpool1 = model.maxpool1
        self.inception_conv4 = model.Conv2d_3b_1x1
        self.inception_conv5 = model.Conv2d_4a_3x3
        self.maxpool2 = model.maxpool2
        self.mixed1 = model.Mixed_5b
        self.mixed2 = model.Mixed_5c
        self.mixed3 = model.Mixed_5d
        self.mixed4 = model.Mixed_6a
        self.mixed5 = model.Mixed_6b
        self.mixed6 = model.Mixed_6c
        self.mixed7 = model.Mixed_6d
        self.mixed8 = model.Mixed_6e
        if training:
            self.auxlogits = model.AuxLogits.conv0
            self.auxlogits1 = model.AuxLogits.conv1
            self.auxlogits2 = nn.Linear(768, num_class)
        self.mixed9 = model.Mixed_7a
        self.mixed10 = model.Mixed_7b
        self.mixed11 = model.Mixed_7c
        self.avgpool = model.avgpool
        self.fc = nn.Linear(2048, 2048)
        self.bnlast = nn.BatchNorm1d(2048)
        self.relulast = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout()

        self.fc1 = nn.Linear(2048, 512)
        self.fc2 = nn.Linear(512, 256)
        self.classifier = nn.Linear(256, num_class)

        self.training = training

    def forward(self, x):
        assert x.size(1) == 3
        x = self.inception_conv1(x)
        x = self.inception_conv2(x)
        x = self.inception_conv3(x)
        x = self.maxpool1(x)
        x = self.inception_conv4(x)
        x = self.inception_conv5(x)
        x = self.maxpool2(x)
        x = self.mixed1(x)
        x = self.mixed2(x)
        x = self.mixed3(x)
        x = self.mixed4(x)
        x = self.mixed5(x)
        x = self.mixed6(x)
        x = self.mixed7(x)
        x = self.mixed8(x)

        if self.training:
            aux = self.auxlogits(x)
            aux = self.auxlogits1(aux)
            aux = aux.view(aux.size(0), -1)
            aux = self.auxlogits2(aux)
        x = self.mixed9(x)
        x = self.mixed10(x)
        x = self.mixed11(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.bnlast(x)
        x = self.relulast(x)
        x = self.dropout(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.classifier(x)
        if self.training:
            return x, aux
        else:
            return x


class SEBlock(nn.Module):
    def __init__(self, c_in):
        super().__init__()
        self.globalavgpooling = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(c_in, max(1, c_in // 16))
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(max(1, c_in // 16), c_in)
        self.sigmoid = nn.Sigmoid()
        self.c_in = c_in

    def forward(self, x):
        assert self.c_in == x.size(1)
        x = self.globalavgpooling(x)
        x = x.flatten(1)  # (N, C, 1, 1) -> (N, C); unlike squeeze(), keeps the batch dim when N == 1
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = x.unsqueeze(-1).unsqueeze(-1)
        x = self.sigmoid(x)
        return x


class CustomDownsampleBlock(nn.Module):
    def __init__(self, in_channels, intermediate_channels, out_channels):
        super(CustomDownsampleBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, intermediate_channels, 1)
        self.conv2 = nn.Conv2d(intermediate_channels, intermediate_channels, 3, 2, 1)
        self.conv3 = nn.Conv2d(intermediate_channels, out_channels, 1)
        self.avgpool = nn.AvgPool2d(2, 2, ceil_mode=True)
        self.conv4 = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        branch = x
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        branch = self.avgpool(branch)
        branch = self.conv4(branch)
        return x + branch


class CustomResNet50(nn.Module):
    def __init__(self, num_class=3, intermediate_channels=[64, 128, 256, 512]):
        super(CustomResNet50, self).__init__()
        model = models.resnet50(pretrained=True)
        self.conv0 = model.conv1
        self.bn0 = model.bn1
        self.relu0 = model.relu
        self.pooling0 = model.maxpool
        self.layer1 = model.layer1
        model.layer1[0] = CustomDownsampleBlock(64, intermediate_channels[0], 256)

        self.layer2 = model.layer2
        model.layer2[0] = CustomDownsampleBlock(256, intermediate_channels[1], 512)

        self.layer3 = model.layer3
        model.layer3[0] = CustomDownsampleBlock(512, intermediate_channels[2], 1024)

        self.layer4 = model.layer4
        model.layer4[0] = CustomDownsampleBlock(1024, intermediate_channels[3], 2048)

        self.avgpool = model.avgpool

        self.fc = nn.Linear(2048, 2048)
        self.bnlast = nn.BatchNorm1d(2048)
        self.relulast = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout()

        self.fc1 = nn.Linear(2048, 512)
        self.fc2 = nn.Linear(512, 256)
        self.classifier = nn.Linear(256, num_class)

    def forward(self, x):
        x = self.conv0(x)
        x = self.bn0(x)
        x = self.relu0(x)
        x = self.pooling0(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.bnlast(x)
        x = self.relulast(x)
        x = self.dropout(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.classifier(x)
        return x


class Baseline(nn.Module):
    def __init__(self, num_class=3, isCustom=False, backend="resnet50"):
        super(Baseline, self).__init__()
        model = getattr(models, backend)(pretrained=True)
        self.conv0 = model.conv1
        self.bn0 = model.bn1
        self.relu0 = model.relu
        self.pooling0 = model.maxpool
        self.layer1 = model.layer1
        self.layer2 = model.layer2
        self.layer3 = model.layer3
        self.layer4 = model.layer4
        self.avgpool = model.avgpool

        self.fc = nn.Linear(2048, 2048)
        self.bnlast = nn.BatchNorm1d(2048)
        self.relulast = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout()

        self.fc1 = nn.Linear(2048, 512)
        self.fc2 = nn.Linear(512, 256)
        self.classifier = nn.Linear(256, num_class)

        self.isCustom = isCustom

    def forward(self, x):
        x = self.conv0(x)
        x = self.bn0(x)
        x = self.relu0(x)
        x = self.pooling0(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.bnlast(x)
        x = self.relulast(x)
        x = self.dropout(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.classifier(x)
        if self.isCustom:
            x = torch.sigmoid(x).squeeze()
        return x


class ResNet50DC5(nn.Module):
    def __init__(self, dilation=True, num_class=3):
        super(ResNet50DC5, self).__init__()
        model = models.resnet50(pretrained=True,
                                replace_stride_with_dilation=[False, False, dilation])
        self.conv0 = model.conv1
        self.bn0 = model.bn1
        self.relu0 = model.relu
        self.pooling0 = model.maxpool
        self.layer1 = model.layer1
        self.layer2 = model.layer2
        self.layer3 = model.layer3
        self.layer4 = model.layer4
        self.avgpool = model.avgpool

        self.fc = nn.Linear(2048, 2048)
        self.bnlast = nn.BatchNorm1d(2048)
        self.relulast = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout()

        self.fc1 = nn.Linear(2048, 512)
        self.fc2 = nn.Linear(512, 256)
        self.classifier = nn.Linear(256, num_class)

    def forward(self, x):
        x = self.conv0(x)
        x = self.bn0(x)
        x = self.relu0(x)
        x = self.pooling0(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.bnlast(x)
        x = self.relulast(x)
        x = self.dropout(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.classifier(x)
        return x


class ResNet50(nn.Module):
    def __init__(self, num_class=3):
        super(ResNet50, self).__init__()
        model = models.resnet50(pretrained=True)
        self.conv0 = model.conv1
        self.bn0 = model.bn1
        self.relu0 = model.relu
        self.pooling0 = model.maxpool
        self.layer1 = model.layer1
        for i in range(len(self.layer1)):
            for name, module in self.layer1[i].named_modules():
                if "bn3" in name:
                    nn.init.constant_(module.weight, 0.)
        self.layer2 = model.layer2
        for i in range(len(self.layer2)):
            for name, module in self.layer2[i].named_modules():
                if "bn3" in name:
                    nn.init.constant_(module.weight, 0.)
        self.layer3 = model.layer3
        for i in range(len(self.layer3)):
            for name, module in self.layer3[i].named_modules():
                if "bn3" in name:
                    nn.init.constant_(module.weight, 0.)
        self.layer4 = model.layer4
        for i in range(len(self.layer4)):
            for name, module in self.layer4[i].named_modules():
                if "bn3" in name:
                    nn.init.constant_(module.weight, 0.)
        self.avgpool = model.avgpool

        self.fc = nn.Linear(2048, 2048)
        self.bnlast = nn.BatchNorm1d(2048)
        self.relulast = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout()

        self.fc1 = nn.Linear(2048, 512)
        self.fc2 = nn.Linear(512, 256)
        self.classifier = nn.Linear(256, num_class)

    def forward(self, x):
        x = self.conv0(x)
        x = self.bn0(x)
        x = self.relu0(x)
        x = self.pooling0(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.bnlast(x)
        x = self.relulast(x)
        x = self.dropout(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.classifier(x)
        return x


class SEDense34(nn.Module):
    def __init__(self, num_class=3, needs_norm=True):
        super().__init__()
        model = models.resnet34(pretrained=True)
        self.conv0 = model.conv1
        self.bn0 = model.bn1
        self.relu0 = model.relu
        self.pooling0 = model.maxpool
        # layer1
        self.bottleneck11 = model.layer1[0]
        self.bottleneck12 = model.layer1[1]
        self.bottleneck13 = model.layer1[2]

        self.seblock11 = SEBlock(64)
        self.seblock12 = SEBlock(64)
        self.seblock13 = SEBlock(64)
        # layer2
        self.bottleneck21 = model.layer2[0]
        self.bottleneck22 = model.layer2[1]
        self.bottleneck23 = model.layer2[2]
        self.bottleneck24 = model.layer2[3]

        self.auxconv1 = nn.Conv2d(64, 128, 1, 2, 0)
        self.optionalbn1 = nn.BatchNorm2d(128)
        self.seblock21 = SEBlock(128)
        self.seblock22 = SEBlock(128)
        self.seblock23 = SEBlock(128)
        self.seblock24 = SEBlock(128)
        # layer3
        self.bottleneck31 = model.layer3[0]
        self.bottleneck32 = model.layer3[1]
        self.bottleneck33 = model.layer3[2]
        self.bottleneck34 = model.layer3[3]
        self.bottleneck35 = model.layer3[4]
        self.bottleneck36 = model.layer3[5]

        self.auxconv2 = nn.Conv2d(128, 256, 1, 2, 0)
        self.optionalbn2 = nn.BatchNorm2d(256)
        self.seblock31 = SEBlock(256)
        self.seblock32 = SEBlock(256)
        self.seblock33 = SEBlock(256)
        self.seblock34 = SEBlock(256)
        self.seblock35 = SEBlock(256)
        self.seblock36 = SEBlock(256)
        # layer4
        self.bottleneck41 = model.layer4[0]
        self.bottleneck42 = model.layer4[1]
        self.bottleneck43 = model.layer4[2]

        self.auxconv3 = nn.Conv2d(256, 512, 1, 2, 0)
        self.optionalbn3 = nn.BatchNorm2d(512)
        self.seblock41 = SEBlock(512)
        self.seblock42 = SEBlock(512)
        self.seblock43 = SEBlock(512)

        self.avgpool = model.avgpool
        self.fc = nn.Linear(512, 128)
        self.bnlast = nn.BatchNorm1d(128)
        self.relulast = nn.ReLU(inplace=True)
        self.dropout = nn.Dropout()

        self.classifier = nn.Linear(128, num_class)

        self.norm = needs_norm

    def forward(self, x):
        x = self.conv0(x)
        x = self.bn0(x)
        x = self.relu0(x)
        x = self.pooling0(x)

        branch1 = x
        x = self.bottleneck11(x)
        scale1 = self.seblock11(x)
        x = scale1 * x + branch1

        branch2 = x
        x = self.bottleneck12(x)
        scale2 = self.seblock12(x)
        x = scale2 * x + branch2

        branch3 = x
        x = self.bottleneck13(x)
        scale3 = self.seblock13(x)
        x = scale3 * x + branch3

        branch4 = x
        x = self.bottleneck21(x)
        scale4 = self.seblock21(x)
        if self.norm:
            x = scale4 * x + self.optionalbn1(self.auxconv1(branch4))
        else:
            x = scale4 * x + self.auxconv1(branch4)

        branch5 = x
        x = self.bottleneck22(x)
        scale5 = self.seblock22(x)
        x = scale5 * x + branch5

        branch6 = x
        x = self.bottleneck23(x)
        scale6 = self.seblock23(x)
        x = scale6 * x + branch6

        branch7 = x
        x = self.bottleneck24(x)
        scale7 = self.seblock24(x)
        x = scale7 * x + branch7

        branch8 = x
        x = self.bottleneck31(x)
        scale8 = self.seblock31(x)
        if self.norm:
            x = scale8 * x + self.optionalbn2(self.auxconv2(branch8))
        else:
            x = scale8 * x + self.auxconv2(branch8)

        branch9 = x
        x = self.bottleneck32(x)
        scale9 = self.seblock32(x)
        x = scale9 * x + branch9

        branch10 = x
        x = self.bottleneck33(x)
        scale10 = self.seblock33(x)
        x = scale10 * x + branch10

        branch11 = x
        x = self.bottleneck34(x)
        scale11 = self.seblock34(x)
        x = scale11 * x + branch11

        branch12 = x
        x = self.bottleneck35(x)
        scale12 = self.seblock35(x)
        x = scale12 * x + branch12

        branch13 = x
        x = self.bottleneck36(x)
        scale13 = self.seblock36(x)
        x = scale13 * x + branch13

        branch14 = x
        x = self.bottleneck41(x)
        scale14 = self.seblock41(x)
        if self.norm:
            x = scale14 * x + self.optionalbn3(self.auxconv3(branch14))
        else:
            x = scale14 * x + self.auxconv3(branch14)

        branch15 = x
        x = self.bottleneck42(x)
        scale15 = self.seblock42(x)
        x = scale15 * x + branch15

        branch16 = x
        x = self.bottleneck43(x)
        scale16 = self.seblock43(x)
        x = scale16 * x + branch16

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.bnlast(x)
        x = self.relulast(x)
        x = self.dropout(x)
        x = self.classifier(x)
        return x


class SEDense18(nn.Module):
    def __init__(self, num_class=3, needs_norm=True):
        super().__init__()
        model = models.resnet18(pretrained=True)
        self.conv0 = model.conv1
        self.bn0 = model.bn1
        self.relu0 = model.relu
        self.pooling0 = model.maxpool
        self.basicBlock11 = model.layer1[0]
        self.seblock1 = SEBlock(64)

        self.basicBlock12 = model.layer1[1]
        self.seblock2 = SEBlock(64)

        self.basicBlock21 = model.layer2[0]
        self.seblock3 = SEBlock(128)
        self.ancillaryconv3 = nn.Conv2d(64, 128, 1, 2, 0)
        self.optionalNorm2dconv3 = nn.BatchNorm2d(128)

        self.basicBlock22 = model.layer2[1]
        self.seblock4 = SEBlock(128)

        self.basicBlock31 = model.layer3[0]
        self.seblock5 = SEBlock(256)
        self.ancillaryconv5 = nn.Conv2d(128, 256, 1, 2, 0)
        self.optionalNorm2dconv5 = nn.BatchNorm2d(256)

        self.basicBlock32 = model.layer3[1]
        self.seblock6 = SEBlock(256)

        self.basicBlock41 = model.layer4[0]
        # last stride = 1
        # Created on the default device; move the whole model with .cuda() or .to(device) instead of hard-coding "cuda:0".
        self.basicBlock41.conv1 = nn.Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.basicBlock41.downsample[0] = nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)
        self.seblock7 = SEBlock(512)
        self.ancillaryconv7 = nn.Conv2d(256, 512, 1, 1, 0)
        self.optionalNorm2dconv7 = nn.BatchNorm2d(512)

        self.basicBlock42 = model.layer4[1]
        self.seblock8 = SEBlock(512)

        self.avgpooling = model.avgpool
        # self.fc = nn.Linear(512, num_class)
        self.classifier = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(256, num_class),
        )
        self.needs_norm = needs_norm

    def forward(self, x):
        x = self.conv0(x)
        x = self.bn0(x)
        x = self.relu0(x)
        x = self.pooling0(x)
        branch1 = x
        x = self.basicBlock11(x)
        scale1 = self.seblock1(x)
        x = scale1 * x + branch1

        branch2 = x
        x = self.basicBlock12(x)
        scale2 = self.seblock2(x)
        x = scale2 * x + branch2

        branch3 = x
        x = self.basicBlock21(x)
        scale3 = self.seblock3(x)
        if self.needs_norm:
            x = scale3 * x + self.optionalNorm2dconv3(self.ancillaryconv3(branch3))
        else:
            x = scale3 * x + self.ancillaryconv3(branch3)

        branch4 = x
        x = self.basicBlock22(x)
        scale4 = self.seblock4(x)
        x = scale4 * x + branch4

        branch5 = x
        x = self.basicBlock31(x)
        scale5 = self.seblock5(x)
        if self.needs_norm:
            x = scale5 * x + self.optionalNorm2dconv5(self.ancillaryconv5(branch5))
        else:
            x = scale5 * x + self.ancillaryconv5(branch5)

        branch6 = x
        x = self.basicBlock32(x)
        scale6 = self.seblock6(x)
        x = scale6 * x + branch6

        branch7 = x
        x = self.basicBlock41(x)
        scale7 = self.seblock7(x)
        if self.needs_norm:
            x = scale7 * x + self.optionalNorm2dconv7(self.ancillaryconv7(branch7))
        else:
            x = scale7 * x + self.ancillaryconv7(branch7)

        branch8 = x
        x = self.basicBlock42(x)
        scale8 = self.seblock8(x)
        x = scale8 * x + branch8

        x = self.avgpooling(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)

        return x

In [None]:
def weight_scheduler(epoch_num, task):
    if task == 'shared':
        return [0.0,1.0,0.0,1.0]
    elif task == 'viewpoint':
        return [0.0,1.0,0.0,0.0]
    elif task == 'category':
        return [0.0,0.0,0.0,1.0]
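
The weight vectors returned by `weight_scheduler` decide which of the four output heads contribute to the loss: index 1 is the viewpoint head and index 3 is the category head, and the other two heads always get weight 0. A minimal sketch of the resulting weighted sum, using made-up per-head loss values and a hypothetical helper `combine_losses`:

```python
def combine_losses(per_head_losses, weights):
    """Weighted sum of per-head losses, one weight per output head."""
    return sum(w * loss for w, loss in zip(weights, per_head_losses))


# Made-up loss values for the four heads, for illustration only.
per_head_losses = [2.0, 0.5, 1.5, 0.25]

shared_loss = combine_losses(per_head_losses, [0.0, 1.0, 0.0, 1.0])
viewpoint_loss = combine_losses(per_head_losses, [0.0, 1.0, 0.0, 0.0])
category_loss = combine_losses(per_head_losses, [0.0, 0.0, 0.0, 1.0])
print(shared_loss, viewpoint_loss, category_loss)
```

In the shared setting both tasks drive the same network, while the viewpoint-only and category-only settings each train a network on a single head.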

In [None]:
def train_epoch(dset_loaders, dset_sizes, model, task, optimizer):
    model.train()
    torch.set_grad_enabled(True)
    phase = 'train'
    
    weights = weight_scheduler(epoch, task)
    iters = 0
    phase_epoch_corrects = [0,0,0,0]
    phase_epoch_loss = 0
    
    for data in dset_loaders[phase]:
        inputs, labels_all, paths = data
        inputs = Variable(inputs.float().cuda())

        optimizer.zero_grad()
        model_outs = model(inputs)
        calculated_loss = 0
        batch_corrects = [0,0,0,0]
        
        for i in range(4):
            labels = labels_all[:,i]
            if GPU:
                labels = Variable(labels.long().cuda())
            loss = multi_losses[i]
            outputs = model_outs[i]
            calculated_loss += weights[i] * loss(outputs,labels)
            _, preds = torch.max(outputs.data, 1)
            batch_corrects[i] = torch.sum(preds == labels.data)
            phase_epoch_corrects[i] += batch_corrects[i]

        
        phase_epoch_loss += calculated_loss.item()  # accumulate a Python float so the autograd graph is not retained
        calculated_loss.backward()
        optimizer.step()
        iters += 1
    epoch_loss = phase_epoch_loss/dset_sizes[phase]
    # print('Train loss:%s'%epoch_loss)
    epoch_accs = [float(i)/dset_sizes[phase] for i in phase_epoch_corrects]

    if task == 'shared':
        epoch_gm = np.sqrt(epoch_accs[1] * epoch_accs[3])
    elif task == 'viewpoint':
        epoch_gm = epoch_accs[1]
    elif task == 'category':
        epoch_gm = epoch_accs[3]
    
    return model, epoch_loss, epoch_gm

In [None]:
def test_epoch(dset_loaders, dset_sizes, model, best_model, best_test_loss, best_test_gm, task):
    model.eval()
    torch.set_grad_enabled(False)
    phase = 'val'
    weights = weight_scheduler(epoch, task)
    iters = 0
    phase_epoch_corrects = [0,0,0,0]
    phase_epoch_loss = 0
    
    for data in dset_loaders[phase]:
        inputs, labels_all, paths = data
        inputs = Variable(inputs.float().cuda())
        model_outs = model(inputs)
        calculated_loss = 0
        batch_corrects = [0,0,0,0]
        
        for i in range(4):
            labels = labels_all[:,i]
            if GPU:
                labels = Variable(labels.long().cuda())
            loss = multi_losses[i]
            outputs = model_outs[i]
            calculated_loss += weights[i] * loss(outputs,labels)
            _, preds = torch.max(outputs.data, 1)
            batch_corrects[i] = torch.sum(preds == labels.data)
            phase_epoch_corrects[i] += batch_corrects[i]


        phase_epoch_loss += calculated_loss
        iters += 1
    epoch_loss = phase_epoch_loss/dset_sizes[phase]
    # print('Test loss:%s'%epoch_loss)
    epoch_accs = [float(i)/dset_sizes[phase] for i in phase_epoch_corrects]
    
    if task == 'shared':
        epoch_gm = np.sqrt(epoch_accs[1] * epoch_accs[3])
    elif task == 'viewpoint':
        epoch_gm = epoch_accs[1]
    elif task == 'category':
        epoch_gm = epoch_accs[3]
    
    if epoch_loss < best_test_loss:
        best_model = copy.deepcopy(model)  # snapshot; a plain assignment would keep changing with further training
        best_test_loss = epoch_loss
        best_test_gm = epoch_gm
    
    return best_model, epoch_loss, epoch_gm, best_test_loss, best_test_gm
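
When tracking a best model across epochs, it helps to remember that assigning an object to a second name only creates another reference to it, whereas `copy.deepcopy` takes an independent snapshot. The sketch below shows the difference with a plain dictionary standing in for model parameters; the names `weights`, `alias`, and `snapshot` are hypothetical.

```python
import copy

weights = {"w": [1.0, 2.0]}  # stand-in for a model's parameters (hypothetical)

alias = weights                    # same object: follows every later update
snapshot = copy.deepcopy(weights)  # independent copy frozen at this point

weights["w"][0] = 99.0  # simulate a further training step

print(alias["w"][0], snapshot["w"][0])
```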

In [None]:
def unseen_test_epoch(dset_loaders, dset_sizes, model, task):
    model.eval()
    torch.set_grad_enabled(False)
    phase = 'test'

    weights = weight_scheduler(epoch, task)
    iters = 0
    phase_epoch_corrects = [0,0,0,0]
    phase_epoch_loss = 0
    
    for data in dset_loaders[phase]:
        inputs, labels_all, paths = data
        inputs = Variable(inputs.float().cuda())
        model_outs = model(inputs)
        calculated_loss = 0
        batch_corrects = [0,0,0,0]
        
        for i in range(4):
            labels = labels_all[:,i]
            if GPU:
                labels = Variable(labels.long().cuda())
            loss = multi_losses[i]
            outputs = model_outs[i]
            calculated_loss += weights[i] * loss(outputs,labels)
            _, preds = torch.max(outputs.data, 1)
            batch_corrects[i] = torch.sum(preds == labels.data)
            phase_epoch_corrects[i] += batch_corrects[i]


        phase_epoch_loss += calculated_loss
        iters += 1
    epoch_loss = phase_epoch_loss/dset_sizes[phase]
    epoch_accs = [float(i)/dset_sizes[phase] for i in phase_epoch_corrects]
    
    if task == 'shared':
        epoch_gm = np.sqrt(epoch_accs[1] * epoch_accs[3])
    elif task == 'viewpoint':
        epoch_gm = epoch_accs[1]
    elif task == 'category':
        epoch_gm = epoch_accs[3]
    
    return epoch_loss, epoch_gm
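
For the shared task, the functions above report the geometric mean of viewpoint and category accuracy. The sketch below, with made-up accuracies and a hypothetical helper `geometric_mean`, shows how this metric penalizes a model that does well on one task and poorly on the other more strongly than the arithmetic mean does.

```python
import math


def geometric_mean(category_acc, viewpoint_acc):
    """Geometric mean of two accuracies in [0, 1]."""
    return math.sqrt(category_acc * viewpoint_acc)


# Made-up accuracies: strong on category, weak on viewpoint.
category_acc, viewpoint_acc = 0.81, 0.25
gm = geometric_mean(category_acc, viewpoint_acc)
am = (category_acc + viewpoint_acc) / 2
print(round(gm, 2), round(am, 2))
```

If either accuracy is zero, the geometric mean is zero as well, regardless of the other task.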

In [None]:
plt.rc('xtick', labelsize=14) 
plt.rc('ytick', labelsize=14) 

In [None]:
DATASET_NAMES

In [None]:
dataset_titles = {}
dataset_titles['mnist_rotation_one_by_nine'] = "10% combinations seen"
dataset_titles['mnist_rotation_three_by_nine'] = "30% combinations seen"
dataset_titles['mnist_rotation_six_by_nine'] = "60% combinations seen"
dataset_titles['mnist_rotation_nine_by_nine'] = "90% combinations seen"

In [None]:
# Result containers are created once, outside the architecture loop, so that
# every architecture is kept for the summary plots below.
# The leading 0 in each list is a placeholder entry.
all_train_all = {}
all_test_all = {}
all_unseen_test_all = {}

for ARCH in ARCHS:
  all_train_all[ARCH + ' ' + 'shared'] = [0]
  all_train_all[ARCH + ' ' +'separate'] = [0]

  all_test_all[ARCH + ' ' +'shared'] = [0]
  all_test_all[ARCH + ' ' +'separate'] = [0]

  all_unseen_test_all[ARCH + ' ' +'shared'] = [0]
  all_unseen_test_all[ARCH + ' ' + 'separate'] = [0]

  for DATASET_NAME in DATASET_NAMES:
      print('______________________________________________________')
      print('model_name:',ARCH,'  comb_name', dataset_titles[DATASET_NAME])
      dsets = all_dsets[DATASET_NAME]
      dset_loaders = all_dset_loaders[DATASET_NAME]
      dset_sizes = all_dset_sizes[DATASET_NAME]
      
      # Named task_models to avoid shadowing the torchvision `models` module imported above.
      task_models = {}

      task_models['shared']= get_model(ARCH,NUM_CLASSES)
      task_models['viewpoint']= get_model(ARCH,NUM_CLASSES)
      task_models['category']= get_model(ARCH,NUM_CLASSES)

      task_models['shared'].cuda();
      task_models['viewpoint'].cuda();
      task_models['category'].cuda();

      best_models = {}
      best_models['shared'] = task_models['shared']
      best_models['viewpoint'] = task_models['viewpoint']
      best_models['category'] = task_models['category']

      # Track the best validation loss and score per task, so that one task's
      # loss does not prevent another task's best model from being recorded.
      best_test_loss = {'shared': 100, 'viewpoint': 100, 'category': 100}
      best_test_gm = {'shared': 0, 'viewpoint': 0, 'category': 0}

      all_train_gms = {}
      all_train_gms['shared'] = [0]
      all_train_gms['separate'] = [0]

      all_test_gms = {}
      all_test_gms['shared'] = [0]
      all_test_gms['separate'] = [0]

      all_unseen_test_gms = {}
      all_unseen_test_gms['shared'] = [0]
      all_unseen_test_gms['separate'] = [0]

      optimizers = {}
      optimizers['shared'] = optim.Adam(task_models['shared'].parameters(), lr=0.001)
      optimizers['viewpoint'] = optim.Adam(task_models['viewpoint'].parameters(), lr=0.001)
      optimizers['category'] = optim.Adam(task_models['category'].parameters(), lr=0.001)
      for epoch in range(NUM_EPOCHS):
          train_gm_separate = 1
          test_gm_separate = 1
          unseen_test_gm_separate = 1

          # 'shared' runs last, so train_gm, test_gm and unseen_test_gm hold
          # the shared model's scores once this loop finishes.
          for TASK in ['viewpoint','category','shared']:
              print('Epoch: %s, Task: %s'%(epoch,TASK))
              print('---------')
              task_models[TASK], train_loss, train_gm = train_epoch(dset_loaders, dset_sizes, task_models[TASK], TASK, optimizers[TASK])
              best_models[TASK], test_loss, test_gm, best_test_loss[TASK], best_test_gm[TASK] = test_epoch(dset_loaders, dset_sizes, task_models[TASK], best_models[TASK], best_test_loss[TASK], best_test_gm[TASK], TASK)
              unseen_test_loss, unseen_test_gm = unseen_test_epoch(dset_loaders, dset_sizes, task_models[TASK], TASK)

              if TASK != 'shared':
                  # Product of the two single-task accuracies; the square root
                  # taken below turns it into a geometric mean.
                  train_gm_separate = train_gm_separate * train_gm
                  test_gm_separate = test_gm_separate * test_gm
                  unseen_test_gm_separate = unseen_test_gm_separate * unseen_test_gm

          all_train_gms['separate'].append(np.sqrt(train_gm_separate))
          all_test_gms['separate'].append(np.sqrt(test_gm_separate))
          all_unseen_test_gms['separate'].append(np.sqrt(unseen_test_gm_separate))
          # The shared scores are already geometric means, so no extra square root.
          all_train_gms['shared'].append(train_gm)
          all_test_gms['shared'].append(test_gm)
          all_unseen_test_gms['shared'].append(unseen_test_gm)



      # Record the last-epoch scores for this architecture and dataset.
      all_train_all[ARCH + ' ' +'separate'].append(np.sqrt(train_gm_separate))
      all_test_all[ARCH + ' ' +'separate'].append(np.sqrt(test_gm_separate))
      all_unseen_test_all[ARCH + ' ' +'separate'].append(np.sqrt(unseen_test_gm_separate))
      all_train_all[ARCH + ' ' +'shared'].append(train_gm)
      all_test_all[ARCH + ' ' +'shared'].append(test_gm)
      all_unseen_test_all[ARCH + ' ' +'shared'].append(unseen_test_gm)
      

      



      # fig,ax = plt.subplots(1, 3, figsize=(18,6))
      # fig.suptitle(dataset_titles[DATASET_NAME], fontsize = 30)
      # l1 = ax[0].plot(all_train_gms['separate'], color = 'blue', marker = 'o', markersize=5)[0]
      # l2 = ax[0].plot(all_train_gms['shared'], color = 'red', marker = 'o', markersize=5)[0]
      # ax[0].set_title('Train Accuracy', fontsize=12)
      # line_labels = ["Separate", "Shared"]

      # ax[1].plot(all_test_gms['separate'], color = 'blue', marker = 'o', markersize=5)
      # ax[1].plot(all_test_gms['shared'], color = 'red', marker = 'o', markersize=5)
      # ax[1].set_title('Test Accuracy on Seen \n Category-Viewpoint Combinations', fontsize=12)

      # ax[2].plot(all_unseen_test_gms['separate'], color = 'blue', marker = 'o', markersize=5)
      # ax[2].plot(all_unseen_test_gms['shared'], color = 'red', marker = 'o', markersize=5)
      # ax[2].set_title('Test Accuracy on Unseen \n Category-Viewpoint Combinations', fontsize=12)
      # fig.legend([l1, l2],     # The line objects
      #         labels=line_labels,   # The labels for each line
      #         loc="center right",   # Position of legend
      #         borderaxespad=0.2,    # Small spacing around legend box
      #         prop={"size":20})
      # plt.subplots_adjust(right=0.85, top =0.80)
      # plt.show()

In [None]:
percentage = [0.1, 0.3, 0.6, 0.9]

In [None]:
plt.figure(figsize = (10,8))
for key, value in all_train_all.items():
    # value[0] is the placeholder 0 added at initialization; skip it so that
    # y has one entry per value in `percentage`.
    sns.lineplot(x = percentage, y = value[1:], label = key)
plt.legend()
plt.xlabel('Fraction of category-viewpoint combinations seen during training')
plt.ylabel('Accuracy')
plt.title('Train accuracy of different CNN models vs. in-distribution combinations')

In [None]:
plt.figure(figsize = (10,8))
for key, value in all_test_all.items():
    # value[0] is the placeholder 0 added at initialization; skip it so that
    # y has one entry per value in `percentage`.
    sns.lineplot(x = percentage, y = value[1:], label = key)
plt.legend()
plt.xlabel('Fraction of category-viewpoint combinations seen during training')
plt.ylabel('Accuracy')
plt.title('Test accuracy of different CNN models on seen category-viewpoint combinations')

In [None]:
plt.figure(figsize = (10,8))
for key, value in all_unseen_test_all.items():
    # value[0] is the placeholder 0 added at initialization; skip it so that
    # y has one entry per value in `percentage`.
    sns.lineplot(x = percentage, y = value[1:], label = key)
plt.legend()
plt.xlabel('Fraction of category-viewpoint combinations seen during training')
plt.ylabel('Accuracy')
plt.title('Test accuracy of different CNN models on unseen (out-of-distribution) category-viewpoint combinations')