In [None]:
import os, io, random
from src.data.load_data import *
from src.data.data_utils import *


# ------------------------------------------------------
# Runtime setup
# ------------------------------------------------------
# Keep OpenMP / MKL and PyTorch's CPU intra-op pool at a single thread.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
torch.set_num_threads(1)
# Let cuDNN benchmark its convolution algorithms and keep the fastest one.
torch.backends.cudnn.benchmark = True

# Configuration
BATCH_SIZE   = 32
PIN_MEMORY   = True
DROP_LAST    = True
SEED         = 7

# Seed the global RNGs so that anything relying on them (e.g. shuffling,
# weight initialisation) starts from a known state on every fresh run.
# SEED was previously defined but never applied.
# NOTE: with cudnn.benchmark enabled the selected kernels may still differ
# between runs, so this does not guarantee bit-identical results.
random.seed(SEED)
torch.manual_seed(SEED)

# Training image size
SIZE         = 256
FINAL_SIZE   = 252  # RandomResizedCrop is applied straight to 252

# Subsampling targets (tune as needed)
CONTENT_KEEP = 10000   # COCO
# WikiArt. Set to the value the WikiArt section actually uses: that section
# re-assigns STYLE_KEEP = 5000 before subsampling, so the former 30000 here
# was never in effect.
STYLE_KEEP   = 5000

# ImageNet normalisation statistics
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

# Auto-select the DataLoader worker count (Kaggle usually provides 2 vCPUs):
# 2 workers on machines with <= 2 CPUs, otherwise CPU_COUNT - 1 capped at 4.
CPU_COUNT = os.cpu_count() or 2  # os.cpu_count() may return None
NUM_WORKERS = 2 if CPU_COUNT <= 2 else min(4, CPU_COUNT - 1)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"



# ======================================================
# COCO (content images)
# ======================================================
# Pipeline: load the HF split -> detect the image column -> drop invalid
# rows -> cast the column to an image feature -> wrap as a torch dataset.
coco_hf = load_dataset("phiyodr/coco2017", split="train")
coco_img_col = detect_image_col(coco_hf)
coco_hf = cast_to_image(
    filter_valid_images(coco_hf, coco_img_col),
    coco_img_col,
)
coco_ds = HFDataset(coco_hf, img_key="image", transform=content_tf)

# Build the loader and immediately trim it to CONTENT_KEEP samples.
# No style loader is passed here (second argument is None).
content_loader = truncate_dataloaders(
    make_loader(coco_ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS),
    None,
    n=CONTENT_KEEP,
    seed=77,
)

# ======================================================
# WikiArt (style images), restricted to a set of useful styles
# ======================================================
wiki_hf = load_dataset("davanstrien/wikiart-resized", split="train")
wiki_img_col = detect_image_col(wiki_hf)
wiki_hf = filter_valid_images(wiki_hf, wiki_img_col)

print("Total WikiArt original:", len(wiki_hf))
print("Columnas WikiArt:", wiki_hf.column_names)

# The "style" feature carries the label names; a row's integer style id is
# used as an index into this list further below.
style_feat = wiki_hf.features["style"]
style_names = style_feat.names

print("Estilos disponibles:", style_names)

# Lower-case substrings; a style is selected when its lower-cased name
# contains at least one of them. Because this is a substring match, a
# pattern also selects every longer label that contains it (for example
# "realism" would also pick up a "surrealism" label, if the dataset has one).
GOOD_PATTERNS = [
    "impression",
    "expression",
    "fauvism",
    "baroque",
    "romantic",
    "symbol",
    "realism",
    "northern_renaissance",
    "naive_art",
    "art_nouveau",
]

# Integer ids of every style whose name matches one of the patterns.
good_style_ids = [
    style_id
    for style_id, style_name in enumerate(style_names)
    if any(pattern in style_name.lower() for pattern in GOOD_PATTERNS)
]

print("IDs de estilos seleccionados:", good_style_ids)
selected_style_names = [style_names[i] for i in good_style_ids]
print("Nombres de estilos seleccionados:", selected_style_names)

def keep_good_styles(example):
    """Row predicate: keep rows whose style id is one of the selected ids.

    Args:
        example: A mapping with a "style" entry convertible to ``int``.

    Returns:
        ``True`` when ``int(example["style"])`` is in the module-level
        ``good_style_ids`` list (looked up at call time), else ``False``.
    """
    return int(example["style"]) in good_style_ids


# Keep only the rows whose style id was selected.
wiki_good = wiki_hf.filter(keep_good_styles)
print("Total WikiArt tras filtro por estilos buenos:", len(wiki_good))

# Re-assigned here: this value takes precedence over whatever STYLE_KEEP was
# set to in the configuration section.
STYLE_KEEP = 5000

# Subsample only when the filtered set is larger than the target.
if len(wiki_good) > STYLE_KEEP:
    pick_kwargs = {"target_total": STYLE_KEEP, "seed": 77}
    try:
        # First choice: stratify by artist.
        idx_good = stratified_pick(wiki_good, group_col="artist", **pick_kwargs)
    except ValueError:
        # Fallback: stratify by style.
        # NOTE(review): presumably raised when the "artist" column cannot be
        # used for grouping -- confirm against stratified_pick.
        idx_good = stratified_pick(wiki_good, group_col="style", **pick_kwargs)

    wiki_good = wiki_good.select(idx_good)

# NOTE: the brightness filter applied further below can still remove rows
# after this message is printed.
print("Total final de estilos usados para NST:", len(wiki_good))

# Compute per-row brightness stats, then drop the rows rejected by
# keep_reasonable_brightness (both steps run in a single process).
wiki_good = wiki_good.map(add_brightness_stats, num_proc=1).filter(
    keep_reasonable_brightness,
    num_proc=1,
)
print("Total WikiArt tras filtro style + brillo/contraste:", len(wiki_good))

# Cast the image column, then wrap as a PyTorch dataset + loader.
wiki_good = cast_to_image(wiki_good, wiki_img_col)
wiki_ds = HFDataset(wiki_good, img_key="image", transform=style_tf)
style_loader = make_loader(
    wiki_ds,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)



# Smoke test: pull one (content, style) pair and report shapes / lengths.
train_iter = make_train_iterator(content_loader, style_loader)
first_pair = next(iter(train_iter))
xb_c, xb_s = first_pair  # must unpack into exactly two items

print("paired content:", xb_c.shape, "| paired style:", xb_s.shape)
n_content_batches = len(content_loader)
n_style_batches = len(style_loader)
print("len(content_loader) =", n_content_batches)
print("len(style_loader) =", n_style_batches)

In [None]:
from src.model.styA2kNet import * 
from src.model.loss import * 
from src.training.train_model import *

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Network, VGG-based perceptual loss and optimizer.
model = StyA2KNet(device=device).to(device)
vgg_loss_extractor = build_vgg_loss_extractor(device)
criterion = PerceptualLoss(vgg_loss_extractor).to(device)
# The optimizer is created from the bare model's parameters, before any
# DataParallel wrapping below.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Wrap in DataParallel only when more than one GPU is visible.
multi_gpu = torch.cuda.device_count() > 1
if multi_gpu:
    print(f"Usando {torch.cuda.device_count()} GPUs con DataParallel")
    model = torch.nn.DataParallel(model)

# First training run; the returned state is what a later resume call reads
# ("last_epoch", "global_step", "scaler_state_dict").
state = train_stya2k(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    content_loader=content_loader,
    style_loader=style_loader,
    device=device,
    epochs=20,
    amp_enabled=True,
    # NOTE(review): "bp16" is not a usual AMP dtype name ("fp16" / "bf16"
    # are); it looks like a typo, and the resume cell passes "fp16".
    # Left unchanged because the intended value and how train_stya2k treats
    # an unknown name cannot be determined from here -- confirm and fix.
    amp_dtype="bp16",
    grad_clip=1.0,
    log_every=100,
    run_name="StyA2K_v1",
    sample_every=2,
    sample_dir="samples_stya2k",
)

## Usando 2 GPUs con DataParallel

Run: StyA2K_v1
Device: cuda | AMP: True (bp16) | epochs: 20 | start_epoch=0

---

## ep | step | loss | content | style | imgs | imgs/s | time

[step 100/312] loss=29.8137 content=16.4376 style=13.3761 time=608.7s

[step 200/312] loss=27.9632 content=14.7957 style=13.1675 time=1184.1s

[step 300/312] loss=26.4480 content=13.8558 style=12.5922 time=1759.3s

0 | 312 | 26.32192 | 13.77635 | 12.54558 | 9984 | 5.5 | 30:28

└─ [SAMPLE] grid guardada en samples_stya2k/StyA2K_v1_e000.png

[step 100/312] loss=25.0838 content=11.5879 style=13.4959 time=600.3s

[step 200/312] loss=23.1558 content=11.3300 style=11.8258 time=1176.1s

[step 300/312] loss=22.4681 content=11.1548 style=11.3133 time=1751.4s

1 | 624 | 22.61840 | 11.14136 | 11.47704 | 9984 | 5.5 | 30:20

[step 100/312] loss=20.9399 content=10.5985 style=10.3414 time=581.0s

[step 200/312] loss=21.0096 content=10.4871 style=10.5225 time=1156.6s

[step 300/312] loss=20.6936 content=10.3809 style=10.3127 time=1732.5s

2 | 936 | 20.67360 | 10.36999 | 10.30361 | 9984 | 5.5 | 30:01

└─ [SAMPLE] grid guardada en samples_stya2k/StyA2K_v1_e002.png

[step 100/312] loss=21.4734 content=10.1112 style=11.3622 time=598.1s

[step 200/312] loss=20.4882 content=9.9646 style=10.5236 time=1172.0s

[step 300/312] loss=20.2428 content=9.9101 style=10.3327 time=1746.8s

3 | 1248 | 20.21931 | 9.90737 | 10.31194 | 9984 | 5.5 | 30:15

[step 100/312] loss=19.9675 content=9.6761 style=10.2914 time=579.8s


In [None]:
# Resume training from the `state` returned by the previous training call.
# This cell fails with NameError if that call never returned (e.g. it was
# interrupted before assigning `state`).
resume_kwargs = {
    "start_epoch": state["last_epoch"] + 1,       # continue after the last finished epoch
    "init_global_step": state["global_step"],     # keep the step counter running
    "scaler_state_dict": state["scaler_state_dict"],
}

state = train_stya2k(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    content_loader=content_loader,
    style_loader=style_loader,
    device=device,
    # NOTE(review): whether `epochs` is a total or an additional count when
    # start_epoch > 0 depends on train_stya2k -- confirm.
    epochs=10,
    amp_enabled=True,
    # NOTE(review): the first run passed amp_dtype="bp16"; confirm both runs
    # are meant to use the same precision.
    amp_dtype="fp16",
    grad_clip=1.0,
    log_every=100,
    run_name="StyA2K_v1",
    sample_every=2,
    sample_dir="samples_stya2k",
    **resume_kwargs,
)