In [1]:
import os

import PIL
import torch
import torchmetrics
from PIL import Image
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import v2

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:


class CustomDataset(Dataset):
    """In-memory paired dataset of (low-res, high-res) images for super-resolution.

    Every entry of ``path`` is opened once at construction time, converted to
    RGB, resized to both resolutions and stored as float32 tensors scaled to
    [0, 1]. All pairs are kept in memory for the lifetime of the dataset.

    Args:
        path: Directory whose entries are all image files readable by PIL.
        smallResolution: PIL ``(width, height)`` of the network input ``X``.
        largeResolution: PIL ``(width, height)`` of the target ``y``.

    Items are ``(X, y)`` tensors of shape ``(3, height, width)``.
    """

    def __init__(self, path, smallResolution, largeResolution):
        super().__init__()
        self.image_filenames = sorted(os.listdir(path))

        # The transform is stateless, so build it once rather than once per image.
        transform = v2.Compose(
            [v2.PILToTensor(), v2.ToDtype(torch.float32, scale=True)]
        )

        self.data = []
        for filename in self.image_filenames:
            imagePath = os.path.join(path, filename)
            # `with` closes the file handle deterministically. convert("RGB")
            # guarantees 3 channels even for grayscale/RGBA/palette files, which
            # the 3-channel model and default batch collation both require.
            with Image.open(imagePath) as im:
                rgb = im.convert("RGB")
            X = transform(rgb.resize(smallResolution))
            y = transform(rgb.resize(largeResolution))
            self.data.append((X, y))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# Building a vision transformer for 2× image upscaling

In [3]:
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and embed each one.

    A ``Conv2d`` with ``kernel_size == stride == patch_size`` projects every
    ``patch_size x patch_size`` patch to an ``embed_dim`` vector. A learned
    positional embedding (one vector per patch) is then added.

    Args:
        img_size: ``(height, width)`` of the images this module will receive.
        patch_size: Side length of a square patch, in pixels.
        in_channels: Number of input image channels.
        embed_dim: Size of each patch embedding.
    """

    def __init__(self, img_size, patch_size, in_channels, embed_dim):
        super().__init__()
        self.patch_size = patch_size
        self.img_size = img_size
        # The strided convolution drops remainder pixels that don't fill a whole
        # patch, so the grid uses floor division.
        self.grid_size = (img_size[0] // patch_size, img_size[1] // patch_size)
        num_patches = self.grid_size[0] * self.grid_size[1]

        self.proj = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches, embed_dim))

    def forward(self, x: torch.Tensor):
        """Embed ``x`` of shape ``(B, in_channels, H, W)``.

        Returns:
            ``(tokens, grid_size)``. ``tokens`` is ``(B, N, embed_dim)`` with
            ``N = grid_h * grid_w`` in row-major patch order, and ``grid_size``
            is ``(grid_h, grid_w)``.

        Raises:
            ValueError: if ``x`` does not produce the patch grid that the
                positional embedding was built for.
        """
        in_shape = tuple(x.shape)
        x = self.proj(x)  # (B, embed_dim, H//P, W//P)
        grid = tuple(x.shape[-2:])
        if grid != self.grid_size:
            raise ValueError(
                f"PatchEmbedding built for img_size={self.img_size} "
                f"(patch grid {self.grid_size}) got input of shape {in_shape} "
                f"(patch grid {grid})"
            )
        x = x.flatten(2).transpose(1, 2)  # (B, N, embed_dim)
        x = x + self.pos_embed
        return x, self.grid_size

In [4]:
class MLP(nn.Module):
    """Position-wise feed-forward network used inside each transformer block.

    Expands the last dimension to ``hidden_features`` with GELU, then projects
    it back to ``in_features``. Dropout follows both the activation and the
    output projection.
    """

    def __init__(self, in_features, hidden_features, drop_rate):
        super().__init__()
        # Module order (and thus the ``layers.<i>`` state_dict keys) is fixed so
        # saved checkpoints keep loading.
        expand = nn.Linear(in_features, hidden_features)
        contract = nn.Linear(hidden_features, in_features)
        self.layers = nn.Sequential(
            expand,
            nn.GELU(),
            nn.Dropout(drop_rate),
            contract,
            nn.Dropout(drop_rate),
        )

    def forward(self, x: torch.Tensor):
        """Apply the feed-forward network along the last dimension of ``x``."""
        out = self.layers(x)
        return out

In [5]:
class TransformerEncoderLayer(nn.Module):
    """Pre-norm transformer encoder block: self-attention then MLP, each residual.

    Args:
        embed_dim: Token feature size.
        num_heads: Number of attention heads (PyTorch requires it to divide
            ``embed_dim``).
        mlp_dim: Hidden width of the feed-forward MLP.
        drop_rate: Dropout used on the attention weights and inside the MLP.
    """

    def __init__(self, embed_dim, num_heads, mlp_dim, drop_rate):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=drop_rate, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(in_features=embed_dim, hidden_features=mlp_dim, drop_rate=drop_rate)

    def forward(self, x):
        """Transform tokens ``x`` of shape ``(B, N, embed_dim)`` (batch-first)."""
        # Normalise once and reuse the result as query, key and value.
        h = self.norm1(x)
        # Only the attention output is used; need_weights=False skips building
        # the averaged (B, N, N) weight matrix and lets PyTorch use its fused
        # scaled-dot-product kernel.
        attn_out, _ = self.attention(h, h, h, need_weights=False)
        x = x + attn_out
        x = x + self.mlp(self.norm2(x))
        return x
        

In [6]:
class Decoder(nn.Module):
    """Turn a sequence of patch tokens back into an image.

    Tokens are laid back out on their ``(H, W)`` patch grid and refined by a
    3x3 conv + ReLU. A transposed conv with ``kernel_size == stride ==
    upsample_factor`` then expands every grid cell into an
    ``upsample_factor x upsample_factor`` block of ``out_channels`` pixels.
    """

    def __init__(self, embed_dim, out_channels, upsample_factor):
        super().__init__()
        # Order is fixed so the ``decode.<i>`` state_dict keys stay stable.
        refine = nn.Conv2d(embed_dim, embed_dim, kernel_size=3, padding=1)
        expand = nn.ConvTranspose2d(
            embed_dim,
            out_channels,
            kernel_size=upsample_factor,
            stride=upsample_factor,
        )
        self.decode = nn.Sequential(refine, nn.ReLU(), expand)

    def forward(self, x, grid_size):
        """Decode tokens ``x`` of shape ``(B, N, E)`` laid out on ``grid_size = (H, W)``.

        Assumes ``N == H * W`` with tokens in row-major grid order. Returns a
        tensor of shape ``(B, out_channels, H * upsample_factor, W * upsample_factor)``.
        """
        batch, _, channels = x.shape
        rows, cols = grid_size
        feature_map = x.permute(0, 2, 1).reshape(batch, channels, rows, cols)
        return self.decode(feature_map)

In [7]:
class VisionTransformer(nn.Module):
    """ViT-style image upscaler.

    The input is first bilinearly upsampled by ``upsample_factor``. The
    upsampled image is then cut into ``patch_size`` patches, run through
    ``depth`` transformer encoder blocks, and decoded back to pixels at the
    upsampled resolution.

    Args:
        img_size: ``(height, width)`` of the low-resolution input.
        patch_size: Patch side length, applied at the upsampled resolution.
        embed_dim: Token feature size.
        depth: Number of transformer encoder blocks.
        num_heads: Attention heads per block.
        mlp_dim: Hidden width of each block's MLP.
        drop_rate: Dropout rate used in attention and MLPs.
        in_channels: Channels of the input image; the output has the same count.
        upsample_factor: Integer scale between input and output resolution.
    """

    def __init__(
        self,
        img_size,
        patch_size,
        embed_dim,
        depth,
        num_heads,
        mlp_dim,
        drop_rate=0.2,
        in_channels=3,
        upsample_factor=2,
    ):
        super().__init__()
        # nn.Upsample has no parameters, so it adds nothing to the state_dict.
        self.upsample = nn.Upsample(scale_factor=upsample_factor, mode='bilinear', align_corners=False)

        # The patch embedding sees the already-upsampled image, so its
        # positional embedding is sized for the output resolution.
        upsampled_size = (img_size[0] * upsample_factor, img_size[1] * upsample_factor)

        self.patchEmbed = PatchEmbedding(
            img_size=upsampled_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=embed_dim,
        )
        self.encoder = nn.Sequential(
            *[
                TransformerEncoderLayer(
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    drop_rate=drop_rate,
                )
                for _ in range(depth)
            ]
        )
        self.norm = nn.LayerNorm(embed_dim)
        # Each token is expanded back into one patch_size x patch_size pixel
        # block. The reconstructed image keeps the input's channel count.
        self.decoder = Decoder(embed_dim=embed_dim, out_channels=in_channels, upsample_factor=patch_size)

    def forward(self, x):
        """Upscale ``x`` of shape ``(B, in_channels, H, W)`` by ``upsample_factor``.

        Returns ``(B, in_channels, H * upsample_factor, W * upsample_factor)``
        when ``patch_size`` divides the upsampled size.
        """
        x = self.upsample(x)  # (B, C, H*f, W*f)
        x, grid_size = self.patchEmbed(x)  # (B, N, E)
        x = self.encoder(x)  # (B, N, E)
        x = self.norm(x)  # (B, N, E)
        x = self.decoder(x, grid_size)  # (B, C, H*f, W*f)
        return x

In [None]:
model = VisionTransformer(img_size=(360, 640), patch_size=20,
                          embed_dim=512, depth=4, num_heads=8, mlp_dim=2048).to(device)

# Smoke-test the forward pass: a 360x640 input should come out at 2x resolution.
# inference_mode keeps this throwaway check from holding an autograd graph
# (and its GPU memory) in the global `output` until training starts.
with torch.inference_mode():
    image = torch.randn(1, 3, 360, 640, device=device)
    output = model(image)
assert output.shape == (1, 3, 720, 1280), output.shape
output.shape

torch.Size([1, 3, 720, 1280])


In [9]:
LEARNING_RATE = 1e-5

optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
# SmoothL1 (Huber-style) loss: squared for small per-pixel errors, L1 for
# large ones, so a few badly reconstructed pixels don't dominate the gradient.
cost_fn = torch.nn.SmoothL1Loss()

In [None]:
# Root of the DIV2K download. Override with the DIV2K_ROOT environment variable
# instead of editing a machine-specific absolute path.
DATASET_ROOT = os.environ.get(
    "DIV2K_ROOT", "/home/radekbys/Code/trasnsformer_upscaling/dataset"
)
LOW_RES = (640, 360)    # PIL (width, height) of the network input
HIGH_RES = (1280, 720)  # PIL (width, height) of the target

trainDataset = CustomDataset(
    os.path.join(DATASET_ROOT, "DIV2K_train_HR", "DIV2K_train_HR"),
    LOW_RES,
    HIGH_RES,
)

testDataset = CustomDataset(
    os.path.join(DATASET_ROOT, "DIV2K_valid_HR", "DIV2K_valid_HR"),
    LOW_RES,
    HIGH_RES,
)

In [None]:
train_dataloader = DataLoader(dataset=trainDataset, batch_size=4, shuffle=True)
# shuffle=False keeps the test batches in the same order on every pass.
test_dataloader = DataLoader(dataset=testDataset, batch_size=4, shuffle=False)
(len(train_dataloader), len(test_dataloader))

(200, 25)

In [None]:
epochs = 100
CHECKPOINT_EVERY = 5  # save a state_dict every N epochs


def _train_one_epoch(model, loader, optimizer, cost_fn, device):
    """Run one optimisation pass over ``loader``.

    Returns the mean of the per-batch losses, or NaN if ``loader`` is empty.
    """
    model.train()
    losses = []
    for X, y in loader:
        X, y = X.to(device), y.to(device)
        output = model(X)
        loss = cost_fn(input=output, target=y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return sum(losses) / len(losses) if losses else float("nan")


def _evaluate(model, loader, cost_fn, metric, device):
    """Score ``model`` on every sample in ``loader`` without tracking gradients.

    Returns ``(mean_loss, mse)``: the per-sample average of ``cost_fn`` and the
    MSE that ``metric`` accumulates over the whole loader. Both are NaN if the
    loader is empty.
    """
    model.eval()
    metric.reset()  # drop state accumulated by any previous evaluation
    total_loss, n_samples = 0.0, 0
    with torch.inference_mode():
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            prediction = model(X)
            batch = X.size(0)
            # Assumes cost_fn averages over the batch (default "mean" reduction);
            # weighting by batch size keeps a short final batch from skewing
            # the per-sample average.
            total_loss += cost_fn(input=prediction, target=y).item() * batch
            n_samples += batch
            metric.update(prediction.contiguous(), y.contiguous())
    if n_samples == 0:
        return float("nan"), float("nan")
    return total_loss / n_samples, metric.compute().item()


metric = torchmetrics.MeanSquaredError().to(device)

for epoch in range(epochs):
    # Evaluate once per epoch, whatever the dataset size (no magic batch count).
    train_loss = _train_one_epoch(model, train_dataloader, optimizer, cost_fn, device)
    test_loss, mse = _evaluate(model, test_dataloader, cost_fn, metric, device)
    print(
        f"epoch: {epoch+1} || train_loss: {train_loss} || test_loss: {test_loss} || mean square error: {mse}"
    )
    if (epoch + 1) % CHECKPOINT_EVERY == 0:
        torch.save(model.state_dict(), f"upscaler_ver_2_e{epoch+1}.pt")

KeyboardInterrupt: 