In [1]:
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms

In [2]:
relu = nn.ReLU()

class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        x = relu(x)
        x = self.conv2(x)
        x = relu(x)
        return x


In [3]:
class DecoderConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        # output_size = (input_size - 1) * stride + kernel_size - 2 * padding
        self.up_conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2) 
        self.conv = DoubleConv(out_channels, out_channels)

    def forward(self, x, x1):
        y1 = self.up_conv(x)
        y2 = torch.cat((y1, x1))
        y = self.conv(y2)
        return y

In [4]:
class UNet(nn.Module):
    def __init__(self, input_channels=1):
        super().__init__()

        # non trainable layers
        self.pool = nn.MaxPool2d(kernel_size=2)
        
        # trainable layers
        ## encoder 
        self.en_conv1 = DoubleConv(input_channels,32)
        self.en_conv2 = DoubleConv(32,64)
        self.en_conv3 = DoubleConv(64,128)
        self.en_conv4 = DoubleConv(128,256)
        
        ## bottleneck
        self.conv = DoubleConv(256,512)

        ## decoder
        self.de_conv4 = DecoderConv(512,256)
        self.de_conv3 = DecoderConv(256,128)
        self.de_conv2 = DecoderConv(128,64)
        self.de_conv1 = DecoderConv(64,32)
        
        ## reconstruct layer
        self.reconstruct = nn.Conv2d(32, 3, kernel_size=3)

    def forward(self, x):
        # encoder inference 
        e1 = self.en_conv1(x)
        p1 = self.pool(e1)
        
        e2 = self.en_conv2(p1)
        p2 = self.pool(e2)
        
        e3 = self.en_conv3(p2)
        p3 = self.pool(e3)
        
        e4 = self.en_conv3(p3)
        p4 = self.pool(e4)
        
        # bottleneck inference
        b = self.conv(e4)
        
        # decoder inference 
        d4 = self.de_conv4(b,e4)
        d3 = self.de_conv3(d4,e3)
        d2 = self.de_conv2(d3,e2)
        d1 = self.de_conv1(d2,e1)

        # reconstruct image
        colored_image = self.reconstruct(d1)


        
        return colored_image

In [5]:
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    # transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5])
])

In [7]:
train_data = datasets.CIFAR100(
    root='data',
    train=True,
    download=True,
    transform=transforms.ToTensor()
)

test_data = datasets.CIFAR100(
    root='data',
    train=False,
    download=True,
    transform=transforms.ToTensor()
)


Files already downloaded and verified
Files already downloaded and verified


In [8]:
train_data.transform = transform
test_data.transform = transform

In [9]:
import torch.utils.data as data_utils

# Create data loaders
batch_size = 128
train_loader = data_utils.DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = data_utils.DataLoader(test_data, batch_size=batch_size)

In [11]:
# import matplotlib.pyplot as plt
# import numpy as np

# # Get a batch of images
# data_iter = iter(train_loader)
# images, labels = data_iter.next()

# # Plot the images
# fig = plt.figure(figsize=(10, 10))
# for i in range(25):
#     ax = fig.add_subplot(5, 5, i+1, xticks=[], yticks=[])
#     img = images[i] / 2 + 0.5 # Unnormalize the image
#     img = np.transpose(img.numpy(), (1, 2, 0))
#     ax.imshow(img)
#     ax.set_title(str(labels[i].item()))
plt.show()

In [12]:
# Initialize the network and optimizer
net = UNet()
optimizer = optim.Adam(net.parameters(), lr=0.001)

# Define the loss function
criterion = nn.CrossEntropyLoss()

# Train the network
num_epochs = 10
for epoch in range(num_epochs):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader, 0):
        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = net(inputs)

        # Compute the loss
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Update the running loss
        running_loss += loss.item()

        # Print the loss every 1000 mini-batches
        if (i+1) % 1000 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {running_loss/1000:.4f}')
            running_loss = 0.0

print('Finished Training')

# Test the network
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        # Forward pass
        outputs = net(inputs)

        # Get the predicted class
        _, predicted = torch.max(outputs.data, 1)

        # Compute the accuracy
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Test Accuracy: {correct/total:.4f}')

AttributeError: cannot assign module before Module.__init__() call