# Train LoRA on Colab

Dataset (already uploaded to Google Drive): `style_1_vietnamese` (40 images)

Expected training time: 45–60 minutes on a Colab T4 GPU

## Step 1: Install Dependencies

In [None]:
!pip install -q diffusers transformers accelerate peft torch torchvision datasets

import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    print("WARNING: No GPU!")

## Step 2: Mount Google Drive

In [None]:
from google.colab import drive

# The dataset and output paths configured in the next cell live under MyDrive,
# so Drive must be mounted at this location first.
MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

print("Drive mounted")

## Step 3: Config

In [None]:
# Paths on the mounted Google Drive
dataset_path = "/content/drive/MyDrive/NCKH_Datasets/style_1_vietnamese"
output_path = "/content/drive/MyDrive/NCKH_LoRAs/lora_vietnamese"

# Training params
EPOCHS = 20
BATCH_SIZE = 1
LEARNING_RATE = 1e-4
RANK = 8  # LoRA rank (inner dimension of each adapter)

# Stable Diffusion v1.5 base weights. The original "runwayml/stable-diffusion-v1-5"
# Hub repository has been taken down; this is the maintained mirror of the same
# weights, now used throughout the diffusers docs.
MODEL_ID = "stable-diffusion-v1-5/stable-diffusion-v1-5"

print(f"Dataset: {dataset_path}")
print(f"Output: {output_path}")
print(f"Epochs: {EPOCHS}")
print(f"LoRA Rank: {RANK}")

## Step 4: Verify Dataset

In [None]:
from pathlib import Path
import json

# Image types accepted for training. Suffixes are compared lower-cased so files
# such as "IMG_001.JPG" or "photo.jpeg" are not silently skipped (a plain
# glob("*.jpg") is case-sensitive on the Drive filesystem).
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}
MIN_IMAGES = 10

dataset_dir = Path(dataset_path)

# Count images (top level of the dataset folder only); sorted for a stable order.
images = sorted(
    p for p in dataset_dir.iterdir()
    if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
)
print(f"Found {len(images)} images")

if not images:
    # Nothing to train on: stop here rather than "training" on an empty dataset.
    raise FileNotFoundError(f"No .jpg/.jpeg/.png images found in {dataset_dir}")
if len(images) < MIN_IMAGES:
    print("ERROR: Not enough images!")
else:
    print("Dataset OK")

# Create metadata if not exists
metadata_file = dataset_dir / "metadata.json"

if not metadata_file.exists():
    print("Creating metadata...")
    metadata = [
        {
            "image": img_path.name,
            "prompt": "beautiful vietnamese landscape, natural scenery, high quality",
            "caption": "vietnamese landscape photography",
        }
        for img_path in images
    ]

    with open(metadata_file, "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)

    print("Metadata created")
else:
    print("Metadata already exists")

    # An existing metadata.json may be stale (images renamed/removed since it
    # was written). A missing file would otherwise crash the dataset loader in
    # the middle of training, so check every entry up front.
    with open(metadata_file, encoding="utf-8") as f:
        metadata = json.load(f)

    missing = [item["image"] for item in metadata if not (dataset_dir / item["image"]).is_file()]
    if missing:
        raise FileNotFoundError(
            f"metadata.json references {len(missing)} missing image(s), e.g. {missing[:5]}; "
            f"delete {metadata_file} to regenerate it"
        )

    listed = {item["image"] for item in metadata}
    unlisted = [p.name for p in images if p.name not in listed]
    if unlisted:
        print(f"WARNING: {len(unlisted)} image(s) are not listed in metadata.json and will not be used")

## Step 5: Load Models

In [None]:
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

print("Loading models...")

# Each Stable Diffusion component is stored in its own subfolder of the same Hub repo.
tokenizer = CLIPTokenizer.from_pretrained(MODEL_ID, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(MODEL_ID, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(MODEL_ID, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(MODEL_ID, subfolder="unet")
noise_scheduler = DDPMScheduler.from_pretrained(MODEL_ID, subfolder="scheduler")

# None of the base weights are trained; only adapter weights injected later
# should receive gradients.
for frozen_model in (vae, text_encoder, unet):
    frozen_model.requires_grad_(False)

print("Models loaded")

## Step 6: Setup LoRA

In [None]:
from peft import LoraConfig, get_peft_model

# LoRA adapters on the UNet attention projections (query/key/value and the
# output linear). PEFT scales the adapter update by lora_alpha / r.
lora_config = LoraConfig(
    lora_dropout=0.0,
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],
    lora_alpha=32,
    r=RANK,
)

# Wrap the frozen UNet; the injected adapter weights are the trainable part.
unet = get_peft_model(unet, lora_config)
unet.print_trainable_parameters()

print("LoRA configured")

## Step 7: Dataset

In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

class ImageDataset(Dataset):
    """Image/prompt pairs listed in ``<data_dir>/metadata.json``.

    ``metadata.json`` is a JSON list of dicts. Each dict has an ``"image"``
    file name (relative to ``data_dir``) and a text prompt under ``"prompt"``;
    when ``"prompt"`` is absent, ``"caption"`` is used instead.

    Each item is a dict with:
      * ``"pixel_values"`` -- RGB image resized (shorter side) and center-cropped
        to ``resolution``, as a tensor normalized from [0, 1] to [-1, 1];
      * ``"input_ids"`` -- the prompt tokenized and padded/truncated to the
        tokenizer's maximum length.
    """

    def __init__(self, data_dir, tokenizer, resolution=512):
        self.data_dir = Path(data_dir)
        self.tokenizer = tokenizer
        # Take the context length from the tokenizer (77 for CLIP) rather than
        # hard-coding it, so a different text encoder stays consistent.
        self.max_length = getattr(tokenizer, "model_max_length", 77)

        with open(self.data_dir / "metadata.json", encoding="utf-8") as f:
            self.metadata = json.load(f)

        self.transform = transforms.Compose([
            transforms.Resize(resolution),
            transforms.CenterCrop(resolution),
            transforms.ToTensor(),
            # Map ToTensor's [0, 1] range to [-1, 1].
            transforms.Normalize([0.5], [0.5]),
        ])

    def __len__(self):
        return len(self.metadata)

    @staticmethod
    def _prompt_for(item):
        """Return the training prompt for one metadata entry.

        Prefers ``"prompt"`` (even if empty), then ``"caption"``.

        Raises:
            KeyError: if the entry has neither key.
        """
        if "prompt" in item:
            return item["prompt"]
        if "caption" in item:
            return item["caption"]
        raise KeyError(f"metadata entry for {item.get('image')!r} has no 'prompt' or 'caption'")

    def __getitem__(self, idx):
        item = self.metadata[idx]
        # Context manager closes the underlying file; convert() returns an
        # independent, fully loaded RGB copy.
        with Image.open(self.data_dir / item["image"]) as img:
            pixel_values = self.transform(img.convert("RGB"))

        input_ids = self.tokenizer(
            self._prompt_for(item),
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids[0]

        return {"pixel_values": pixel_values, "input_ids": input_ids}

# Reshuffled every epoch; BATCH_SIZE comes from the config cell.
dataset = ImageDataset(dataset_path, tokenizer)
dataloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=BATCH_SIZE,
)

num_examples = len(dataset)
print(f"Dataset ready: {num_examples} images")

## Step 8: Training Loop

In [None]:
from accelerate import Accelerator
import torch.nn.functional as F
from tqdm.auto import tqdm

# Train the LoRA adapters with the standard denoising objective: noise the VAE
# latents at a random timestep and regress the UNet output onto the scheduler's
# target (the noise itself, or the velocity for v-prediction models).

accelerator = Accelerator(mixed_precision="fp16")

# Only the injected LoRA weights require gradients; give just those to the
# optimizer (and to gradient clipping below).
trainable_params = [p for p in unet.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=LEARNING_RATE)

unet, optimizer, dataloader = accelerator.prepare(unet, optimizer, dataloader)
vae.to(accelerator.device)
text_encoder.to(accelerator.device)

# Read these from the model configs instead of hard-coding the SD 1.x values
# (0.18215, 1000, epsilon), so switching MODEL_ID keeps training correct.
latent_scale = vae.config.scaling_factor
num_train_timesteps = noise_scheduler.config.num_train_timesteps
prediction_type = noise_scheduler.config.prediction_type

MAX_GRAD_NORM = 1.0  # guards against fp16 loss spikes
LOG_EVERY = 50

print("Starting training...")
print(f"Total steps: {EPOCHS * len(dataloader)}")

unet.train()
global_step = 0

for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    epoch_loss = 0.0

    for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
        with accelerator.accumulate(unet):
            # Frozen encoders: no autograd graph is needed for these.
            with torch.no_grad():
                latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                latents = latents * latent_scale
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

            # Forward diffusion: add noise at a random timestep per sample.
            noise = torch.randn_like(latents)
            timesteps = torch.randint(
                0, num_train_timesteps, (latents.shape[0],), device=latents.device
            ).long()
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            if prediction_type == "epsilon":
                target = noise
            elif prediction_type == "v_prediction":
                target = noise_scheduler.get_velocity(latents, noise, timesteps)
            else:
                raise ValueError(f"Unsupported prediction_type: {prediction_type!r}")

            model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
            loss = F.mse_loss(model_pred.float(), target.float())

            accelerator.backward(loss)
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(trainable_params, MAX_GRAD_NORM)
            optimizer.step()
            optimizer.zero_grad()

        global_step += 1
        step_loss = loss.item()
        epoch_loss += step_loss

        if global_step % LOG_EVERY == 0:
            print(f"Step {global_step}, Loss: {step_loss:.4f}")

    # Single-step losses are very noisy; the epoch mean is a better trend signal.
    print(f"Epoch {epoch+1} mean loss: {epoch_loss / max(len(dataloader), 1):.4f}")

print("\nTraining complete!")

## Step 9: Save LoRA

In [None]:
!mkdir -p "{output_path}"

# Unwrap and save
unet_lora = accelerator.unwrap_model(unet)
unet_lora.save_pretrained(output_path)

print(f"LoRA saved to: {output_path}")

# List saved files
!ls -lh "{output_path}"

## Step 10: Test LoRA

In [None]:
from diffusers import StableDiffusionPipeline
from IPython.display import display

print("Testing LoRA...")

# Fresh fp16 inference pipeline from the same base model used for training.
pipe = StableDiffusionPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16
).to("cuda")

# Load the diffusers-format LoRA file (pytorch_lora_weights.safetensors) from
# output_path. `load_lora_weights` is the pipeline-level LoRA loader;
# `unet.load_attn_procs` is deprecated for LoRA and cannot read a PEFT
# adapter_model.safetensors directory.
pipe.load_lora_weights(output_path)

# Test generation
prompt = "beautiful vietnamese landscape with mountains and rice fields"

# Fixed seed so images from different training runs are directly comparable.
generator = torch.Generator(device="cuda").manual_seed(42)

image = pipe(
    prompt,
    num_inference_steps=30,
    guidance_scale=7.5,
    generator=generator,
).images[0]

image.save("/content/test_lora.png")
print("Test image saved")

# Display
display(image)

## Step 11: Download LoRA to Local

In [None]:
# Zip LoRA weights
!cd "{output_path}" && zip -r /content/lora_vietnamese.zip .

# Download
from google.colab import files
files.download('/content/lora_vietnamese.zip')

print("Download started!")
print("File: lora_vietnamese.zip (20-50MB)")
print("Extract and copy to: models/lora/")

## Notes

After the download finishes:
1. Extract `lora_vietnamese.zip`.
2. Copy `pytorch_lora_weights.safetensors` to `models/lora/vietnamese.safetensors`.
3. Test the LoRA in the local app.

A copy of the LoRA also remains on Drive at: /content/drive/MyDrive/NCKH_LoRAs/lora_vietnamese/