In [None]:
import torch
from ddpm import DDPM
import tqdm
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
from old_unet import UNet
from data import SequencesDataset
from train import train
import torchvision.transforms as transforms
import os

In [None]:
# --- Training schedule -------------------------------------------------
EPOCHS = 1000

# --- Diffusion / model dimensions --------------------------------------
T = 1000
input_channels = 3
context_length = 4
actions_count = 5

# --- Data loading -------------------------------------------------------
batch_size = 32
num_workers = 2

# --- Compute device -----------------------------------------------------
# The MPS backend (Mac OS) takes priority whenever it is available;
# otherwise use CUDA if present, and fall back to the CPU last.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Directory that all relative data paths are resolved against.
# Alternative root used previously: "../snake_agent/q_learning"
ROOT_PATH = "./"


def local_path(path: str) -> str:
    """Return `path` joined onto ROOT_PATH.

    Args:
        path: File or directory name to place under ROOT_PATH.

    Returns:
        The result of ``os.path.join(ROOT_PATH, path)``.
    """
    resolved = os.path.join(ROOT_PATH, path)
    return resolved

In [3]:
# Noise-prediction network handed to DDPM as `eps_model`.
# Its input stacks (context_length + 1) frames of `input_channels` channels
# each; it emits 3 channels and is configured with T + 1.
noise_predictor = UNet(
    in_channels=input_channels * (context_length + 1),
    out_channels=3,
    T=T + 1,
    actions_count=actions_count,
    seq_length=context_length,
)

# Diffusion wrapper around the network, placed on the selected device.
ddpm = DDPM(
    T=T,
    eps_model=noise_predictor,
    context_length=context_length,
    device=device,
)

In [4]:
# Image preprocessing: ToTensor followed by a per-channel
# (x - 0.5) / 0.5 normalization on each of the 3 channels.
transform_to_tensor = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((.5, .5, .5), (.5, .5, .5))
])

dataset = SequencesDataset(
    images_dir=local_path("snapshots"),
    actions_path=local_path("actions"),
    seq_length=context_length,
    transform=transform_to_tensor
)

# 80% of the samples go to training, the remainder to validation.
total_size = len(dataset)
train_size = int(0.8 * total_size)
valid_size = total_size - train_size

# Seed the split explicitly so the train/validation membership is identical
# after every kernel restart. Without a fixed generator the partition changes
# on each run, which makes validation losses incomparable between runs and
# lets samples cross between the two sets when training is resumed.
SPLIT_SEED = 42
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, valid_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED)
)

# Training batches are reshuffled every epoch; validation batches keep a
# fixed order because shuffling adds nothing to evaluation.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers
)

In [5]:
# Adam over all DDPM parameters with a 2e-4 learning rate.
optimizer = torch.optim.Adam(params=ddpm.parameters(), lr=2e-4)

# NOTE(review): the output recorded for this cell is
# "AttributeError: 'list' object has no attribute 'to'", raised before the
# first of 458 iterations completed. Presumably a batch (or one of its
# elements) reaches a `.to(...)` call as a Python list -- verify against
# `train` and `SequencesDataset.__getitem__`, which are not shown here.
_, val_losses = train(
    model=ddpm,
    optimizer=optimizer,
    epochs=EPOCHS,
    device=device,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    save_every_epoch=100,
)

  0%|          | 0/458 [00:13<?, ?it/s]


AttributeError: 'list' object has no attribute 'to'