In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import os

In [4]:
# Load the ground-truth tensor.
# NOTE: `torch.load` unpickles the file, so only load files from a trusted source.
# `map_location='cpu'` keeps loading working on machines without the device the
# file was saved from (the model below runs on CPU).
DATA_FILE = 'HUC8_CA_PFAS_GTruth_Summa2.pth'
DATA_DIR = 'pth_data'


def load_channel_tensors(directory):
    """Load every ``.pth`` file in ``directory`` in sorted filename order.

    Sorting makes the order of the returned list (and therefore the channel
    order after concatenation) deterministic; ``os.listdir`` order is arbitrary.
    Non-``.pth`` files are skipped.

    Args:
        directory: path of the folder to scan.

    Returns:
        list of loaded tensors, one per ``.pth`` file.
    """
    return [
        torch.load(os.path.join(directory, name), map_location='cpu')
        for name in sorted(os.listdir(directory))
        if name.endswith('.pth')
    ]


# Only the single file is used for now; `load_channel_tensors(DATA_DIR)` is
# available for multi-file input but is not called here.
tensor = torch.load(DATA_FILE, map_location='cpu')
tensor.shape

torch.Size([965866, 1, 10, 10])

In [None]:
# Tensors to stack along dim=1 (assumed to be the channel axis of an
# (N, C, H, W) layout — TODO confirm). The multi-file loader in the previous
# cell is not used, so only the single loaded tensor is stacked; defining the
# list here keeps this cell runnable on a fresh kernel.
tensors = [tensor]
concat_tensor = torch.cat(tensors, dim=1)
concat_tensor.shape

torch.Size([965866, 1, 10, 10])

In [None]:
# Define the UNet model
class UNet(nn.Module):
    """Four-level U-Net with additive skip connections.

    The encoder halves spatial resolution four times (enc2, enc3, enc4 and the
    bottleneck), and the decoder doubles it four times. Each decoder output is
    summed with the encoder feature map of the same resolution. Because the
    skips are sums (not concatenations), the channel counts of the two summed
    tensors must match, which the channel schedule below guarantees.

    Pooling uses ``ceil_mode=True``, and every decoder output is cropped to the
    spatial size of its skip tensor. As a result, inputs whose height/width are
    not multiples of 16 (e.g. 10x10) also work and come back at the input size.

    Args:
        in_channels: number of channels in the input tensor.
        out_channels: number of channels in the output tensor.
        base_channels: width of the first encoder stage. Deeper stages use
            2x, 4x, 8x and 16x this value. The default of 64 gives the classic
            64-128-256-512-1024 U-Net.
    """

    def __init__(self, in_channels=1, out_channels=1, base_channels=64):
        super().__init__()
        c = base_channels

        def conv_block(in_c, out_c):
            # Two 3x3 convs with padding=1: spatial size is preserved.
            block = nn.Sequential(
                nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
                nn.ReLU(inplace=True)
            )
            return block

        def down_block(in_c, out_c):
            # ceil_mode=True: an odd size n pools to ceil(n/2) instead of
            # floor(n/2), so small/odd inputs never shrink to zero and the
            # 2x upsample on the way back is always >= the skip size.
            block = nn.Sequential(
                nn.MaxPool2d(2, ceil_mode=True),
                conv_block(in_c, out_c)
            )
            return block

        def up_block(in_c, out_c):
            # Transposed conv with kernel=stride=2 exactly doubles H and W.
            block = nn.Sequential(
                nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2),
                conv_block(out_c, out_c)
            )
            return block

        # Encoder
        self.enc1 = conv_block(in_channels, c)     # c,   H
        self.enc2 = down_block(c, 2 * c)           # 2c,  H/2
        self.enc3 = down_block(2 * c, 4 * c)       # 4c,  H/4
        self.enc4 = down_block(4 * c, 8 * c)       # 8c,  H/8

        # Bottleneck: must downsample too, otherwise every decoder level is
        # paired with the wrong encoder level (channel/size mismatch).
        self.bottleneck = down_block(8 * c, 16 * c)  # 16c, H/16

        # Decoder: each block's output channels equal those of the skip it is added to.
        self.dec4 = up_block(16 * c, 8 * c)        # -> 8c at H/8 (+ e4)
        self.dec3 = up_block(8 * c, 4 * c)         # -> 4c at H/4 (+ e3)
        self.dec2 = up_block(4 * c, 2 * c)         # -> 2c at H/2 (+ e2)
        self.dec1 = up_block(2 * c, c)             # -> c  at H   (+ e1)

        # Final layer: 1x1 conv projecting to the requested output channels.
        self.final = nn.Conv2d(c, out_channels, kernel_size=1)

    @staticmethod
    def _match_size(x, ref):
        """Crop ``x`` to ``ref``'s spatial height/width (x is never smaller)."""
        return x[..., :ref.shape[-2], :ref.shape[-1]]

    def forward(self, x):
        """Map ``x`` of shape (B, in_channels, H, W) to (B, out_channels, H, W)."""
        # Encoder path
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)

        # Bottleneck
        b = self.bottleneck(e4)

        # Decoder path: upsample, crop to the skip's size, add the skip.
        d4 = self.dec4(b)
        d3 = self.dec3(self._match_size(d4, e4) + e4)
        d2 = self.dec2(self._match_size(d3, e3) + e3)
        d1 = self.dec1(self._match_size(d2, e2) + e2)

        # Final output layer (last skip connection at full resolution)
        out = self.final(self._match_size(d1, e1) + e1)
        return out

In [None]:
# Instantiate the model, define the loss function and the optimizer
SEED = 0
torch.manual_seed(SEED)  # reproducible weight init and dummy input

model = UNet(in_channels=1, out_channels=1)
criterion = nn.MSELoss()  # Mean Squared Error for a single-channel (grayscale) regression target
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Smoke test: a dummy batch of one 1x128x128 image must come back the same shape.
input_image = torch.randn((1, 1, 128, 128))  # 1 batch, 1 channel, 128x128 resolution
with torch.no_grad():  # shape check only; no autograd graph needed
    output = model(input_image)

assert output.shape == input_image.shape, output.shape
print("Output shape:", output.shape)  # (1, 1, 128, 128)