In [34]:
from picturedamagerr import PictureDamager
import pytorch_lightning as pl
import torch.nn.functional as F
from torch.optim import Adam
# from utils import get_image_cluster
from torch import nn
import einops
import torch
import joblib
import numpy as np
import os

In [35]:
kmeans = joblib.load("kmeans_20_clusters.pkl")

In [36]:
def get_image_cluster(latents):
    # b_latent = b_latent.view(b_latent.size(0), -1)  
    latents = latents.cpu().numpy() 
    return kmeans.predict(latents)

In [37]:
def setup_device():
    if torch.cuda.is_available():
        device = torch.device('cuda')
        # Set default tensor type for cuda
        torch.set_default_dtype(torch.float32)
    elif torch.backends.mps.is_available():
        device = torch.device('mps')
    else:
        device = torch.device('cpu')
        # Ensure we're using float32 on CPU

        torch.set_default_dtype(torch.float64)
    return device

device = setup_device()

print(f"Using {device} device")

# Set random seed for reproducibility
torch.manual_seed(42)
torch.cuda.manual_seed_all(42)


Using cuda device


In [39]:
from torchvision import models
class Encoder(nn.Module):
  @staticmethod
  def ConvBlock(in_channels:int,out_channels:int):
    return nn.Sequential(
        nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=1,padding=1),
        nn.ReLU(True)
    )
  def __init__(self,latent_dim:int=32768) -> None:
    super().__init__()
    self.latent_dim=latent_dim
    resnet34 = models.resnet34(pretrained=True)
    self.model = nn.Sequential(
        *list(resnet34.children())[:-2],
        nn.Conv2d(512,52,kernel_size=1,stride=1,padding=0)
        )
  def forward(self,x):
    x=self.model(x)
    return x
class Decoder(nn.Module):
  @staticmethod
  def ConvBlock(in_channels:int,out_channels:int):
    return nn.Sequential(
        nn.ConvTranspose2d(in_channels,out_channels,kernel_size=3,stride=1,padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(True)
    )
  def __init__(self,latent_dim:int=32768) -> None:
    super().__init__()
    self.latent_dim=latent_dim
    self.model=nn.Sequential(
            nn.Conv2d(52,512,kernel_size=1,stride=1,padding=0),
            nn.Upsample(scale_factor=2, mode='nearest'),
            Decoder.ConvBlock(512,256),
            nn.Upsample(scale_factor=2, mode='nearest'),
            Decoder.ConvBlock(256,128),

            nn.Upsample(scale_factor=2, mode='nearest'),
            Decoder.ConvBlock(128,64),
            nn.Upsample(scale_factor=2, mode='nearest'),
            Decoder.ConvBlock(64,64),
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.ConvTranspose2d(64,3,kernel_size=3,stride=1,padding=1),

            nn.Sigmoid()
    )

  def forward(self,x):
    x=self.model(x)
    return x
class ArtAutoEncoder(nn.Module):

    def __init__(self):
      #Latent raczej koło 2000 (aktualnie 52*7*7)
        super(ArtAutoEncoder, self).__init__()

        self.encoder=Encoder()


        self.decoder=Decoder()
    def forward(self, x):
        x = self.encoder(x)

        x = self.decoder(x)
        return x

In [40]:
masks_path = "masksnpy"
mini_MASKS = [np.load(os.path.join(masks_path, f)) for f in os.listdir(masks_path) if f.endswith(".npy")]  

In [41]:
damager = PictureDamager(mini_MASKS, 1/16)

In [42]:
def add_4th_channel(img):
    if not isinstance(img, torch.Tensor):
        img = torch.from_numpy(img).permute(2, 0, 1).float()

    C, H, W = img.shape

    mask = damager.generate_random_mask((H, W)) 
    if not isinstance(mask, torch.Tensor):
        mask = torch.from_numpy(mask).float().to(device) 

    mask = mask.unsqueeze(0)  

    img4 = torch.cat([img, mask], dim=0)
    return img4.to(device)

In [43]:
def damage_image(img4):
    rgb = img4[:3]  
    mask = img4[3]  

    mask = mask.round() 

    rgb_damaged = rgb.clone()  
    rgb_damaged[:, mask == 1] = 255 

    return rgb_damaged.to(device)

In [44]:
model_for_clusters=ArtAutoEncoder().to(device)
name="resnet34_autoencoder"
model_for_clusters.load_state_dict(torch.load(f"autoencoder/{name}.pth",map_location=device))
model_for_clusters=model_for_clusters.float()

model_for_clusters.eval()

  model_for_clusters.load_state_dict(torch.load(f"autoencoder/{name}.pth",map_location=device))


ArtAutoEncoder(
  (encoder): Encoder(
    (model): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=

In [45]:
class LitModel(pl.LightningModule):
    def __init__(self, learning_rate=1e-3, in_channels=3, out_channels=3, embedding_dim=4):
        super(LitModel, self).__init__()

        self.model = ArtAutoEncoder()
        self.learning_rate = learning_rate
        self.conv_ = nn.Conv2d(in_channels + embedding_dim, out_channels, kernel_size=3)

        self.embedder = nn.Embedding(num_embeddings=20, embedding_dim=embedding_dim)

        self.save_hyperparameters()

    def forward(self, x):
        x = self.conv_(x)
        x = self.model(x)
        return x

    def add_embedding_dims(self, x, clusters):
        embeddings = self.embedder(clusters)
        embeddings = einops.repeat(embeddings, 'batch_size embedding_dim -> batch_size embedding_dim h w', h=256, w=256)
        x_with_embeddings = torch.cat([embeddings, x], dim=1)

        return x_with_embeddings

    # DRY - Don't Repeat Yourself
    def step(self, batch, batch_idx):
        x = batch['image'].to(device).float()
        # print(x.shape)

        batch_with_4th_channel = torch.stack([add_4th_channel(img) for img in x])
        # print(batch_with_4th_channel.shape)
        x = torch.stack([damage_image(img) for img in batch_with_4th_channel])
        # print(x.shape)

        with torch.inference_mode():
            latents = model_for_clusters.encoder(x)
        # print(latents.shape)
        latents= latents.view(latents.shape[0], -1)

        

        clusters = torch.from_numpy(get_image_cluster(latents)).to(device)
        x_with_embeddings = self.add_embedding_dims(x, clusters)

        x_recon = self(x_with_embeddings)

        loss = F.mse_loss(x_recon, x)

        return loss

    def training_step(self, batch, batch_idx):
        loss = self.step(batch, batch_idx)
        self.log('train_loss', loss)

        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.step(batch, batch_idx)
        self.log('validation_loss', loss)

        return loss

    def configure_optimizers(self):
        # Resnet - posiada pretrenowane parametry, być może nie ma sensu uczyć ich razem z parametrami jeszcze nie
        # trenowanymi. Jeśli chcemy użyć pretrenowanego resneta, to może warto rozważyć wpierw zamrożenie wag
        # pretrenowanych, a po nauczeniu modelu dokonać ich fine-tuningu.
        return Adam(self.parameters(), lr=self.learning_rate)

In [46]:
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader

In [47]:
inpainting_model = LitModel()

In [48]:
from datasets import load_dataset

full_dataset = load_dataset("Artificio/WikiArt_Full").with_format("torch")


full_dataset=full_dataset['train']


train_test_split = full_dataset.train_test_split(test_size=0.2)
train_dataset = train_test_split['train']
test_dataset = train_test_split['test']

test_val_split = test_dataset.train_test_split(test_size=0.5)
test_dataset = test_val_split['train']
val_dataset = test_val_split['test']


In [49]:
batch_size = 32

enable_pin_memory = True

number_of_workers =  3

In [50]:
train_loader = DataLoader(train_dataset, batch_size=batch_size, pin_memory = enable_pin_memory, num_workers=number_of_workers,shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, pin_memory = enable_pin_memory, num_workers=number_of_workers,shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, pin_memory = enable_pin_memory, num_workers=number_of_workers,shuffle=True)

In [51]:
from lightning.pytorch.loggers import CometLogger

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

EXPERIMENT_NAME = "inpainting training"

comet_logger = CometLogger(
    api_key= "0RfBkzG1OWWq2b8BwzxjFp0Q2",
    project_name="inpainting",
    experiment_name=EXPERIMENT_NAME,
)

best_checkpoint = ModelCheckpoint(
    monitor='validation_loss',
    dirpath=f'checkpoints/{EXPERIMENT_NAME}/',
    filename='model-{epoch:02d}-{validation_loss:.2f}',
    save_top_k=1,
    mode='min'
)
# model-epoch=14-validation_loss=0.2

last_checkpoint = ModelCheckpoint(
    dirpath=f'checkpoints/{EXPERIMENT_NAME}/',
    filename='model-{epoch:02d}',
    save_top_k=1,
    every_n_epochs=1,
)
# model-epoch=20

early_stopping = EarlyStopping(
    monitor='validation_loss',
    patience=5,
    mode='min'
)

trainer = pl.Trainer(
    max_epochs=30,
    callbacks=[last_checkpoint, best_checkpoint, early_stopping],
    logger=comet_logger
)


trainer.fit(inpainting_model, train_loader, test_loader)

CometLogger will be initialized in online mode
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
[1;38;5;39mCOMET INFO:[0m Experiment is live on comet.com https://www.comet.com/elpolaco/inpainting/34ae6091ecc04bfe8f7b002ea710a430

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name     | Type           | Params | Mode 
----------------------------------------------------
0 | model    | ArtAutoEncoder | 22.9 M | train
1 | conv_    | Conv2d         | 192    | train
2 | embedder | Embedding      | 80     | train
----------------------------------------------------
22.9 M    Trainable params
0         Non-trainable params
22.9 M    Total params
91.709    Total estimated model params size (MB)
145       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\peter\AppData\Local\Programs\Python\Python311\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:419: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Training: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...
[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m Comet.ml Experiment Summary
[1;38;5;39mCOMET INFO:[0m ---------------------------------------------------------------------------------------
[1;38;5;39mCOMET INFO:[0m   Data:
[1;38;5;39mCOMET INFO:[0m     display_summary_level : 1
[1;38;5;39mCOMET INFO:[0m     name                  : inpainting training
[1;38;5;39mCOMET INFO:[0m     url                   : https://www.comet.com/elpolaco/inpainting/34ae6091ecc04bfe8f7b002ea710a430
[1;38;5;39mCOMET INFO:[0m   Others:
[1;38;5;39mCOMET INFO:[0m     Created from : pytorch-lightning
[1;38;5;39mCOMET INFO:[0m     Name         : inpainting training
[1;38;5;39mCOMET INFO:[0m   Parameters:
[1;38;5;39mCOMET INFO:[0m     embedding_dim : 4
[1;38;5;39mCOMET INFO:[0m     in_channels   : 3
[1;38;5;39mCOMET INFO:[0m     learn

NameError: name 'exit' is not defined

: 