In [2]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl
import torchmetrics

import os
# Pin GPU selection through environment variables: order devices by PCI bus
# id and expose only device "0" to this process.
# NOTE(review): these assignments run after `import torch` above; presumably
# that is fine as long as CUDA has not been initialised yet -- confirm.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"]="0"

def conv_res(in_channels, out_channels, stride=1):
    """Build a bias-free 3x3 convolution with padding 1.

    Args:
        in_channels: Number of input feature channels.
        out_channels: Number of output feature channels.
        stride: Convolution stride (default 1).

    Returns:
        An ``nn.Conv2d`` configured with ``kernel_size=3``, ``padding=1`` and
        ``bias=False``.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_channels, out_channels, **conv_kwargs)

# Residual block
class res_block(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers plus a shortcut.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
        stride: Stride of the first convolution (default 1).
        downsample: Optional module applied to the input to produce the
            shortcut branch; when ``None`` the input itself is the shortcut.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(res_block, self).__init__()
        # Attribute names are part of the state_dict keys; keep them stable.
        self.conv1 = conv_res(in_channels, out_channels, stride)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv_res(out_channels, out_channels)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(bn1(conv2(relu(bn(conv1(x))))) + shortcut(x))``."""
        residual = x
        out = self.relu(self.bn(self.conv1(x)))
        out = self.bn1(self.conv2(out))
        # Compare against None explicitly: truthiness of a module falls back to
        # __len__ for containers such as nn.Sequential, so a length-0 shortcut
        # module would otherwise be skipped silently.
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

# ResNet
class res_net(nn.Module):
    """ResNet-style backbone: two stem convolutions, three residual stages, two linear layers.

    With ``triplet`` set, ``forward`` returns the 64-dimensional output of
    ``fc1``; otherwise it applies ReLU followed by ``fc2`` and returns
    ``num_classes`` values per sample.

    Args:
        block: Callable that builds one residual block. It is called as
            ``block(in_channels, out_channels, stride, downsample)`` for the
            first block of a stage and ``block(out_channels, out_channels)``
            for the rest.
        layers: Indexable with the number of blocks for each of the three stages.
        triplet: When true, ``forward`` returns the ``fc1`` output directly.
        num_classes: Output size of ``fc2``.
    """

    def __init__(self, block, layers, triplet=False, num_classes=10):
        super().__init__()
        # Channel count entering the next stage; rep_layer() updates it.
        self.in_channels = 16
        self.triplet = triplet

        # Stem: 3 -> 8 -> 16 channels at full resolution.
        self.conv = conv_res(3, 8)
        self.conv1 = conv_res(8, 16)
        self.bn = nn.BatchNorm2d(8)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)

        # Three residual stages; the second and third use stride 2.
        self.layer1 = self.rep_layer(block, 16, layers[0])
        self.layer2 = self.rep_layer(block, 32, layers[1], stride=2)
        self.layer3 = self.rep_layer(block, 64, layers[2], stride=2)

        self.avg_pool = nn.AvgPool2d(8)
        # 576 input features = 64 channels x 3 x 3, which is what a 96x96
        # input yields after two stride-2 stages and the 8x8 average pool.
        self.fc1 = nn.Linear(576, 64)
        self.fc2 = nn.Linear(64, num_classes)
        # Registered, but not called by forward_pass()/forward() below.
        self.sigmoid = nn.Sigmoid()

    def rep_layer(self, block, out_channels, blocks, stride=1):
        """Assemble one stage of ``blocks`` residual blocks as an ``nn.Sequential``.

        The first block receives ``stride`` and, when the stride is not 1 or
        the channel count changes, a conv + batch-norm shortcut projection.
        Updates ``self.in_channels`` to ``out_channels``.
        """
        shape_preserved = stride == 1 and self.in_channels == out_channels
        shortcut = None
        if not shape_preserved:
            shortcut = nn.Sequential(
                conv_res(self.in_channels, out_channels, stride=stride),
                nn.BatchNorm2d(out_channels),
            )

        stage = [block(self.in_channels, out_channels, stride, shortcut)]
        self.in_channels = out_channels
        stage.extend(block(out_channels, out_channels) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def forward_pass(self, x):
        """Run stem, residual stages, pooling and ``fc1``; return the ``fc1`` output."""
        stem = self.relu(self.bn(self.conv(x)))
        stem = self.relu(self.bn1(self.conv1(stem)))
        features = self.layer3(self.layer2(self.layer1(stem)))
        pooled = self.avg_pool(features)
        # Flatten everything but the batch dimension before the linear layer.
        flat = pooled.view(pooled.size(0), -1)
        return self.fc1(flat)

    def forward(self, x):
        """Return the ``fc1`` output (triplet mode) or ``fc2(relu(fc1 output))``."""
        embedding = self.forward_pass(x)
        if self.triplet:
            return embedding
        return self.fc2(self.relu(embedding))
        
    
# Backbone in triplet mode (forward returns the fc1 output), with 4 / 16 / 32
# residual blocks in the three stages.
resnet = res_net(
    block=res_block,
    layers=[4, 16, 32],
    triplet=True,
    num_classes=5,
)

from torchsummary import summary
# Print a per-layer summary for a 3-channel 96x96 input.
summary(resnet, (3, 96, 96))





Layer (type:depth-idx)                   Output Shape              Param #
├─Conv2d: 1-1                            [-1, 8, 96, 96]           216
├─BatchNorm2d: 1-2                       [-1, 8, 96, 96]           16
├─ReLU: 1-3                              [-1, 8, 96, 96]           --
├─Conv2d: 1-4                            [-1, 16, 96, 96]          1,152
├─BatchNorm2d: 1-5                       [-1, 16, 96, 96]          32
├─ReLU: 1-6                              [-1, 16, 96, 96]          --
├─Sequential: 1-7                        [-1, 16, 96, 96]          --
|    └─res_block: 2-1                    [-1, 16, 96, 96]          --
|    |    └─Conv2d: 3-1                  [-1, 16, 96, 96]          2,304
|    |    └─BatchNorm2d: 3-2             [-1, 16, 96, 96]          32
|    |    └─ReLU: 3-3                    [-1, 16, 96, 96]          --
|    |    └─Conv2d: 3-4                  [-1, 16, 96, 96]          2,304
|    |    └─BatchNorm2d: 3-5             [-1, 16, 96, 96]          32
|    

Layer (type:depth-idx)                   Output Shape              Param #
├─Conv2d: 1-1                            [-1, 8, 96, 96]           216
├─BatchNorm2d: 1-2                       [-1, 8, 96, 96]           16
├─ReLU: 1-3                              [-1, 8, 96, 96]           --
├─Conv2d: 1-4                            [-1, 16, 96, 96]          1,152
├─BatchNorm2d: 1-5                       [-1, 16, 96, 96]          32
├─ReLU: 1-6                              [-1, 16, 96, 96]          --
├─Sequential: 1-7                        [-1, 16, 96, 96]          --
|    └─res_block: 2-1                    [-1, 16, 96, 96]          --
|    |    └─Conv2d: 3-1                  [-1, 16, 96, 96]          2,304
|    |    └─BatchNorm2d: 3-2             [-1, 16, 96, 96]          32
|    |    └─ReLU: 3-3                    [-1, 16, 96, 96]          --
|    |    └─Conv2d: 3-4                  [-1, 16, 96, 96]          2,304
|    |    └─BatchNorm2d: 3-5             [-1, 16, 96, 96]          32
|    

In [17]:
# chkpoint = torch.load('/data/sathya/viper/weights/triplet_loss+bce/CE_loss-epoch=113-valid_acc=0.73.ckpt')



class final_net(nn.Module):
    """Wrap a backbone and put a 64 -> 5 linear layer on top of its output.

    Args:
        model: Module whose output is fed to ``fc2``, which takes 64 input
            features.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        # Metric and loss objects are stored here; forward() does not use them.
        self.accuracy = torchmetrics.Accuracy()
        self.triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
        self.fc2 = nn.Linear(64, 5)

    def forward(self, x):
        """Return ``fc2(model(x))``; no ReLU is applied between the two."""
        return self.fc2(self.model(x))


    
# Wrap the backbone with the final linear layer and restore trained weights.
FN = final_net(resnet)
# map_location='cpu' lets the checkpoint deserialize on hosts that lack the
# device it was saved from; load_state_dict then copies the values into FN's
# existing parameters, so FN's own device placement is unaffected.
FN.load_state_dict(
    torch.load(
        './weights/triplet_loss/CE_loss-epoch=84-train_loss=0.00.ckpt',
        map_location='cpu',
    )['state_dict']
)




<All keys matched successfully>

In [18]:

from torchsummary import summary
# Per-layer summary of the wrapped network for a 3-channel 96x96 input.
summary(FN, (3, 96, 96))

Layer (type:depth-idx)                   Output Shape              Param #
├─res_net: 1-1                           [-1, 64]                  --
|    └─Conv2d: 2-1                       [-1, 8, 96, 96]           216
|    └─BatchNorm2d: 2-2                  [-1, 8, 96, 96]           16
|    └─ReLU: 2-3                         [-1, 8, 96, 96]           --
|    └─Conv2d: 2-4                       [-1, 16, 96, 96]          1,152
|    └─BatchNorm2d: 2-5                  [-1, 16, 96, 96]          32
|    └─ReLU: 2-6                         [-1, 16, 96, 96]          --
|    └─Sequential: 2-7                   [-1, 16, 96, 96]          --
|    |    └─res_block: 3-1               [-1, 16, 96, 96]          4,672
|    |    └─res_block: 3-2               [-1, 16, 96, 96]          4,672
|    |    └─res_block: 3-3               [-1, 16, 96, 96]          4,672
|    |    └─res_block: 3-4               [-1, 16, 96, 96]          4,672
|    └─Sequential: 2-8                   [-1, 32, 48, 48]          --

Layer (type:depth-idx)                   Output Shape              Param #
├─res_net: 1-1                           [-1, 64]                  --
|    └─Conv2d: 2-1                       [-1, 8, 96, 96]           216
|    └─BatchNorm2d: 2-2                  [-1, 8, 96, 96]           16
|    └─ReLU: 2-3                         [-1, 8, 96, 96]           --
|    └─Conv2d: 2-4                       [-1, 16, 96, 96]          1,152
|    └─BatchNorm2d: 2-5                  [-1, 16, 96, 96]          32
|    └─ReLU: 2-6                         [-1, 16, 96, 96]          --
|    └─Sequential: 2-7                   [-1, 16, 96, 96]          --
|    |    └─res_block: 3-1               [-1, 16, 96, 96]          4,672
|    |    └─res_block: 3-2               [-1, 16, 96, 96]          4,672
|    |    └─res_block: 3-3               [-1, 16, 96, 96]          4,672
|    |    └─res_block: 3-4               [-1, 16, 96, 96]          4,672
|    └─Sequential: 2-8                   [-1, 32, 48, 48]          --

In [16]:
class params:
    """Run configuration: triplet flag, class list and batch sizes.

    Evaluating this class body reads a module-level ``classes`` object, which
    must already be defined; otherwise a ``NameError`` is raised.
    """
    # NOTE(review): this is the *string* "False", which is truthy. Any code
    # that does ``if params.triplet:`` or passes it on as a boolean flag will
    # treat it as enabled. Left unchanged because its consumers are not visible
    # here -- confirm, then switch to the boolean ``False`` if nothing compares
    # against the string.
    triplet = "False"
    num_classes = len(classes)
    clases = classes  # original (misspelled) attribute, kept for existing users
    classes = classes  # correctly spelled alias for the same object
    train_batch_size = 8*4 # Reduce if triplet is True, default : 8*4*4
    val_batch_size = 8*4*4

NameError: name 'classes' is not defined

In [14]:
# Display the configured flag; it is assigned the string "False", not a boolean.
params.triplet

'False'