In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import Places365
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from PIL import ImageTransform
import lightning as L
from lightning.pytorch.loggers import WandbLogger
import wandb

In [2]:
# Precomputed jigsaw permutation set (per the filename: 100 permutations of 9 tiles).
# The file stores 1-based tile indices; shift them to 0-based so they can index arrays.
permutations = np.load("naroozi_perms_100_patches_9_max.npy") - 1
permutations

array([[7, 6, 5, 8, 4, 3, 2, 1, 0],
       [8, 7, 6, 5, 3, 4, 1, 0, 2],
       [6, 8, 7, 4, 5, 2, 0, 3, 1],
       [5, 4, 8, 6, 1, 0, 7, 2, 3],
       [4, 5, 3, 2, 0, 1, 6, 7, 8],
       [3, 2, 1, 0, 8, 7, 4, 6, 5],
       [2, 3, 0, 1, 7, 8, 5, 4, 6],
       [1, 0, 4, 3, 2, 6, 8, 5, 7],
       [0, 1, 2, 7, 6, 5, 3, 8, 4],
       [8, 7, 6, 5, 2, 1, 3, 4, 0],
       [7, 8, 5, 3, 6, 2, 1, 0, 4],
       [6, 5, 4, 1, 8, 0, 7, 2, 3],
       [5, 6, 2, 7, 0, 8, 4, 3, 1],
       [4, 3, 8, 2, 1, 5, 0, 6, 7],
       [3, 1, 0, 4, 5, 6, 2, 7, 8],
       [2, 0, 7, 8, 4, 3, 6, 1, 5],
       [1, 2, 3, 0, 7, 4, 8, 5, 6],
       [0, 4, 1, 6, 3, 7, 5, 8, 2],
       [8, 7, 5, 1, 0, 3, 2, 6, 4],
       [7, 8, 4, 0, 1, 5, 3, 2, 6],
       [6, 5, 2, 3, 7, 1, 4, 0, 8],
       [5, 2, 8, 7, 6, 0, 1, 4, 3],
       [4, 1, 0, 5, 8, 2, 6, 3, 7],
       [3, 6, 7, 4, 2, 8, 5, 1, 0],
       [2, 4, 6, 8, 3, 7, 0, 5, 1],
       [0, 3, 1, 2, 4, 6, 7, 8, 5],
       [1, 0, 3, 6, 5, 4, 8, 7, 2],
       [8, 7, 4, 1, 6, 2, 0,

In [2]:
# Places365 validation split, read from the local "places365" directory.
# To fetch it first: Places365("places365", download=True, split="val")
ds = Places365(root="places365", split="val")

In [None]:
# Even the small-image variant is too big.
# ds = Places365("places365_2", download=True, small=True)

In [4]:
# Sanity check on one image: resize it, cut it into a 3x3 grid and print the value
# range of each tile after ImageNet mean/std normalization.
img = ds[0][0]
img = img.resize((600, 600))

# Integer tile size (as in TileDataset.__getitem__). Float division would give
# fractional crop boxes and, for sizes not divisible by 3, tiles of unequal size
# that np.array(tiles) could not stack into a single array.
w = img.size[0] // 3
h = img.size[1] // 3

tiles = []

T = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

for i in range(3):
    for j in range(3):
        # Crop box is (left, upper, right, lower) in pixels; tiles are visited row by row.
        x_min, y_min = j * w, i * h
        x_max, y_max = (j + 1) * w, (i + 1) * h
        tile = img.crop([x_min, y_min, x_max, y_max])
        tile_t = T(tile)
        print(tile_t.max(), tile_t.min())
        tiles.append(tile)  # keep the PIL tile (not the normalized tensor) for plotting

tensor(2.6400) tensor(-2.0494)
tensor(2.6400) tensor(-1.9809)
tensor(2.6400) tensor(-2.1008)
tensor(2.6400) tensor(-2.1179)
tensor(2.6400) tensor(-2.1179)
tensor(2.6400) tensor(-2.1179)
tensor(2.6400) tensor(-2.1179)
tensor(2.6400) tensor(-2.1179)
tensor(2.6400) tensor(-2.1179)


In [None]:
# Reorder the demo tiles with the first permutation and convert them back to PIL images.
tile_stack = np.array(tiles)  # presumably (9, H, W, 3) uint8 for RGB tiles
permuted_tile_tensor = tile_stack[permutations[0], ...]
permuted_tiles = list(map(Image.fromarray, permuted_tile_tensor))

In [None]:
# Show the permuted tiles on a 3x3 grid (row-major, same order as the permutation).
fig, axes = plt.subplots(3, 3)

for row in range(3):
    for col in range(3):
        ax = axes[row, col]
        ax.imshow(permuted_tiles[row * 3 + col])
        ax.axis("off")

In [5]:
# Number of permutation classes used: TileDataset samples only from the first N rows.
N_PERMUTATIONS: int = 10

In [6]:
class TileDataset(Dataset):
    """Jigsaw-puzzle dataset: cuts each image into a 3x3 grid and shuffles the tiles.

    Each item is ``(tiles, label)``:
      * ``tiles`` -- tensor of the 9 normalized tiles stacked along dim 0 (each tile is
        the output of ``self.T``, i.e. ``(C, h, w)``), reordered so that output position
        ``k`` holds original tile ``permutation[k]``;
      * ``label`` -- float32 one-hot numpy vector of length ``n_permutations`` marking
        which permutation was applied.

    Args:
        dataset: indexable dataset whose items are ``(image, target)`` where ``image``
            supports PIL's ``resize``/``crop``/``size``; the target is ignored.
        n_permutations: sample from the first ``n_permutations`` rows of the
            permutation table. Must be between 1 and the number of rows in the file.
        permutations_path: ``.npy`` file with one permutation of 1-based tile indices
            per row.
        resize_to: side length every image is resized to before tiling.

    Raises:
        ValueError: if ``n_permutations`` is outside ``[1, number of stored permutations]``.
    """

    def __init__(self, dataset, n_permutations=100,
                 permutations_path="naroozi_perms_100_patches_9_max.npy",
                 resize_to=600):
        self.dataset = dataset
        # The file stores 1-based tile indices; shift to 0-based for indexing.
        self.permutations = np.load(permutations_path) - 1
        # Fail here instead of with an IndexError deep inside a DataLoader worker.
        if not 0 < n_permutations <= len(self.permutations):
            raise ValueError(
                f"n_permutations must be in [1, {len(self.permutations)}], got {n_permutations}"
            )
        self.n_permutations = n_permutations
        self.resize_to = resize_to
        self.T = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet statistics
        ])

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img, _ = self.dataset[idx]
        img = img.resize((self.resize_to, self.resize_to))

        # Integer tile size so every crop box is whole-pixel and all tiles share one shape.
        w = img.size[0] // 3
        h = img.size[1] // 3

        # Row-major order: tile k comes from grid row k // 3, column k % 3.
        tiles = [
            self.T(img.crop([j * w, i * h, (j + 1) * w, (i + 1) * h]))
            for i in range(3)
            for j in range(3)
        ]

        perm_idx = np.random.choice(self.n_permutations)
        permutation = torch.as_tensor(self.permutations[perm_idx], dtype=torch.long)
        permuted_tile_tensor = torch.stack(tiles)[permutation]

        # float32 so the collated target matches the float32 logits in CrossEntropyLoss;
        # a float64 target would promote the loss computation to double precision.
        perm_label = np.zeros(self.n_permutations, dtype=np.float32)
        perm_label[perm_idx] = 1.0

        return permuted_tile_tensor, perm_label

# Wrap the raw Places365 images into jigsaw samples and batch them.
tile_dataset = TileDataset(ds, n_permutations=N_PERMUTATIONS)

batch_size = 32  # tune to fit GPU memory
tile_dataloader = DataLoader(
    tile_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
)

In [None]:
# Fetch one sample once and inspect both shapes. Indexing the dataset twice would
# decode the image twice and draw two different random permutations.
sample_tiles, sample_label = tile_dataset[0]
sample_tiles.shape, sample_label.shape

In [8]:
# Output channels of the five convolutional layers of the backbone.
CONV1_C, CONV2_C, CONV3_C, CONV4_C, CONV5_C = 96, 256, 384, 384, 256

# Fully connected widths, reduced from the previously noted values (1024 / 4096).
FC1 = 256  # per-tile embedding size (previously 1024)
FC2 = 256  # width of the joint layer over all tile embeddings (previously 4096)

In [9]:
# AlexNet-style per-tile encoder for the jigsaw task.
# Spatial sizes for a 200x200 input tile (no padding in conv1, floor division in pooling):
#   conv1 (k=11, s=2)      -> 95x95,  max-pool (k=3, s=2) -> 47x47
#   conv2 (k=5, p=2)       -> 47x47,  max-pool            -> 23x23
#   conv3-conv5 (k=3, p=1) -> 23x23,  max-pool            -> 11x11
# so each tile yields a CONV5_C x 11 x 11 feature map (fc1 below depends on this).
backbone = nn.Sequential(
    # conv1
    nn.Conv2d(3, CONV1_C, kernel_size=11, stride=2, padding=0),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.LocalResponseNorm(5),  # local size 5 = number of neighbouring channels normalized over
    # conv2 (groups=2: two independent channel groups)
    nn.Conv2d(CONV1_C, CONV2_C, kernel_size=5, padding=2, groups=2),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.LocalResponseNorm(5),
    # conv3
    nn.Conv2d(CONV2_C, CONV3_C, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    # conv4
    nn.Conv2d(CONV3_C, CONV4_C, kernel_size=3, padding=1, groups=2),
    nn.ReLU(inplace=True),
    # conv5
    nn.Conv2d(CONV4_C, CONV5_C, kernel_size=3, padding=1, groups=2),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(kernel_size=3, stride=2)
)

# Per-tile fully connected layer (fc6 in the original code). It receives each tile's
# flattened backbone output individually, so its input size assumes 200x200 tiles.
fc1 = nn.Sequential(
    nn.Linear(CONV5_C*11*11, FC1),
    nn.ReLU(inplace=True),
    nn.Dropout(p=0.5),
)

# fc7 + classifier: consumes the 9 concatenated per-tile embeddings and outputs one
# logit per permutation class.
fc_head = nn.Sequential(
    # fc7
    nn.Linear(9*FC1, FC2),
    nn.ReLU(inplace=True),
    nn.Dropout(p=0.5),
    # classifier
    nn.Linear(FC2, N_PERMUTATIONS),
    # nn.Softmax(dim=0) # no softmax: nn.CrossEntropyLoss expects unnormalized logits
)

In [10]:
class JigsawTorchModel(nn.Module):
    """Siamese jigsaw network: a shared per-tile encoder followed by a joint classifier.

    Every tile goes through the same ``backbone`` and ``fc1``. The ``n_tiles``
    per-tile embeddings are concatenated in tile order, and ``fc_head`` maps them to
    the output logits.

    Args:
        backbone: convolutional feature extractor applied to each tile.
        fc1: per-tile layer; receives the flattened backbone output of one tile.
        fc_head: classifier over the concatenated tile embeddings.
        num_permutations: number of permutation classes (stored for reference only;
            the output size is determined by ``fc_head``).
        n_tiles: tiles per puzzle (default 9, i.e. a 3x3 grid).
        tile_size: side length of each square 3-channel tile (default 200).
    """

    def __init__(self, backbone, fc1, fc_head, num_permutations=100, n_tiles=9, tile_size=200):
        super().__init__()
        self.backbone = backbone
        self.fc1 = fc1
        self.fc_head = fc_head
        self.num_permutations = num_permutations
        self.n_tiles = n_tiles
        self.tile_size = tile_size

    def forward(self, x):
        """Return permutation logits for a batch of puzzles.

        Args:
            x: tensor whose elements can be arranged as
                ``(B, n_tiles, 3, tile_size, tile_size)``, e.g. stacked tiles or the
                same data flattened. Non-contiguous inputs are accepted.

        Returns:
            ``fc_head`` output for the ``(B, n_tiles * F)`` concatenated embeddings.
        """
        x = x.reshape(-1, self.n_tiles, 3, self.tile_size, self.tile_size)
        B = x.shape[0]

        # Fold the tile axis into the batch axis so the shared encoder runs once on all
        # B * n_tiles tiles instead of once per tile position. This assumes backbone/fc1
        # treat samples independently, which holds for the conv/ReLU/pool/LRN/dropout
        # layers used in this notebook. NOTE(review): it would not hold for batch norm.
        z = self.backbone(x.flatten(0, 1))
        z = self.fc1(z.flatten(1))

        # (B * n_tiles, F) -> (B, n_tiles * F): embeddings concatenated in tile order.
        z = z.reshape(B, -1)
        return self.fc_head(z)

In [11]:
class JigsawLightningModel(L.LightningModule):
    """Lightning wrapper that trains a jigsaw permutation classifier with cross-entropy.

    Args:
        model: module mapping a batch of tile stacks to permutation logits.
        lr: Adam learning rate (default 1e-2, the previously hard-coded value).
    """

    def __init__(self, model, lr=1e-2):
        super().__init__()
        self.model = model
        self.lr = lr
        self.loss = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute and log the cross-entropy loss for one ``(inputs, labels)`` batch.

        ``labels`` may be class indices or per-class probabilities (e.g. one-hot).
        """
        inputs, labels = batch
        outputs = self(inputs)
        # Probability/one-hot targets are floating point: cast them to the logits' dtype
        # so e.g. a float64 target does not promote the loss computation to double.
        if labels.is_floating_point():
            labels = labels.to(outputs.dtype)
        loss = self.loss(outputs, labels)
        self.log("train loss", loss)
        return loss

    def configure_optimizers(self):
        # NOTE(review): 1e-2 is high for Adam on an AlexNet-style net trained from
        # scratch; consider ~1e-4 if the loss diverges.
        return optim.Adam(self.parameters(), lr=self.lr)

In [12]:
# Instantiate the Lightning model. Pass the class count explicitly so the torch model's
# num_permutations matches the N_PERMUTATIONS outputs of fc_head (the default is 100).
jigsaw_model = JigsawLightningModel(
    JigsawTorchModel(backbone, fc1, fc_head, num_permutations=N_PERMUTATIONS)
)

In [13]:
# Log to Weights & Biases and train on the jigsaw dataloader.
wandb_logger = WandbLogger(project="DL-HF")
trainer = L.Trainer(
    max_epochs=10,
    logger=wandb_logger,
    log_every_n_steps=4,
)  # adjust max_epochs / accelerator / devices as needed
trainer.fit(jigsaw_model, train_dataloaders=tile_dataloader)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3060 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mpeter-istvan[0m ([33mhey-chatgpt-suggest-team-name[0m). Use [1m`wandb login --relogin`[0m to force relogin


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type             | Params
-------------------------------------------
0 | model | JigsawTorchModel | 10.9 M
1 | loss  | CrossEntropyLoss | 0     
-------------------------------------------
10.9 M    Trainable params
0         Non-trainable params
10.9 M    Total params
43.427    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

/home/i/BME/9/melytanulas/DeepLearningHW23/venv/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [13]:
# Close the W&B run. WandbLogger.finalize expects a status string (e.g. "success"),
# not 0, and it does not end the underlying run; wandb.finish() does.
wandb_logger.finalize("success")
wandb.finish()