Merge pull request #409 from pykale/refactor_resnet
Refactor resnet
haipinglu committed Oct 4, 2023
2 parents cc561ef + 253f98f commit 1240a4a
Showing 3 changed files with 76 additions and 126 deletions.
182 changes: 58 additions & 124 deletions kale/embed/image_cnn.py
@@ -8,6 +8,44 @@
from torchvision import models


+class Flatten(nn.Module):
+    """Flatten layer.
+
+    This module replaces the last fully connected (fc) layer of a pre-trained model with a flatten operation. It
+    flattens the input tensor to a 2D tensor of shape (B, N), where B is the batch size and N is the product of all
+    dimensions except the batch size.
+
+    Examples:
+        >>> x = torch.randn(8, 3, 224, 224)
+        >>> x = Flatten()(x)
+        >>> print(x.shape)
+        torch.Size([8, 150528])
+    """
+
+    def __init__(self):
+        super(Flatten, self).__init__()
+
+    def forward(self, x):
+        return x.view(x.size(0), -1)
+
+
+class Identity(nn.Module):
+    """Identity layer.
+
+    This module replaces any unwanted layer in a pre-defined model with an identity layer; it returns the input
+    tensor unchanged.
+
+    Examples:
+        >>> x = torch.randn(8, 3, 224, 224)
+        >>> x = Identity()(x)
+        >>> print(x.shape)
+        torch.Size([8, 3, 224, 224])
+    """
+
+    def __init__(self):
+        super(Identity, self).__init__()
+
+    def forward(self, x):
+        return x


# From FeatureExtractorDigits in adalib
class SmallCNNFeature(nn.Module):
"""
@@ -158,36 +196,16 @@ class ResNet18Feature(nn.Module):
        weights (models.ResNet18_Weights or string): The pretrained weights to use. See
            https://pytorch.org/vision/stable/models/generated/torchvision.models.resnet18.html#torchvision.models.ResNet18_Weights
            for more details. By default, ResNet18_Weights.DEFAULT will be used.
-    Note:
-        Code adapted by pytorch-ada from https://github.com/thuml/Xlearn/blob/master/pytorch/src/network.py
    """

    def __init__(self, weights=models.ResNet18_Weights.DEFAULT):
        super(ResNet18Feature, self).__init__()
-        model_resnet18 = models.resnet18(weights=weights)
-        self.conv1 = model_resnet18.conv1
-        self.bn1 = model_resnet18.bn1
-        self.relu = model_resnet18.relu
-        self.maxpool = model_resnet18.maxpool
-        self.layer1 = model_resnet18.layer1
-        self.layer2 = model_resnet18.layer2
-        self.layer3 = model_resnet18.layer3
-        self.layer4 = model_resnet18.layer4
-        self.avgpool = model_resnet18.avgpool
-        self._out_features = model_resnet18.fc.in_features
+        self.model = models.resnet18(weights=weights)
+        self._out_features = self.model.fc.in_features
+        self.model.fc = Flatten()

    def forward(self, x):
-        x = self.conv1(x)
-        x = self.bn1(x)
-        x = self.relu(x)
-        x = self.maxpool(x)
-        x = self.layer1(x)
-        x = self.layer2(x)
-        x = self.layer3(x)
-        x = self.layer4(x)
-        x = self.avgpool(x)
-        x = x.view(x.size(0), -1)
-        return x
+        return self.model(x)

    def output_size(self):
        return self._out_features
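
A minimal usage sketch of the refactored extractor (not part of the diff; it assumes the ImageNet checkpoint can be downloaded — the 512-dimensional output follows from ResNet-18's final pooling stage):

import torch

from kale.embed.image_cnn import ResNet18Feature

extractor = ResNet18Feature()              # uses ResNet18_Weights.DEFAULT
extractor.eval()
with torch.no_grad():
    features = extractor(torch.randn(4, 3, 224, 224))
print(features.shape)                      # torch.Size([4, 512])
print(extractor.output_size())             # 512
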
@@ -201,37 +219,16 @@ class ResNet34Feature(nn.Module):
        weights (models.ResNet34_Weights or string): The pretrained weights to use. See
            https://pytorch.org/vision/stable/models/generated/torchvision.models.resnet34.html#torchvision.models.ResNet34_Weights
            for more details. By default, ResNet34_Weights.DEFAULT will be used.
-    Note:
-        Code adapted by pytorch-ada from https://github.com/thuml/Xlearn/blob/master/pytorch/src/network.py
    """

    def __init__(self, weights=models.ResNet34_Weights.DEFAULT):
        super(ResNet34Feature, self).__init__()
-        model_resnet34 = models.resnet34(weights=weights)
-        self.conv1 = model_resnet34.conv1
-        self.bn1 = model_resnet34.bn1
-        self.relu = model_resnet34.relu
-        self.maxpool = model_resnet34.maxpool
-        self.layer1 = model_resnet34.layer1
-        self.layer2 = model_resnet34.layer2
-        self.layer3 = model_resnet34.layer3
-        self.layer4 = model_resnet34.layer4
-        self.avgpool = model_resnet34.avgpool
-        self._out_features = model_resnet34.fc.in_features
+        self.model = models.resnet34(weights=weights)
+        self._out_features = self.model.fc.in_features
+        self.model.fc = Flatten()

    def forward(self, x):
-        x = self.conv1(x)
-        x = self.bn1(x)
-        x = self.relu(x)
-        x = self.maxpool(x)
-        x = self.layer1(x)
-        x = self.layer2(x)
-        x = self.layer3(x)
-        x = self.layer4(x)
-        x = self.avgpool(x)
-        x = x.view(x.size(0), -1)
-        return x
+        return self.model(x)

    def output_size(self):
        return self._out_features
@@ -245,37 +242,16 @@ class ResNet50Feature(nn.Module):
        weights (models.ResNet50_Weights or string): The pretrained weights to use. See
            https://pytorch.org/vision/stable/models/generated/torchvision.models.resnet50.html#torchvision.models.ResNet50_Weights
            for more details. By default, ResNet50_Weights.DEFAULT will be used.
-    Note:
-        Code adapted by pytorch-ada from https://github.com/thuml/Xlearn/blob/master/pytorch/src/network.py
    """

    def __init__(self, weights=models.ResNet50_Weights.DEFAULT):
        super(ResNet50Feature, self).__init__()
-        model_resnet50 = models.resnet50(weights=weights)
-        self.conv1 = model_resnet50.conv1
-        self.bn1 = model_resnet50.bn1
-        self.relu = model_resnet50.relu
-        self.maxpool = model_resnet50.maxpool
-        self.layer1 = model_resnet50.layer1
-        self.layer2 = model_resnet50.layer2
-        self.layer3 = model_resnet50.layer3
-        self.layer4 = model_resnet50.layer4
-        self.avgpool = model_resnet50.avgpool
-        self._out_features = model_resnet50.fc.in_features
+        self.model = models.resnet50(weights=weights)
+        self._out_features = self.model.fc.in_features
+        self.model.fc = Flatten()

    def forward(self, x):
-        x = self.conv1(x)
-        x = self.bn1(x)
-        x = self.relu(x)
-        x = self.maxpool(x)
-        x = self.layer1(x)
-        x = self.layer2(x)
-        x = self.layer3(x)
-        x = self.layer4(x)
-        x = self.avgpool(x)
-        x = x.view(x.size(0), -1)
-        return x
+        return self.model(x)

    def output_size(self):
        return self._out_features
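
Since weights is forwarded unchanged to torchvision, the usual torchvision spellings should work here as well. A sketch of the accepted forms (illustrative — the string and None variants are not exercised in this diff, and they assume a torchvision release with the weight-enum API):

from torchvision import models

from kale.embed.image_cnn import ResNet50Feature

ResNet50Feature(weights=models.ResNet50_Weights.IMAGENET1K_V2)  # explicit weight enum
ResNet50Feature(weights="IMAGENET1K_V1")                        # string form, resolved by torchvision
ResNet50Feature(weights=None)                                   # randomly initialised backbone
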
@@ -289,37 +265,16 @@ class ResNet101Feature(nn.Module):
        weights (models.ResNet101_Weights or string): The pretrained weights to use. See
            https://pytorch.org/vision/stable/models/generated/torchvision.models.resnet101.html#torchvision.models.ResNet101_Weights
            for more details. By default, ResNet101_Weights.DEFAULT will be used.
-    Note:
-        Code adapted by pytorch-ada from https://github.com/thuml/Xlearn/blob/master/pytorch/src/network.py
    """

    def __init__(self, weights=models.ResNet101_Weights.DEFAULT):
        super(ResNet101Feature, self).__init__()
-        model_resnet101 = models.resnet101(weights=weights)
-        self.conv1 = model_resnet101.conv1
-        self.bn1 = model_resnet101.bn1
-        self.relu = model_resnet101.relu
-        self.maxpool = model_resnet101.maxpool
-        self.layer1 = model_resnet101.layer1
-        self.layer2 = model_resnet101.layer2
-        self.layer3 = model_resnet101.layer3
-        self.layer4 = model_resnet101.layer4
-        self.avgpool = model_resnet101.avgpool
-        self._out_features = model_resnet101.fc.in_features
+        self.model = models.resnet101(weights=weights)
+        self._out_features = self.model.fc.in_features
+        self.model.fc = Flatten()

    def forward(self, x):
-        x = self.conv1(x)
-        x = self.bn1(x)
-        x = self.relu(x)
-        x = self.maxpool(x)
-        x = self.layer1(x)
-        x = self.layer2(x)
-        x = self.layer3(x)
-        x = self.layer4(x)
-        x = self.avgpool(x)
-        x = x.view(x.size(0), -1)
-        return x
+        return self.model(x)

    def output_size(self):
        return self._out_features
@@ -333,37 +288,16 @@ class ResNet152Feature(nn.Module):
        weights (models.ResNet152_Weights or string): The pretrained weights to use. See
            https://pytorch.org/vision/stable/models/generated/torchvision.models.resnet152.html#torchvision.models.ResNet152_Weights
            for more details. By default, ResNet152_Weights.DEFAULT will be used.
-    Note:
-        Code adapted by pytorch-ada from https://github.com/thuml/Xlearn/blob/master/pytorch/src/network.py
    """

    def __init__(self, weights=models.ResNet152_Weights.DEFAULT):
        super(ResNet152Feature, self).__init__()
-        model_resnet152 = models.resnet152(weights=weights)
-        self.conv1 = model_resnet152.conv1
-        self.bn1 = model_resnet152.bn1
-        self.relu = model_resnet152.relu
-        self.maxpool = model_resnet152.maxpool
-        self.layer1 = model_resnet152.layer1
-        self.layer2 = model_resnet152.layer2
-        self.layer3 = model_resnet152.layer3
-        self.layer4 = model_resnet152.layer4
-        self.avgpool = model_resnet152.avgpool
-        self._out_features = model_resnet152.fc.in_features
+        self.model = models.resnet152(weights=weights)
+        self._out_features = self.model.fc.in_features
+        self.model.fc = Flatten()

    def forward(self, x):
-        x = self.conv1(x)
-        x = self.bn1(x)
-        x = self.relu(x)
-        x = self.maxpool(x)
-        x = self.layer1(x)
-        x = self.layer2(x)
-        x = self.layer3(x)
-        x = self.layer4(x)
-        x = self.avgpool(x)
-        x = x.view(x.size(0), -1)
-        return x
+        return self.model(x)

    def output_size(self):
        return self._out_features
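
To convince yourself that wrapping the backbone and swapping fc for Flatten() reproduces the removed layer-by-layer forward, a rough equivalence check along these lines can be run (a sketch, not part of the test suite; it assumes the pretrained checkpoint is available):

import torch
from torchvision import models

from kale.embed.image_cnn import ResNet18Feature

resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT).eval()
extractor = ResNet18Feature().eval()

x = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    # Old route: run the stages by hand and flatten the pooled output.
    h = resnet.maxpool(resnet.relu(resnet.bn1(resnet.conv1(x))))
    h = resnet.layer4(resnet.layer3(resnet.layer2(resnet.layer1(h))))
    manual = torch.flatten(resnet.avgpool(h), 1)
    # New route: the wrapped model with fc replaced by Flatten().
    wrapped = extractor(x)

print(torch.allclose(manual, wrapped, atol=1e-6))  # True
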
18 changes: 17 additions & 1 deletion tests/embed/test_image_cnn.py
@@ -2,6 +2,8 @@
import torch

from kale.embed.image_cnn import (
+    Flatten,
+    Identity,
    LeNet,
    ResNet18Feature,
    ResNet34Feature,
@@ -48,14 +50,28 @@ def test_shapes(param):
    model.eval()
    output_batch = model(INPUT_BATCH)
    assert output_batch.size() == (BATCH_SIZE, out_size)
    assert model.output_size() == out_size


def test_lenet_output_shapes():
    input_channels = 3
    output_channels = 6
    additional_layers = 2
    lenet = LeNet(input_channels, output_channels, additional_layers)

    x = torch.randn(16, 3, 32, 32)
    output = lenet(x)
    assert output.shape == (16, 24, 4, 4), "Unexpected output shape"


+def test_flatten_output_shapes():
+    flatten = Flatten()
+    x = torch.randn(16, 3, 32, 32)
+    output = flatten(x)
+    assert output.shape == (16, 3072), "Unexpected output shape"
+
+
+def test_identity_output_shapes():
+    identity = Identity()
+    x = torch.randn(16, 3, 32, 32)
+    output = identity(x)
+    assert output.shape == (16, 3, 32, 32), "Unexpected output shape"
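
Because each ResNet*Feature now wraps the whole torchvision backbone in a single self.model attribute, code that used to peel layers off with children() has to reach through that attribute first; the one-line change in tests/pipeline/test_multi_domain_adapter.py below does exactly that. An illustrative sketch of the difference (not part of the diff; assumes the ResNet-50 checkpoint is available):

import torch

from kale.embed.image_cnn import ResNet50Feature

feature_network = ResNet50Feature()

# The wrapper itself has a single child: the torchvision ResNet.
print([type(m).__name__ for m in feature_network.children()])   # ['ResNet']

# To drop the final Flatten (the replaced fc), unpack the wrapped model instead.
truncated = torch.nn.Sequential(*list(feature_network.model.children())[:-1])
print(truncated(torch.randn(2, 3, 224, 224)).shape)             # torch.Size([2, 2048, 1, 1])
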
2 changes: 1 addition & 1 deletion tests/pipeline/test_multi_domain_adapter.py
@@ -51,7 +51,7 @@ def test_multi_source(method, input_dimension, office_caltech_access, testing_cf
if method == "MFSAN":
train_params["input_dimension"] = input_dimension
if input_dimension == 2:
feature_network = torch.nn.Sequential(*(list(feature_network.children())[:-1]))
feature_network = torch.nn.Sequential(*(list(feature_network.model.children())[:-1]))

model = create_ms_adapt_trainer(
method=method,
