# In [None]:
import sys
import os
sys.path.append('..')

import torch
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np

from models.text_to_fashion import TextToFashionGenerator
from models.style_transfer import StyleTransferModel
# from models.virtual_tryon import VirtualTryOnModel
from models.trend_predictor import TrendPredictor

# Reached only when every import above resolved; a failing import raises
# before this line, so the message doubles as a quick environment check.
print("Models imported successfully!")

# ------------------------------------------------------------
# 1. Text-to-Fashion Generation
# ------------------------------------------------------------

# Initialize text-to-fashion model
text2fashion = TextToFashionGenerator()

# Prompts to render, one panel each. Entries can be added or removed freely:
# the figure grid below is sized from this list.
prompts = [
    "elegant black evening dress",
    "casual blue denim jacket",
    "floral summer dress",
    "minimalist white t-shirt",
]

# Derive the grid from the prompt count instead of hard-coding 2x2, so a
# fifth prompt no longer raises IndexError and a shorter list does not leave
# framed, empty panels behind. Four prompts still give a 2x2, 10x10 figure.
n_cols = 2
n_rows = max(1, (len(prompts) + n_cols - 1) // n_cols)  # ceiling division
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False
)
# squeeze=False guarantees a 2-D array even for a single row; flatten it so
# panels can be paired with prompts in reading order.
axes = axes.ravel()

for ax, prompt in zip(axes, prompts):
    # NOTE(review): generate() presumably returns something imshow accepts
    # (PIL image or array) — confirm against TextToFashionGenerator.
    generated_img = text2fashion.generate(prompt)
    ax.imshow(generated_img)
    ax.set_title(prompt)
    ax.axis('off')

# Blank any panel that did not receive an image (odd prompt counts)
for ax in axes[len(prompts):]:
    ax.axis('off')

plt.tight_layout()
plt.show()

# ------------------------------------------------------------
# 2. Style Transfer
# ------------------------------------------------------------

# Build the style transfer model
style_transfer = StyleTransferModel()

# Two solid-colour 256x256 RGB placeholders stand in for real inputs
content_img = Image.new('RGB', (256, 256), 'lightblue')
style_img = Image.new('RGB', (256, 256), 'red')

# Run the transfer: content image first, style image second
result_img = style_transfer.transfer(content_img, style_img)

# Show both inputs and the output side by side in a single row
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
panels = (
    ('Content', content_img),
    ('Style', style_img),
    ('Result', result_img),
)
for ax, (title, img) in zip(axes, panels):
    ax.imshow(img)
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
plt.show()

# ------------------------------------------------------------
# 3. Trend Prediction
# ------------------------------------------------------------

# Build the trend predictor
trend_predictor = TrendPredictor()

# Three solid-colour 256x256 RGB placeholders act as the input batch
sample_images = [
    Image.new('RGB', (256, 256), colour)
    for colour in ('red', 'blue', 'green')
]

# Run the prediction; the result is indexed by the string keys used below
trends = trend_predictor.predict(sample_images)

print("Trend Analysis Results:")

# Fields printed with their default formatting
for label, key in (
    ("Dominant Colors", 'colors'),
    ("Popular Styles", 'styles'),
    ("Trending Patterns", 'patterns'),
    ("Seasonal Prediction", 'season'),
):
    print(f"{label}: {trends[key]}")

# Score fields, rendered with two decimal places
for label, key in (
    ("Trend Score", 'trend_score'),
    ("Popularity Score", 'popularity_score'),
):
    print(f"{label}: {trends[key]:.2f}")
print(f"Popularity Score: {trends['popularity_score']:.2f}")
