In [1]:
import torch
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as ds
from torchsummary import summary

In [2]:
# setting device on GPU if available, else CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

#Additional Info when using cuda
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')


Using device: cuda

NVIDIA GeForce RTX 3060 Ti
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


# Model Set Up

In [83]:
class IsFishClassifier(nn.Module):
    def __init__(self, pretrained_model):
        super(IsFishClassifier, self).__init__()
        self.model = nn.Sequential(*(list(pretrained_model.children())[:-2]))
        self.pool = nn.AdaptiveAvgPool2d((1,1))
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64, 32)
        self.fc2 = nn.Linear(32, 2)
        #nn.Sequential(pretrained, nn.Sequential(nn.AdaptiveAvgPool2d((1,1)), nn.Linear(64, 16), nn.Linear(16 , 2, bias = True)))

    def forward(self, x):
        x = self.model(x)
        x = self.pool(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        return x
                    

In [84]:
"""criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)"""

'criterion = nn.CrossEntropyLoss()\noptimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)'

# Data Loading

In [85]:
ds.ImageFolder(root="/home/shivaram/DS/Projects/FishID/is_fish_images")

Dataset ImageFolder
    Number of datapoints: 59483
    Root location: /home/shivaram/DS/Projects/FishID/is_fish_images

# Model Testing

In [87]:
model = torch.hub.load("chenyaofo/pytorch-cifar-models", "cifar10_resnet20", pretrained=True)
#summary(model, (3, 32, 32), device="cpu")

Using cache found in /home/shivaram/.cache/torch/hub/chenyaofo_pytorch-cifar-models_master


In [89]:
fish_model = IsFishClassifier(model)
#summary(fish_model, (3, 32, 32), device="cpu")

In [91]:
sample = torch.randn(4, 3, 32, 32)
fish_model(sample)

tensor([[-0.0738,  0.0866],
        [ 0.1429, -0.2087],
        [-0.1219, -0.1098],
        [-0.0896, -0.0839]], grad_fn=<AddmmBackward0>)

# References
- https://discuss.pytorch.org/t/load-only-a-part-of-the-network-with-pretrained-weights/88397/2
- https://discuss.pytorch.org/t/remove-the-fc-layer-from-the-pretrained-resnet50-and-save-the-new-model-file/51124/4
- https://github.com/chenyaofo/pytorch-cifar-models