In [29]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random
import time
import os
import torchvision.models as models
from torchvision.models.resnet import ResNet18_Weights
from torchvision.models.mobilenetv3 import mobilenet_v3_large,MobileNet_V3_Large_Weights
from torchvision.models.shufflenetv2 import shufflenet_v2_x1_0,ShuffleNet_V2_X1_0_Weights

In [27]:
class Identity(nn.Module):
    """Pass-through layer: ``forward`` hands back exactly what it receives.

    It has no parameters and no buffers. In this notebook it is dropped in
    place of a backbone's final classification layer so that the backbone
    emits features instead of class logits.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # No computation at all: the very same object is returned.
        return x
class ChannelAttention(nn.Module):
    """Channel attention driven by global average- and max-pooled descriptors.

    Both pooled descriptors (each ``(B, C, 1, 1)``) are passed through the same
    pair of bias-free 1x1 convolutions (``C -> C // reduction -> C`` with a ReLU
    in between). The two results are added and squashed with a sigmoid, giving
    one weight in ``(0, 1)`` per channel.

    Args:
        in_channels: number of channels ``C`` of the input feature map.
        reduction: divisor used to size the hidden bottleneck. The hidden width
            is clamped to at least 1 so narrow inputs still work.
        apply: if truthy, ``forward`` returns the input scaled by the channel
            weights (same shape as the input); otherwise it returns the raw
            ``(B, C, 1, 1)`` weights.
    """

    def __init__(self, in_channels, reduction=16, apply=True):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # The flag must NOT be stored as ``self.apply``: that would shadow
        # ``nn.Module.apply`` and make ``parent.apply(init_fn)`` fail with
        # "'bool' object is not callable" when it recurses into this module.
        self.apply_attention = bool(apply)

        # Guard against a zero-width bottleneck when in_channels < reduction.
        hidden_channels = max(1, in_channels // reduction)
        self.fc1 = nn.Conv2d(in_channels, hidden_channels, 1, bias=False)
        self.relu = nn.ReLU()
        self.fc2 = nn.Conv2d(hidden_channels, in_channels, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` re-weighted per channel, or the weights themselves."""
        # The same fc1 -> ReLU -> fc2 stack is shared by both pooled branches.
        avg_out = self.fc2(self.relu(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu(self.fc1(self.max_pool(x))))
        weights = self.sigmoid(avg_out + max_out)
        return weights * x if self.apply_attention else weights
    
class SEBlock(nn.Module):
    """Squeeze-and-excitation style channel gate.

    The input is globally average-pooled to ``(B, C)``, passed through a
    bias-free ``Linear -> ReLU -> Linear -> Sigmoid`` bottleneck, and the
    resulting per-channel weights are broadcast back over the spatial dims.

    Args:
        channel: number of channels ``C`` of the input feature map.
        reduction: divisor used to size the hidden bottleneck. The hidden width
            is clamped to at least 1 so narrow inputs still work.
        apply: if truthy, ``forward`` returns ``x`` scaled by the weights;
            otherwise it returns the weights expanded to the shape of ``x``.
    """

    def __init__(self, channel, reduction=16, apply=True):
        super().__init__()
        # The flag must NOT be stored as ``self.apply``: that would shadow
        # ``nn.Module.apply`` and make ``parent.apply(init_fn)`` fail with
        # "'bool' object is not callable" when it recurses into this module.
        self.apply_attention = bool(apply)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        # Guard against a zero-width bottleneck when channel < reduction.
        hidden = max(1, channel // reduction)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` gated per channel, or the expanded gate itself."""
        b, c, _, _ = x.size()
        # Squeeze: (B, C, H, W) -> (B, C)
        y = self.avg_pool(x).view(b, c)
        # Excite: (B, C) -> (B, C, 1, 1) weights in (0, 1)
        y = self.fc(y).view(b, c, 1, 1)
        gate = y.expand_as(x)
        return x * gate if self.apply_attention else gate

class SuperImageStich(nn.Module):
    """Tile the channels of a feature map into one single-channel image.

    Channel ``i * grid_size + j`` of each sample is placed at grid cell
    ``(row i, column j)``, so a ``(B, C, H, W)`` input becomes a
    ``(B, 1, grid_size * H, grid_size * W)`` output. With the default
    ``grid_size=16`` and 256-channel 90x90 inputs this yields 1440x1440 images.

    The result keeps the device and dtype of the input (a freshly allocated
    ``torch.zeros`` buffer would silently fall back to CPU / float32).

    Args:
        grid_size: number of tiles per side. Only the first ``grid_size ** 2``
            channels are used; any further channels are ignored.

    Raises:
        ValueError: from ``forward`` if the input has fewer than
            ``grid_size ** 2`` channels.
    """

    def __init__(self, grid_size=16):
        super().__init__()
        self.grid_size = grid_size

    def forward(self, x):
        batch, c, h, w = x.size()
        g = self.grid_size
        if c < g * g:
            raise ValueError(
                f"SuperImageStich needs at least {g * g} channels, got {c}"
            )

        # (B, g*g, H, W) -> (B, grid_row, grid_col, H, W)
        tiles = x[:, :g * g].reshape(batch, g, g, h, w)
        # Interleave grid rows with pixel rows and grid cols with pixel cols:
        # (B, grid_row, H, grid_col, W) -> (B, 1, g*H, g*W)
        return tiles.permute(0, 1, 3, 2, 4).reshape(batch, 1, g * h, g * w)
    
class CustomResNet18(nn.Module):
    """ResNet-18 backbone adapted to 256-channel inputs, used as a feature extractor.

    The stem convolution is replaced by one that accepts 256 input channels
    (same output channels, kernel size, stride and padding), and the final
    ``fc`` layer is replaced by ``Identity`` so the model returns the pooled
    features (512 values per sample) instead of class logits.

    Args:
        pretrained: if True, load ``ResNet18_Weights.DEFAULT`` and seed the new
            stem from it: input channels 0-2 receive the pretrained RGB
            kernels, channels 3-255 receive the mean of those three kernels.
            If False, the whole network is randomly initialised.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        weights = ResNet18_Weights.DEFAULT if pretrained else None
        self.resnet = models.resnet18(weights=weights)

        # Build a 256-channel stem mirroring the original layer's geometry.
        original_conv1 = self.resnet.conv1
        has_bias = original_conv1.bias is not None
        new_conv1 = nn.Conv2d(
            256,
            original_conv1.out_channels,
            kernel_size=original_conv1.kernel_size,
            stride=original_conv1.stride,
            padding=original_conv1.padding,
            # nn.Conv2d expects a bool here, not the bias tensor (or None).
            bias=has_bias,
        )

        if pretrained:
            with torch.no_grad():
                rgb_weight = original_conv1.weight  # (out, 3, kH, kW)
                new_conv1.weight[:, :3] = rgb_weight
                # One mean over the RGB axis, broadcast to every extra channel.
                new_conv1.weight[:, 3:] = rgb_weight.mean(dim=1, keepdim=True)
                if has_bias:
                    new_conv1.bias.copy_(original_conv1.bias)

        self.resnet.conv1 = new_conv1

        # Drop the classification head; every other layer is left untouched.
        self.resnet.fc = Identity()

    def forward(self, x):
        """Run the backbone on ``x`` and return its pooled features."""
        return self.resnet(x)
class CustomMobileNetV3(nn.Module):
    """MobileNetV3-Large backbone adapted to 256-channel inputs, used as a feature extractor.

    The first convolution (``features[0][0]``) is replaced by one that accepts
    256 input channels (same output channels, kernel size, stride and
    padding). The last classifier layer (``classifier[3]``) is replaced by
    ``Identity``, so the model returns 1280 features per sample instead of
    class logits; the earlier classifier layers are kept.

    Args:
        pretrained: if True, load ``MobileNet_V3_Large_Weights.DEFAULT`` and
            seed the new first convolution from it: input channels 0-2 receive
            the pretrained RGB kernels, channels 3-255 receive the mean of
            those three kernels. If False, the network is randomly initialised.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        weights = MobileNet_V3_Large_Weights.DEFAULT if pretrained else None
        self.mobilenet = models.mobilenet_v3_large(weights=weights)

        # Build a 256-channel first conv mirroring the original layer's geometry.
        original_conv1 = self.mobilenet.features[0][0]
        has_bias = original_conv1.bias is not None
        new_conv1 = nn.Conv2d(
            256,
            original_conv1.out_channels,
            kernel_size=original_conv1.kernel_size,
            stride=original_conv1.stride,
            padding=original_conv1.padding,
            # nn.Conv2d expects a bool here, not the bias tensor (or None).
            bias=has_bias,
        )

        if pretrained:
            with torch.no_grad():
                rgb_weight = original_conv1.weight  # (out, 3, kH, kW)
                new_conv1.weight[:, :3] = rgb_weight
                # One mean over the RGB axis, broadcast to every extra channel.
                new_conv1.weight[:, 3:] = rgb_weight.mean(dim=1, keepdim=True)
                if has_bias:
                    new_conv1.bias.copy_(original_conv1.bias)

        self.mobilenet.features[0][0] = new_conv1

        # Replace only the final classification layer with a pass-through.
        self.mobilenet.classifier[3] = Identity()

    def forward(self, x):
        """Run the backbone on ``x`` and return its features."""
        return self.mobilenet(x)

class CustomShuffleNetV2(nn.Module):
    """ShuffleNetV2 x1.0 backbone adapted to 256-channel inputs, used as a feature extractor.

    The first convolution (``conv1[0]``) is replaced by one that accepts 256
    input channels (same output channels, kernel size, stride and padding),
    and the final ``fc`` layer is replaced by ``Identity`` so the model
    returns 1024 features per sample instead of class logits.

    Args:
        pretrained: if True, load ``ShuffleNet_V2_X1_0_Weights.DEFAULT`` and
            seed the new first convolution from it: input channels 0-2 receive
            the pretrained RGB kernels, channels 3-255 receive the mean of
            those three kernels. If False, the network is randomly initialised.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        weights = ShuffleNet_V2_X1_0_Weights.DEFAULT if pretrained else None
        self.shufflenet = models.shufflenet_v2_x1_0(weights=weights)

        # Build a 256-channel first conv mirroring the original layer's geometry.
        original_conv1 = self.shufflenet.conv1[0]
        has_bias = original_conv1.bias is not None
        new_conv1 = nn.Conv2d(
            256,
            original_conv1.out_channels,
            kernel_size=original_conv1.kernel_size,
            stride=original_conv1.stride,
            padding=original_conv1.padding,
            # nn.Conv2d expects a bool here, not the bias tensor (or None).
            bias=has_bias,
        )

        if pretrained:
            with torch.no_grad():
                rgb_weight = original_conv1.weight  # (out, 3, kH, kW)
                new_conv1.weight[:, :3] = rgb_weight
                # One mean over the RGB axis, broadcast to every extra channel.
                new_conv1.weight[:, 3:] = rgb_weight.mean(dim=1, keepdim=True)
                if has_bias:
                    new_conv1.bias.copy_(original_conv1.bias)

        self.shufflenet.conv1[0] = new_conv1

        # Drop the classification head so the pooled features are returned.
        self.shufflenet.fc = Identity()

    def forward(self, x):
        """Run the backbone on ``x`` and return its pooled features."""
        return self.shufflenet(x)



In [33]:
# Smoke test: ChannelAttention with default arguments on a random
# 256-channel 90x90 batch; the printed size should match the input's.
test = torch.randn(14, 256, 90, 90)
model = ChannelAttention(256)
print(model(test).size())

torch.Size([14, 256, 90, 90])


In [34]:
# Smoke test: SEBlock with default arguments on a random 256-channel
# 90x90 batch; the printed size should match the input's.
test = torch.randn(2, 256, 90, 90)
model = SEBlock(256)
print(model(test).size())

torch.Size([2, 256, 90, 90])


In [6]:
# Smoke test: stitch 256 channels of 90x90 into one 16x16 grid per sample
# and print the size of the resulting single-channel image batch.
test = torch.randn(4, 256, 90, 90)
model = SuperImageStich()
output = model(test)
print(output.shape)

torch.Size([4, 1, 1440, 1440])


In [17]:
# Smoke test: push a random 256-channel batch through the adapted ResNet-18
# (pretrained weights by default) and print the size of the feature output.
test = torch.randn(4, 256, 90, 90)
model = CustomResNet18()
output = model(test)
print(output.shape)

torch.Size([4, 512])


In [25]:
# Smoke test: push a random 256-channel batch through the adapted
# MobileNetV3-Large (pretrained weights by default; the first run downloads
# the checkpoint) and print the size of the feature output.
test = torch.randn(4, 256, 90, 90)
model = CustomMobileNetV3()
output = model(test)
print(output.shape)

Downloading: "https://download.pytorch.org/models/mobilenet_v3_large-5c1a4163.pth" to /home/yatbaz_h@WMGDS.WMG.WARWICK.AC.UK/.cache/torch/hub/checkpoints/mobilenet_v3_large-5c1a4163.pth
100%|██████████| 21.1M/21.1M [00:00<00:00, 78.7MB/s]


torch.Size([4, 1280])


In [30]:
# Smoke test: push a random 256-channel batch through the adapted
# ShuffleNetV2 x1.0 (pretrained weights by default; the first run downloads
# the checkpoint) and print the size of the feature output.
test = torch.randn(4, 256, 90, 90)
model = CustomShuffleNetV2()
output = model(test)
print(output.shape)

Downloading: "https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth" to /home/yatbaz_h@WMGDS.WMG.WARWICK.AC.UK/.cache/torch/hub/checkpoints/shufflenetv2_x1-5666bf0f80.pth
100%|██████████| 8.79M/8.79M [00:00<00:00, 13.8MB/s]


torch.Size([4, 1024])


In [None]:
import N