In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader, RandomSampler
from PIL import Image

In [2]:
# Select the compute device once. Later cells read the module-level `device`
# (e.g. `model.to(device)`), so it must be defined on both branches.
if torch.cuda.is_available():
    # Set the default device to CUDA
    device = torch.device('cuda')
    torch.set_default_device(device)
    print('Using CUDA for tensor operations')
    torch.cuda.empty_cache()
else:
    # Without this assignment every later `.to(device)` raises NameError on
    # machines that have no GPU.
    device = torch.device('cpu')
    print('CUDA is not available. Using CPU for tensor operations')

Using CUDA for tensor operations


In [3]:
class CLEVRDataset(Dataset):
    """Frame-level dataset over a directory of ``video*`` sub-folders.

    Each folder is read as ``image_<i>.png`` frames plus one ``mask.npy`` whose
    first axis is indexed by the same frame number ``i``. The number of frames
    per folder is taken from the first folder's ``mask.npy`` and assumed to be
    the same for every folder.

    Args:
        path: Directory containing the ``video*`` folders.
        transform: Callable applied to the RGB PIL image. Defaults to
            ``ToTensor`` followed by ImageNet mean/std normalisation.
        device: Device the returned tensors are placed on. Defaults to CUDA
            when available, otherwise CPU.
        num_objects: Number of channels in the one-hot mask; every value in a
            mask must be smaller than this.

    Raises:
        FileNotFoundError: If ``path`` contains no entry starting with
            ``video``.
    """

    def __init__(self, path, transform=None, device=None, num_objects=49):
        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> (folder, frame) mapping reproducible between runs.
        self.video_paths = sorted(
            os.path.join(path, dir_path)
            for dir_path in os.listdir(path)
            if dir_path.startswith('video')
        )
        if not self.video_paths:
            raise FileNotFoundError(f"No 'video*' entries found under {path!r}")

        self.transform = transform
        self.num_objects = num_objects
        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = torch.device(device)
        self._get_num_samples()

        if self.transform is None:
            self.transform = transforms.Compose([
                transforms.ToTensor(),  # Converts PIL Image to Tensor.
                transforms.Normalize(mean=[0.485, 0.456, 0.406],  # Standard normalization for pre-trained models.
                                     std=[0.229, 0.224, 0.225])
            ])
        # Kept for backward compatibility; __getitem__ does not use it.
        self.mask_transform = transforms.Compose([
            transforms.ToTensor()  # Only convert masks to tensor without normalization.
        ])

    def __len__(self):
        """Return the total number of frames across all folders."""
        return self.num_samples

    def _get_num_samples(self):
        """Derive per-folder and total frame counts from the first folder's mask file."""
        mask_sample_path = os.path.join(self.video_paths[0], 'mask.npy')
        # Memory-map the file: only the header is needed to read the shape.
        mask_shapes = np.load(mask_sample_path, mmap_mode='r').shape
        self.folder_num_samples = mask_shapes[0]
        self.num_samples = self.folder_num_samples * len(self.video_paths)

    def _locate(self, idx):
        """Map a flat sample index to ``(folder_index, image_index)``.

        Negative indices count from the end, as with a list.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        if idx < 0:
            idx += self.num_samples
        if not 0 <= idx < self.num_samples:
            raise IndexError(f'index out of range for dataset of size {self.num_samples}')
        # Integer divmod avoids the float division / truncation of int(idx / n),
        # which sent negative indices to the wrong folder.
        return divmod(idx, self.folder_num_samples)

    def __getitem__(self, idx):
        """Return ``(image, one_hot_mask)`` for the frame at flat index ``idx``.

        ``image`` is the transformed RGB frame; ``one_hot_mask`` is a float
        tensor of shape ``(num_objects, height, width)`` with a 1 in the
        channel given by each pixel's mask value. Both are on ``self.device``.
        """
        folder_index, image_index = self._locate(idx)

        img_name = os.path.join(self.video_paths[folder_index], f'image_{image_index}.png')
        mask_name = os.path.join(self.video_paths[folder_index], 'mask.npy')

        image = Image.open(img_name).convert("RGB")
        # Memory-map the mask file and copy out only the requested frame rather
        # than loading every frame of the video on each access. The copy also
        # gives torch.from_numpy a writable in-memory array.
        mask = np.array(np.load(mask_name, mmap_mode='r')[image_index])

        # Create a one-hot encoded tensor
        mask_tensor = torch.from_numpy(mask)
        height, width = mask_tensor.shape
        one_hot_mask = torch.zeros((self.num_objects, height, width), dtype=torch.float, device=self.device)
        one_hot_mask.scatter_(0, mask_tensor.unsqueeze(0).to(torch.int64).to(self.device), 1)

        image = self.transform(image).to(self.device)

        return image, one_hot_mask


In [4]:
class UNet(nn.Module):
    """Small encoder/decoder segmentation network with additive skip connections.

    Four contracting blocks each halve the spatial size; three expanding blocks
    and a final transposed convolution each double it. Encoder features are
    added (not concatenated) to decoder features of the same resolution.

    The input height and width should be multiples of 16: each pooling step
    yields ``ceil(n / 2)`` while each transposed convolution yields exactly
    ``2 * n``, so other sizes make the skip additions mismatch in shape.

    Args:
        num_classes: Number of output channels (one score map per class).
        in_channels: Number of channels in the input image.
    """

    def __init__(self, num_classes=49, in_channels=3):
        super(UNet, self).__init__()
        # Encoder
        self.down1 = self.contract_block(in_channels, 64, 7, 3)
        self.down2 = self.contract_block(64, 128, 3, 1)
        self.down3 = self.contract_block(128, 256, 3, 1)
        self.down4 = self.contract_block(256, 512, 3, 1)

        # Decoder
        self.up3 = self.expand_block(512, 256, 3, 1)
        self.up2 = self.expand_block(256, 128, 3, 1)
        self.up1 = self.expand_block(128, 64, 3, 1)
        # Last 2x upsampling back to the input resolution, straight to class channels.
        self.final_up = nn.ConvTranspose2d(64, num_classes, kernel_size=3, stride=2, padding=1, output_padding=1)
        # 1x1 convolution over the class channels; no activation follows it.
        self.final = nn.Conv2d(num_classes, num_classes, kernel_size=1)

    def forward(self, x):
        """Return raw per-class scores of shape ``(N, num_classes, H, W)`` for input ``(N, in_channels, H, W)``."""
        # Encoder
        x1 = self.down1(x)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)

        # Decoder: each stage receives the previous output plus the encoder
        # feature map of matching resolution and channel count.
        x = self.up3(x4)
        x = self.up2(x + x3)
        x = self.up1(x + x2)
        x = self.final_up(x + x1)
        x = self.final(x)

        return x

    def contract_block(self, in_channels, out_channels, kernel_size, padding):
        """Build one encoder stage: Conv -> ReLU -> BatchNorm -> 2x max-pool."""
        contract = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        return contract

    def expand_block(self, in_channels, out_channels, kernel_size, padding):
        """Build one decoder stage: 2x transposed conv, then a refining conv.

        ``kernel_size`` and ``padding`` apply to the refining ``Conv2d`` only;
        the transposed convolution always uses a 3x3 kernel with stride 2.
        """
        expand = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
            nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=padding),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels)
        )
        return expand

In [5]:
# Hyperparameters
num_epochs = 25
batch_size = 64
learning_rate = 0.001

# Dataset and DataLoader
dataset = CLEVRDataset(path='dataset/train/')

# Create the sampler's generator on the same device used everywhere else instead
# of hard-coding 'cuda', so this cell does not fail on a CPU-only machine.
# NOTE(review): a non-CPU generator is presumably required because the default
# device set earlier affects where the sampler builds its permutation - confirm.
generator = torch.Generator(device=device)
sampler = RandomSampler(dataset, generator=generator)
# The sampler already randomises the order, so `shuffle` keeps its default (False).
dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

# Model, Loss, and Optimizer
model = UNet().to(device)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training Loop
for epoch in range(num_epochs):
    total_loss = 0.0  # sum of per-batch losses over the whole epoch
    batch_loss = 0.0  # sum of per-batch losses since the last progress line
    # `batch` counts completed batches starting at 1, so each progress line
    # covers exactly 100 batches and is labelled with the true batch number.
    for batch, (images, masks) in enumerate(dataloader, start=1):
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        total_loss += loss_value
        batch_loss += loss_value

        if batch % 100 == 0:
            print(f'[{epoch + 1}, {batch:5d}] Batch Loss: {batch_loss / 100:.8f}')
            batch_loss = 0.0

    print(f'Epoch {epoch+1}, Total Loss: {total_loss}')
    if (epoch+1)%5 == 0:
        # NOTE: the file name uses the zero-based `epoch`, i.e. one less than the
        # epoch number printed above; left unchanged so existing checkpoint
        # paths stay valid.
        checkpoint_path = f"Unet_checkpoint_epoch_{epoch}.pt"
        torch.save(model.state_dict(), checkpoint_path)

  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


[1,   100] Batch Loss: 0.14566248
[1,   200] Batch Loss: 0.01025905
[1,   300] Batch Loss: 0.00758162
Epoch 1, Total Loss: 16.68228873424232
[2,   100] Batch Loss: 0.00599147
[2,   200] Batch Loss: 0.00558053
[2,   300] Batch Loss: 0.00458001
Epoch 2, Total Loss: 1.8149344213306904
[3,   100] Batch Loss: 0.00414734
[3,   200] Batch Loss: 0.00370698
[3,   300] Batch Loss: 0.00368675
Epoch 3, Total Loss: 1.3038259758614004
[4,   100] Batch Loss: 0.00293263
[4,   200] Batch Loss: 0.00279254
[4,   300] Batch Loss: 0.00258937
Epoch 4, Total Loss: 0.9451475997921079
[5,   100] Batch Loss: 0.00229698
[5,   200] Batch Loss: 0.00212843
[5,   300] Batch Loss: 0.00194772
Epoch 5, Total Loss: 0.7162891470361501
[6,   100] Batch Loss: 0.00150876
[6,   200] Batch Loss: 0.00126956
[6,   300] Batch Loss: 0.00104856
Epoch 6, Total Loss: 0.42119537241524085
[7,   100] Batch Loss: 0.00080216
[7,   200] Batch Loss: 0.00070042
[7,   300] Batch Loss: 0.00063050
Epoch 7, Total Loss: 0.24150251364335418
[8,  

In [6]:
# Training Loop (continuation: epochs num_epochs+1 .. 2*num_epochs)
# Reuses `model`, `optimizer`, `criterion` and `dataloader` from the earlier
# training cell; this loop is a copy of that one with a shifted epoch range.
for epoch in range(num_epochs, 2*num_epochs):
    total_loss = 0.0  # sum of per-batch losses over the whole epoch
    batch_loss = 0.0  # sum of per-batch losses since the last progress line
    # `batch` counts completed batches starting at 1, so each progress line
    # covers exactly 100 batches and is labelled with the true batch number.
    for batch, (images, masks) in enumerate(dataloader, start=1):
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        total_loss += loss_value
        batch_loss += loss_value

        if batch % 100 == 0:
            print(f'[{epoch + 1}, {batch:5d}] Batch Loss: {batch_loss / 100:.8f}')
            batch_loss = 0.0

    print(f'Epoch {epoch+1}, Total Loss: {total_loss}')
    if (epoch+1)%5 == 0:
        # NOTE: the file name uses the zero-based `epoch`, one less than the
        # epoch number printed above.
        checkpoint_path = f"Unet_checkpoint_epoch_{epoch}.pt"
        torch.save(model.state_dict(), checkpoint_path)

[26,   100] Batch Loss: 0.00013680
[26,   200] Batch Loss: 0.00014168
[26,   300] Batch Loss: 0.00014591
Epoch 26, Total Loss: 0.0510188798289164
[27,   100] Batch Loss: 0.00036189
[27,   200] Batch Loss: 0.00023557
[27,   300] Batch Loss: 0.00018590
Epoch 27, Total Loss: 0.08608283301873598
[28,   100] Batch Loss: 0.00015653
[28,   200] Batch Loss: 0.00015898
[28,   300] Batch Loss: 0.00015448
Epoch 28, Total Loss: 0.053673918460845016
[29,   100] Batch Loss: 0.00014426
[29,   200] Batch Loss: 0.00014535
[29,   300] Batch Loss: 0.00013715
Epoch 29, Total Loss: 0.04875242297566729
[30,   100] Batch Loss: 0.00012515
[30,   200] Batch Loss: 0.00012648
[30,   300] Batch Loss: 0.00013156
Epoch 30, Total Loss: 0.04430344058346236
[31,   100] Batch Loss: 0.00011842
[31,   200] Batch Loss: 0.00012007
[31,   300] Batch Loss: 0.00012235
Epoch 31, Total Loss: 0.04170447447540937
[32,   100] Batch Loss: 0.00011385
[32,   200] Batch Loss: 0.00012223
[32,   300] Batch Loss: 0.00012201
Epoch 32, Tot

In [7]:
# Training Loop (continuation: epochs 2*num_epochs+1 .. 3*num_epochs)
# Reuses `model`, `optimizer`, `criterion` and `dataloader` from the earlier
# training cell; this loop is a copy of that one with a shifted epoch range.
for epoch in range(2*num_epochs, 3*num_epochs):
    total_loss = 0.0  # sum of per-batch losses over the whole epoch
    batch_loss = 0.0  # sum of per-batch losses since the last progress line
    # `batch` counts completed batches starting at 1, so each progress line
    # covers exactly 100 batches and is labelled with the true batch number.
    for batch, (images, masks) in enumerate(dataloader, start=1):
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        total_loss += loss_value
        batch_loss += loss_value

        if batch % 100 == 0:
            print(f'[{epoch + 1}, {batch:5d}] Batch Loss: {batch_loss / 100:.8f}')
            batch_loss = 0.0

    print(f'Epoch {epoch+1}, Total Loss: {total_loss}')
    if (epoch+1)%5 == 0:
        # NOTE: the file name uses the zero-based `epoch`, one less than the
        # epoch number printed above.
        checkpoint_path = f"Unet_checkpoint_epoch_{epoch}.pt"
        torch.save(model.state_dict(), checkpoint_path)

[51,   100] Batch Loss: 0.00008873
[51,   200] Batch Loss: 0.00023209
[51,   300] Batch Loss: 0.00020863
Epoch 51, Total Loss: 0.059385071726865135
[52,   100] Batch Loss: 0.00010968
[52,   200] Batch Loss: 0.00010077
[52,   300] Batch Loss: 0.00009586
Epoch 52, Total Loss: 0.034787286844220944
[53,   100] Batch Loss: 0.00007949
[53,   200] Batch Loss: 0.00007960
Epoch 53, Total Loss: 0.027653157674649265
[54,   100] Batch Loss: 0.00007219
[54,   200] Batch Loss: 0.00007236
[54,   300] Batch Loss: 0.00007830
Epoch 54, Total Loss: 0.025752232009836007
[55,   100] Batch Loss: 0.00006736
[55,   200] Batch Loss: 0.00006957
[55,   300] Batch Loss: 0.00007130
Epoch 55, Total Loss: 0.024036996401264332
[56,   100] Batch Loss: 0.00006458
[56,   200] Batch Loss: 0.00006646
[56,   300] Batch Loss: 0.00006795
Epoch 56, Total Loss: 0.022954800544539467
[57,   100] Batch Loss: 0.00006182
[57,   200] Batch Loss: 0.00006508
[57,   300] Batch Loss: 0.00006667
Epoch 57, Total Loss: 0.022395457213860936

In [8]:
# Training Loop (continuation: epochs 3*num_epochs+1 .. 4*num_epochs)
# Reuses `model`, `optimizer`, `criterion` and `dataloader` from the earlier
# training cell; this loop is a copy of that one with a shifted epoch range.
for epoch in range(3*num_epochs, 4*num_epochs):
    total_loss = 0.0  # sum of per-batch losses over the whole epoch
    batch_loss = 0.0  # sum of per-batch losses since the last progress line
    # `batch` counts completed batches starting at 1, so each progress line
    # covers exactly 100 batches and is labelled with the true batch number.
    for batch, (images, masks) in enumerate(dataloader, start=1):
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        total_loss += loss_value
        batch_loss += loss_value

        if batch % 100 == 0:
            print(f'[{epoch + 1}, {batch:5d}] Batch Loss: {batch_loss / 100:.8f}')
            batch_loss = 0.0

    print(f'Epoch {epoch+1}, Total Loss: {total_loss}')
    if (epoch+1)%5 == 0:
        # NOTE: the file name uses the zero-based `epoch`, one less than the
        # epoch number printed above.
        checkpoint_path = f"Unet_checkpoint_epoch_{epoch}.pt"
        torch.save(model.state_dict(), checkpoint_path)

[76,   100] Batch Loss: 0.00004523
[76,   200] Batch Loss: 0.00004824
[76,   300] Batch Loss: 0.00004951
Epoch 76, Total Loss: 0.01655479281907901
[77,   100] Batch Loss: 0.00004516
[77,   200] Batch Loss: 0.00004728
[77,   300] Batch Loss: 0.00004798
Epoch 77, Total Loss: 0.016305480243318016
[78,   100] Batch Loss: 0.00004425
[78,   200] Batch Loss: 0.00004574
[78,   300] Batch Loss: 0.00004803
Epoch 78, Total Loss: 0.015993970846466254
[79,   100] Batch Loss: 0.00004355
[79,   200] Batch Loss: 0.00004525
[79,   300] Batch Loss: 0.00004676
Epoch 79, Total Loss: 0.015698602594056865
[80,   100] Batch Loss: 0.00004308
[80,   200] Batch Loss: 0.00004531
[80,   300] Batch Loss: 0.00004671
Epoch 80, Total Loss: 0.01559824840296642
[81,   100] Batch Loss: 0.00004339
[81,   200] Batch Loss: 0.00004550
[81,   300] Batch Loss: 0.00004633
Epoch 81, Total Loss: 0.015680738448281772
[82,   100] Batch Loss: 0.00004224
[82,   200] Batch Loss: 0.00004408
[82,   300] Batch Loss: 0.00004559
Epoch 82,