In [7]:
import pandas as pd
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchmetrics.segmentation import DiceScore
import sys
sys.path.append("..")

from models.unet import GameUNet
from models.early_stopping import EarlyStopping

import matplotlib.pyplot as plt

In [2]:
train_ds = torch.load("../../dataset/train.pt", weights_only=False)
test_ds = torch.load("../../dataset/test.pt", weights_only=False)
val_ds = torch.load("../../dataset/val.pt", weights_only=False)

In [3]:
from torch.utils.data import TensorDataset

train_ds = TensorDataset(*train_ds[:1000])
test_ds = TensorDataset(*test_ds[:100])
val_ds = TensorDataset(*val_ds[:200])

In [4]:
from torch.utils.data import DataLoader

batch_size = 32
num_epochs = 20

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

es = EarlyStopping(patience=10)

In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GameUNet(n_classes=3, n_actions=4).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss().to(device)
dice = DiceScore(num_classes=3, average='macro', input_format='index').to(device)


In [None]:
es = EarlyStopping(patience=10)

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for x_batch, a_batch, y_batch in train_loader:
        x_batch = x_batch.to(device)
        a_batch = a_batch.to(device)
        y_batch = y_batch.to(device)

        y_pred = model(x_batch, a_batch)
        loss = criterion(y_pred, y_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * x_batch.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f"Epoch [{epoch+1}/{num_epochs}] - Loss: {epoch_loss:.4f}")

    model.eval()
    val_loss = 0.0
    val_dice = 0.0
    with torch.no_grad():
        for x_batch, a_batch, y_batch in val_loader:
            x_batch = x_batch.to(device)
            a_batch = a_batch.to(device)
            y_batch = y_batch.to(device)

            y_pred = model(x_batch, a_batch)
            val_loss += criterion(y_pred, y_batch).item() * x_batch.size(0)

            y_batch_onehot = F.one_hot(y_batch.long(), num_classes=3)
            y_batch_onehot = y_batch_onehot.permute(0, 3, 1, 2).float()

            val_dice += dice(y_pred, y_batch_onehot) * x_batch.size(0)

    val_loss /= len(val_loader.dataset)
    val_dice /= len(val_loader.dataset)
    print(f"Validation Loss: {val_loss: .4f} - DiceScore: {val_dice: .4f}")

    es(val_loss, model)
    if es.early_stop:
        print("⏹️ Early stopping")
        break


Epoch [1/20] - Loss: 0.0813


RuntimeError: one_hot is only applicable to index tensor of type LongTensor.

In [None]:
model(torch.zeros((1, 3, 16, 16)), torch.zeros(1, 4)).shape

torch.Size([1, 3, 16, 16])

In [None]:
import torchsummary

torchsummary.summary(model, input_size=[(3, 16, 16), [4]])

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 16, 16]             864
       BatchNorm2d-2           [-1, 32, 16, 16]              64
              ReLU-3           [-1, 32, 16, 16]               0
            Conv2d-4           [-1, 32, 16, 16]           9,216
       BatchNorm2d-5           [-1, 32, 16, 16]              64
              ReLU-6           [-1, 32, 16, 16]               0
        DoubleConv-7           [-1, 32, 16, 16]               0
         MaxPool2d-8             [-1, 32, 8, 8]               0
            Conv2d-9             [-1, 64, 8, 8]          18,432
      BatchNorm2d-10             [-1, 64, 8, 8]             128
             ReLU-11             [-1, 64, 8, 8]               0
           Conv2d-12             [-1, 64, 8, 8]          36,864
      BatchNorm2d-13             [-1, 64, 8, 8]             128
             ReLU-14             [-1, 6

ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (2,) + inhomogeneous part.