In [None]:
# Mount Google Drive at /content/drive so later cells can read and write under /content/drive/MyDrive.
from google.colab import drive
drive.mount(mountpoint='/content/drive')


In [None]:
!git clone https://github.com/NVlabs/Sana.git
%cd Sana
!pip install -q -r requirements.txt

In [None]:
import os, torch
from sana_pipeline import SanaPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): confirm this checkpoint id exists and is accepted by `sana_pipeline.SanaPipeline`;
# the official diffusers-format Sana checkpoints are published under "Efficient-Large-Model/...".
pipe = SanaPipeline.from_pretrained("nv-tlabs/sana-stable-diffusion").to(device)

# Every prompt shares the same style prefix; only the dish description varies.
PROMPT_STYLE = "Pixel art Ghibli style: "
dish_descriptions = [
    "Chicken rice curry dish, warm colors, cozy",
    "Pepperoni pizza, vibrant and tasty",
    "Fresh salad with avocado and tomato, pastel colors",
    "Classic burger with bacon and lettuce, charming",
    "Pasta with tomato sauce, cozy cottage style",
    "Sushi plate with tuna and cucumber, soft palette",
    "Steak with potatoes, rustic tavern style",
    "Carrot ginger soup, warm and comforting",
    "Chocolate cake topped with strawberries, whimsical",
    "Grilled salmon with asparagus, seaside tavern feel",
    "Cheese and spinach omelette, morning sunlight mood",
    "Pancakes with syrup, cozy breakfast",
    "Shrimp sautéed in garlic butter, coastal atmosphere",
    "Beef lasagna, warm family dinner scene",
    "Vegetable stir-fry noodles, vibrant street market",
    "Fruit salad with pineapple, mango, sunny picnic",
    "Roasted chicken with herbs, countryside tavern",
    "Mushroom risotto, woodland cottage vibe",
    "Fish tacos with fresh salsa, beachside charm",
    "Avocado toast with egg, cozy morning glow",
]
prompts = [PROMPT_STYLE + dish for dish in dish_descriptions]

# Image idx is generated with seed GEN_SEED + idx, so the whole dataset is reproducible
# and any single image can be regenerated on its own.
GEN_SEED = 0
GUIDANCE_SCALE = 7.5

output_dir = "/content/drive/MyDrive/SANA_dataset"
os.makedirs(output_dir, exist_ok=True)

for idx, prompt in enumerate(prompts):
    # torch.manual_seed seeds the global torch RNG on all devices. This assumes the pipeline
    # draws its initial noise from that RNG when no explicit generator is passed — verify for
    # this pipeline.
    torch.manual_seed(GEN_SEED + idx)
    img = pipe(prompt, guidance_scale=GUIDANCE_SCALE).images[0]
    # The numeric suffix is the prompt's position in `prompts`.
    img.save(f"{output_dir}/dish_{idx}.png")
    print(f"Généré : {prompt}")

print("✅ Dataset prêt dans ton Drive.")


In [None]:
import re

import torch
from torchvision import transforms
from PIL import Image
import numpy as np

# Load the images
dataset_path = "/content/drive/MyDrive/SANA_dataset"

# Keep only the generated "dish_<n>.png" files (a Drive folder can contain other entries) and order
# them by the numeric <n>. A plain sorted() on the names would put "dish_10.png" before "dish_2.png",
# so image/condition row i would no longer correspond to prompts[i].
_dish_file_re = re.compile(r"^dish_(\d+)\.png$")
_indexed_files = []
for _name in os.listdir(dataset_path):
    _match = _dish_file_re.match(_name)
    if _match:
        _indexed_files.append((int(_match.group(1)), os.path.join(dataset_path, _name)))
_indexed_files.sort()
images_files = [path for _, path in _indexed_files]
assert images_files, f"No dish_<n>.png images found in {dataset_path}"

# Simplified conditioning vectors: one-hot, one unique vector per dish (row i <-> images_files[i]).
conditions = torch.eye(len(images_files))

# Image preprocessing: resize to 128x128, then ToTensor.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])


def _load_image(path):
    """Open one image file, force RGB and apply `transform`; the file handle is closed afterwards."""
    # convert("RGB") guards against RGBA / palette / grayscale PNGs, whose channel count would not
    # match the generator's 3 output channels (and would make torch.stack fail on mixed inputs).
    with Image.open(path) as img:
        return transform(img.convert("RGB"))


images = torch.stack([_load_image(f) for f in images_files])

# Final dataset
train_conditions, train_images = conditions, images

print(f"Conditions: {train_conditions.shape}, Images: {train_images.shape}")


In [None]:
import torch.nn as nn
import torch.optim as optim

# Modèle simplifié : Condition (vecteur) → Image
class SimpleGenerator(nn.Module):
    """Simplified model: condition vector -> image.

    A linear layer (+ReLU) expands each condition vector into a 128-channel 16x16 feature map.
    Three transposed convolutions (kernel 4, stride 2, padding 1) then each double the spatial
    size: 16 -> 32 -> 64 -> 128. The last one is followed by a sigmoid so outputs lie in [0, 1].

    Args:
        input_dim: length of each condition vector.
        output_channels: number of channels of the generated image (3 by default).
    """

    # Shape of the feature map produced by `self.fc`, before upsampling.
    _BASE_CHANNELS = 128
    _BASE_SIZE = 16

    def __init__(self, input_dim, output_channels=3):
        super().__init__()
        base_features = self._BASE_CHANNELS * self._BASE_SIZE * self._BASE_SIZE
        self.fc = nn.Sequential(nn.Linear(input_dim, base_features), nn.ReLU())

        # Channel widths along the upsampling path; each consecutive pair is one conv stage.
        # Stages alternate conv/activation, so the convs sit at indices 0, 2 and 4 of self.conv.
        widths = [self._BASE_CHANNELS, 64, 32, output_channels]
        stage_count = len(widths) - 1
        upsampling = []
        for stage, (in_ch, out_ch) in enumerate(zip(widths[:-1], widths[1:])):
            upsampling.append(nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1))
            is_last_stage = stage == stage_count - 1
            upsampling.append(nn.Sigmoid() if is_last_stage else nn.ReLU())
        self.conv = nn.Sequential(*upsampling)

    def forward(self, x):
        """Map condition vectors to images of shape (batch, output_channels, 128, 128)."""
        # view(-1, ...) infers the batch dimension, so a single un-batched vector yields batch 1.
        feature_map = self.fc(x).view(-1, self._BASE_CHANNELS, self._BASE_SIZE, self._BASE_SIZE)
        return self.conv(feature_map)

# Training configuration. The seed is set before the model is built so that weight
# initialisation, and therefore the trained result, is reproducible.
TRAIN_SEED = 0
NUM_EPOCHS = 100
LOG_EVERY = 20
torch.manual_seed(TRAIN_SEED)

# One input per condition component (== number of dishes for the one-hot conditions).
gen = SimpleGenerator(input_dim=train_conditions.shape[1]).to(device)
optimizer = optim.Adam(gen.parameters(), lr=0.001)
criterion = nn.MSELoss()

# The dataset is small: move it to the device once instead of on every epoch.
conditions_dev = train_conditions.to(device)
images_dev = train_images.to(device)

# Quick full-batch training: each epoch is a single optimisation step over all images.
gen.train()
for epoch in range(NUM_EPOCHS):
    optimizer.zero_grad()
    output = gen(conditions_dev)
    loss = criterion(output, images_dev)
    loss.backward()
    optimizer.step()
    # Also log the last epoch so the final loss is visible.
    if epoch % LOG_EVERY == 0 or epoch == NUM_EPOCHS - 1:
        print(f"Epoch {epoch}/{NUM_EPOCHS} - Loss : {loss.item():.4f}")

print("✅ Entraînement terminé.")


In [None]:
import matplotlib.pyplot as plt

# Quick check on one training condition (here the dish at position 2 of the loaded order).
test_idx = 2

# Recover the prompt that produced this image from its "dish_<n>.png" file name, rather than
# assuming images_files[test_idx] is dish_<test_idx>.png. A lexicographic file sort places
# dish_10 before dish_2, which would pair the image with the wrong prompt.
test_file = os.path.basename(images_files[test_idx])
name_prefix, _, name_suffix = os.path.splitext(test_file)[0].rpartition("_")
if name_prefix == "dish" and name_suffix.isdecimal() and int(name_suffix) < len(prompts):
    test_label = prompts[int(name_suffix)]
else:
    # Unexpected file name: show the file itself instead of guessing a prompt.
    test_label = test_file

gen.eval()
with torch.no_grad():
    generated_image = gen(train_conditions[test_idx:test_idx+1].to(device)).cpu()[0]

# Display: the model outputs CHW, matplotlib expects HWC.
fig, ax = plt.subplots()
ax.imshow(generated_image.permute(1, 2, 0).numpy())
ax.set_title(f"Plat généré : {test_label}", wrap=True)
ax.axis('off')
plt.show()


In [None]:
# Export the trained generator to ONNX on Drive.
# NOTE: the input named 'ingredients' is a one-hot dish condition row, not an ingredient list.
# Despite "cgan" in the file name, the model is a plain conditional decoder trained with MSE.
# Both names are kept unchanged for existing consumers.
onnx_path = "/content/drive/MyDrive/cgan_plats.onnx"

# Explicitly export inference behaviour. The model has no dropout or batch-norm, so this does
# not change the graph.
gen.eval()
dummy_input = train_conditions[0:1].to(device)
torch.onnx.export(
    gen, dummy_input, onnx_path,
    export_params=True,
    opset_version=11,
    input_names=['ingredients'],
    output_names=['generated_dish'],
    # Dim 0 is the batch. forward() reshapes with view(-1, ...), so any batch size is valid;
    # without this the exported graph would be fixed to the dummy batch size of 1.
    dynamic_axes={'ingredients': {0: 'batch'}, 'generated_dish': {0: 'batch'}},
)

print("✅ Modèle exporté en ONNX dans ton Drive.")
