In [None]:
import torch
import torch.nn as nn
from IPython.display import clear_output
from typing import Dict
from torch.utils.tensorboard.summary import hparams


class DownBlock(nn.Module):
    """Residual 1-D block: ``conv -> BN -> ReLU -> conv -> BN`` plus a shortcut.

    The main path is two ``Conv1d`` layers; the first is followed by batch norm
    and ReLU, the second by batch norm only. The shortcut is the identity
    (``r=False``) or a learned 1x1 ``Conv1d`` projection (``r=True``). The two
    paths are summed; no activation is applied after the sum.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        kernel_size: Kernel size of both main-path convolutions.
        padding: Padding of both main-path convolutions. The default ``"same"``
            keeps the length unchanged; PyTorch only allows it with ``stride=1``.
        stride: Stride of the first main-path convolution and of the projection
            shortcut. The second convolution always uses stride 1, so both
            paths are downsampled by the same factor.
        r: If true, project the shortcut with a 1x1 convolution. This is required
            whenever ``in_channels != out_channels`` or ``stride != 1``.

    Raises:
        ValueError: If ``r`` is false but an identity shortcut cannot match the
            main path (different channel count or a stride other than 1).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding="same", stride=1, r=False):
        super().__init__()
        if not r and (in_channels != out_channels or stride != 1):
            raise ValueError(
                "DownBlock with r=False uses an identity shortcut, which requires "
                f"in_channels == out_channels and stride == 1 (got in_channels={in_channels}, "
                f"out_channels={out_channels}, stride={stride}); pass r=True instead."
            )
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, stride=stride)
        self.bn = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()
        # Stride only once on the main path: striding here as well would downsample
        # the main path twice while the shortcut is downsampled once.
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, padding=padding, stride=1)
        self.bn2 = nn.BatchNorm1d(out_channels)
        # A 1x1 kernel needs no padding: padding=0 is identical to "same" here and,
        # unlike an integer `padding`, never makes the shortcut longer than the main path.
        self.conv3 = (
            nn.Conv1d(in_channels, out_channels, kernel_size=1, padding=0, stride=stride)
            if r
            else nn.Identity()
        )
        # NOTE(review): unlike the classic ResNet block there is no ReLU after the
        # residual sum; callers (e.g. Runet.down5) add one explicitly where wanted.

    def forward(self, x):
        """Return ``main_path(x) + shortcut(x)``.

        Args:
            x: Tensor of shape ``(batch, in_channels, length)``.

        Returns:
            Tensor of shape ``(batch, out_channels, length_out)``, where
            ``length_out == length`` for the default ``padding="same"``.
        """
        main = self.relu(self.bn(self.conv(x)))
        main = self.bn2(self.conv2(main))
        return main + self.conv3(x)
    
class BottleNeck(nn.Module):
    """``ConvTranspose1d`` followed by ReLU.

    The output length is ``(length - 1) * stride - 2 * padding + kernel_size``
    (``ConvTranspose1d`` with dilation 1 and no output padding). With the
    defaults (``kernel_size=3``, ``padding=1``, ``stride=1``) the length is
    preserved.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        kernel_size: Kernel size of the transposed convolution.
        padding: Padding of the transposed convolution.
        stride: Stride of the transposed convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, stride=1):
        super().__init__()
        # Honour the caller's `padding` (previously hard-coded to 1 and silently ignored).
        self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, stride=stride)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the transposed convolution and ReLU to ``x`` of shape ``(batch, in_channels, length)``."""
        return self.relu(self.conv(x))

class UpBlock(nn.Module):
    """Decoder block: concatenate with a skip connection, then upsample.

    ``forward(x, skip)`` concatenates ``x`` and ``skip`` along dim 1 and applies
    ``ConvTranspose1d -> BatchNorm1d -> ReLU``. With the defaults
    (``kernel_size=2``, ``stride=2``, no padding) the length doubles.

    Args:
        in_channels: Not used by any layer. NOTE(review): kept for API
            compatibility; the convolution is sized by ``skip_channels``.
        out_channels: Channels produced by the transposed convolution.
        skip_channels: Channel count of ``torch.cat([x, skip], dim=1)``, i.e. the
            *combined* channels of ``x`` and ``skip``.
        scale: Selects ``self.pixel_shuffle`` (``nn.PixelShuffle`` when > 1,
            otherwise identity). NOTE(review): that module is never applied in
            ``forward``, so ``scale`` currently has no effect.
        kernel_size: Kernel size of the transposed convolution.
        padding: Ignored. NOTE(review): the transposed convolution always uses
            ``padding=0``.
        stride: Stride of the transposed convolution.
        last: If true, ``self.relu2`` is a second ReLU, otherwise an identity.
            NOTE(review): ``self.relu`` is always applied and ReLU is idempotent,
            so this flag does not change the output.
    """

    def __init__(self, in_channels, out_channels, skip_channels, scale=1, kernel_size=2, padding=1, stride=2, last=True):
        super().__init__()
        self.pixel_shuffle = nn.PixelShuffle(scale) if scale > 1 else nn.Identity()
        self.conv = nn.ConvTranspose1d(skip_channels, out_channels, kernel_size=kernel_size, padding=0, stride=stride)
        self.bn = nn.BatchNorm1d(out_channels)

        self.relu = nn.ReLU()
        self.relu2 = nn.ReLU() if last else nn.Identity()

    def forward(self, x, skip):
        """Concatenate ``x`` and ``skip`` on dim 1 and upsample the result.

        ``x`` and ``skip`` must agree on every dimension except dim 1, and their
        channel counts must sum to ``skip_channels``.

        Returns:
            Tensor of shape ``(batch, out_channels, length_out)``; with the
            defaults ``length_out == 2 * length``.
        """
        x = torch.cat([x, skip], dim=1)
        x = self.bn(self.conv(x))
        return self.relu2(self.relu(x))


class Runet(nn.Module):
    """1-D residual U-Net mapping a single-channel sequence to a single-channel sequence.

    Encoder: ``down1`` is a kernel-size-7 stem convolution (64 channels). Each of
    ``down2``..``down5`` halves the length with ``MaxPool1d(2)`` and applies
    ``DownBlock`` residual blocks, giving 128, 256, 512 and 512 channels.
    ``BottleNeck`` maps 512 -> 1024 -> 512 channels at the lowest resolution.
    Decoder: each ``UpBlock`` concatenates the current features with the encoder
    output of the same length and doubles the length with a stride-2 transposed
    convolution. ``up5`` ends at twice the input length, so a 1x1 convolution
    projects to one channel and ``AvgPool1d(2)`` halves the length back.

    Input ``(batch, 1, length)`` -> output ``(batch, 1, length)``. ``length``
    must be a positive multiple of 16 so every encoder output lines up with
    its decoder counterpart.
    """

    # Total encoder downsampling: four MaxPool1d(2) stages (down2..down5).
    _DOWNSAMPLE_FACTOR = 16

    def __init__(self, **kwargs):
        # kwargs are forwarded unchanged to nn.Module.__init__.
        super().__init__(**kwargs)

        self.down1 = nn.Sequential(
            nn.Conv1d(1, 64, kernel_size=7, padding="same", stride=1),
            nn.BatchNorm1d(64),
            nn.ReLU())

        self.down2 = nn.Sequential(
            nn.MaxPool1d(2),
            DownBlock(64, 64, r=False),
            DownBlock(64, 64, r=False),
            DownBlock(64, 64, r=False),
            DownBlock(64, 128, r=True))
        self.down3 = nn.Sequential(
            nn.MaxPool1d(2, stride=2),
            DownBlock(128, 128, r=False),
            DownBlock(128, 128, r=False),
            DownBlock(128, 128, r=False),
            DownBlock(128, 256, r=True))

        self.down4 = nn.Sequential(
            nn.MaxPool1d(2),
            DownBlock(256, 256, r=False),
            DownBlock(256, 256, r=False),
            DownBlock(256, 256, r=False),
            DownBlock(256, 256, r=False),
            DownBlock(256, 256, r=False),
            DownBlock(256, 512, r=True))

        # DownBlock has no activation after its residual sum, so add BN + ReLU here.
        self.down5 = nn.Sequential(
            nn.MaxPool1d(2),
            DownBlock(512, 512, r=False),
            DownBlock(512, 512, r=False),
            nn.BatchNorm1d(512),
            nn.ReLU())

        self.BottleNeck = nn.Sequential(
            BottleNeck(512, 1024),
            BottleNeck(1024, 512))

        # Third argument = channels of cat([decoder features, encoder skip]).
        self.up1 = UpBlock(512, 512, 1024, scale=1)   # 512 + 512 (down5)
        self.up2 = UpBlock(512, 384, 1024, scale=1)   # 512 + 512 (down4)
        self.up3 = UpBlock(384, 256, 640, scale=1)    # 384 + 256 (down3)
        self.up4 = UpBlock(256, 96, 384, scale=1)     # 256 + 128 (down2)

        self.up5 = UpBlock(96, 99, 160, scale=1, last=False)  # 96 + 64 (down1)
        self.conv = nn.Conv1d(99, 1, kernel_size=1, padding=0, stride=1)
        self.avg_pool = nn.AvgPool1d(2, stride=2)

    def forward(self, x):
        """Run the U-Net.

        Args:
            x: Tensor of shape ``(batch, 1, length)``.

        Returns:
            Tensor of shape ``(batch, 1, length)``.

        Raises:
            ValueError: If ``length`` is not a positive multiple of 16. Any other
                length makes a skip concatenation in the decoder fail.
        """
        length = x.shape[-1]
        if length == 0 or length % self._DOWNSAMPLE_FACTOR:
            raise ValueError(
                f"Runet expects an input length that is a positive multiple of "
                f"{self._DOWNSAMPLE_FACTOR}, got {length}"
            )

        # Encoder; every stage output is kept for the matching skip connection.
        x1 = self.down1(x)   # (B, 64, L)
        x2 = self.down2(x1)  # (B, 128, L/2)
        x3 = self.down3(x2)  # (B, 256, L/4)
        x4 = self.down4(x3)  # (B, 512, L/8)
        x5 = self.down5(x4)  # (B, 512, L/16)

        x6 = self.BottleNeck(x5)  # (B, 512, L/16)

        # Decoder: concatenate with the same-length skip, then double the length.
        x7 = self.up1(x6, x5)    # (B, 512, L/8)
        x8 = self.up2(x7, x4)    # (B, 384, L/4)
        x9 = self.up3(x8, x3)    # (B, 256, L/2)
        x10 = self.up4(x9, x2)   # (B, 96, L)
        x11 = self.up5(x10, x1)  # (B, 99, 2L)
        result = self.conv(x11)  # (B, 1, 2L)
        output = self.avg_pool(result)  # (B, 1, L)
        return output
    