In [4]:
import sys
import os

# Get the current directory
current_dir = os.getcwd()

# Get the parent directory
parent_dir = os.path.dirname(current_dir)

# Add the parent directory to the Python path
sys.path.append(parent_dir)

In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from functools import partial

In [70]:
from glasses.models.segmentation.unet import DownLayer, UpLayer, ConvBnAct, UNetDecoder, UNetEncoder
from glasses.models.base import Encoder

In [3]:
device = "cpu"

In [None]:
class ASPP(nn.Module):
    def __init__(self, in_channels, out_channels, atrous_rates):
        super(ASPP, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 1)
        self.conv2 = nn.Conv2d(in_channels, out_channels, 3, padding=atrous_rates[0], dilation=atrous_rates[0])
        self.conv3 = nn.Conv2d(in_channels, out_channels, 3, padding=atrous_rates[1], dilation=atrous_rates[1])
        self.conv4 = nn.Conv2d(in_channels, out_channels, 3, padding=atrous_rates[2], dilation=atrous_rates[2])
        self.conv5 = nn.Conv2d(in_channels, out_channels, 1)
        self.bnorm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.final_conv = nn.Conv2d(out_channels * 5, out_channels, 1)
    
    def forward(self, x):
        out1 = self.relu(self.bnorm(self.conv1(x)))
        out2 = self.relu(self.bnorm(self.conv2(x)))
        out3 = self.relu(self.bnorm(self.conv3(x)))
        out4 = self.relu(self.bnorm(self.conv4(x)))
        out5 = F.adaptive_avg_pool2d(x, 1)
        out5 = F.interpolate(out5, size=(x.size(2), x.size(3)), mode='bilinear', align_corners=True)
        out5 = self.relu(self.bnorm(self.conv5(out5)))
        out = torch.cat([out1, out2, out3, out4, out5], dim=1)
        out = self.relu(self.bnorm(self.final_conv(out)))

        return out

In [63]:
class AtrousBasicBlock(nn.Sequential):
    """Basic Block for UNet. It is composed by a double 3x3 conv."""

    def __init__(
        self,
        in_features: int,
        out_features: int,
        activation: nn.Module = partial(nn.ReLU, inplace=True),
        atrous_rates: list = [6, 12, 18],
        *args,
        **kwargs,
    ):
        super().__init__(
            ConvBnAct(
                in_features,
                out_features,
                kernel_size=3,
                activation=activation,
                *args,
                **kwargs,
            ),
            ConvBnAct(
                out_features,
                out_features,
                kernel_size=3,
                activation=activation,
                *args,
                **kwargs,
            ),
            ASPP(out_features, out_features, atrous_rates, *args, **kwargs),
        )

In [72]:
enc = UNetEncoder(in_channels=2, block=AtrousBasicBlock, atrous_rates=[6, 12, 18])

In [8]:
import torchinfo

In [73]:
device = "cpu"
inp = torch.Tensor(1, 2, 256, 256)
torchinfo.summary(enc, input_data=inp, device=device)

Layer (type:depth-idx)                             Output Shape              Param #
UNetEncoder                                        [1, 1024, 16, 16]         --
├─ModuleList: 1-1                                  --                        --
│    └─DownLayer: 2-1                              [1, 64, 256, 256]         --
│    │    └─Sequential: 3-1                        [1, 64, 256, 256]         178,048
│    └─DownLayer: 2-2                              [1, 128, 128, 128]        --
│    │    └─Sequential: 3-2                        [1, 128, 128, 128]        779,776
│    └─DownLayer: 2-3                              [1, 256, 64, 64]          --
│    │    └─Sequential: 3-3                        [1, 256, 64, 64]          3,116,032
│    └─DownLayer: 2-4                              [1, 512, 32, 32]          --
│    │    └─Sequential: 3-4                        [1, 512, 32, 32]          12,457,984
│    └─DownLayer: 2-5                              [1, 1024, 16, 16]         --
│    │    

In [74]:
enc.features
out = enc(inp)

In [75]:
out.shape

torch.Size([1, 1024, 16, 16])

In [76]:
for feats in enc.features:
    print(feats.shape)

torch.Size([1, 64, 256, 256])
torch.Size([1, 128, 128, 128])
torch.Size([1, 256, 64, 64])
torch.Size([1, 512, 32, 32])


In [77]:
features = enc.features
residuals = features[::-1]
dec = UNetDecoder(start_features=enc.widths[-1], lateral_widths=enc.features_widths[::-1])

In [78]:
residuals.extend([None] * (len(dec.layers) - len(residuals)))

In [79]:
out = dec(out, residuals)

In [83]:
from atrous_networks import unet

In [14]:
import importlib
importlib.reload(unet)

<module 'atrous_networks.unet' from '/home/prateek/ms_change_detection/atrous_networks/unet.py'>

In [89]:
model = unet.ASPPUNet(in_channels=2, n_classes=1)

In [90]:
torchinfo.summary(model, ((1, 1, 256, 256), (1, 1, 256, 256)), device=device)

Layer (type:depth-idx)                                  Output Shape              Param #
ASPPUNet                                                [1, 1, 256, 256]          --
├─UNetEncoder: 1-1                                      [1, 1024, 16, 16]         --
│    └─ModuleList: 2-1                                  --                        --
│    │    └─DownLayer: 3-1                              [1, 64, 256, 256]         141,056
│    │    └─DownLayer: 3-2                              [1, 128, 128, 128]        632,064
│    │    └─DownLayer: 3-3                              [1, 256, 64, 64]          2,525,696
│    │    └─DownLayer: 3-4                              [1, 512, 32, 32]          10,097,664
│    │    └─DownLayer: 3-5                              [1, 1024, 16, 16]         40,380,416
├─UNetDecoder: 1-2                                      [1, 32, 256, 256]         --
│    └─ModuleList: 2-2                                  --                        --
│    │    └─UpLayer: 3-6   

In [2]:

# Check if a GPU is available
if torch.cuda.is_available():
    # Get the free and total memory
    free_memory, total_memory = torch.cuda.mem_get_info()
    
    # Convert bytes to megabytes for easier reading
    free_memory_MB = free_memory / 1024**2
    total_memory_MB = total_memory / 1024**2
    
    print(f"Free Memory: {free_memory_MB:.2f} MB")
    print(f"Total Memory: {total_memory_MB:.2f} MB")
else:
    print("CUDA is not available.")

Free Memory: 7900.81 MB
Total Memory: 24217.31 MB




In [6]:
torch.cuda.empty_cache()

In [5]:
from atrous_networks import unet

In [17]:
model = unet.ASPPUNet(in_channels=2, n_classes=1, encoder_widths=[32, 64, 128, 256], decoder_widths=[256, 128, 64])

In [18]:
torchinfo.summary(model, ((1, 1, 256, 256), (1, 1, 256, 256)), device=device)

Layer (type:depth-idx)                                  Output Shape              Param #
ASPPUNet                                                [1, 1, 512, 512]          --
├─UNetEncoder: 1-1                                      [1, 256, 32, 32]          --
│    └─ModuleList: 2-1                                  --                        --
│    │    └─DownLayer: 3-1                              [1, 32, 256, 256]         35,712
│    │    └─DownLayer: 3-2                              [1, 64, 128, 128]         158,336
│    │    └─DownLayer: 3-3                              [1, 128, 64, 64]          632,064
│    │    └─DownLayer: 3-4                              [1, 256, 32, 32]          2,525,696
├─UNetDecoder: 1-2                                      [1, 32, 512, 512]         --
│    └─ModuleList: 2-2                                  --                        --
│    │    └─UpLayer: 3-5                                [1, 256, 64, 64]          1,737,984
│    │    └─UpLayer: 3-6        