In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import cv2
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image
import os
from torchvision.transforms import ToTensor
from torchvision import transforms


In [23]:
class SatelliteDataset(Dataset):
    def __init__(self, image_dir, images_name, targets_name, transform):
        # self.image_dir = image_dir
        self.transform = transform
        self.images_name = images_name
        self.targets_name = targets_name
        self.image_dir = image_dir
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    # def __getitem__(self, idx):
    #     image_path = os.path.join(self.image_dir, self.images[idx])
    #     mask_path = os.path.join(self.mask_dir, self.images[idx].replace('image', 'mask'))
    #     image = Image.open(image_path).convert("RGB")
    #     mask = Image.open(mask_path).convert("L")
    #     if self.transform:
    #         image = self.transform(image)
    #         mask = self.transform(mask)
    #     return image, mask

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.images_name[idx])
        mask_path = os.path.join(self.image_dir, self.targets_name[idx])

        image = Image.open(image_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")
        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)
        return image, mask

In [24]:
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = DoubleConv(64, 128)
        self.down2 = DoubleConv(128, 256)
        self.down3 = DoubleConv(256, 512)
        self.up1 = DoubleConv(512, 256)
        self.up2 = DoubleConv(256, 128)
        self.up3 = DoubleConv(128, 64)
        self.outc = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        x1 = F.max_pool2d(self.inc(x), 2)
        x2 = F.max_pool2d(self.down1(x1), 2)
        x3 = F.max_pool2d(self.down2(x2), 2)
        x4 = F.max_pool2d(self.down3(x3), 2)
        x = F.interpolate(self.up1(x4), scale_factor=2, mode='bilinear', align_corners=True)
        x = F.interpolate(self.up2(x + x3), scale_factor=2, mode='bilinear', align_corners=True)
        x = F.interpolate(self.up3(x + x2), scale_factor=2, mode='bilinear', align_corners=True)
        logits = self.outc(x + x1)
        return logits

In [31]:
#base_dir = "/Users/yijiewang/Documents/WINTER 2024/APS Project/dataset"
base_dir = "/u/d/wang3812/Documents/APSProject/cropped_dataset"
DATA_DIR = os.path.join(base_dir, "train")
COLOR_CODES = pd.read_csv(os.path.join(base_dir, "class_dict.csv"))
train_annotations = pd.read_csv(os.path.join(base_dir, "cropped_proportion_train.csv"))
valid_annotations = pd.read_csv(os.path.join(base_dir, "cropped_proportion_valid.csv"))
test_annotations = pd.read_csv(os.path.join(base_dir, "cropped_proportion_test.csv"))
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((128, 128)),
])

In [33]:
images_name = sorted(train_annotations['sat_image_path'])
images_name = images_name[:1000]
labels_name = sorted(train_annotations['mask_path'])
labels_name = labels_name[:1000]
# img_dir = "/Users/yijiewang/Documents/WINTER 2024/APS Project/dataset"
img_dir = "/u/d/wang3812/Documents/APSProject/cropped_dataset"

In [34]:
# num_classes = len(label_map.keys())
lr = 0.01
batch_size = 4
# model = UNet(n_class=num_classes)
train_dataset = SatelliteDataset(img_dir, images_name, labels_name, transform)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
model = UNet(n_channels=3, n_classes=6)  # for binary classification
optimizer = torch.optim.AdamW(model.parameters(), lr)
criterion = torch.nn.CrossEntropyLoss()
num_epochs = 10
device =  torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [35]:
num_epochs = 3  # Just as an example, you might need more epochs

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    for images, masks in train_dataloader:
        optimizer.zero_grad()
        predictions = model(images)
        loss = criterion(predictions, masks.squeeze(1).long())
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

RuntimeError: size mismatch (got input: [4, 6, 64, 64] , target: [4, 128, 128]