In [None]:
import string
import torch.nn as nn
from tqdm import tqdm
import logging
from torch.utils.tensorboard import  SummaryWriter
import torch.nn.functional as F
import numpy as np
import os
import copy
import torch
import torchvision
from PIL import Image
import torch.functional
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch import optim
from torchvision.datasets import DatasetFolder, ImageFolder

In [None]:
class Diffusion:
    """Linear-beta DDPM noise schedule with forward noising and ancestral sampling.

    On construction it precomputes, on `device`:
      * ``beta``      - linearly spaced from ``beta_start`` to ``beta_end`` (``noise_steps`` values),
      * ``alpha``     - ``1 - beta``,
      * ``alpha_hat`` - running product of ``alpha`` along the step axis.
    """

    def __init__(self, noise_steps=400, beta_start=1e-4, beta_end=0.02, img_size=64, device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size
        self.device = device

        schedule = self.prepare_noise_schedule()
        self.beta = schedule.to(device)
        self.alpha = 1. - self.beta
        self.alpha_hat = self.alpha.cumprod(dim=0)

    def prepare_noise_schedule(self):
        """Return the (noise_steps,) tensor of betas, evenly spaced from beta_start to beta_end."""
        return torch.linspace(start=self.beta_start, end=self.beta_end, steps=self.noise_steps)

    def noise_images(self, x, t):
        """Jump straight to step ``t`` of the forward process.

        Returns ``(sqrt(alpha_hat[t]) * x + sqrt(1 - alpha_hat[t]) * eps, eps)`` where
        ``eps`` is fresh standard-normal noise shaped like ``x``. The per-sample
        coefficients are broadcast over the three trailing dims, so ``x`` is assumed
        to be 4-D with ``t`` indexing its first dimension.
        """
        ah = self.alpha_hat[t]
        signal_scale = ah.sqrt()[:, None, None, None]
        noise_scale = (1 - ah).sqrt()[:, None, None, None]
        eps = torch.randn_like(x)
        noisy = signal_scale * x + noise_scale * eps
        return noisy, eps

    def sample_timesteps(self, n):
        """Draw ``n`` integer timesteps uniformly from [1, noise_steps - 1]."""
        return torch.randint(1, self.noise_steps, (n,))

    def sample(self, model, n, labels, cfg_scale=3):
        """Generate ``n`` images by running the reverse process from pure noise.

        ``model(x, t, labels)`` is queried at every step for the predicted noise.
        When ``cfg_scale > 0`` an unconditional prediction ``model(x, t, None)`` is
        also computed and the two are combined with classifier-free guidance,
        ``lerp(uncond, cond, cfg_scale)``. The model is put in eval mode for
        sampling and switched back to train mode afterwards.

        Returns a uint8 tensor of shape (n, 3, img_size, img_size) with values in [0, 255].
        """
        logging.info(f"Sampling {n} new images....")
        model.eval()
        with torch.no_grad():
            x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
            for step in tqdm(reversed(range(1, self.noise_steps)), position=0):
                t = (torch.ones(n) * step).long().to(self.device)
                eps_hat = model(x, t, labels)
                if cfg_scale > 0:
                    eps_uncond = model(x, t, None)
                    eps_hat = torch.lerp(eps_uncond, eps_hat, cfg_scale)
                a = self.alpha[t][:, None, None, None]
                a_hat = self.alpha_hat[t][:, None, None, None]
                b = self.beta[t][:, None, None, None]
                # The final step (step == 1) adds no noise.
                noise = torch.randn_like(x) if step > 1 else torch.zeros_like(x)
                mean = 1 / torch.sqrt(a) * (x - ((1 - a) / (torch.sqrt(1 - a_hat))) * eps_hat)
                x = mean + torch.sqrt(b) * noise
        model.train()
        # Map from [-1, 1] to [0, 255] bytes.
        x = (x.clamp(-1, 1) + 1) / 2
        return (x * 255).type(torch.uint8)

In [None]:
class SelfAttention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    A (B, C, H, W) input is flattened into a sequence of H*W tokens of size C.
    The sequence goes through a pre-LayerNorm attention sub-block and then a
    pre-LayerNorm feed-forward sub-block, each with a residual connection. The
    result is reshaped back to (B, C, H, W).

    Args:
        channels: number of input channels C. This is also the attention
            embedding size, so it must be divisible by the 4 attention heads.
        size: nominal spatial side length. It is kept as an attribute for
            backward compatibility; ``forward`` reads H and W from the input.
            The old code used ``size`` in a ``view(-1, ...)``, so a mismatched
            input could silently spill spatial data into the batch dimension.
    """

    def __init__(self, channels, size):
        super(SelfAttention, self).__init__()
        self.channels = channels
        self.size = size
        # Attribute names below define the state_dict keys of saved checkpoints.
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        batch, channels, height, width = x.shape
        # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
        tokens = x.reshape(batch, channels, height * width).swapaxes(1, 2)
        normed = self.ln(tokens)
        attention_value, _ = self.mha(normed, normed, normed)
        attention_value = attention_value + tokens
        attention_value = self.ff_self(attention_value) + attention_value
        # (B, H*W, C) -> (B, C, H, W), using the input's own spatial dims.
        return attention_value.swapaxes(2, 1).reshape(batch, channels, height, width)

In [None]:
class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by GroupNorm(1, ·), with a GELU between them.

    With ``residual=True`` the input is added to the conv stack's output and the
    sum goes through GELU. That addition requires ``in_channels == out_channels``.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the second convolution.
        mid_channels: channels between the two convolutions. Any falsy value
            (None, 0) falls back to ``out_channels``.
        residual: whether to add the skip connection.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        hidden = mid_channels or out_channels
        layers = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, hidden),
            nn.GELU(),
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
        ]
        # The attribute name and layer order define the state_dict keys
        # (double_conv.0.weight, ...) expected by saved checkpoints.
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        h = self.double_conv(x)
        return F.gelu(x + h) if self.residual else h

In [None]:
class MyUNetConditioned(nn.Module):
    """Class-conditioned U-Net that maps (noisy image, timestep, label) to an image-shaped output.

    Structure:
      * Encoder: three max-pool stages (64 -> 128 -> 256 -> 256 channels).
      * Bottleneck of DoubleConv blocks.
      * Decoder: three bilinear-upsample stages with skip connections.
    After most stages a projected timestep (+ label) embedding is added to the
    feature map, and some stages apply SelfAttention.

    The SelfAttention sizes (32/16/16/32) appear to assume 64x64 inputs.
    NOTE(review): confirm against the training resolution.

    Args:
        c_in: input image channels.
        c_out: output channels of the final 1x1 convolution.
        time_dim: size of the sinusoidal timestep encoding (expected to be even).
        num_classes: if given, an ``nn.Embedding(num_classes, time_dim)`` label
            table is created. Passing labels without it raises AttributeError.
        device: stored for backward compatibility. The timestep encoding is now
            allocated on the device of the ``t`` tensor it encodes, so the model
            also works after ``.to(other_device)``.

    All submodule attribute names are part of the checkpoint (state_dict) format
    and must not be renamed.
    """

    def __init__(self, c_in=3, c_out=3, time_dim=256,num_classes=None, device="cuda"):
        super().__init__()
        self.time_dim = time_dim
        self.device = device

        self.init = DoubleConv(c_in,64)

        # --- Encoder stage 1: 64 -> 128 channels, spatial / 2 ---
        self.maxpool_conv1 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(64, 64, residual=True),
            DoubleConv(64, 128),
        )
        self.emb_layer1 = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_dim,128),
        )
        self.sa1 = nn.Sequential(
            SelfAttention(128,32)
        )

        # --- Encoder stage 2: 128 -> 256 channels, spatial / 2 ---
        self.maxpool_conv2 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(128, 128, residual=True),
            DoubleConv(128, 256),
        )
        self.emb_layer2 = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_dim,256),
        )
        self.sa1_1 = nn.Sequential(
            SelfAttention(256,16)
        )

        # --- Encoder stage 3: 256 -> 256 channels, spatial / 2 ---
        self.maxpool_conv2_1 = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(256, 256, residual=True),
            DoubleConv(256, 256),
        )
        self.emb_layer2_1 = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_dim,256),
        )

        # --- Bottleneck: 256 -> 512 -> 512 -> 256 channels ---
        self.bottleneck = nn.Sequential(
            DoubleConv(256, 512),
            DoubleConv(512, 512),
            DoubleConv(512, 256),
        )

        # --- Decoder stage 1: upsample, concat with encoder stage 2 (256 + 256), -> 128 ---
        self.up3 = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.up_conv3 = nn.Sequential(
            DoubleConv(512, 512, residual=True),
            DoubleConv(512, 128),
        )
        self.emb_layer3 = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_dim,128),
        )
        self.sa2 = nn.Sequential(
            SelfAttention(128,16)
        )

        # --- Decoder stage 2: upsample, concat with encoder stage 1 (128 + 128), -> 64 ---
        self.up3_1 = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.up_conv3_1 = nn.Sequential(
            DoubleConv(256, 256, residual=True),
            DoubleConv(256, 64),
        )
        self.emb_layer3_1 = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_dim,64),
        )
        self.sa1_2 = nn.Sequential(
            SelfAttention(64,32)
        )

        # --- Decoder stage 3: upsample, concat with the stem output (64 + 64), -> 64 ---
        self.up4 = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.up_conv4 = nn.Sequential(
            DoubleConv(128, 128, residual=True),
            DoubleConv(128, 64),
        )
        self.emb_layer4 = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_dim,64),
        )

        self.exit = nn.Conv2d(64,c_out,1)

        if num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_dim)

    def pos_encoding(self, t, channels):
        """Sinusoidal encoding of float timesteps.

        Args:
            t: float tensor of shape (B, 1).
            channels: encoding size (expected to be even).

        Returns:
            (B, channels) tensor. The first half holds ``sin(t * inv_freq)`` and
            the second half ``cos(t * inv_freq)``, with
            ``inv_freq = 10000 ** (-arange(0, channels, 2) / channels)``.
            It lives on ``t``'s device, not ``self.device``, so a model moved with
            ``.to()`` does not mix devices.
        """
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2, device=t.device).float() / channels))
        angles = t.repeat(1, channels // 2) * inv_freq
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    @staticmethod
    def _add_time_emb(features, emb_layer, t_emb):
        """Project ``t_emb`` (B, time_dim) with ``emb_layer`` and add it to every spatial position of ``features`` (B, C, H, W)."""
        return features + emb_layer(t_emb)[:, :, None, None]

    def forward(self, x, t, y):
        """Run the U-Net.

        Args:
            x: image batch (B, c_in, H, W).
            t: integer timesteps of shape (B,).
            y: integer class labels of shape (B,), or None for an unconditional pass.

        Returns:
            Tensor of shape (B, c_out, H, W).
        """
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)

        if y is not None:
            # Class conditioning: add the label embedding to the timestep encoding.
            t = t + self.label_emb(y)

        x = self.init(x)

        # Encoder
        x1 = self._add_time_emb(self.maxpool_conv1(x), self.emb_layer1, t)
        x1 = self.sa1(x1)

        x2 = self._add_time_emb(self.maxpool_conv2(x1), self.emb_layer2, t)
        x2 = self.sa1_1(x2)

        x2_1 = self._add_time_emb(self.maxpool_conv2_1(x2), self.emb_layer2_1, t)

        # Bottleneck
        bottleneck = self.bottleneck(x2_1)

        # Decoder with skip connections
        x3 = self.up_conv3(torch.cat((self.up3(bottleneck), x2), dim=1))
        x3 = self._add_time_emb(x3, self.emb_layer3, t)
        x3 = self.sa2(x3)

        x3_1 = self.up_conv3_1(torch.cat((self.up3_1(x3), x1), dim=1))
        x3_1 = self._add_time_emb(x3_1, self.emb_layer3_1, t)
        x3_1 = self.sa1_2(x3_1)

        x4 = self.up_conv4(torch.cat((self.up4(x3_1), x), dim=1))
        x4 = self._add_time_emb(x4, self.emb_layer4, t)

        return self.exit(x4)

In [None]:
# Inference: generate `n` images for every "<image name>;<attribute key>" line of
# the test list, using the EMA checkpoint of the class-conditioned U-Net.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

n = 10  # Number of images to generate for each label
NUM_CLASSES = 22  # Must match the size of the checkpoint's label embedding table
CKPT_PATH = r"/kaggle/input/ema-ckpt-final/ema_ckpt.pt"
TEST_LIST_PATH = "/kaggle/input/testtxt/test.txt"

# Map each attribute key (a 5-character string of 0/1 digits) to its class label
dizionario = {
    "00000": 0,
    "00001": 1,
    "00010": 2,
    "00011": 3,
    "00101": 4,
    "00111": 5,
    "01001": 6,
    "01010": 7,
    "01011": 8,
    "01101": 9,
    "01111": 10,
    "10000": 11,
    "10001": 12,
    "10010": 13,
    "10011": 14,
    "10101": 15,
    "10111": 16,
    "11000": 17,
    "11001": 18,
    "11011": 19,
    "11101": 20,
    "11111": 21
}
assert sorted(dizionario.values()) == list(range(NUM_CLASSES)), "labels must be exactly 0..NUM_CLASSES-1"

letters = string.ascii_uppercase  # Alphabet letters used as per-image file suffixes
assert n <= len(letters), "not enough suffix letters for n images per label"

# Build the model for `device`. The constructor's default device is "cuda",
# which breaks on CPU-only machines. Map the checkpoint tensors onto the same
# device, so a checkpoint saved on GPU can still be loaded on CPU.
model = MyUNetConditioned(num_classes=NUM_CLASSES, device=device).to(device)
# NOTE: torch.load unpickles the file; only load trusted checkpoints
# (on torch >= 1.13 consider weights_only=True, since this is a plain state_dict).
ckpt = torch.load(CKPT_PATH, map_location=device)
model.load_state_dict(ckpt)

diffusion = Diffusion(img_size=64, device=device)

save_dir = "generated_images_1"  # Destination folder for the generated images
os.makedirs(save_dir, exist_ok=True)  # Create the folder if it does not exist

with open(TEST_LIST_PATH, "r", encoding="utf-8") as file:
    for riga in file:
        riga = riga.strip()
        if not riga:
            continue
        elementi = riga.split(";")
        if len(elementi) != 2:
            print(f"La riga {riga} non ha il formato corretto.")
            continue
        # Tolerate stray whitespace around the separator.
        nome_immagine, chiave = (elemento.strip() for elemento in elementi)
        if chiave not in dizionario:
            print(f"La chiave {chiave} non è presente nel dizionario.")
            continue
        label = dizionario[chiave]
        print(chiave)
        y = torch.full((n,), label, dtype=torch.long, device=device)
        x = diffusion.sample(model, n, y, cfg_scale=3)  # uint8, (n, 3, H, W)
        for i in range(n):
            # CHW -> HWC, as expected by matplotlib's imsave.
            image = x[i].cpu().permute(1, 2, 0).numpy()
            file_name = f"{nome_immagine}_{letters[i]}.jpg"
            save_path = os.path.join(save_dir, file_name)
            plt.imsave(save_path, image)